Module pyminflux.correct

Correction/restoration functions.

Functions

def align_datasets_using_beads(reference_mbm_dict: Dict,
moving_mbm_dict: Dict,
bead_correspondence: Dict[str, str] | None = None,
transform_type: str = 'euclidean',
n_points: int | None = 3) ‑> object | None
Expand source code
def align_datasets_using_beads(
    reference_mbm_dict: Dict,
    moving_mbm_dict: Dict,
    bead_correspondence: Optional[Dict[str, str]] = None,
    transform_type: str = 'euclidean',
    n_points: Optional[int] = 3,
) -> Optional[object]:
    """
    Align two datasets using bead localizations.
    
    Args:
        reference_mbm_dict: MBM dictionary from reference dataset
        moving_mbm_dict: MBM dictionary from dataset to be aligned
        bead_correspondence: Optional dict mapping moving bead names to reference bead names.
                           If None, assumes beads with same names correspond.
        transform_type: Type of transformation ('euclidean' or 'affine')
        n_points: Number of earliest time points to average per bead.
                  If None, all time points of the bead are averaged.
        
    Returns:
        model: The transformation model, or None if alignment fails
    """
    # Create dataframes from mbm dicts
    df_ref = mbm_dict_to_dataframe(reference_mbm_dict, additional_metadata={'type': 'reference'})
    df_mov = mbm_dict_to_dataframe(moving_mbm_dict, additional_metadata={'type': 'moving'})
    
    if df_ref.empty or df_mov.empty:
        print("Error: One or both datasets have no bead data.")
        return None
    
    # Filter to only include beads marked as "used" ('used' holds booleans)
    df_ref = df_ref[df_ref['used']]
    df_mov = df_mov[df_mov['used']]
    
    if df_ref.empty or df_mov.empty:
        print("Error: No beads marked as 'used' in one or both datasets.")
        return None
    
    # Bead names present in each dataset (computed once, reused below)
    ref_names = set(df_ref['bead_name'])
    mov_names = set(df_mov['bead_name'])
    
    # Identify beads to use for alignment
    if bead_correspondence is None:
        # Automatic: use beads with matching names
        common_beads = ref_names & mov_names
        
        if not common_beads:
            print("Error: No common beads found between datasets.")
            return None
        
        # Create identity mapping
        bead_correspondence = {bn: bn for bn in common_beads}
    else:
        # Manual correspondence provided: validate that all beads exist
        missing_ref = set(bead_correspondence.values()) - ref_names
        missing_mov = set(bead_correspondence.keys()) - mov_names
        
        if missing_ref:
            print(f"Warning: Reference beads not found: {missing_ref}")
        if missing_mov:
            print(f"Warning: Moving beads not found: {missing_mov}")
        
        # Filter to only valid correspondences
        bead_correspondence = {
            mov: ref for mov, ref in bead_correspondence.items()
            if mov in mov_names and ref in ref_names
        }
        
        if not bead_correspondence:
            print("Error: No valid bead correspondences.")
            return None
    
    print(f"Using {len(bead_correspondence)} bead pairs for alignment")
    
    # Calculate average positions for each bead (using earliest n_points)
    def get_bead_position(df, bead_name, n_points):
        bead_data = df[df['bead_name'] == bead_name]
        if len(bead_data) == 0:
            return None
        # Get n_points with smallest tim values; n_points=None means "use all"
        # (DataFrame.nsmallest requires an integer, so handle None explicitly)
        earliest = bead_data if n_points is None else bead_data.nsmallest(n_points, 'tim')
        # Return mean position as [z, y, x] for consistency with original code
        return earliest[['z', 'y', 'x']].mean(axis=0).to_numpy()
    
    # Build point clouds
    pts_ref = []
    pts_mov = []
    used_beads = []
    
    for bead_mov, bead_ref in bead_correspondence.items():
        pos_ref = get_bead_position(df_ref, bead_ref, n_points)
        pos_mov = get_bead_position(df_mov, bead_mov, n_points)
        
        if pos_ref is not None and pos_mov is not None:
            pts_ref.append(pos_ref)
            pts_mov.append(pos_mov)
            used_beads.append((bead_ref, bead_mov))
    
    if len(pts_ref) < 1:
        print(f"Error: Not enough valid bead pairs ({len(pts_ref)}).")
        return None
    
    pts_ref = np.array(pts_ref)
    pts_mov = np.array(pts_mov)
    
    print(f"Aligning using {len(pts_ref)} bead positions")
    
    # Warn if using translation-only mode
    if len(pts_ref) < 3:
        print(f"Warning: Only {len(pts_ref)} bead pair(s) available. Using translation-only alignment.")
    
    # Execute alignment
    try:
        model = point_registration(
            pts_ref,
            pts_mov,
            transform_type=transform_type,
        )
    except Exception as e:
        print(f"Error during registration: {e}")
        return None
    
    if model is None:
        print("Registration failed.")
        return None
    
    # Calculate residuals (after alignment, and plain distances before alignment)
    residuals = model.residuals(pts_mov, pts_ref)
    residual_mean = np.mean(residuals)
    residual_before = np.mean(np.linalg.norm(pts_mov - pts_ref, axis=1))
    
    alignment_mode = "translation-only" if len(pts_ref) < 3 else "rigid (rotation + translation)"
    print(f"Alignment completed using {alignment_mode} mode.")
    print(f"Mean residual: {residual_mean:.2f} nm.")
    print(f"Mean residual before alignment: {residual_before:.2f} nm.")
    
    return model

