Source code for stdpipe.astrometry_quad

"""
Quad-hash astrometry solver (Python-only) with practical upgrades:

1) Multi-scale quad pools:
   - quads are partitioned into baseline-length *rank bins* (e.g. 6 bins)
   - consistent binning between detection and reference sets via shared edges
   - matching is done only within the same bin, reducing ambiguity

2) Two-stage hypothesis scoring with accuracy enhancements:
   - Stage 1: weighted scoring on a small subset with multi-probe hashing
   - Stage 2: mutual nearest-neighbor matching with weighted scoring
   - Iterative affine re-matching to grow the match set
   - Progressive sigma-clipping in final WCS refinement

Inputs
------
obj : astropy.table.Table with 'x','y' and either 'flux' or 'mag'
cat : astropy.table.Table with 'ra','dec','mag'
wcs_init : astropy.wcs.WCS (rough)

Outputs
-------
refined_wcs : astropy.wcs.WCS
match : astropy.table.Table
diagnostics : dict

Deps: numpy, scipy, astropy
"""

from __future__ import annotations

import dataclasses
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple

import numpy as np
from scipy.spatial import cKDTree

from astropy.coordinates import SkyCoord
import astropy.units as u
from astropy.table import Table
from astropy.wcs import WCS
from stdpipe.astrometry_wcs import fit_wcs_from_points


# -----------------------------
# Utilities
# -----------------------------


def _as_float_array(x) -> np.ndarray:
    a = np.asarray(x)
    return a.astype(np.float64, copy=False)


def _finite_mask(*cols) -> np.ndarray:
    m = np.ones(len(cols[0]), dtype=bool)
    for c in cols:
        c = np.asarray(c)
        m &= np.isfinite(c)
    return m


def _pick_brightest_obj(obj: Table, n: int) -> np.ndarray:
    if "flux" in obj.colnames:
        v = _as_float_array(obj["flux"])
        key = -v  # higher flux = brighter
    elif "mag" in obj.colnames:
        v = _as_float_array(obj["mag"])
        key = v  # lower mag = brighter
    else:
        raise ValueError("obj must contain 'flux' or 'mag'")
    idx = np.argsort(key)
    return idx[: min(n, len(idx))]


def _pick_brightest_cat(cat: Table, n: int) -> np.ndarray:
    v = _as_float_array(cat["mag"])
    idx = np.argsort(v)
    return idx[: min(n, len(idx))]


def _robust_sigma(x: np.ndarray) -> float:
    x = np.asarray(x, dtype=np.float64)
    med = np.nanmedian(x)
    mad = np.nanmedian(np.abs(x - med))
    return 1.4826 * mad if mad > 0 else np.nanstd(x)


