Module pyminflux.correct
Correction/restoration functions.
Functions
def align_datasets_using_beads(reference_mbm_dict: Dict,
moving_mbm_dict: Dict,
bead_correspondence: Dict[str, str] | None = None,
transform_type: str = 'euclidean',
n_points: int | None = 3) -> object | None
Expand source code
def align_datasets_using_beads(
    reference_mbm_dict: Dict,
    moving_mbm_dict: Dict,
    bead_correspondence: Optional[Dict[str, str]] = None,
    transform_type: str = 'euclidean',
    n_points: Optional[int] = 3,
) -> Optional[object]:
    """
    Align two datasets using bead localizations.

    Each bead is reduced to one mean position (as [z, y, x], in nm) and the
    resulting corresponding point clouds are registered with
    `point_registration`. With fewer than 3 bead pairs the model is
    translation-only.

    Args:
        reference_mbm_dict: MBM dictionary from reference dataset
        moving_mbm_dict: MBM dictionary from dataset to be aligned
        bead_correspondence: Optional dict mapping moving bead names to reference
            bead names. If None, assumes beads with same names correspond.
        transform_type: Type of transformation. Only 'euclidean' (rigid:
            rotation + translation) is currently supported by
            `point_registration`; any other value makes the registration fail
            and None is returned.
        n_points: Number of earliest time points (smallest 'tim') to average per
            bead. If None, all localizations of each bead are averaged.

    Returns:
        model: The transformation model, or None if alignment fails
    """
    if n_points is not None and n_points < 1:
        # nsmallest(0) would return no rows and the bead means would be NaN
        print(f"Error: n_points must be at least 1 (got {n_points}).")
        return None

    # Create dataframes from mbm dicts
    df_ref = mbm_dict_to_dataframe(reference_mbm_dict, additional_metadata={'type': 'reference'})
    df_mov = mbm_dict_to_dataframe(moving_mbm_dict, additional_metadata={'type': 'moving'})
    if df_ref.empty or df_mov.empty:
        print("Error: One or both datasets have no bead data.")
        return None

    # Filter to only include beads marked as "used" (element-wise == True)
    df_ref = df_ref[df_ref['used'].eq(True)]
    df_mov = df_mov[df_mov['used'].eq(True)]
    if df_ref.empty or df_mov.empty:
        print("Error: No beads marked as 'used' in one or both datasets.")
        return None

    # Build the name sets once; they are used for every membership test below
    ref_names = set(df_ref['bead_name'])
    mov_names = set(df_mov['bead_name'])

    # Identify beads to use for alignment
    if bead_correspondence is None:
        # Automatic: use beads with matching names
        common_beads = ref_names & mov_names
        if not common_beads:
            print("Error: No common beads found between datasets.")
            return None
        # Identity mapping; sorted so the processing order does not depend on
        # (randomized) string hashing and results are reproducible across runs
        bead_correspondence = {bn: bn for bn in sorted(common_beads, key=str)}
    else:
        # Manual correspondence provided: report and drop unknown beads
        missing_ref = set(bead_correspondence.values()) - ref_names
        missing_mov = set(bead_correspondence.keys()) - mov_names
        if missing_ref:
            print(f"Warning: Reference beads not found: {missing_ref}")
        if missing_mov:
            print(f"Warning: Moving beads not found: {missing_mov}")

        # Filter to only valid correspondences (the caller's dict is not modified)
        bead_correspondence = {
            mov: ref
            for mov, ref in bead_correspondence.items()
            if mov in mov_names and ref in ref_names
        }
        if not bead_correspondence:
            print("Error: No valid bead correspondences.")
            return None

    print(f"Using {len(bead_correspondence)} bead pairs for alignment")

    def get_bead_position(df, bead_name):
        """Mean [z, y, x] position of one bead, or None if it has no rows."""
        bead_data = df[df['bead_name'] == bead_name]
        if bead_data.empty:
            return None
        if n_points is not None:
            # Keep only the n_points localizations with the smallest 'tim'
            bead_data = bead_data.nsmallest(n_points, 'tim')
        # [z, y, x] order kept for consistency with the original implementation
        return bead_data[['z', 'y', 'x']].mean(axis=0).to_numpy()

    # Build corresponding point clouds
    pts_ref = []
    pts_mov = []
    for bead_mov, bead_ref in bead_correspondence.items():
        pos_ref = get_bead_position(df_ref, bead_ref)
        pos_mov = get_bead_position(df_mov, bead_mov)
        if pos_ref is not None and pos_mov is not None:
            pts_ref.append(pos_ref)
            pts_mov.append(pos_mov)

    if not pts_ref:
        print(f"Error: Not enough valid bead pairs ({len(pts_ref)}).")
        return None

    pts_ref = np.array(pts_ref)
    pts_mov = np.array(pts_mov)
    n_pairs = len(pts_ref)
    print(f"Aligning using {n_pairs} bead positions")

    # point_registration falls back to translation-only below 3 pairs
    translation_only = n_pairs < 3
    if translation_only:
        print(f"Warning: Only {n_pairs} bead pair(s) available. Using translation-only alignment.")

    # Execute alignment; any registration error (e.g. unsupported
    # transform_type) is reported and mapped to the documented None result
    try:
        model = point_registration(
            pts_ref,
            pts_mov,
            transform_type=transform_type,
        )
    except Exception as e:
        print(f"Error during registration: {e}")
        return None

    if model is None:
        print("Registration failed.")
        return None

    # Quality report: mean point distance after vs. before applying the model
    residual_mean = np.mean(model.residuals(pts_mov, pts_ref))
    residual_before = np.mean(np.linalg.norm(pts_mov - pts_ref, axis=1))
    alignment_mode = "translation-only" if translation_only else "rigid (rotation + translation)"
    print(f"Alignment completed using {alignment_mode} mode.")
    print(f"Mean residual: {residual_mean:.2f} nm.")
    print(f"Mean residual before alignment: {residual_before:.2f} nm.")
    return model
Args
reference_mbm_dict- MBM dictionary from reference dataset
moving_mbm_dict- MBM dictionary from dataset to be aligned
bead_correspondence- Optional dict mapping moving bead names to reference bead names. If None, assumes beads with same names correspond.
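transform_type- Type of transformation. Only 'euclidean' (rigid: rotation + translation) is currently supported; any other value makes the registration fail, and None is returned.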
n_points- Number of earliest time points to average per bead
Returns
model- The transformation model, or None if alignment fails
def drift_correction_time_windows_2d(x: numpy.ndarray,
y: numpy.ndarray,
t: numpy.ndarray,
sxy: float,
rx: tuple | None = None,
ry: tuple | None = None,
T: float | None = None,
tid: numpy.ndarray | None = None)
Expand source code
def drift_correction_time_windows_2d(
    x: np.ndarray,
    y: np.ndarray,
    t: np.ndarray,
    sxy: float,
    rx: Optional[tuple] = None,
    ry: Optional[tuple] = None,
    T: Optional[float] = None,
    tid: Optional[np.ndarray] = None,
):
    """Estimate 2D drift correction based on auto-correlation.

    Reimplemented (with modifications) from:

    * [paper] Ostersehlt, L.M., Jans, D.C., Wittek, A. et al. DNA-PAINT MINFLUX nanoscopy. Nat Methods 19, 1072-1075 (2022). https://doi.org/10.1038/s41592-022-01577-1
    * [code] https://zenodo.org/record/6563100

    The localizations are split into consecutive time windows of length T.
    Each window is rendered into a 2D histogram, the relative shifts between
    windows are estimated from FFT-based cross-correlations, and a regularized
    cost (`lse_distance`) is minimized to obtain one shift per window, which
    is finally interpolated to every time point.

    Parameters
    ----------

    x: np.ndarray
        Array of localization x coordinates.
    y: np.ndarray
        Array of localization y coordinates.
    t: np.ndarray
        Array of localization time points.
    sxy: float
        Resolution (pixel size) in nm in both the x and y direction.
    rx: tuple (Optional)
        (min, max) boundaries for the x coordinates. If omitted, it will default to (x.min(), x.max()).
    ry: tuple (Optional)
        (min, max) boundaries for the y coordinates. If omitted, it will default to (y.min(), y.max()).
    T: float (Optional)
        Time window for analysis (same units as t).
    tid: np.ndarray (Optional)
        Only used if T is None. The unique trace IDs are used to calculate the time window for analysis using some heuristics.

    Returns
    -------

    dx: np.ndarray
        Estimated x drift at every time point in t.
    dy: np.ndarray
        Estimated y drift at every time point in t.
    dxt: np.ndarray
        Estimated x drift at the mean time of each window (ti).
    dyt: np.ndarray
        Estimated y drift at the mean time of each window (ti).
    ti: np.ndarray
        Mean time of each time window.
    T: float
        Length of the time window actually used.

    Raises
    ------

    ValueError
        If neither T nor tid is provided.
    AssertionError
        If fewer than two time windows fit into the time range of t.
    """
    if T is None and tid is None:
        raise ValueError("If T is not defined, the array of TIDs must be provided.")

    # Make sure we are working with NumPy arrays
    x = np.array(x)
    y = np.array(y)
    t = np.array(t)
    if tid is not None:
        tid = np.array(tid)

    # Make sure we have valid ranges
    if rx is None:
        rx = (x.min(), x.max())
    if ry is None:
        ry = (y.min(), y.max())

    # Full time range; min/max (instead of t[0]/t[-1]) so that unsorted time
    # stamps are handled correctly (identical result for sorted t)
    t_start, t_end = t.min(), t.max()

    # Heuristics to define the length of the time window
    if T is None:
        T = len(np.unique(tid)) * np.diff(rx)[0] * np.diff(ry)[0] / 3e6
        T = np.min([T, (t_end - t_start) / 2, 3600])  # At least two time windows
        T = max([T, 600])  # And at least 10 minutes long

    # Number of time windows to use
    Ns = int(np.floor((t_end - t_start) / T))
    # NOTE: `assert` is skipped under `python -O`; kept so that callers keep
    # receiving an AssertionError for this condition.
    assert Ns > 1, "At least two time windows are needed, please reduce T"

    # Radius (in pixels) of the region used for the centroid estimation
    CR = 10
    # Maximum number of frames in the cross-correlation
    D = Ns
    # Weight with distance
    w = np.linspace(1, 0.5, D)
    # Regularization term (roughness)
    reg = 0.1

    # Image dimensions and center in pixels, in (y, x) order
    Nx = int(np.ceil((rx[1] - rx[0]) / sxy))
    Ny = int(np.ceil((ry[1] - ry[0]) / sxy))
    c = np.round(np.array([Ny, Nx]) / 2)

    # Render one histogram per time window and keep its Fourier transform
    h = [None] * Ns
    ti = np.zeros(Ns)  # Average times of the snapshots
    for j in range(Ns):
        t0 = t_start + j * T
        idx = (t >= t0) & (t < t0 + T)
        # NOTE(review): a window without localizations gives ti[j] = NaN
        # (mean of an empty array) - gaps in the acquisition are not handled.
        ti[j] = np.mean(t[idx])
        hj, _, _, _ = render_xy(
            x[idx],
            y[idx],
            sx=sxy,
            sy=sxy,
            rx=rx,
            ry=ry,
            render_type="fixed_gaussian",
            fwhm=3 * sxy,
        )
        h[j] = fftn(hj)

    # Compute cross-correlations between all pairs of windows (i < j)
    dx = np.zeros((Ns, Ns))
    dy = np.zeros((Ns, Ns))
    dm = np.zeros((Ns, Ns), dtype=bool)  # Marks the (i, j) pairs computed
    dx0 = np.zeros(Ns - 1)  # Shifts between consecutive windows only
    dy0 = np.zeros(Ns - 1)
    for i in range(Ns - 1):
        hi = np.conj(h[i])
        for j in range(i + 1, min(Ns, i + D)):  # Either to Ns or maximally D more
            cc = ifftshift(np.real(ifftn(hi * h[j])))
            yc = c[0]
            xc = c[1]

            # Centroid estimation on a (4 * CR + 1) x (4 * CR + 1) region
            # around the image center
            gy = np.arange(yc - 2 * CR, yc + 2 * CR + 1).astype(int)
            gx = np.arange(xc - 2 * CR, xc + 2 * CR + 1).astype(int)
            gy, gx = np.meshgrid(gy, gx, indexing="ij")
            d = cc[gy, gx]
            gy = gy.flatten()
            gx = gx.flatten()
            d = d.flatten() - np.min(d)
            d = d / np.sum(d)
            # Iteratively re-weighted centroid with a Gaussian weight of FWHM CR
            for _ in range(20):
                wc = np.exp(-4 * np.log(2) * ((xc - gx) ** 2 + (yc - gy) ** 2) / CR**2)
                n = np.sum(wc * d)
                xc = np.sum(gx * d * wc) / n
                yc = np.sum(gy * d * wc) / n

            # Shift of the peak from the center; sign convention kept from the
            # original implementation
            sh = np.array([-1.0, 1.0]) * (np.array([yc, xc]) - c)
            dy[i, j] = sh[0]
            dx[i, j] = sh[1]
            dm[i, j] = True
            if j == i + 1:
                dy0[i] = sh[0]
                dx0[i] = sh[1]

    a, b = np.nonzero(dm)
    dx = dx[dm]
    dy = dy[dm]

    # Initial guess: accumulated shifts between consecutive windows
    sx0 = np.cumsum(dx0)
    sy0 = np.cumsum(dy0)

    # Minimize cost function with some kind of regularization; the first
    # window is the fixed reference (shift 0)
    options = {"disp": False, "maxiter": int(1e5)}
    res = minimize(lse_distance, sx0, args=(a, b, dx, w, reg), options=options, method="BFGS")
    sx = np.concatenate(([0], res.x))
    res = minimize(lse_distance, sy0, args=(a, b, dy, w, reg), options=options, method="BFGS")
    sy = np.concatenate(([0], res.x))

    # Reduce by mean (so shift is minimal)
    sx = sx - np.mean(sx)
    sy = sy - np.mean(sy)

    # Multiply by pixel size
    sx = sx * sxy
    sy = sy * sxy

    # Create interpolants
    fx = interp1d(ti, sx, kind="slinear", fill_value="extrapolate")
    fy = interp1d(ti, sy, kind="slinear", fill_value="extrapolate")

    # Correct drift
    dx = fx(t)
    dy = fy(t)

    # Apply to the time points used to estimate the correction
    dxt = fx(ti)
    dyt = fy(ti)

    return dx, dy, dxt, dyt, ti, T
Reimplemented (with modifications) from:
- [paper] Ostersehlt, L.M., Jans, D.C., Wittek, A. et al. DNA-PAINT MINFLUX nanoscopy. Nat Methods 19, 1072-1075 (2022). https://doi.org/10.1038/s41592-022-01577-1
- [code] https://zenodo.org/record/6563100
Parameters
x:np.ndarray- Array of localization x coordinates.
y:np.ndarray- Array of localization y coordinates.
t:np.ndarray- Array of localization time points.
sxy:float- Resolution (pixel size) in nm in both the x and y direction.
rx:tuple (Optional)- (min, max) boundaries for the x coordinates. If omitted, it will default to (x.min(), x.max()).
ry:tuple (Optional)- (min, max) boundaries for the y coordinates. If omitted, it will default to (y.min(), y.max()).
T:float (Optional)- Time window for analysis.
tid:np.ndarray (Optional)- Only used if T is None. The unique trace IDs are used to calculate the time window for analysis using some heuristics.
def drift_correction_time_windows_3d(x: numpy.ndarray,
y: numpy.ndarray,
z: numpy.ndarray,
t: numpy.ndarray,
sxyz: float,
rx: tuple | None = None,
ry: tuple | None = None,
rz: tuple | None = None,
T: float | None = None,
tid: numpy.ndarray | None = None)
Expand source code
def drift_correction_time_windows_3d(
    x: np.ndarray,
    y: np.ndarray,
    z: np.ndarray,
    t: np.ndarray,
    sxyz: float,
    rx: Optional[tuple] = None,
    ry: Optional[tuple] = None,
    rz: Optional[tuple] = None,
    T: Optional[float] = None,
    tid: Optional[np.ndarray] = None,
):
    """Estimate 3D drift correction based on auto-correlation.

    Reimplemented (with modifications) from:

    * [paper] Ostersehlt, L.M., Jans, D.C., Wittek, A. et al. DNA-PAINT MINFLUX nanoscopy. Nat Methods 19, 1072-1075 (2022). https://doi.org/10.1038/s41592-022-01577-1
    * [code] https://zenodo.org/record/6563100

    The localizations are split into consecutive time windows of length T.
    Each window is rendered into a 3D histogram, the relative shifts between
    windows are estimated from FFT-based cross-correlations, and a regularized
    cost (`lse_distance`) is minimized to obtain one shift per window, which
    is finally interpolated to every time point.

    Parameters
    ----------

    x: np.ndarray
        Array of localization x coordinates.
    y: np.ndarray
        Array of localization y coordinates.
    z: np.ndarray
        Array of localization z coordinates.
    t: np.ndarray
        Array of localization time points.
    sxyz: float
        Resolution (voxel size) in nm in x, y and z direction.
    rx: tuple (Optional)
        (min, max) boundaries for the x coordinates. If omitted, it will default to (x.min(), x.max()).
    ry: tuple (Optional)
        (min, max) boundaries for the y coordinates. If omitted, it will default to (y.min(), y.max()).
    rz: tuple (Optional)
        (min, max) boundaries for the z coordinates. If omitted, it will default to (z.min(), z.max()).
    T: float (Optional)
        Time window for analysis (same units as t).
    tid: np.ndarray (Optional)
        Only used if T is None. The unique trace IDs are used to calculate the time window for analysis using some heuristics.

    Returns
    -------

    dx, dy, dz: np.ndarray
        Estimated x, y and z drift at every time point in t.
    dxt, dyt, dzt: np.ndarray
        Estimated x, y and z drift at the mean time of each window (ti).
    ti: np.ndarray
        Mean time of each time window.
    T: float
        Length of the time window actually used.

    Raises
    ------

    ValueError
        If neither T nor tid is provided.
    AssertionError
        If fewer than two time windows fit into the time range of t.
    """
    if T is None and tid is None:
        raise ValueError("If T is not defined, the array of TIDs must be provided.")

    # Make sure we are working with NumPy arrays
    x = np.array(x)
    y = np.array(y)
    z = np.array(z)
    t = np.array(t)
    if tid is not None:
        tid = np.array(tid)

    # Make sure we have valid ranges
    if rx is None:
        rx = (x.min(), x.max())
    if ry is None:
        ry = (y.min(), y.max())
    if rz is None:
        rz = (z.min(), z.max())

    # Full time range; min/max (instead of t[0]/t[-1]) so that unsorted time
    # stamps are handled correctly (identical result for sorted t)
    t_start, t_end = t.min(), t.max()

    # Heuristics to define the length of the time window (uses the x/y extent only)
    if T is None:
        T = len(np.unique(tid)) * np.diff(rx)[0] * np.diff(ry)[0] / 3e6
        T = np.min([T, (t_end - t_start) / 2, 3600])  # At least two time windows
        T = max([T, 600])  # And at least 10 minutes long

    # Number of time windows to use
    Ns = int(np.floor((t_end - t_start) / T))
    # NOTE: `assert` is skipped under `python -O`; kept so that callers keep
    # receiving an AssertionError for this condition.
    assert Ns > 1, "At least two time windows are needed, please reduce T"

    # Radius (in voxels) of the region used for the centroid estimation
    CR = 8
    # Maximum number of frames in the cross-correlation
    D = Ns
    # Weight with distance
    w = np.linspace(1, 0.2, D)
    # Regularization term (roughness)
    reg = 0.01

    # Volume dimensions and center in voxels, in (z, y, x) order
    Nx = int(np.ceil((rx[1] - rx[0]) / sxyz))
    Ny = int(np.ceil((ry[1] - ry[0]) / sxyz))
    Nz = int(np.ceil((rz[1] - rz[0]) / sxyz))
    c = np.round(np.array([Nz, Ny, Nx]) / 2)

    # Render one histogram per time window and keep its Fourier transform
    h = [None] * Ns
    ti = np.zeros(Ns)  # Average times of the snapshots
    for j in range(Ns):
        t0 = t_start + j * T
        idx = (t >= t0) & (t < t0 + T)
        # NOTE(review): a window without localizations gives ti[j] = NaN
        # (mean of an empty array) - gaps in the acquisition are not handled.
        ti[j] = np.mean(t[idx])
        hj, _, _, _, _ = render_xyz(
            x[idx],
            y[idx],
            z[idx],
            sx=sxyz,
            sy=sxyz,
            sz=sxyz,
            rx=rx,
            ry=ry,
            rz=rz,
            render_type="fixed_gaussian",
            fwhm=3 * sxyz,
        )
        h[j] = fftn(hj)

    # Compute cross-correlations between all pairs of windows (i < j)
    dx = np.zeros((Ns, Ns))
    dy = np.zeros((Ns, Ns))
    dz = np.zeros((Ns, Ns))
    dm = np.zeros((Ns, Ns), dtype=bool)  # Marks the (i, j) pairs computed
    dx0 = np.zeros(Ns - 1)  # Shifts between consecutive windows only
    dy0 = np.zeros(Ns - 1)
    dz0 = np.zeros(Ns - 1)
    for i in range(Ns - 1):
        hi = np.conj(h[i])
        for j in range(i + 1, min(Ns, i + D)):  # either to Ns or maximally D more
            cc = ifftshift(np.real(ifftn(hi * h[j])))
            zc = c[0]
            yc = c[1]
            xc = c[2]

            # Centroid estimation on a (4 * CR + 1)^3 region around the center
            gz = np.arange(zc - 2 * CR, zc + 2 * CR + 1).astype(int)
            gy = np.arange(yc - 2 * CR, yc + 2 * CR + 1).astype(int)
            gx = np.arange(xc - 2 * CR, xc + 2 * CR + 1).astype(int)
            gz, gy, gx = np.meshgrid(gz, gy, gx, indexing="ij")
            d = cc[gz, gy, gx]
            gz = gz.flatten()
            gy = gy.flatten()
            gx = gx.flatten()
            d = d.flatten() - np.min(d)
            d = d / np.sum(d)
            # Iteratively re-weighted centroid with a Gaussian weight of FWHM CR
            for _ in range(20):
                wc = np.exp(
                    -4
                    * np.log(2)
                    * ((xc - gx) ** 2 + (yc - gy) ** 2 + (zc - gz) ** 2)
                    / CR**2
                )
                n = np.sum(wc * d)
                xc = np.sum(gx * d * wc) / n
                yc = np.sum(gy * d * wc) / n
                zc = np.sum(gz * d * wc) / n

            # Shift of the peak from the center; sign convention kept from the
            # original implementation
            sh = np.array([1.0, -1.0, 1.0]) * (np.array([zc, yc, xc]) - c)
            dz[i, j] = sh[0]
            dy[i, j] = sh[1]
            dx[i, j] = sh[2]
            dm[i, j] = True
            if j == i + 1:
                dz0[i] = sh[0]
                dy0[i] = sh[1]
                dx0[i] = sh[2]

    a, b = np.nonzero(dm)
    dx = dx[dm]
    dy = dy[dm]
    dz = dz[dm]

    # Initial guess: accumulated shifts between consecutive windows
    sx0 = np.cumsum(dx0)
    sy0 = np.cumsum(dy0)
    sz0 = np.cumsum(dz0)

    # Minimize cost function with some kind of regularization; the first
    # window is the fixed reference (shift 0)
    options = {"disp": False, "maxiter": int(1e5)}
    res = minimize(lse_distance, sx0, args=(a, b, dx, w, reg), options=options, method="BFGS")
    sx = np.concatenate(([0], res.x))
    res = minimize(lse_distance, sy0, args=(a, b, dy, w, reg), options=options, method="BFGS")
    sy = np.concatenate(([0], res.x))
    res = minimize(lse_distance, sz0, args=(a, b, dz, w, reg), options=options, method="BFGS")
    sz = np.concatenate(([0], res.x))

    # Reduce by mean (so shift is minimal)
    sx = sx - np.mean(sx)
    sy = sy - np.mean(sy)
    sz = sz - np.mean(sz)

    # Multiply by voxel size
    sx = sx * sxyz
    sy = sy * sxyz
    sz = sz * sxyz

    # Create interpolants
    fx = interp1d(ti, sx, kind="slinear", fill_value="extrapolate")
    fy = interp1d(ti, sy, kind="slinear", fill_value="extrapolate")
    fz = interp1d(ti, sz, kind="slinear", fill_value="extrapolate")

    # Correct drift
    dx = fx(t)
    dy = fy(t)
    dz = fz(t)

    # Apply to the time points used to estimate the correction
    dxt = fx(ti)
    dyt = fy(ti)
    dzt = fz(ti)

    return dx, dy, dz, dxt, dyt, dzt, ti, T
Reimplemented (with modifications) from:
- [paper] Ostersehlt, L.M., Jans, D.C., Wittek, A. et al. DNA-PAINT MINFLUX nanoscopy. Nat Methods 19, 1072-1075 (2022). https://doi.org/10.1038/s41592-022-01577-1
- [code] https://zenodo.org/record/6563100
Parameters
x:np.ndarray- Array of localization x coordinates.
y:np.ndarray- Array of localization y coordinates.
z:np.ndarray- Array of localization z coordinates.
t:np.ndarray- Array of localization time points.
sxyz:float- Resolution (voxel size) in nm in x, y and z direction.
rx:tuple (Optional)- (min, max) boundaries for the x coordinates. If omitted, it will default to (x.min(), x.max()).
ry:tuple (Optional)- (min, max) boundaries for the y coordinates. If omitted, it will default to (y.min(), y.max()).
rz:tuple (Optional)- (min, max) boundaries for the z coordinates. If omitted, it will default to (z.min(), z.max()).
T:float (Optional)- Time window for analysis.
tid:np.ndarray (Optional)- Only used if T is None. The unique trace IDs are used to calculate the time window for analysis using some heuristics.
def mbm_dict_to_dataframe(mbm_dict: Dict, additional_metadata: Dict | None = None) -> pandas.core.frame.DataFrame
Expand source code
def mbm_dict_to_dataframe(mbm_dict: Dict, additional_metadata: Optional[Dict] = None) -> pd.DataFrame:
    """
    Flatten an MBM (bead) dictionary into a single long-format DataFrame.

    Every bead contributes one row per entry of its ``points`` record array.
    Columns, in order: ``gri``, ``tim``, ``str`` (copied from ``points``);
    ``x``, ``y``, ``z`` (``points['xyz']`` multiplied by 1e9, i.e. converted
    to nm); ``bead_name``, ``bead_gri``, ``used`` (copied from the bead
    entry); then one constant column per key of ``additional_metadata``.

    Args:
        mbm_dict: Dictionary containing bead measurement data. Each value must
            provide 'points', 'bead_name', 'gri' and 'used'.
        additional_metadata: Optional mapping of column name -> value that is
            added as a constant column to every row.

    Returns:
        pd.DataFrame: Combined DataFrame with a fresh 0..n-1 index, or an
        empty DataFrame if ``mbm_dict`` contains no beads.
    """
    per_bead_frames = []
    for bead in mbm_dict.values():
        records = bead['points']

        scalar_cols = pd.DataFrame({
            'gri': records['gri'],
            'tim': records['tim'],
            'str': records['str'],
        })
        coords_nm = pd.DataFrame(records['xyz'], columns=['x', 'y', 'z'])
        coords_nm *= 1e9  # Convert to nm

        frame = pd.concat([scalar_cols, coords_nm], axis=1)
        frame['bead_name'] = bead['bead_name']
        frame['bead_gri'] = bead['gri']
        frame['used'] = bead['used']
        for column, value in (additional_metadata or {}).items():
            frame[column] = value

        per_bead_frames.append(frame)

    return pd.concat(per_bead_frames, ignore_index=True) if per_bead_frames else pd.DataFrame()
Args
mbm_dict- Dictionary containing bead measurement data
additional_metadata- Optional dictionary of metadata to add as columns
Returns
pd.DataFrame- Combined DataFrame with all bead data
def point_registration(pts_fixed: numpy.ndarray,
pts_moving: numpy.ndarray,
transform_type: str = 'euclidean') -> pyminflux.correct._bead_alignment.RigidTransform | pyminflux.correct._bead_alignment.TranslationTransform | None
Expand source code
def point_registration(
    pts_fixed: np.ndarray,
    pts_moving: np.ndarray,
    transform_type: str = 'euclidean',
) -> Optional[Union[RigidTransform, TranslationTransform]]:
    """
    Estimate transformation between two sets of corresponding points.

    For 3+ correspondences, the Kabsch algorithm is used for rigid alignment
    (reflections are not allowed). When fewer than 3 correspondences are
    available, a translation-only transform (difference of the centroids) is
    used.

    Args:
        pts_fixed (np.ndarray): The (N, D) array of points in the fixed coordinate
            system. Any array-like (e.g. nested lists) is accepted.
        pts_moving (np.ndarray): The (N, D) array of points in the moving coordinate
            system, in the same order as ``pts_fixed``. Any array-like is accepted.
        transform_type (str): The type of transformation to estimate
            (case-insensitive). Currently only 'euclidean' (rigid:
            rotation+translation) is supported.

    Returns:
        Optional[Union[RigidTransform, TranslationTransform]]: The estimated
        transformation model, intended to map moving points onto fixed points.

    Raises:
        ValueError: If ``transform_type`` is not supported, if the point arrays
            are not 2D, contain no points, or have different shapes.
    """
    if transform_type.lower() != 'euclidean':
        raise ValueError("Only 'euclidean' transform type is currently supported.")

    # Accept any array-like input (no copy is made for ndarrays)
    pts_fixed = np.asarray(pts_fixed)
    pts_moving = np.asarray(pts_moving)

    if pts_moving.ndim != 2:
        raise ValueError(
            f"Point arrays must be 2D with shape (N, D). "
            f"Got pts_moving: {pts_moving.shape}"
        )

    n_points = pts_moving.shape[0]
    if n_points < 1:
        raise ValueError(
            f"Not enough data points ({n_points}) for alignment. "
            f"At least 1 point is required."
        )

    if pts_fixed.shape != pts_moving.shape:
        raise ValueError(
            f"Point arrays must have the same shape. "
            f"Got pts_fixed: {pts_fixed.shape}, pts_moving: {pts_moving.shape}"
        )

    # Use translation-only transform for 1-2 correspondences
    if n_points < 3:
        t = pts_fixed.mean(axis=0) - pts_moving.mean(axis=0)
        return TranslationTransform(t)

    # Compute optimal rigid transformation using Kabsch algorithm for 3+ correspondences
    R, t = _kabsch(pts_moving, pts_fixed, allow_reflection=False)
    return RigidTransform(R, t)
For 3+ correspondences, the Kabsch algorithm is used for rigid alignment. When fewer than 3 correspondences are available, a translation-only transform is used.
Args
pts_fixed:np.ndarray- The (N, D) array of points in the fixed coordinate system.
pts_moving:np.ndarray- The (N, D) array of points in the moving coordinate system.
transform_type:str- The type of transformation to estimate. Currently only 'euclidean' (rigid: rotation+translation) is supported.
Returns
Optional[Union[RigidTransform, TranslationTransform]]- The estimated transformation model.
Classes
class RigidTransform (rotation: numpy.ndarray, translation: numpy.ndarray)
Expand source code
class RigidTransform:
    """Simple rigid transformation model compatible with the existing interface.

    Maps coordinates as ``x' = rotation @ x + translation``. ``params`` holds
    the equivalent (d + 1, d + 1) homogeneous matrix.
    """

    def __init__(self, rotation: np.ndarray, translation: np.ndarray):
        """
        Initialize rigid transform.

        Args:
            rotation: (d, d) rotation matrix (any array-like).
            translation: (d,) translation vector (any array-like).
        """
        # np.asarray does not copy ndarrays, so array inputs are stored as-is
        self.rotation = np.asarray(rotation)
        self.translation = np.asarray(translation)
        self.dimensionality = self.rotation.shape[0]

        # Build homogeneous transformation matrix for compatibility
        d = self.dimensionality
        self.params = np.eye(d + 1)
        self.params[:d, :d] = self.rotation
        self.params[:d, d] = self.translation

    def __call__(self, coords: np.ndarray) -> np.ndarray:
        """
        Apply transformation to coordinates.

        Args:
            coords: (n, d) array of coordinates (any array-like).

        Returns:
            Transformed coordinates (n, d)
        """
        coords = np.asarray(coords)
        return (self.rotation @ coords.T).T + self.translation

    def residuals(self, src: np.ndarray, dst: np.ndarray) -> np.ndarray:
        """
        Calculate residuals between transformed source and destination.

        Args:
            src: (n, d) source points
            dst: (n, d) destination points

        Returns:
            (n,) array of residuals (Euclidean distances)
        """
        transformed = self(src)
        return np.linalg.norm(transformed - dst, axis=1)

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(rotation={self.rotation.tolist()!r}, "
            f"translation={self.translation.tolist()!r})"
        )
Initialize rigid transform.
Args
rotation- (d, d) rotation matrix
translation- (d,) translation vector
Methods
def residuals(self, src: numpy.ndarray, dst: numpy.ndarray) -> numpy.ndarray
Expand source code
def residuals(self, src: np.ndarray, dst: np.ndarray) -> np.ndarray:
    """
    Return the per-point Euclidean distance between ``self(src)`` and ``dst``.

    Args:
        src: (n, d) source points, mapped through this transformation
        dst: (n, d) destination points

    Returns:
        (n,) array with one distance per point pair
    """
    mismatch = self(src) - dst
    distances = np.linalg.norm(mismatch, axis=1)
    return distances
Args
src- (n, d) source points
dst- (n, d) destination points
Returns
(n,) array of residuals (Euclidean distances)
class TranslationTransform (translation: numpy.ndarray)
Expand source code
class TranslationTransform:
    """Translation-only transformation model for cases with insufficient correspondences.

    Maps coordinates as ``x' = x + translation``. ``rotation`` is the identity
    and ``params`` holds the equivalent (d + 1, d + 1) homogeneous matrix, so
    the attributes mirror those of the rigid model.
    """

    def __init__(self, translation: np.ndarray):
        """
        Initialize translation-only transform.

        Args:
            translation: (d,) translation vector (any array-like).
        """
        # np.asarray does not copy ndarrays, so array inputs are stored as-is
        self.translation = np.asarray(translation)
        self.dimensionality = self.translation.shape[0]

        # Build homogeneous transformation matrix for compatibility
        d = self.dimensionality
        self.params = np.eye(d + 1)
        self.params[:d, d] = self.translation

        # Store identity rotation for compatibility
        self.rotation = np.eye(d)

    def __call__(self, coords: np.ndarray) -> np.ndarray:
        """
        Apply transformation to coordinates.

        Args:
            coords: (n, d) array of coordinates (any array-like).

        Returns:
            Transformed coordinates (n, d)
        """
        # asarray avoids list + list concatenation when plain lists are passed
        return np.asarray(coords) + self.translation

    def residuals(self, src: np.ndarray, dst: np.ndarray) -> np.ndarray:
        """
        Calculate residuals between transformed source and destination.

        Args:
            src: (n, d) source points
            dst: (n, d) destination points

        Returns:
            (n,) array of residuals (Euclidean distances)
        """
        transformed = self(src)
        return np.linalg.norm(transformed - dst, axis=1)

    def __repr__(self) -> str:
        return f"{type(self).__name__}(translation={self.translation.tolist()!r})"
Initialize translation-only transform.
Args
translation- (d,) translation vector
Methods
def residuals(self, src: numpy.ndarray, dst: numpy.ndarray) -> numpy.ndarray
Expand source code
def residuals(self, src: np.ndarray, dst: np.ndarray) -> np.ndarray:
    """
    Euclidean distance, point by point, between the transformed ``src`` and ``dst``.

    Args:
        src: (n, d) points that are passed through this transformation
        dst: (n, d) target points

    Returns:
        (n,) array of distances
    """
    return np.linalg.norm(np.subtract(self(src), dst), axis=1)
Args
src- (n, d) source points
dst- (n, d) destination points
Returns
(n,) array of residuals (Euclidean distances)