Align two datasets using bead localizations.

Args

reference_mbm_dict
MBM dictionary from reference dataset
moving_mbm_dict
MBM dictionary from dataset to be aligned
bead_correspondence
Optional dict mapping moving bead names to reference bead names. If None, assumes beads with same names correspond.
transform_type
Type of transformation ('euclidean' or 'affine')
n_points
Number of earliest time points to average per bead

Returns

model
The transformation model, or None if alignment fails
def drift_correction_time_windows_2d(x: numpy.ndarray,
y: numpy.ndarray,
t: numpy.ndarray,
sxy: float,
rx: tuple | None = None,
ry: tuple | None = None,
T: float | None = None,
tid: numpy.ndarray | None = None)
Expand source code
def drift_correction_time_windows_2d(
    x: np.ndarray,
    y: np.ndarray,
    t: np.ndarray,
    sxy: float,
    rx: Optional[tuple] = None,
    ry: Optional[tuple] = None,
    T: Optional[float] = None,
    tid: Optional[np.ndarray] = None,
):
    """Estimate 2D drift correction based on auto-correlation.

    Reimplemented (with modifications) from:

    * [paper] Ostersehlt, L.M., Jans, D.C., Wittek, A. et al. DNA-PAINT MINFLUX nanoscopy. Nat Methods 19, 1072-1075 (2022). https://doi.org/10.1038/s41592-022-01577-1
    * [code]  https://zenodo.org/record/6563100

    Parameters
    ----------

    x: np.ndarray
        Array of localization x coordinates.

    y: np.ndarray
        Array of localization y coordinates.

    t: np.ndarray
        Array of localization time points.

    sxy: float
        Resolution in nm in both the x and y direction.

    rx: tuple (Optional)
        (min, max) boundaries for the x coordinates. If omitted, it will default to (x.min(), x.max()).

    ry: tuple (Optional)
        (min, max) boundaries for the y coordinates. If omitted, it will default to (y.min(), y.max()).

    T: float (Optional)
        Time window for analysis.

    tid: np.ndarray (Optional)
        Only used if T is None. The unique trace IDs are used to calculate the time window for
        analysis using some heuristics.

    Returns
    -------

    dx: np.ndarray
        Estimated x drift, interpolated at each localization time point in `t`.

    dy: np.ndarray
        Estimated y drift, interpolated at each localization time point in `t`.

    dxt: np.ndarray
        Estimated x drift evaluated at the average time of each time window.

    dyt: np.ndarray
        Estimated y drift evaluated at the average time of each time window.

    ti: np.ndarray
        Average time of the localizations in each time window.

    T: float
        Length of the time window actually used for the analysis.

    Raises
    ------

    ValueError
        If both `T` and `tid` are None.
    """

    if T is None and tid is None:
        raise ValueError("If T is not defined, the array of TIDs must be provided.")

    # Make sure we are working with NumPy arrays
    x = np.array(x)
    y = np.array(y)
    t = np.array(t)
    if tid is not None:
        tid = np.array(tid)

    # Make sure we have valid ranges
    if rx is None:
        rx = (x.min(), x.max())

    if ry is None:
        ry = (y.min(), y.max())

    # Heuristics to define the length of the time window
    if T is None:
        rt = (t[0], t[-1])
        T = len(np.unique(tid)) * np.diff(rx)[0] * np.diff(ry)[0] / 3e6
        T = np.min([T, np.diff(rt)[0] / 2, 3600])  # At least two time windows
        T = max([T, 600])  # And at least 10 minutes long

    # Number of time windows to use
    Rt = [t[0], t[-1]]
    Ns = int(np.floor((Rt[1] - Rt[0]) / T))  # total number of time windows
    assert Ns > 1, "At least two time windows are needed, please reduce T"

    # Center of mass: half-width (in pixels) of the region used for centroid refinement
    CR = 10

    # Maximum number of frames in the cross-correlation
    D = Ns

    # Weight with distance: pairs of windows far apart in time contribute less
    w = np.linspace(1, 0.5, D)

    # Regularization term (roughness)
    l = 0.1

    # get dimensions
    Nx = int(np.ceil((rx[1] - rx[0]) / sxy))
    Ny = int(np.ceil((ry[1] - ry[0]) / sxy))
    c = np.round(np.array([Ny, Nx]) / 2)

    # Create all the histograms
    h = [None] * Ns
    ti = np.zeros(Ns)  # Average times of the snapshots
    for j in range(Ns):
        t0 = Rt[0] + j * T
        idx = (t >= t0) & (t < t0 + T)
        ti[j] = np.mean(t[idx])
        hj, _, _, _ = render_xy(
            x[idx],
            y[idx],
            sx=sxy,
            sy=sxy,
            rx=rx,
            ry=ry,
            render_type="fixed_gaussian",
            fwhm=3 * sxy,
        )
        h[j] = hj

    # Compute fourier transform of every histogram
    for j in range(Ns):
        h[j] = fftn(h[j])

    # Compute cross-correlations
    dx = np.zeros((Ns, Ns))
    dy = np.zeros((Ns, Ns))
    dm = np.zeros((Ns, Ns), dtype=bool)
    dx0 = np.zeros(Ns - 1)
    dy0 = np.zeros(Ns - 1)

    for i in range(Ns - 1):
        hi = np.conj(h[i])
        for j in range(i + 1, min(Ns, i + D)):  # Either to Ns or maximally D more
            # Cross-correlation of windows i and j via the Fourier domain
            hj = ifftshift(np.real(ifftn(hi * h[j])))
            yc = c[0]
            xc = c[1]

            # Centroid estimation
            gy = np.arange(yc - 2 * CR, yc + 2 * CR + 1).astype(int)
            gx = np.arange(xc - 2 * CR, xc + 2 * CR + 1).astype(int)
            gy, gx = np.meshgrid(gy, gx, indexing="ij")
            d = hj[gy, gx]
            gy = gy.flatten()
            gx = gx.flatten()
            d = d.flatten() - np.min(d)
            d = d / np.sum(d)
            # Iteratively refine the centroid with a Gaussian weighting (20 fixed iterations)
            for k in range(20):
                wc = np.exp(-4 * np.log(2) * ((xc - gx) ** 2 + (yc - gy) ** 2) / CR**2)
                n = np.sum(wc * d)
                xc = np.sum(gx * d * wc) / n
                yc = np.sum(gy * d * wc) / n
            # Shift of window j relative to window i, in pixels
            sh = np.array([-1.0, 1.0]) * (np.array([yc, xc]) - c)
            dy[i, j] = sh[0]
            dx[i, j] = sh[1]
            dm[i, j] = True
            if j == i + 1:
                dy0[i] = sh[0]
                dx0[i] = sh[1]

    # Keep only the measured (i, j) pairs; dx/dy become 1D arrays of shifts
    a, b = np.nonzero(dm)
    dx = dx[dm]
    dy = dy[dm]
    # Cumulative shifts between consecutive windows serve as the initial guess
    sx0 = np.cumsum(dx0)
    sy0 = np.cumsum(dy0)

    # Minimize cost function with some kind of regularization
    options = {"disp": False, "maxiter": 1e5}

    minimizer = lambda x: lse_distance(x, a, b, dx, w, l)
    res = minimize(minimizer, sx0, options=options, method="BFGS")
    sx = np.concatenate(([0], res.x))

    minimizer = lambda x: lse_distance(x, a, b, dy, w, l)
    res = minimize(minimizer, sy0, options=options, method="BFGS")
    sy = np.concatenate(([0], res.x))

    # Reduce by mean (so shift is minimal)
    sx = sx - np.mean(sx)
    sy = sy - np.mean(sy)

    # Multiply by pixel size
    sx = sx * sxy
    sy = sy * sxy

    # Create interpolants
    fx = interp1d(ti, sx, kind="slinear", fill_value="extrapolate")
    fy = interp1d(ti, sy, kind="slinear", fill_value="extrapolate")

    # Correct drift (note: dx/dy are reused here as per-localization drift estimates)
    dx = fx(t)
    dy = fy(t)

    # Apply to the time points used to estimate the correction
    dxt = fx(ti)
    dyt = fy(ti)

    return dx, dy, dxt, dyt, ti, T