[docs] def mag_signature(mags4: np.ndarray) -> Tuple[int, int, int, int]: mags4 = np.asarray(mags4, dtype=np.float64) if len(mags4) != 4: raise ValueError(f"Expected 4 magnitudes, got {len(mags4)}") return tuple(np.argsort(mags4).astype(int).tolist())
# ----------------------------- # TAN projection # -----------------------------
[docs] def tan_project_deg( ra_deg: np.ndarray, dec_deg: np.ndarray, ra0_deg: float, dec0_deg: float ) -> Tuple[np.ndarray, np.ndarray]: ra = np.deg2rad(_as_float_array(ra_deg)) dec = np.deg2rad(_as_float_array(dec_deg)) ra0 = np.deg2rad(float(ra0_deg)) dec0 = np.deg2rad(float(dec0_deg)) dra = (ra - ra0 + np.pi) % (2 * np.pi) - np.pi sin_dec, cos_dec = np.sin(dec), np.cos(dec) sin_dec0, cos_dec0 = np.sin(dec0), np.cos(dec0) cos_dra = np.cos(dra) sin_dra = np.sin(dra) denom = sin_dec0 * sin_dec + cos_dec0 * cos_dec * cos_dra eps = 1e-12 denom = np.where(np.abs(denom) < eps, np.sign(denom) * eps, denom) u = (cos_dec * sin_dra) / denom v = (cos_dec0 * sin_dec - sin_dec0 * cos_dec * cos_dra) / denom return u, v
# ----------------------------- # Similarity transform (Umeyama) # -----------------------------
[docs] def estimate_similarity_2d( A: np.ndarray, B: np.ndarray, allow_reflection: bool = True ) -> Tuple[np.ndarray, np.ndarray, float]: A = np.asarray(A, dtype=np.float64) B = np.asarray(B, dtype=np.float64) if A.shape != B.shape or A.shape[1] != 2: raise ValueError("A and B must be (N,2) and same shape") mu_A = A.mean(axis=0) mu_B = B.mean(axis=0) X = A - mu_A Y = B - mu_B var_A = np.sum(X**2) / A.shape[0] if var_A <= 0: raise ValueError("Degenerate A variance") Sigma = (Y.T @ X) / A.shape[0] U, D, Vt = np.linalg.svd(Sigma) R = U @ Vt if not allow_reflection and np.linalg.det(R) < 0: U[:, -1] *= -1 R = U @ Vt # If allow_reflection=True, accept the reflection as-is s = np.sum(D) / var_A t = mu_B - s * (R @ mu_A) return R, t, float(s)
[docs] def apply_similarity(xy: np.ndarray, R: np.ndarray, t: np.ndarray, s: float) -> np.ndarray: xy = np.asarray(xy, dtype=np.float64) return (s * (xy @ R.T)) + t
# ----------------------------- # Affine transform # -----------------------------
[docs] def estimate_affine_2d(A: np.ndarray, B: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: """Fit affine transform B = A @ M.T + t via least squares (6 DOF). More flexible than similarity (4 DOF) - handles shear and non-square pixels. Returns (M, t) where M is (2,2) linear part and t is (2,) translation. """ A = np.asarray(A, dtype=np.float64) B = np.asarray(B, dtype=np.float64) n = A.shape[0] if n < 3: raise ValueError("Need at least 3 points for affine fit") A_aug = np.hstack([A, np.ones((n, 1))]) # (n, 3) cond = np.linalg.cond(A_aug) if cond > 1e12: raise ValueError(f"Degenerate affine fit (condition number {cond:.1e})") X, _, _, _ = np.linalg.lstsq(A_aug, B, rcond=None) M = X[:2].T # (2, 2) t = X[2] # (2,) if not (np.all(np.isfinite(M)) and np.all(np.isfinite(t))): raise ValueError("Degenerate affine fit (non-finite values)") return M, t
[docs] def apply_affine(xy: np.ndarray, M: np.ndarray, t: np.ndarray) -> np.ndarray: """Apply affine transform: result = xy @ M.T + t Raises ValueError if the result contains non-finite values. """ xy = np.asarray(xy, dtype=np.float64) with np.errstate(over="ignore", invalid="ignore", divide="ignore"): result = xy @ M.T + t if not np.all(np.isfinite(result)): raise ValueError("Affine transform produced non-finite values") return result
# ----------------------------- # Mutual nearest-neighbor matching # ----------------------------- def _mutual_nearest_neighbor( xy_a: np.ndarray, xy_b: np.ndarray, radius: float, tree_b: Optional[cKDTree] = None, ) -> Tuple[np.ndarray, np.ndarray]: """Find mutual nearest-neighbor pairs between two point sets within radius. Returns ``(idx_a, idx_b)`` arrays of matched indices where each pair is the mutual closest match within the given radius. Parameters ---------- xy_a : (N, 2) ndarray First point set. xy_b : (M, 2) ndarray Second point set. radius : float Maximum matching distance. tree_b : cKDTree, optional Pre-built cKDTree for ``xy_b`` (avoids rebuilding). Returns ------- idx_a : ndarray of int Indices into ``xy_a`` for matched pairs. idx_b : ndarray of int Indices into ``xy_b`` for matched pairs. """ if len(xy_a) == 0 or len(xy_b) == 0: return np.array([], dtype=np.int32), np.array([], dtype=np.int32) tree_a = cKDTree(xy_a) if tree_b is None: tree_b = cKDTree(xy_b) # Forward: for each point in A, find nearest in B dist_fwd, nn_fwd = tree_b.query(xy_a, k=1, distance_upper_bound=radius) # Reverse: for each point in B, find nearest in A dist_rev, nn_rev = tree_a.query(xy_b, k=1, distance_upper_bound=radius) ok_fwd = np.isfinite(dist_fwd) & (dist_fwd < radius) & (nn_fwd < len(xy_b)) a_ids = np.nonzero(ok_fwd)[0] b_ids = nn_fwd[a_ids].astype(int) # Mutual check: B's nearest neighbor back to A must be the same point mutual = nn_rev[b_ids] == a_ids return a_ids[mutual], b_ids[mutual] # ----------------------------- # Quad generation & descriptors # -----------------------------
[docs] def make_local_quads(points: np.ndarray, k: int = 8) -> List[Tuple[int, int, int, int]]: points = np.asarray(points, dtype=np.float64) n = points.shape[0] if n < 4: return [] k = int(min(k, max(3, n - 1))) tree = cKDTree(points) _, idxs = tree.query(points, k=k + 1) quads: List[Tuple[int, int, int, int]] = [] for i in range(n): neigh = idxs[i, 1:] for a in range(len(neigh) - 2): for b in range(a + 1, len(neigh) - 1): for c in range(b + 1, len(neigh)): quads.append((i, int(neigh[a]), int(neigh[b]), int(neigh[c]))) return quads
[docs] def quad_descriptor(P: np.ndarray) -> Tuple[np.ndarray, float]: """ Descriptor invariant to translation/rotation/scale. Returns (desc[4], baseline_len). """ P = np.asarray(P, dtype=np.float64) d2 = np.sum((P[None, :, :] - P[:, None, :]) ** 2, axis=2) i, j = np.unravel_index(np.argmax(d2), d2.shape) if i == j: return np.full(4, np.nan), np.nan A = P[i] B = P[j] base = B - A L = float(np.hypot(base[0], base[1])) if L <= 0: return np.full(4, np.nan), np.nan e1 = base / L e2 = np.array([-e1[1], e1[0]], dtype=np.float64) idx = [k for k in range(4) if k not in (i, j)] C = P[idx[0]] D = P[idx[1]] c = np.array([np.dot(C - A, e1), np.dot(C - A, e2)], dtype=np.float64) / L d = np.array([np.dot(D - A, e1), np.dot(D - A, e2)], dtype=np.float64) / L pair = np.stack([c, d], axis=0) order = np.lexsort((pair[:, 1], pair[:, 0])) pair = pair[order] desc = np.array([pair[0, 0], pair[0, 1], pair[1, 0], pair[1, 1]], dtype=np.float64) return desc, L
# Phase 2 Optimization: Vectorized quad descriptor calculation
[docs] def quad_descriptor_batch( points: np.ndarray, quad_indices: np.ndarray ) -> Tuple[np.ndarray, np.ndarray]: """ Vectorized quad descriptor calculation for multiple quads. Args: points: (N, 2) array of point coordinates quad_indices: (M, 4) array of quad indices into points Returns: descs: (M, 4) array of descriptors lens: (M,) array of baseline lengths """ M = quad_indices.shape[0] descs = np.full((M, 4), np.nan, dtype=np.float64) lens = np.full(M, np.nan, dtype=np.float64) if M == 0: return descs, lens # Gather all quad points: (M, 4, 2) P = points[quad_indices] # Compute pairwise distances for all quads: (M, 4, 4) diff = P[:, None, :, :] - P[:, :, None, :] # (M, 4, 4, 2) d2 = np.sum(diff**2, axis=3) # (M, 4, 4) # Find longest baseline for each quad flat_idx = np.argmax(d2.reshape(M, -1), axis=1) i_idx = flat_idx // 4 j_idx = flat_idx % 4 # Check for degenerate quads (i == j) valid = i_idx != j_idx # Process only valid quads if not np.any(valid): return descs, lens valid_mask = np.where(valid)[0] M_valid = len(valid_mask) # Extract baseline points for valid quads A = P[valid_mask, i_idx[valid_mask]] # (M_valid, 2) B = P[valid_mask, j_idx[valid_mask]] # (M_valid, 2) # Compute baseline vectors and lengths base = B - A # (M_valid, 2) L = np.hypot(base[:, 0], base[:, 1]) # (M_valid,) # Filter out zero-length baselines valid_L = L > 0 if not np.any(valid_L): return descs, lens final_mask = valid_mask[valid_L] M_final = len(final_mask) # Update for final valid quads A = A[valid_L] B = B[valid_L] base = base[valid_L] L = L[valid_L] i_idx_final = i_idx[final_mask] j_idx_final = j_idx[final_mask] P_final = P[final_mask] # Compute orthonormal basis e1 = base / L[:, None] # (M_final, 2) e2 = np.stack([-e1[:, 1], e1[:, 0]], axis=1) # (M_final, 2) # Find the other two points (C, D) for each quad (vectorized) # Create a mask for each quad indicating which indices are NOT i or j all_indices = np.arange(4) mask = np.ones((M_final, 4), dtype=bool) mask[np.arange(M_final), i_idx_final] = False mask[np.arange(M_final), j_idx_final] = False # Find the two other indices for each quad other_indices = np.where(mask) # Reshape to (M_final, 2) - two "other" indices per quad other_indices_reshaped = other_indices[1].reshape(M_final, 2) C_indices = other_indices_reshaped[:, 0] D_indices = other_indices_reshaped[:, 1] C = P_final[np.arange(M_final), C_indices] # (M_final, 2) D = P_final[np.arange(M_final), D_indices] # (M_final, 2) # Project C and D onto basis and normalize by L C_vec = C - A D_vec = D - A c_coords = ( np.stack([np.sum(C_vec * e1, axis=1), np.sum(C_vec * e2, axis=1)], axis=1) / L[:, None] ) # (M_final, 2) d_coords = ( np.stack([np.sum(D_vec * e1, axis=1), np.sum(D_vec * e2, axis=1)], axis=1) / L[:, None] ) # (M_final, 2) # Sort coordinates lexicographically for each quad (vectorized) # Stack c and d coordinates: (M_final, 2, 2) - 2 points × 2 coords pairs = np.stack([c_coords, d_coords], axis=1) # (M_final, 2, 2) # Determine which point should come first (lexicographic order) # First compare x-coordinates ([:,0]), then y-coordinates ([:,1]) swap = (pairs[:, 0, 0] > pairs[:, 1, 0]) | ( (pairs[:, 0, 0] == pairs[:, 1, 0]) & (pairs[:, 0, 1] > pairs[:, 1, 1]) ) # Swap where needed pairs[swap] = pairs[swap, ::-1, :] # Flatten to descriptor format: [x1, y1, x2, y2] descs[final_mask] = pairs.reshape(M_final, 4) lens[final_mask] = L return descs, lens
[docs] def quantize_desc(desc: np.ndarray, eps: float) -> Tuple[int, int, int, int]: q = np.floor(desc / eps + 0.5).astype(np.int64) return int(q[0]), int(q[1]), int(q[2]), int(q[3])
[docs] def reflect_desc(desc: np.ndarray) -> np.ndarray: """Reflect quad descriptor across its baseline (flip y) with lexicographic re-ordering.""" x1, y1, x2, y2 = (float(desc[0]), float(desc[1]), float(desc[2]), float(desc[3])) y1 = -y1 y2 = -y2 if (x1 > x2) or (x1 == x2 and y1 > y2): x1, y1, x2, y2 = x2, y2, x1, y1 return np.array([x1, y1, x2, y2], dtype=np.float64)
# Phase 1 Optimization: Pre-computed neighbor offsets cache _NEIGHBOR_OFFSETS_CACHE: Dict[int, np.ndarray] = {} def _get_neighbor_offsets(r: int) -> np.ndarray: """Get pre-computed neighbor offsets for given radius.""" if r not in _NEIGHBOR_OFFSETS_CACHE: x = np.arange(-r, r + 1, dtype=np.int32) # Create meshgrid and reshape to (n_neighbors, 4) grid = np.meshgrid(x, x, x, x, indexing='ij') offsets = np.stack([g.ravel() for g in grid], axis=1) _NEIGHBOR_OFFSETS_CACHE[r] = offsets return _NEIGHBOR_OFFSETS_CACHE[r]
[docs] def neighbors_bins( key: Tuple[int, int, int, int], r: int = 1 ) -> Iterable[Tuple[int, int, int, int]]: """Generate neighbor bins using optimized pre-computed offsets.""" offsets = _get_neighbor_offsets(r) key_arr = np.array(key, dtype=np.int32) neighbors = key_arr + offsets # Return as tuples of Python ints for compatibility (important for dict keys) for neighbor in neighbors: yield tuple(int(x) for x in neighbor)
[docs] def multiprobe_desc_keys( desc: np.ndarray, eps: float, threshold: float = 0.3 ) -> List[Tuple[int, int, int, int]]: """Multi-probe hashing: return the primary quantized key plus keys for nearby bins where the descriptor is close to a bin boundary. Instead of searching all (2r+1)^4 = 81 neighboring bins, this probes only the bins that the descriptor might fall into due to quantization noise. Typically returns 1-8 keys instead of 81. Args: desc: 4D descriptor array eps: quantization step size threshold: fraction of eps within which a boundary is considered "near" """ q = desc / eps + 0.5 base = np.floor(q).astype(np.int64) frac = q - base # fractional part in [0, 1) primary = tuple(int(x) for x in base) keys = [primary] # Find dimensions near bin boundaries near_dims = [] offsets_per_dim = {} for d in range(4): if frac[d] < threshold: near_dims.append(d) offsets_per_dim[d] = -1 elif frac[d] > (1.0 - threshold): near_dims.append(d) offsets_per_dim[d] = 1 # Single-dimension probes for d in near_dims: alt = list(base) alt[d] += offsets_per_dim[d] keys.append(tuple(int(x) for x in alt)) # Two-dimension combination probes for i in range(len(near_dims)): for j in range(i + 1, len(near_dims)): d1, d2 = near_dims[i], near_dims[j] alt = list(base) alt[d1] += offsets_per_dim[d1] alt[d2] += offsets_per_dim[d2] keys.append(tuple(int(x) for x in alt)) return keys
[docs] def compute_bin_edges(lengths: np.ndarray, n_bins: int) -> Optional[np.ndarray]: """Compute quantile-based bin edges from a set of baseline lengths. Returns array of (n_bins+1) edge values, or None if binning not possible. """ lengths = np.asarray(lengths, dtype=np.float64) if len(lengths) == 0 or n_bins <= 1: return None qs = np.quantile(lengths, q=np.linspace(0, 1, n_bins + 1)) for i in range(1, len(qs)): if qs[i] <= qs[i - 1]: qs[i] = np.nextafter(qs[i - 1], np.inf) return qs
[docs] def baseline_rank_bins( lengths: np.ndarray, n_bins: int, edges: Optional[np.ndarray] = None ) -> np.ndarray: """ Assign each length to a *rank* bin based on quantiles. If edges are provided, use those instead of computing from the data. This allows consistent binning between detection and reference sets. """ lengths = np.asarray(lengths, dtype=np.float64) n = len(lengths) if n == 0: return np.zeros(0, dtype=np.int16) if n_bins <= 1: return np.zeros(n, dtype=np.int16) if edges is not None: b = np.searchsorted(edges[1:-1], lengths, side="right") return np.clip(b, 0, n_bins - 1).astype(np.int16) qs = np.quantile(lengths, q=np.linspace(0, 1, n_bins + 1)) # make strictly increasing to avoid pathological equal-quantile edges for i in range(1, len(qs)): if qs[i] <= qs[i - 1]: qs[i] = np.nextafter(qs[i - 1], np.inf) b = np.searchsorted(qs[1:-1], lengths, side="right") return b.astype(np.int16)
[docs] @dataclass class QuadEntry: idxs: Tuple[int, int, int, int] desc: np.ndarray bin_id: int mags: Optional[Tuple[float, float, float, float]] = None
[docs] def build_quad_hash_multiscale( points: np.ndarray, mags: Optional[np.ndarray], k: int, eps: float, n_scale_bins: int, allow_reflection: bool = True, bin_edges: Optional[np.ndarray] = None, quad_array: Optional[np.ndarray] = None, descs: Optional[np.ndarray] = None, lens: Optional[np.ndarray] = None, ) -> Dict[Tuple[int, Tuple[int, int, int, int]], List[QuadEntry]]: """ Hash key: (scale_bin_id, quantized_descriptor) If bin_edges is provided, use those for baseline binning instead of computing quantiles from the data. This enables consistent binning between detection and reference sets. Phase 2 Optimization: Uses vectorized quad descriptor calculation. """ if quad_array is None or descs is None or lens is None: quads = make_local_quads(points, k=k) if not quads: return {} quad_array = np.array(quads, dtype=np.int32) descs, lens = quad_descriptor_batch(points, quad_array) else: quad_array = np.asarray(quad_array, dtype=np.int32) descs = np.asarray(descs, dtype=np.float64) lens = np.asarray(lens, dtype=np.float64) if quad_array.ndim != 2 or quad_array.shape[1] != 4: raise ValueError("quad_array must be (N, 4)") if descs.shape != (len(quad_array), 4): raise ValueError("descs must have shape (N, 4)") if lens.shape != (len(quad_array),): raise ValueError("lens must have shape (N,)") # Phase 1: Use boolean indexing instead of list comprehension ok = np.isfinite(descs).all(axis=1) & np.isfinite(lens) & (lens > 0) quads_ok = quad_array[ok] # Keep as numpy array descs_ok = descs[ok] lens_ok = lens[ok] if len(quads_ok) == 0: return {} bins = baseline_rank_bins(lens_ok, n_bins=n_scale_bins, edges=bin_edges) # Phase 1: Pre-allocate and use numpy operations H: Dict[Tuple[int, Tuple[int, int, int, int]], List[QuadEntry]] = {} for i in range(len(quads_ok)): q = tuple(quads_ok[i]) # Convert to tuple for QuadEntry desc = descs_ok[i] b = int(bins[i]) me = None if mags is not None: me = tuple(float(mags[idx]) for idx in q) entry = QuadEntry(idxs=q, desc=desc, bin_id=b, mags=me) # Insert descriptor and its reflected counterpart to handle parity flips key = (b, quantize_desc(desc, eps)) H.setdefault(key, []).append(entry) if allow_reflection: desc_ref = reflect_desc(desc) key_ref = (b, quantize_desc(desc_ref, eps)) if key_ref != key: H.setdefault(key_ref, []).append(entry) return H
class _PatternMatchFailed(RuntimeError): """Internal sentinel for pattern matching failure (used by adaptive retry).""" pass def _auto_match_resolution(det_xy: np.ndarray, ref_uv: np.ndarray) -> Optional[float]: """Compute matching resolution from source density (SCAMP-inspired). Uses the mean inter-source spacing as confusion limit: matchresol = sqrt(field_area / n_sources) Returns the confusion radius in pixels, or None if it can't be computed. """ # Use the smaller of the two sets (like SCAMP's cross-section approach) n = min(len(det_xy), len(ref_uv)) if n < 4: return None # Estimate field area from convex hull or bounding box of detections x_range = det_xy[:, 0].max() - det_xy[:, 0].min() y_range = det_xy[:, 1].max() - det_xy[:, 1].min() field_area = x_range * y_range if field_area <= 0: return None # Mean area per source → confusion radius mean_area = field_area / n confusion_radius = np.sqrt(mean_area) return confusion_radius # ----------------------------- # Solver # -----------------------------
[docs] @dataclass class AstrometryConfig: # Selection sizes n_det: int = 180 n_ref: int = 380 # Quads neighbor_k: int = 8 eps_desc: float = 0.015 bin_neighbor_radius: int = 1 n_scale_bins: int = 6 allow_reflection: bool = True use_mag_signature: bool = True # WCS prior constraints (use initial WCS to filter hypotheses) use_wcs_prior: bool = True scale_tolerance: float = 0.30 # fractional, e.g. 0.30 = +/-30% enforce_parity: bool = True # Two-stage scoring stage1_n_det: int = 50 # cheap scoring uses only brightest subset of selected detections stage1_radius_arcsec: float = 20.0 # loose stage2_radius_arcsec: float = 8.0 # tighter top_k_hypotheses: int = 60 # keep best hypotheses from stage 1 max_quads_tested: int = 5000 # max detection quads to try max_candidates_per_bucket: int = 60 # cap ref candidates pulled per hash bucket # Multi-probe hashing: probe only nearby bins instead of all (2r+1)^4 use_multiprobe: bool = True # Final fit sip_degree: int = 3 refine_clip_sigma: float = 4.0 refine_clip_sigma_start: float = 5.0 # progressive clipping: start value refine_min_match_fraction: float = 0.5 # keep at least this fraction of initial matches refine_max_iter: int = 5 refine_rematch_iters: int = 2 # iterative affine re-matching rounds # Optional expanded refinement pool (use many more objects for final fit) refine_use_all: bool = False refine_n_det: Optional[int] = None refine_n_ref: Optional[int] = None refine_match_radius_arcsec: Optional[float] = None # Auto matching resolution: compute stage radii from source density auto_match_resolution: bool = True # Adaptive source count: retry with more sources if matching fails adaptive_n_retry: int = 2 # number of retry doublings (0 = disabled) adaptive_min_inliers: int = 12 # minimum inliers to accept without retry
[docs] class QuadHashAstrometry: def __init__(self, config: AstrometryConfig = AstrometryConfig()): self.cfg = config def _pattern_match( self, det_xy: np.ndarray, det_mag: np.ndarray, ref_uv: np.ndarray, ref_mag: np.ndarray, wcs_init: WCS, pixel_scale_arcsec: float, cfg: Optional['AstrometryConfig'] = None, ) -> Tuple[list, list, dict]: """Run quad-hash pattern matching (stages 1 & 2) + affine re-matching. Returns (pairs, top_hyp, best) where pairs is list of (det_idx, ref_idx). Raises _PatternMatchFailed if matching fails. """ if cfg is None: cfg = self.cfg # WCS prior: in pixel space expected scale ≈ 1.0, compute parity from WCS expected_scale = None expected_parity = None # sign of determinant if cfg.use_wcs_prior: expected_scale = 1.0 # Both sets in pixel space try: crpix = np.asarray(wcs_init.wcs.crpix, dtype=np.float64) if crpix.size == 2 and np.all(np.isfinite(crpix)): px = np.array([crpix[0], crpix[0] + 1.0, crpix[0]], dtype=np.float64) py = np.array([crpix[1], crpix[1], crpix[1] + 1.0], dtype=np.float64) ra_s, dec_s = wcs_init.all_pix2world(px, py, 1) x_s, y_s = wcs_init.all_world2pix(ra_s, dec_s, 0) J = np.array( [ [x_s[1] - x_s[0], x_s[2] - x_s[0]], [y_s[1] - y_s[0], y_s[2] - y_s[0]], ], dtype=np.float64, ) det_j = float(np.linalg.det(J)) if np.isfinite(det_j) and det_j != 0: expected_parity = 1.0 if det_j > 0 else -1.0 except Exception: expected_parity = None # --- Consistent baseline binning --- ref_qa = None ref_descs_raw = None ref_lens_raw = None ref_bin_edges = None ref_quads_raw = make_local_quads(ref_uv, k=cfg.neighbor_k) if ref_quads_raw: ref_qa = np.array(ref_quads_raw, dtype=np.int32) ref_descs_raw, ref_lens_raw = quad_descriptor_batch(ref_uv, ref_qa) ok_ref = np.isfinite(ref_lens_raw) & (ref_lens_raw > 0) if np.any(ok_ref): ref_bin_edges = compute_bin_edges(ref_lens_raw[ok_ref], cfg.n_scale_bins) # Reference multiscale hash ref_hash = build_quad_hash_multiscale( ref_uv, ref_mag, k=cfg.neighbor_k, eps=cfg.eps_desc, n_scale_bins=cfg.n_scale_bins, allow_reflection=cfg.allow_reflection, bin_edges=ref_bin_edges, quad_array=ref_qa, descs=ref_descs_raw, lens=ref_lens_raw, ) # Detection quads det_quads = make_local_quads(det_xy, k=cfg.neighbor_k) if not det_quads: raise _PatternMatchFailed("Not enough detections for quad matching.") det_quad_array = np.array(det_quads, dtype=np.int32) det_descs, det_lens = quad_descriptor_batch(det_xy, det_quad_array) okq = np.isfinite(det_descs).all(axis=1) & np.isfinite(det_lens) & (det_lens > 0) det_quads = det_quad_array[okq] det_descs = det_descs[okq] det_lens = det_lens[okq] if ref_bin_edges is not None: det_bins = baseline_rank_bins(det_lens, n_bins=cfg.n_scale_bins, edges=ref_bin_edges) else: det_bins = baseline_rank_bins(det_lens, n_bins=cfg.n_scale_bins) det_quads_quantized = np.array( [quantize_desc(desc, cfg.eps_desc) for desc in det_descs], dtype=object ) ref_tree = cKDTree(ref_uv) # Stage radii in pixels r1 = cfg.stage1_radius_arcsec / pixel_scale_arcsec r2 = cfg.stage2_radius_arcsec / pixel_scale_arcsec # Stage1 detection subset order_det = np.argsort(det_mag) det_stage1_ids = order_det[: min(cfg.stage1_n_det, len(order_det))] det_xy_stage1 = det_xy[det_stage1_ids] top_hyp: List[Tuple[float, np.ndarray, np.ndarray, float]] = [] rng = np.random.default_rng(12345) q_order = np.arange(len(det_quads)) rng.shuffle(q_order) q_order = q_order[: min(cfg.max_quads_tested, len(q_order))] def _try_insert(score: float, R: np.ndarray, t: np.ndarray, s: float): nonlocal top_hyp if len(top_hyp) < cfg.top_k_hypotheses: top_hyp.append((score, R, t, s)) return worst_i = int(np.argmin([h[0] for h in top_hyp])) if score > top_hyp[worst_i][0]: top_hyp[worst_i] = (score, R, t, s) # --- Stage 1 --- for qi in q_order: q = det_quads[qi] bin_id = int(det_bins[qi]) kdesc = det_quads_quantized[qi] candidates: List[QuadEntry] = [] if cfg.use_multiprobe: for kdesc2 in multiprobe_desc_keys(det_descs[qi], cfg.eps_desc): kk = (bin_id, kdesc2) if kk in ref_hash: candidates.extend(ref_hash[kk]) else: for kdesc2 in neighbors_bins(kdesc, r=cfg.bin_neighbor_radius): kk = (bin_id, kdesc2) if kk in ref_hash: candidates.extend(ref_hash[kk]) if not candidates: continue det_sig = mag_signature(det_mag[q]) if cfg.use_mag_signature else None rng.shuffle(candidates) for ce in candidates[: cfg.max_candidates_per_bucket]: if det_sig is not None and ce.mags is not None: if mag_signature(np.array(ce.mags)) != det_sig: continue P = det_xy[q] Q = ref_uv[list(ce.idxs)] try: R, t, s = estimate_similarity_2d(P, Q, allow_reflection=cfg.allow_reflection) except Exception: continue if expected_scale is not None and cfg.scale_tolerance is not None: tol = float(cfg.scale_tolerance) if tol >= 0: if not (expected_scale * (1 - tol) <= s <= expected_scale * (1 + tol)): continue if expected_parity is not None and cfg.enforce_parity and cfg.allow_reflection: detR = float(np.linalg.det(R)) if not np.isfinite(detR) or np.sign(detR) != expected_parity: continue det_uv1 = apply_similarity(det_xy_stage1, R, t, s) dist, nn = ref_tree.query(det_uv1, k=1, distance_upper_bound=r1) ok = np.isfinite(dist) & (dist < r1) & (nn < len(ref_uv)) if not np.any(ok): continue weights = 1.0 - (dist[ok] / r1) ** 2 score = float(np.sum(weights)) _try_insert(score, R, t, s) if not top_hyp: raise _PatternMatchFailed( "Pattern matching failed at stage 1. " "Try increasing n_det/n_ref, eps_desc, or stage1_radius_arcsec." ) # --- Stage 2 --- top_hyp.sort(key=lambda x: x[0], reverse=True) best = {"score": -1.0, "inliers": -1, "R": None, "t": None, "s": None, "pairs": None} for score1, R, t, s in top_hyp: if expected_scale is not None and cfg.scale_tolerance is not None: tol = float(cfg.scale_tolerance) if tol >= 0: if not (expected_scale * (1 - tol) <= s <= expected_scale * (1 + tol)): continue if expected_parity is not None and cfg.enforce_parity and cfg.allow_reflection: detR = float(np.linalg.det(R)) if not np.isfinite(detR) or np.sign(detR) != expected_parity: continue det_uv = apply_similarity(det_xy, R, t, s) det_ids, ref_ids = _mutual_nearest_neighbor(det_uv, ref_uv, r2, tree_b=ref_tree) if len(det_ids) == 0: continue dists = np.hypot( det_uv[det_ids, 0] - ref_uv[ref_ids, 0], det_uv[det_ids, 1] - ref_uv[ref_ids, 1] ) weights = 1.0 - (dists / r2) ** 2 score = float(np.sum(weights)) pairs = list(zip(det_ids.tolist(), ref_ids.tolist())) if score > best["score"]: best = { "score": score, "inliers": len(pairs), "R": R, "t": t, "s": s, "pairs": pairs, } if best["inliers"] < 8: raise _PatternMatchFailed( f"Pattern matching failed at stage 2 (best inliers={best['inliers']}). " f"Try loosening stage2_radius_arcsec or increasing top_k_hypotheses." ) # --- Iterative affine re-matching --- pairs = best["pairs"] r_rematch = r2 for rematch_it in range(cfg.refine_rematch_iters): if len(pairs) < 6: break pairs_arr = np.array(pairs, dtype=np.int32) A_pts = det_xy[pairs_arr[:, 0]] B_pts = ref_uv[pairs_arr[:, 1]] try: M_affine, t_affine = estimate_affine_2d(A_pts, B_pts) det_uv_affine = apply_affine(det_xy, M_affine, t_affine) except Exception: break r_iter = r_rematch * (0.8 if rematch_it > 0 else 1.0) new_det_ids, new_ref_ids = _mutual_nearest_neighbor( det_uv_affine, ref_uv, r_iter, tree_b=ref_tree ) if len(new_det_ids) < max(8, int(len(pairs) * 0.6)): break pairs = list(zip(new_det_ids.tolist(), new_ref_ids.tolist())) return pairs, top_hyp, best
[docs] def refine( self, obj: Table, cat: Table, wcs_init: WCS, fit_projection: Optional[WCS] = None, pv_deg: int = 5, ) -> Tuple[WCS, Table, dict]: if not all(k in obj.colnames for k in ("x", "y")): raise ValueError("obj must contain 'x' and 'y'") if not all(k in cat.colnames for k in ("ra", "dec", "mag")): raise ValueError("cat must contain 'ra','dec','mag'") x = _as_float_array(obj["x"]) y = _as_float_array(obj["y"]) if "mag" in obj.colnames: m_obj = _as_float_array(obj["mag"]) elif "flux" in obj.colnames: f = _as_float_array(obj["flux"]) f = np.where(f > 0, f, np.nan) m_obj = -2.5 * np.log10(f) else: raise ValueError("obj must contain 'flux' or 'mag'") ra = _as_float_array(cat["ra"]) dec = _as_float_array(cat["dec"]) m_cat = _as_float_array(cat["mag"]) m0 = _finite_mask(x, y, m_obj) x, y, m_obj = x[m0], y[m0], m_obj[m0] m1 = _finite_mask(ra, dec, m_cat) ra, dec, m_cat = ra[m1], dec[m1], m_cat[m1] # Rough sky center (used for residuals and SIP check) try: ra0, dec0 = float(wcs_init.wcs.crval[0]), float(wcs_init.wcs.crval[1]) if not (np.isfinite(ra0) and np.isfinite(dec0)): raise ValueError except Exception: ra0, dec0 = float(np.nanmedian(ra)), float(np.nanmedian(dec)) # Pixel scale for radius conversion (arcsec/pixel) try: pscales = wcs_init.proj_plane_pixel_scales() # proj_plane_pixel_scales() may return Quantity with units if hasattr(pscales[0], 'to_value'): pixel_scale_deg = float(np.mean([s.to_value(u.deg) for s in pscales])) else: pixel_scale_deg = float(np.mean(pscales)) except Exception: pixel_scale_deg = None if pixel_scale_deg is None or pixel_scale_deg <= 0 or not np.isfinite(pixel_scale_deg): raise RuntimeError("Could not determine pixel scale from WCS.") pixel_scale_arcsec = pixel_scale_deg * 3600.0 # --- Adaptive source count retry loop (SCAMP-inspired) --- # Try pattern matching with increasing source counts. On each retry, # double n_det/n_ref to bring in more sources for matching. # Pre-sort by brightness (avoids re-sorting each retry) obj_order = np.argsort(m_obj) # lower mag = brighter cat_order = np.argsort(m_cat) # Project full catalog to pixel space once (slice per retry) all_ref_xy = np.column_stack(wcs_init.all_world2pix(ra, dec, 0)) all_ref_finite = np.all(np.isfinite(all_ref_xy), axis=1) # Local config copy so auto_match_resolution doesn't mutate self.cfg cfg = dataclasses.replace(self.cfg) n_retries = max(0, cfg.adaptive_n_retry) last_error = None pairs = None for _attempt in range(n_retries + 1): scale_factor = 2**_attempt cur_n_det = min(cfg.n_det * scale_factor, len(x)) cur_n_ref = min(cfg.n_ref * scale_factor, len(ra)) # Select bright subsets (pre-sorted, just slice) obj_sub_idx = obj_order[:cur_n_det] cat_sub_idx = cat_order[:cur_n_ref] det_xy = np.column_stack([x[obj_sub_idx], y[obj_sub_idx]]) det_mag = m_obj[obj_sub_idx] # Slice pre-projected catalog and filter non-finite ref_uv = all_ref_xy[cat_sub_idx] ref_mag = m_cat[cat_sub_idx] ref_ok = all_ref_finite[cat_sub_idx] if not np.all(ref_ok): ref_uv = ref_uv[ref_ok] ref_mag = ref_mag[ref_ok] cat_sub_idx = cat_sub_idx[ref_ok] if len(det_xy) < 8 or len(ref_uv) < 8: last_error = _PatternMatchFailed("Not enough points after selection.") continue # Auto matching resolution: widen stage radii from source density if cfg.auto_match_resolution: auto_r_pix = _auto_match_resolution(det_xy, ref_uv) if auto_r_pix is not None: auto_r_arcsec = auto_r_pix * pixel_scale_arcsec cfg = dataclasses.replace( cfg, stage1_radius_arcsec=max(cfg.stage1_radius_arcsec, auto_r_arcsec * 2.5), stage2_radius_arcsec=max(cfg.stage2_radius_arcsec, auto_r_arcsec), ) try: pairs, top_hyp, best = self._pattern_match( det_xy, det_mag, ref_uv, ref_mag, wcs_init, pixel_scale_arcsec, cfg=cfg, ) except _PatternMatchFailed as e: last_error = e if cur_n_det < len(x) or cur_n_ref < len(ra): continue else: break # Check if we have enough inliers if len(pairs) >= cfg.adaptive_min_inliers: break # Not enough inliers — retry with more sources if possible last_error = _PatternMatchFailed( f"Only {len(pairs)} inliers (need {cfg.adaptive_min_inliers})" ) if cur_n_det >= len(x) and cur_n_ref >= len(ra): break # can't add more sources if pairs is None or len(pairs) < 8: raise RuntimeError(str(last_error) if last_error else "Pattern matching failed.") # Build match list for final fit pairs_arr = np.array(pairs, dtype=np.int32) det_indices = pairs_arr[:, 0] ref_indices = pairs_arr[:, 1] det_match_idx = obj_sub_idx[det_indices] ref_match_idx = cat_sub_idx[ref_indices] seed_matches = int(len(det_match_idx)) used_expanded = False refine_pool_det = int(len(det_match_idx)) refine_pool_ref = int(len(ref_match_idx)) # Optional: expand refinement pool using many more detections/catalog entries if self.cfg.refine_use_all: det_pool_idx = np.arange(len(x)) if self.cfg.refine_n_det is not None and self.cfg.refine_n_det > 0: det_pool_idx = np.argsort(m_obj)[: min(self.cfg.refine_n_det, len(m_obj))] ref_pool_idx = np.arange(len(ra)) if self.cfg.refine_n_ref is not None and self.cfg.refine_n_ref > 0: ref_pool_idx = np.argsort(m_cat)[: min(self.cfg.refine_n_ref, len(m_cat))] if len(det_pool_idx) >= 8 and len(ref_pool_idx) >= 8: det_xy_pool = np.column_stack([x[det_pool_idx], y[det_pool_idx]]) # Project expanded catalog to pixel space using initial WCS ref_x_all, ref_y_all = wcs_init.all_world2pix( ra[ref_pool_idx], dec[ref_pool_idx], 0 ) ref_uv_all = np.column_stack([ref_x_all, ref_y_all]) # Filter non-finite _ref_ok = np.all(np.isfinite(ref_uv_all), axis=1) if not np.all(_ref_ok): ref_uv_all = ref_uv_all[_ref_ok] ref_pool_idx = ref_pool_idx[_ref_ok] # Use affine from last rematch iteration for better projection try: _pa = np.array(pairs, dtype=np.int32) _A = det_xy[_pa[:, 0]] _B = ref_uv[_pa[:, 1]] M_exp, t_exp = estimate_affine_2d(_A, _B) det_uv_pool = apply_affine(det_xy_pool, M_exp, t_exp) except Exception: det_uv_pool = apply_similarity(det_xy_pool, best["R"], best["t"], best["s"]) r_match_arcsec = self.cfg.refine_match_radius_arcsec if r_match_arcsec is None: r_match_arcsec = self.cfg.stage2_radius_arcsec r_match = float(r_match_arcsec) / pixel_scale_arcsec # Mutual nearest-neighbor for expanded pool exp_det_ids, exp_ref_ids = _mutual_nearest_neighbor( det_uv_pool, ref_uv_all, r_match ) if len(exp_det_ids) >= 8: det_match_idx = det_pool_idx[exp_det_ids] ref_match_idx = ref_pool_idx[exp_ref_ids] used_expanded = True refine_pool_det = int(len(det_pool_idx)) refine_pool_ref = int(len(ref_pool_idx)) det_x = x[det_match_idx] det_y = y[det_match_idx] ref_ra = ra[ref_match_idx] ref_dec = dec[ref_match_idx] ref_m = m_cat[ref_match_idx] det_m = m_obj[det_match_idx] refine_matches_preclip = int(len(det_match_idx)) # Minimum match floor (#6): don't clip below this count min_matches = max(8, int(refine_matches_preclip * self.cfg.refine_min_match_fraction)) # --- Progressive sigma-clipping refinement (#6) --- sky = SkyCoord(ref_ra * u.deg, ref_dec * u.deg, frame="icrs") xy = np.vstack([det_x, det_y]) # (2,N) refined_wcs = None keep = np.ones(xy.shape[1], dtype=bool) # Projection template for WCS fitting (may differ from wcs_init # which is used only for the matching phase) wcs_fit_proj = fit_projection if fit_projection is not None else wcs_init # Progressive clipping: interpolate from start to end sigma clip_start = self.cfg.refine_clip_sigma_start clip_end = self.cfg.refine_clip_sigma n_iter = self.cfg.refine_max_iter for it in range(n_iter): # Linearly interpolate clip threshold from start (loose) to end (tight) if n_iter > 1: frac = it / (n_iter - 1) else: frac = 1.0 clip_sigma = clip_start + (clip_end - clip_start) * frac xy_use = xy[:, keep] sky_use = sky[keep] # Wrapper handles SIP for TAN, PV for ZPN, linear for others refined_wcs = fit_wcs_from_points( xy_use, sky_use, projection=wcs_fit_proj, sip_degree=int(self.cfg.sip_degree), pv_deg=pv_deg, ) wcs_fit_proj = refined_wcs # Residuals via spherical distance (projection-independent) ra_fit, dec_fit = refined_wcs.all_pix2world(xy_use[0], xy_use[1], 0) ref_ra_use = sky_use.ra.deg ref_dec_use = sky_use.dec.deg cos_dec = np.cos(np.deg2rad(0.5 * (dec_fit + ref_dec_use))) dra_deg = (ra_fit - ref_ra_use) * cos_dec ddec_deg = dec_fit - ref_dec_use dr = np.deg2rad(np.hypot(dra_deg, ddec_deg)) sig = _robust_sigma(dr) if not np.isfinite(sig) or sig <= 0: break thresh = clip_sigma * sig dr_full = np.full(keep.shape[0], np.nan, dtype=np.float64) dr_full[keep] = dr keep_new = keep & np.isfinite(dr_full) & (dr_full < thresh) # Enforce minimum match floor (#6) if keep_new.sum() < min_matches: # Don't clip further if we'd go below the minimum. # Keep the current inlier set instead of partially applying # a clip that violates the configured floor. break if keep_new.sum() == keep.sum(): keep = keep_new break keep = keep_new if refined_wcs is None: raise RuntimeError("WCS refinement failed in fit_wcs_from_points stage.") # Output match table final_idx = np.nonzero(keep)[0] match = Table() match["x"] = det_x[final_idx] match["y"] = det_y[final_idx] match["obj_mag"] = det_m[final_idx] match["ra"] = ref_ra[final_idx] match["dec"] = ref_dec[final_idx] match["cat_mag"] = ref_m[final_idx] # Residuals in arcsec using final WCS (spherical, projection-independent) ra_fit, dec_fit = refined_wcs.all_pix2world(match["x"], match["y"], 0) cos_dec_out = np.cos(np.deg2rad(0.5 * (dec_fit + _as_float_array(match["dec"])))) match["du_arcsec"] = (ra_fit - _as_float_array(match["ra"])) * cos_dec_out * 3600.0 match["dv_arcsec"] = (dec_fit - _as_float_array(match["dec"])) * 3600.0 match["dr_arcsec"] = np.hypot(match["du_arcsec"], match["dv_arcsec"]) diagnostics = { "rough_center_deg": (float(ra0), float(dec0)), "stage1_hypotheses_kept": int(len(top_hyp)), "pattern_inliers_stage2": int(best["inliers"]), "seed_matches": int(seed_matches), "refine_used_expanded": bool(used_expanded), "refine_pool_det": int(refine_pool_det), "refine_pool_ref": int(refine_pool_ref), "refine_matches_preclip": int(refine_matches_preclip), "final_matches": int(len(match)), "rms_dr_arcsec": float(np.sqrt(np.mean(match["dr_arcsec"] ** 2))) if len(match) else np.nan, "mad_dr_arcsec": float(_robust_sigma(_as_float_array(match["dr_arcsec"]))) if len(match) else np.nan, "config": self.cfg, } return refined_wcs, match, diagnostics
[docs] def refine_wcs_quadhash( obj: Table, cat: Table, wcs: WCS = None, header=None, sr: float = 10 / 3600, order: int = 2, projection: str = None, pv_deg: int = 5, cat_col_ra: str = 'RAJ2000', cat_col_dec: str = 'DEJ2000', cat_col_mag: str = 'rmag', obj_col_mag: str = None, obj_col_flux: str = 'flux', sn: float = None, n_det: int = 150, n_ref: int = 600, max_quads_tested: int = 8000, allow_reflection: bool = True, use_wcs_prior: bool = True, scale_tolerance: float = 0.30, enforce_parity: bool = True, refine_use_all: bool = False, refine_n_det: Optional[int] = None, refine_n_ref: Optional[int] = None, refine_match_radius_arcsec: Optional[float] = None, get_header: bool = False, update: bool = False, verbose: bool = False, ) -> WCS: """Refine WCS using quad-hash pattern matching algorithm. Pure Python implementation of astrometric refinement using quad-hash based pattern matching with no external dependencies (only numpy, scipy, astropy). Typically achieves sub-arcsecond accuracy; ~2–7× more accurate than SCAMP, especially for challenging conditions, though ~4× slower. Parameters ---------- obj : astropy.table.Table List of objects on the frame, must contain at least ``x``, ``y``, and either ``flux`` or ``mag`` columns. cat : astropy.table.Table Reference astrometric catalogue with RA, Dec, and magnitude columns. wcs : astropy.wcs.WCS, optional Initial WCS solution (rough estimate). header : astropy.io.fits.Header, optional FITS header containing the initial astrometric solution (alternative to ``wcs``). sr : float, optional Matching radius in degrees for stage 2 matching. order : int, optional SIP polynomial order for the distortion solution (0–5). For TAN projection, controls the SIP polynomial degree. For ZPN projection, controls additional SIP corrections on top of the PV radial model (0 means ZPN-only, 2 is recommended for wide-field images). Ignored for other projections. projection : str or None, optional Target projection type for the output WCS. If ``None`` (default), uses the same projection as the input WCS. Supported values: - ``'TAN'`` — gnomonic (TAN) with SIP distortion polynomials - ``'ZPN'`` — zenithal polynomial with PV radial terms (+ optional SIP for non-radial corrections). Best for wide-field images (FoV > 5°) as ZPN naturally handles radial distortion that TAN-SIP struggles with. - Any other FITS WCS projection code supported by astropy (``'STG'``, ``'ARC'``, ``'ZEA'``, ``'SIN'``, etc.) — linear WCS fit only (no distortion polynomials). The initial WCS (``wcs``) is always used for the pattern matching phase; the projection conversion only affects the final WCS fit. pv_deg : int, optional ZPN PV polynomial degree (``PV2_0`` … ``PV2_pv_deg``). Only used when ``projection='ZPN'``. Default is 5, which is sufficient for most optical systems. cat_col_ra : str, optional Catalogue column name for Right Ascension. Default is ``'RAJ2000'``. cat_col_dec : str, optional Catalogue column name for Declination. Default is ``'DEJ2000'``. cat_col_mag : str, optional Catalogue column name for magnitude. Default is ``'rmag'``. obj_col_mag : str, optional Object list column name for magnitude. Auto-detected if not provided. obj_col_flux : str, optional Object list column name for flux. Used if magnitude column is absent. Default is ``'flux'``. sn : float, optional If provided, only objects with S/N exceeding this value will be used. n_det : int, optional Number of brightest detections to use for matching. Default is 150. n_ref : int, optional Number of brightest catalogue stars to use for matching. Default is 600. max_quads_tested : int, optional Maximum number of detection quads to test. Default is 8000. allow_reflection : bool, optional Allow reflected quad matches. Default is True. use_wcs_prior : bool, optional Use the initial WCS to filter hypotheses by scale/parity. Default is True. scale_tolerance : float, optional Fractional tolerance for scale filtering. Default is 0.30 (±30%). enforce_parity : bool, optional Enforce parity consistency with the initial WCS. Default is True. refine_use_all : bool, optional If True, expand the final refinement to a larger pool of objects/catalogue entries. refine_n_det : int or None, optional Max number of detections in the expanded refinement pool. None means all. refine_n_ref : int or None, optional Max number of catalogue stars in the expanded refinement pool. None means all. refine_match_radius_arcsec : float or None, optional Matching radius in arcsec for the expanded refinement pool. Defaults to the stage 2 radius. get_header : bool, optional If True, return the FITS header object instead of WCS solution. update : bool, optional If True, update object sky coordinates in-place using the refined WCS. verbose : bool or callable, optional Whether to show verbose messages. May be boolean or a ``print``-like callable. Returns ------- astropy.wcs.WCS or astropy.io.fits.Header or None Refined WCS solution, or FITS header if ``get_header=True``, or None on failure. Examples -------- >>> from stdpipe.astrometry_quad import refine_wcs_quadhash >>> wcs_refined = refine_wcs_quadhash( ... obj, cat, wcs_init, ... sr=10/3600, # 10 arcsec matching radius ... order=2, # quadratic SIP distortion ... sn=5, # use only S/N > 5 objects ... verbose=True ... ) For wide-field images (FoV > 5 degrees), ZPN projection with SIP corrections typically gives lower systematic residuals at field edges: >>> wcs_refined = refine_wcs_quadhash( ... obj, cat, wcs_init, ... sr=30/3600, ... order=2, # SIP order on top of ZPN radial model ... projection='ZPN', # zenithal polynomial projection ... pv_deg=5, # PV radial polynomial degree ... verbose=True ... ) """ # Simple wrapper around print for logging in verbose mode only log = (verbose if callable(verbose) else print) if verbose else lambda *args, **kwargs: None # Get WCS from header if needed if wcs is None and header is not None: from astropy.io import fits wcs = WCS(header) if wcs is None or not wcs.is_celestial: log("Error: Valid initial WCS required") return None # Prepare catalog columns if cat_col_ra not in cat.colnames or cat_col_dec not in cat.colnames: log(f"Error: Catalog must contain '{cat_col_ra}' and '{cat_col_dec}' columns") return None if cat_col_mag not in cat.colnames: log( f"Warning: Catalog magnitude column '{cat_col_mag}' not found, using first magnitude column" ) # Try to find any magnitude column mag_cols = [c for c in cat.colnames if 'mag' in c.lower()] if mag_cols: cat_col_mag = mag_cols[0] log(f" Using '{cat_col_mag}' as magnitude column") else: log("Error: No magnitude column found in catalog") return None # Prepare object columns if obj_col_mag is None: # Auto-detect magnitude column if 'mag' in obj.colnames: obj_col_mag = 'mag' else: obj_col_mag = None has_mag = obj_col_mag is not None and obj_col_mag in obj.colnames has_flux = obj_col_flux in obj.colnames if not has_mag and not has_flux: log( f"Error: Object list must contain either '{obj_col_mag or 'mag'}' or '{obj_col_flux}' column" ) return None # Apply S/N filtering if requested obj_filtered = obj if sn is not None: if 'fluxerr' in obj.colnames or 'flux_err' in obj.colnames: err_col = 'fluxerr' if 'fluxerr' in obj.colnames else 'flux_err' flux_col = obj_col_flux if has_flux else None if flux_col and flux_col in obj.colnames: obj_sn = obj[flux_col] / obj[err_col] mask = obj_sn > sn obj_filtered = obj[mask] log(f"S/N filtering: {len(obj_filtered)}/{len(obj)} objects with S/N > {sn}") else: log(f"Warning: Cannot apply S/N filter - flux column '{flux_col}' not found") else: log("Warning: Cannot apply S/N filter - no flux error column found") # Create standardized tables with required columns cat_std = Table() cat_std['ra'] = cat[cat_col_ra] cat_std['dec'] = cat[cat_col_dec] cat_std['mag'] = cat[cat_col_mag] obj_std = Table() obj_std['x'] = obj_filtered['x'] obj_std['y'] = obj_filtered['y'] if has_mag: obj_std['mag'] = obj_filtered[obj_col_mag] if has_flux: obj_std['flux'] = obj_filtered[obj_col_flux] # Configure refinement # Compute pixel scale from input WCS for scale-aware defaults try: _pscales = wcs.proj_plane_pixel_scales() if hasattr(_pscales[0], 'to_value'): _pixel_scale_deg = float(np.mean([s.to_value(u.deg) for s in _pscales])) else: _pixel_scale_deg = float(np.mean(_pscales)) pixel_scale_arcsec = _pixel_scale_deg * 3600.0 except Exception: pixel_scale_arcsec = 1.0 # fallback: assume 1 arcsec/pixel _pixel_scale_deg = pixel_scale_arcsec / 3600.0 # Estimate FOV diagonal in degrees for scaling source counts nx = wcs.pixel_shape[0] if wcs.pixel_shape else 2048 ny = wcs.pixel_shape[1] if wcs.pixel_shape else 2048 fov_diag_deg = np.hypot(nx, ny) * _pixel_scale_deg # Scale-aware matching radii: ensure at least a few pixels even for coarse plates sr_arcsec = sr * 3600 stage2_arcsec = max(sr_arcsec, 2.0 * pixel_scale_arcsec) stage1_arcsec = max(sr_arcsec * 2, 3.0 * pixel_scale_arcsec) # Scale n_det/n_ref with FOV: for larger fields use more sources # Base: n_det=150, n_ref=600 for a ~0.5 deg field; scale with sqrt(area) fov_scale = max(1.0, fov_diag_deg / 0.5) scaled_n_det = min(int(n_det * np.sqrt(fov_scale)), len(obj_std)) scaled_n_ref = min(int(n_ref * np.sqrt(fov_scale)), len(cat_std)) log("Starting quad-hash WCS refinement...") log(f" Catalog: {len(cat_std)} stars") log(f" Detections: {len(obj_std)} objects") log(f" Initial WCS center: RA={wcs.wcs.crval[0]:.6f}, Dec={wcs.wcs.crval[1]:.6f}") log( f" Pixel scale: {pixel_scale_arcsec:.2f} arcsec/pixel, FOV diagonal: {fov_diag_deg:.2f} deg" ) log( f" Matching radii: stage1={stage1_arcsec:.1f} arcsec ({stage1_arcsec / pixel_scale_arcsec:.1f} pix), " f"stage2={stage2_arcsec:.1f} arcsec ({stage2_arcsec / pixel_scale_arcsec:.1f} pix)" ) log(f" Source counts: n_det={scaled_n_det}, n_ref={scaled_n_ref}") log(f" SIP order: {order}") # Build projection template for the WCS fitting phase fit_projection = None if projection is not None: from stdpipe.astrometry_wcs import convert_wcs_projection proj_code = projection.upper().strip() fit_projection = convert_wcs_projection(wcs, proj_code, pv_deg=pv_deg) log(f" Output projection: {proj_code} (fit CTYPE: {fit_projection.wcs.ctype})") if proj_code == 'ZPN': log(f" ZPN PV degree: {pv_deg}") elif proj_code != 'TAN' and order > 0: log(f" Note: SIP order={order} ignored for {proj_code} projection (linear fit only)") config = AstrometryConfig( n_det=scaled_n_det, n_ref=scaled_n_ref, neighbor_k=10, eps_desc=0.02, n_scale_bins=6, stage1_n_det=min(60, len(obj_std)), stage1_radius_arcsec=stage1_arcsec, stage2_radius_arcsec=stage2_arcsec, top_k_hypotheses=80, max_quads_tested=max_quads_tested, allow_reflection=allow_reflection, use_wcs_prior=use_wcs_prior, scale_tolerance=scale_tolerance, enforce_parity=enforce_parity, sip_degree=max(0, min(5, order)), # Clamp to 0-5 refine_clip_sigma=4.0, refine_clip_sigma_start=5.0, refine_min_match_fraction=0.5, refine_max_iter=5, refine_rematch_iters=2, refine_use_all=refine_use_all, refine_n_det=refine_n_det, refine_n_ref=refine_n_ref, refine_match_radius_arcsec=refine_match_radius_arcsec, # Disable auto_match_resolution: the wrapper already computes # pixel-scale-aware radii above, so the density-based inflation # (which can produce absurdly large radii for sparse wide-field # images) is not needed here. auto_match_resolution=False, ) try: solver = QuadHashAstrometry(config=config) wcs_refined, match, diagnostics = solver.refine( obj=obj_std, cat=cat_std, wcs_init=wcs, fit_projection=fit_projection, pv_deg=pv_deg, ) log(f"Refinement successful:") log(f" Pattern matches: {diagnostics['pattern_inliers_stage2']}") log(f" Final matches: {diagnostics['final_matches']}") log(f" RMS residual: {diagnostics['rms_dr_arcsec']:.3f} arcsec") log(f" MAD residual: {diagnostics['mad_dr_arcsec']:.3f} arcsec") log( f" Refined WCS center: RA={wcs_refined.wcs.crval[0]:.6f}, Dec={wcs_refined.wcs.crval[1]:.6f}" ) # Update object table with sky coordinates if requested if update and wcs_refined is not None: ra_obj, dec_obj = wcs_refined.all_pix2world(obj['x'], obj['y'], 0) obj['ra'] = ra_obj obj['dec'] = dec_obj log(f"Updated object table with ra/dec coordinates") # Return header if requested if get_header and wcs_refined is not None: return wcs_refined.to_header(relax=True) return wcs_refined except Exception as e: log(f"Refinement failed: {e}") if verbose: import traceback traceback.print_exc() return None