Estimate 2D drift correction based on auto-correlation.

Reimplemented (with modifications) from:

Parameters

x : np.ndarray
Array of localization x coordinates.
y : np.ndarray
Array of localization y coordinates.
t : np.ndarray
Array of localization time points.
sxy : float
Resolution in nm in both the x and y direction.
rx : tuple (Optional)
(min, max) boundaries for the x coordinates. If omitted, it will default to (x.min(), x.max()).
ry : tuple (Optional)
(min, max) boundaries for the y coordinates. If omitted, it will default to (y.min(), y.max()).
T : float (Optional)
Time window for analysis.
tid : np.ndarray (Optional)
Only used if T is None. The unique trace IDs are used to calculate the time window for analysis using some heuristics.
def drift_correction_time_windows_3d(x: numpy.ndarray,
y: numpy.ndarray,
z: numpy.ndarray,
t: numpy.ndarray,
sxyz: float,
rx: tuple | None = None,
ry: tuple | None = None,
rz: tuple | None = None,
T: float | None = None,
tid: numpy.ndarray | None = None)
Expand source code
def drift_correction_time_windows_3d(
    x: np.ndarray,
    y: np.ndarray,
    z: np.ndarray,
    t: np.ndarray,
    sxyz: float,
    rx: Optional[tuple] = None,
    ry: Optional[tuple] = None,
    rz: Optional[tuple] = None,
    T: Optional[float] = None,
    tid: Optional[np.ndarray] = None,
):
    """Estimate 3D drift correction based on auto-correlation.

    Reimplemented (with modifications) from:

    * [paper] Ostersehlt, L.M., Jans, D.C., Wittek, A. et al. DNA-PAINT MINFLUX nanoscopy. Nat Methods 19, 1072-1075 (2022). https://doi.org/10.1038/s41592-022-01577-1
    * [code]  https://zenodo.org/record/6563100

    Parameters
    ----------

    x: np.ndarray
        Array of localization x coordinates.

    y: np.ndarray
        Array of localization y coordinates.

    z: np.ndarray
        Array of localization z coordinates.

    t: np.ndarray
        Array of localization time points.

    sxyz: float
        Resolution in nm in x, y and z direction.

    rx: tuple (Optional)
        (min, max) boundaries for the x coordinates. If omitted, it will default to (x.min(), x.max()).

    ry: tuple (Optional)
        (min, max) boundaries for the y coordinates. If omitted, it will default to (y.min(), y.max()).

    rz: tuple (Optional)
        (min, max) boundaries for the z coordinates. If omitted, it will default to (z.min(), z.max()).

    T: float (Optional)
        Time window for analysis.

    tid: np.ndarray (Optional)
        Only used if T is None. The unique trace IDs are used to calculate the time window for
        analysis using some heuristics.

    Returns
    -------

    dx: np.ndarray
        Estimated x drift, interpolated at each localization time point in `t`.

    dy: np.ndarray
        Estimated y drift, interpolated at each localization time point in `t`.

    dz: np.ndarray
        Estimated z drift, interpolated at each localization time point in `t`.

    dxt: np.ndarray
        Estimated x drift evaluated at the average time of each time window.

    dyt: np.ndarray
        Estimated y drift evaluated at the average time of each time window.

    dzt: np.ndarray
        Estimated z drift evaluated at the average time of each time window.

    ti: np.ndarray
        Average time of the localizations in each time window.

    T: float
        Length of the time window actually used for the analysis.

    Raises
    ------

    ValueError
        If both `T` and `tid` are None.
    """

    if T is None and tid is None:
        raise ValueError("If T is not defined, the array of TIDs must be provided.")

    # Make sure we are working with NumPy arrays
    x = np.array(x)
    y = np.array(y)
    z = np.array(z)
    t = np.array(t)
    if tid is not None:
        tid = np.array(tid)

    # Make sure we have valid ranges
    if rx is None:
        rx = (x.min(), x.max())

    if ry is None:
        ry = (y.min(), y.max())

    if rz is None:
        rz = (z.min(), z.max())

    # Heuristics to define the length of the time window
    if T is None:
        rt = (t[0], t[-1])
        T = len(np.unique(tid)) * np.diff(rx)[0] * np.diff(ry)[0] / 3e6
        T = np.min([T, np.diff(rt)[0] / 2, 3600])  # At least two time windows
        T = max([T, 600])  # And at least 10 minutes long

    # Number of time windows to use
    Rt = [t[0], t[-1]]
    Ns = int(np.floor((Rt[1] - Rt[0]) / T))  # total number of time windows
    assert Ns > 1, "At least two time windows are needed, please reduce T"

    # Center of mass: half-width (in voxels) of the region used for centroid refinement
    CR = 8

    # Maximum number of frames in the cross-correlation
    D = Ns

    # Weight with distance: pairs of windows far apart in time contribute less
    w = np.linspace(1, 0.2, D)

    # Regularization term (roughness)
    l = 0.01

    # get dimensions
    Nx = int(np.ceil((rx[1] - rx[0]) / sxyz))
    Ny = int(np.ceil((ry[1] - ry[0]) / sxyz))
    Nz = int(np.ceil((rz[1] - rz[0]) / sxyz))
    c = np.round(np.array([Nz, Ny, Nx]) / 2)

    # Create all the histograms
    h = [None] * Ns
    ti = np.zeros(Ns)  # Average times of the snapshots
    for j in range(Ns):
        t0 = Rt[0] + j * T
        idx = (t >= t0) & (t < t0 + T)
        ti[j] = np.mean(t[idx])
        hj, _, _, _, _ = render_xyz(
            x[idx],
            y[idx],
            z[idx],
            sx=sxyz,
            sy=sxyz,
            sz=sxyz,
            rx=rx,
            ry=ry,
            rz=rz,
            render_type="fixed_gaussian",
            fwhm=3 * sxyz,
        )
        h[j] = hj

    # Compute fourier transform of every histogram
    for j in range(Ns):
        h[j] = fftn(h[j])

    # Compute cross-correlations
    dx = np.zeros((Ns, Ns))
    dy = np.zeros((Ns, Ns))
    dz = np.zeros((Ns, Ns))
    dm = np.zeros((Ns, Ns), dtype=bool)
    dx0 = np.zeros(Ns - 1)
    dy0 = np.zeros(Ns - 1)
    dz0 = np.zeros(Ns - 1)

    for i in range(Ns - 1):
        hi = np.conj(h[i])
        for j in range(i + 1, min(Ns, i + D)):  # either to Ns or maximally D more
            # Cross-correlation of windows i and j via the Fourier domain
            hj = ifftshift(np.real(ifftn(hi * h[j])))
            zc = c[0]
            yc = c[1]
            xc = c[2]

            # Centroid estimation
            gz = np.arange(zc - 2 * CR, zc + 2 * CR + 1).astype(int)
            gy = np.arange(yc - 2 * CR, yc + 2 * CR + 1).astype(int)
            gx = np.arange(xc - 2 * CR, xc + 2 * CR + 1).astype(int)
            gz, gy, gx = np.meshgrid(gz, gy, gx, indexing="ij")
            d = hj[gz, gy, gx]
            gz = gz.flatten()
            gy = gy.flatten()
            gx = gx.flatten()
            d = d.flatten() - np.min(d)
            d = d / np.sum(d)
            # Iteratively refine the centroid with a Gaussian weighting (20 fixed iterations)
            for k in range(20):
                wc = np.exp(
                    -4
                    * np.log(2)
                    * ((xc - gx) ** 2 + (yc - gy) ** 2 + (zc - gz) ** 2)
                    / CR**2
                )
                n = np.sum(wc * d)
                xc = np.sum(gx * d * wc) / n
                yc = np.sum(gy * d * wc) / n
                zc = np.sum(gz * d * wc) / n
            # Shift of window j relative to window i, in voxels
            sh = np.array([1.0, -1.0, 1.0]) * (np.array([zc, yc, xc]) - c)
            dz[i, j] = sh[0]
            dy[i, j] = sh[1]
            dx[i, j] = sh[2]
            dm[i, j] = True
            if j == i + 1:
                dz0[i] = sh[0]
                dy0[i] = sh[1]
                dx0[i] = sh[2]

    # Keep only the measured (i, j) pairs; dx/dy/dz become 1D arrays of shifts
    a, b = np.nonzero(dm)
    dx = dx[dm]
    dy = dy[dm]
    dz = dz[dm]
    # Cumulative shifts between consecutive windows serve as the initial guess
    sx0 = np.cumsum(dx0)
    sy0 = np.cumsum(dy0)
    sz0 = np.cumsum(dz0)

    # Minimize cost function with some kind of regularization
    options = {"disp": False, "maxiter": 1e5}

    minimizer = lambda x: lse_distance(x, a, b, dx, w, l)
    res = minimize(minimizer, sx0, options=options, method="BFGS")
    sx = np.concatenate(([0], res.x))

    minimizer = lambda x: lse_distance(x, a, b, dy, w, l)
    res = minimize(minimizer, sy0, options=options, method="BFGS")
    sy = np.concatenate(([0], res.x))

    minimizer = lambda x: lse_distance(x, a, b, dz, w, l)
    res = minimize(minimizer, sz0, options=options, method="BFGS")
    sz = np.concatenate(([0], res.x))

    # Reduce by mean (so shift is minimal)
    sx = sx - np.mean(sx)
    sy = sy - np.mean(sy)
    sz = sz - np.mean(sz)

    # Multiply by voxel size
    sx = sx * sxyz
    sy = sy * sxyz
    sz = sz * sxyz

    # Create interpolants
    fx = interp1d(ti, sx, kind="slinear", fill_value="extrapolate")
    fy = interp1d(ti, sy, kind="slinear", fill_value="extrapolate")
    fz = interp1d(ti, sz, kind="slinear", fill_value="extrapolate")

    # Correct drift (note: dx/dy/dz are reused here as per-localization drift estimates)
    dx = fx(t)
    dy = fy(t)
    dz = fz(t)

    # Apply to the time points used to estimate the correction
    dxt = fx(ti)
    dyt = fy(ti)
    dzt = fz(ti)

    return dx, dy, dz, dxt, dyt, dzt, ti, T

Estimate 3D drift correction based on auto-correlation.

Reimplemented (with modifications) from:

Parameters

x : np.ndarray
Array of localization x coordinates.
y : np.ndarray
Array of localization y coordinates.
z : np.ndarray
Array of localization z coordinates.
t : np.ndarray
Array of localization time points.
sxyz : float
Resolution in nm in x, y and z direction.
rx : tuple (Optional)
(min, max) boundaries for the x coordinates. If omitted, it will default to (x.min(), x.max()).
ry : tuple (Optional)
(min, max) boundaries for the y coordinates. If omitted, it will default to (y.min(), y.max()).
rz : tuple (Optional)
(min, max) boundaries for the z coordinates. If omitted, it will default to (z.min(), z.max()).
T : float (Optional)
Time window for analysis.
tid : np.ndarray (Optional)
Only used if T is None. The unique trace IDs are used to calculate the time window for analysis using some heuristics.
def mbm_dict_to_dataframe(mbm_dict: Dict, additional_metadata: Dict | None = None) ‑> pandas.core.frame.DataFrame
Expand source code
def mbm_dict_to_dataframe(mbm_dict: Dict, additional_metadata: Optional[Dict] = None) -> pd.DataFrame:
    """
    Convert mbm_dict to a single pandas DataFrame.

    Args:
        mbm_dict: Dictionary containing bead measurement data
        additional_metadata: Optional dictionary of metadata to add as columns

    Returns:
        pd.DataFrame: Combined DataFrame with all bead data
    """
    frames = []
    for bead in mbm_dict.values():
        pts = bead['points']

        # Start from the scalar per-localization fields
        frame = pd.DataFrame(
            {
                'gri': pts['gri'],
                'tim': pts['tim'],
                'str': pts['str'],
            }
        )

        # Expand the xyz coordinates into separate columns, converting to nm
        coords = pd.DataFrame(pts['xyz'], columns=['x', 'y', 'z']) * 1e9
        frame = pd.concat([frame, coords], axis=1)

        # Attach per-bead metadata as constant columns
        frame['bead_name'] = bead['bead_name']
        frame['bead_gri'] = bead['gri']
        frame['used'] = bead['used']

        if additional_metadata:
            for key, value in additional_metadata.items():
                frame[key] = value

        frames.append(frame)

    # An empty input yields an empty DataFrame
    if not frames:
        return pd.DataFrame()

    return pd.concat(frames, ignore_index=True)

Convert mbm_dict to a single pandas DataFrame.

Args

mbm_dict
Dictionary containing bead measurement data
additional_metadata
Optional dictionary of metadata to add as columns

Returns

pd.DataFrame
Combined DataFrame with all bead data
def point_registration(pts_fixed: numpy.ndarray,
pts_moving: numpy.ndarray,
transform_type: str = 'euclidean') ‑> pyminflux.correct._bead_alignment.RigidTransform | pyminflux.correct._bead_alignment.TranslationTransform | None
Expand source code
def point_registration(
    pts_fixed: np.ndarray,
    pts_moving: np.ndarray,
    transform_type: str = 'euclidean',
) -> Optional[Union[RigidTransform, TranslationTransform]]:
    """
    Estimate transformation between two sets of corresponding points.
    
    For 3+ correspondences, the Kabsch algorithm is used for rigid alignment.
    When fewer than 3 correspondences are available, a translation-only transform is used.

    Args:
        pts_fixed (np.ndarray): The (N, D) array of points in the fixed coordinate system.
        pts_moving (np.ndarray): The (N, D) array of points in the moving coordinate system.
        transform_type (str): The type of transformation to estimate. 
                            Currently only 'euclidean' (rigid: rotation+translation) is supported.

    Returns:
        Optional[Union[RigidTransform, TranslationTransform]]: The estimated transformation model.

    Raises:
        ValueError: If the transform type is unsupported, the point arrays are not 2D,
            their shapes differ, or there are no points to align.
    """
    if transform_type.lower() != 'euclidean':
        raise ValueError("Only 'euclidean' transform type is currently supported.")
    
    # Validate the inputs before touching their shapes, so malformed (e.g. 1D)
    # arrays produce a clear error instead of an opaque unpacking failure.
    if pts_fixed.shape != pts_moving.shape:
        raise ValueError(
            f"Point arrays must have the same shape. "
            f"Got pts_fixed: {pts_fixed.shape}, pts_moving: {pts_moving.shape}"
        )
    
    if pts_moving.ndim != 2:
        raise ValueError(
            f"Point arrays must be 2D with shape (N, D). Got shape {pts_moving.shape}."
        )
    
    n_points = pts_moving.shape[0]
    
    if n_points < 1:
        raise ValueError(
            f"Not enough data points ({n_points}) for alignment. "
            f"At least 1 point is required."
        )
    
    # Use translation-only transform for 1-2 correspondences
    if n_points < 3:
        # Compute mean translation between the two point clouds
        t = pts_fixed.mean(axis=0) - pts_moving.mean(axis=0)
        
        # Create translation-only transformation model
        return TranslationTransform(t)
    
    # Compute optimal rigid transformation using Kabsch algorithm for 3+ correspondences
    R, t = _kabsch(pts_moving, pts_fixed, allow_reflection=False)
    
    # Create transformation model
    return RigidTransform(R, t)

Estimate transformation between two sets of corresponding points.

For 3+ correspondences, the Kabsch algorithm is used for rigid alignment. When fewer than 3 correspondences are available, a translation-only transform is used.

Args

pts_fixed : np.ndarray
The (N, D) array of points in the fixed coordinate system.
pts_moving : np.ndarray
The (N, D) array of points in the moving coordinate system.
transform_type : str
The type of transformation to estimate. Currently only 'euclidean' (rigid: rotation+translation) is supported.

Returns

Optional[Union[RigidTransform, TranslationTransform]]
The estimated transformation model.

Classes

class RigidTransform (rotation: numpy.ndarray, translation: numpy.ndarray)
Expand source code
class RigidTransform:
    """Simple rigid transformation model compatible with the existing interface."""

    def __init__(self, rotation: np.ndarray, translation: np.ndarray):
        """
        Initialize rigid transform.

        Args:
            rotation: (d, d) rotation matrix
            translation: (d,) translation vector
        """
        self.rotation = rotation
        self.translation = translation
        self.dimensionality = rotation.shape[0]

        # Homogeneous (d+1, d+1) transformation matrix for compatibility
        dim = self.dimensionality
        homogeneous = np.eye(dim + 1)
        homogeneous[:dim, :dim] = rotation
        homogeneous[:dim, dim] = translation
        self.params = homogeneous

    def __call__(self, coords: np.ndarray) -> np.ndarray:
        """
        Apply transformation to coordinates.

        Args:
            coords: (n, d) array of coordinates

        Returns:
            Transformed coordinates (n, d)
        """
        # Row-vector convention: rotate each point, then translate
        return coords @ self.rotation.T + self.translation

    def residuals(self, src: np.ndarray, dst: np.ndarray) -> np.ndarray:
        """
        Calculate residuals between transformed source and destination.

        Args:
            src: (n, d) source points
            dst: (n, d) destination points

        Returns:
            (n,) array of residuals (Euclidean distances)
        """
        delta = self(src) - dst
        return np.sqrt(np.sum(delta * delta, axis=1))

Simple rigid transformation model compatible with the existing interface.

Initialize rigid transform.

Args

rotation
(d, d) rotation matrix
translation
(d,) translation vector

Methods

def residuals(self, src: numpy.ndarray, dst: numpy.ndarray) ‑> numpy.ndarray
Expand source code
def residuals(self, src: np.ndarray, dst: np.ndarray) -> np.ndarray:
    """
    Calculate per-point residuals between transformed source and destination.

    Args:
        src: (n, d) source points
        dst: (n, d) destination points

    Returns:
        (n,) array of residuals (Euclidean distances)
    """
    delta = self(src) - dst
    return np.sqrt(np.sum(delta * delta, axis=1))

Calculate residuals between transformed source and destination.

Args

src
(n, d) source points
dst
(n, d) destination points

Returns

(n,) array of residuals (Euclidean distances)

class TranslationTransform (translation: numpy.ndarray)
Expand source code
class TranslationTransform:
    """Translation-only transformation model for cases with insufficient correspondences."""

    def __init__(self, translation: np.ndarray):
        """
        Initialize translation-only transform.

        Args:
            translation: (d,) translation vector
        """
        self.translation = translation
        self.dimensionality = translation.shape[0]

        # Homogeneous (d+1, d+1) transformation matrix for compatibility
        dim = self.dimensionality
        homogeneous = np.eye(dim + 1)
        homogeneous[:dim, dim] = translation
        self.params = homogeneous

        # Identity rotation kept for interface compatibility with RigidTransform
        self.rotation = np.eye(dim)

    def __call__(self, coords: np.ndarray) -> np.ndarray:
        """
        Apply transformation to coordinates.

        Args:
            coords: (n, d) array of coordinates

        Returns:
            Transformed coordinates (n, d)
        """
        return coords + self.translation

    def residuals(self, src: np.ndarray, dst: np.ndarray) -> np.ndarray:
        """
        Calculate residuals between transformed source and destination.

        Args:
            src: (n, d) source points
            dst: (n, d) destination points

        Returns:
            (n,) array of residuals (Euclidean distances)
        """
        delta = self(src) - dst
        return np.sqrt(np.sum(delta * delta, axis=1))

Translation-only transformation model for cases with insufficient correspondences.

Initialize translation-only transform.

Args

translation
(d,) translation vector

Methods

def residuals(self, src: numpy.ndarray, dst: numpy.ndarray) ‑> numpy.ndarray
Expand source code
def residuals(self, src: np.ndarray, dst: np.ndarray) -> np.ndarray:
    """
    Calculate per-point residuals between transformed source and destination.

    Args:
        src: (n, d) source points
        dst: (n, d) destination points

    Returns:
        (n,) array of residuals (Euclidean distances)
    """
    offset = self(src) - dst
    return np.sqrt((offset ** 2).sum(axis=1))

Calculate residuals between transformed source and destination.

Args

src
(n, d) source points
dst
(n, d) destination points

Returns

(n,) array of residuals (Euclidean distances)