Source code for stdpipe.astrometry_wcs

import numpy as np
from astropy.wcs import WCS
from astropy.coordinates import SkyCoord, SkyOffsetFrame
import astropy.units as u


def _sky_residuals_arcsec(wcs: WCS, xy: np.ndarray, sky: SkyCoord, center: SkyCoord) -> np.ndarray:
    """
    Model-minus-reference sky residuals, in arcsec, in an offset frame at `center`.

    The pixel positions `xy` (0-based, shape (N, 2)) are mapped to the sky with
    ``wcs.all_pix2world``; both the predicted and the reference positions are
    then expressed in a `SkyOffsetFrame` whose origin is `center`, and their
    longitude/latitude differences are taken directly (no cos(lat) factor).

    Returns a 1-D array of length 2N: all longitude differences followed by
    all latitude differences.
    """
    x_pix, y_pix = xy[:, 0], xy[:, 1]
    ra_deg, dec_deg = wcs.all_pix2world(x_pix, y_pix, 0)
    predicted = SkyCoord(ra=ra_deg * u.deg, dec=dec_deg * u.deg, frame=sky.frame)

    # Local frame in which `center` sits at (lon, lat) = (0, 0)
    offset_frame = SkyOffsetFrame(origin=center)
    reference_local = sky.transform_to(offset_frame)
    predicted_local = predicted.transform_to(offset_frame)

    deltas = [
        (getattr(predicted_local, axis) - getattr(reference_local, axis)).to_value(u.arcsec)
        for axis in ("lon", "lat")
    ]
    return np.concatenate(deltas)


def _pack_params(w: WCS, pv_deg: int) -> np.ndarray:
    """
    Flatten the fitted WCS quantities into one float vector.

    Layout: ``[CRPIX1, CRPIX2, CRVAL1, CRVAL2, CD1_1, CD1_2, CD2_1, CD2_2,
    PV2_0, ..., PV2_pv_deg]``.  PV2_m terms that are absent from `w` are
    reported as 0; terms on other axes or with m > pv_deg are ignored.
    """
    cd_matrix = np.array(w.wcs.cd, dtype=float)
    linear_part = [
        float(w.wcs.crpix[0]),
        float(w.wcs.crpix[1]),
        float(w.wcs.crval[0]),
        float(w.wcs.crval[1]),
        cd_matrix[0, 0],
        cd_matrix[0, 1],
        cd_matrix[1, 0],
        cd_matrix[1, 1],
    ]

    # get_pv() yields (axis, m, value) triples; keep only latitude-axis terms
    pv2 = np.zeros(pv_deg + 1, dtype=float)
    for axis, order, value in w.wcs.get_pv():
        if axis == 2 and 0 <= order <= pv_deg:
            pv2[order] = float(value)

    return np.array(linear_part + pv2.tolist(), dtype=float)


def _unpack_params_to_wcs(base: WCS, p: np.ndarray, pv_deg: int) -> WCS:
    """
    Build a copy of `base` carrying the parameters stored in `p`.

    `p` uses the layout produced by `_pack_params`.  PV keywords of `base`
    that belong to axes other than 2 are kept; the PV2 set is replaced by
    exactly PV2_0..PV2_pv_deg taken from `p`.  `base` itself is not modified.
    """
    n_pv = pv_deg + 1
    updated = base.deepcopy()

    crpix_x, crpix_y, crval_lon, crval_lat, cd11, cd12, cd21, cd22 = (p[k] for k in range(8))
    updated.wcs.crpix = [crpix_x, crpix_y]
    updated.wcs.crval = [crval_lon, crval_lat]
    updated.wcs.cd = np.array([[cd11, cd12], [cd21, cd22]], dtype=float)

    pv2_values = p[8 : 8 + n_pv]
    # Carry over PV terms of the other axes (e.g. PV1_*) untouched
    other_axes = [(axis, order, value) for (axis, order, value) in base.wcs.get_pv() if axis != 2]
    pv2_entries = [(2, order, float(pv2_values[order])) for order in range(n_pv)]
    updated.wcs.set_pv(other_axes + pv2_entries)
    return updated


def _normalize_zpn_pv1(w: WCS) -> WCS:
    """Rescale a ZPN WCS so that PV2_1 equals 1.

    Multiplying every PV2_m and the CD matrix by the same constant leaves the
    pixel→sky mapping unchanged, so the linear radial coefficient can be fixed
    to 1 and the plate scale carried entirely by CD.

    Parameters
    ----------
    w : WCS
        ZPN WCS; it is modified in place.

    Returns
    -------
    WCS
        The same object.  It is returned untouched when PV2_1 is missing,
        non-finite, smaller than 1e-15 in magnitude, or already within 1e-12
        of 1.
    """
    all_pv = w.wcs.get_pv()
    pv2 = {order: value for axis, order, value in all_pv if axis == 2}

    linear = pv2.get(1, 1.0)
    can_normalize = (
        np.isfinite(linear) and abs(linear) >= 1e-15 and abs(linear - 1.0) >= 1e-12
    )
    if not can_normalize:
        return w

    inverse = 1.0 / linear

    # Intermediate world coordinates must shrink by the same factor as PV2
    w.wcs.cd = np.array(w.wcs.cd, dtype=float) / linear

    rescaled = [(axis, order, value) for (axis, order, value) in all_pv if axis != 2]
    rescaled.extend((2, order, float(value * inverse)) for order, value in pv2.items())
    w.wcs.set_pv(rescaled)

    return w


def fit_zpn_wcs_from_points(
    xy: np.ndarray,
    sky: SkyCoord,
    wcs_init: WCS,
    pv_deg: int = 5,
    fit_crpix: bool = True,
    fit_crval: bool = True,
    fit_cd: bool = True,
    fit_pv: bool = True,
    robust_loss: str = "soft_l1",
    f_scale_arcsec: float = 2.0,
    max_nfev: int = 200,
):
    """
    Fit a ZPN WCS by optimizing WCS parameters against matched (x,y) <-> (ra,dec).

    Parameters
    ----------
    xy : (N,2) array
        Pixel coordinates (0-based as in astropy WCS, i.e. origin=0).
    sky : SkyCoord (N)
        Reference sky positions.
    wcs_init : astropy.wcs.WCS
        Initial WCS; MUST already be ZPN (RA---ZPN/DEC--ZPN) or at least
        usable as base.  It must carry a CD matrix.
    pv_deg : int
        Degree for PV2_m coefficients to fit (PV2_0..PV2_pv_deg).
    fit_crpix, fit_crval, fit_cd, fit_pv : bool
        Toggle which parameter blocks to optimize.
    robust_loss : str
        Passed to scipy.optimize.least_squares(loss=...).
        Good options: 'linear', 'soft_l1', 'huber', 'cauchy'.
    f_scale_arcsec : float
        Robust loss scale in arcsec.
    max_nfev : int
        Maximum number of function evaluations (SciPy).

    Returns
    -------
    wcs_best : astropy.wcs.WCS
    result : scipy OptimizeResult of the last stage, or None when no
        parameter was free in that stage.

    Raises
    ------
    ValueError
        If `xy` is not (N, 2) or its length differs from `sky`.
    RuntimeError
        If SciPy cannot be imported.

    Notes
    -----
    For stability, the solver runs in two stages when *fit_pv* is True:
    it first fits CRPIX/CRVAL/CD with PV fixed, then fits all free parameters
    (including PV) with conservative bounds to prevent invalid projections.

    ZPN coefficients follow the FITS convention
    ``R = (180/pi) * sum_m PV2_m * gamma**m`` with the zenith distance
    ``gamma`` in radians, so the bounds on the higher-order terms are derived
    from the field radius expressed in radians.
    """
    xy = np.asarray(xy, dtype=float)
    if xy.ndim != 2 or xy.shape[1] != 2:
        raise ValueError("xy must be (N,2)")
    if len(sky) != xy.shape[0]:
        raise ValueError("sky and xy must have the same length")

    # Choose a stable center for residuals (near the middle of the matched set)
    center = SkyCoord(
        ra=np.median(sky.ra).to(u.deg), dec=np.median(sky.dec).to(u.deg), frame=sky.frame
    )

    # SciPy is required; fail with a clear message if it cannot be imported.
    try:
        from scipy.optimize import least_squares
    except Exception as e:
        raise RuntimeError(
            "This fitter needs SciPy (scipy.optimize.least_squares). "
            "Install scipy to use fit_zpn_wcs_from_points."
        ) from e

    def _estimate_theta_max_deg(w: WCS) -> float | None:
        """Approximate angular radius of the image (deg), or None if unknown."""
        if w.pixel_shape is None:
            return None
        nx, ny = w.pixel_shape
        if nx is None or ny is None:
            return None

        # Prefer a pixel-scale estimate (robust to projection issues)
        pixel_scale_deg = None
        try:
            pscales = w.proj_plane_pixel_scales()
            if hasattr(pscales[0], "to_value"):
                pixel_scale_deg = float(np.mean([s.to_value(u.deg) for s in pscales]))
            else:
                pixel_scale_deg = float(np.mean(pscales))
        except Exception:
            # Best effort only: fall through to the footprint estimate
            pixel_scale_deg = None

        if pixel_scale_deg is not None and np.isfinite(pixel_scale_deg) and pixel_scale_deg > 0:
            r_pix = 0.5 * np.hypot(nx, ny)
            return max(float(pixel_scale_deg * r_pix), 0.01)

        # Fallback: estimate from footprint if possible
        try:
            # CRPIX is 1-based; all_pix2world is called with origin=0
            crpix0 = np.array(w.wcs.crpix, dtype=float) - 1.0
            corners = np.array(
                [[0.0, 0.0], [nx - 1.0, 0.0], [0.0, ny - 1.0], [nx - 1.0, ny - 1.0]],
                dtype=float,
            )
            ra_c, dec_c = w.all_pix2world(crpix0[0], crpix0[1], 0)
            ra_k, dec_k = w.all_pix2world(corners[:, 0], corners[:, 1], 0)
            if (
                np.isfinite(ra_c)
                and np.isfinite(dec_c)
                and np.all(np.isfinite(ra_k))
                and np.all(np.isfinite(dec_k))
            ):
                ref_point = SkyCoord(ra_c * u.deg, dec_c * u.deg, frame="icrs")
                sky_k = SkyCoord(ra_k * u.deg, dec_k * u.deg, frame="icrs")
                theta = float(np.max(ref_point.separation(sky_k).to_value(u.deg)))
                return max(theta, 0.01)
        except Exception:
            # Best effort only: no estimate available
            pass

        return None

    def _make_bounds(p0: np.ndarray, mask: np.ndarray, base: WCS, allow_pv: bool) -> tuple:
        """Lower/upper bounds for the free parameters selected by `mask`."""
        # CRPIX/CRVAL/CD unbounded by default
        free0 = p0[mask]
        lb = np.full_like(free0, -np.inf, dtype=float)
        ub = np.full_like(free0, +np.inf, dtype=float)

        if not allow_pv or pv_deg < 1:
            return lb, ub

        # Position of each free full-vector index inside the free vector
        free_pos = {int(full): k for k, full in enumerate(np.flatnonzero(mask))}

        # PV bounds: keep solution in a physically plausible neighborhood
        pv0 = float(p0[8])
        pv1_init = float(p0[9])
        # Reference magnitude of the linear term; fall back to 1 if unusable
        pv1_abs = abs(pv1_init) if np.isfinite(pv1_init) and pv1_init != 0 else 1.0

        theta_max_deg = _estimate_theta_max_deg(base)
        theta_max_rad = None
        if theta_max_deg is not None and np.isfinite(theta_max_deg) and theta_max_deg > 0:
            # PV2_m multiply the zenith distance in radians (FITS convention)
            theta_max_rad = float(np.deg2rad(theta_max_deg))

        # PV2_0 ~ 0 (allow small drift; keep init inside bounds)
        if mask[8]:
            k = free_pos[8]
            pv0_abs = max(1e-3, abs(pv0) * 2.0)
            lb[k] = pv0 - pv0_abs
            ub[k] = pv0 + pv0_abs

        # PV2_1 (linear): allow moderate range, positive when the start value
        # permits.  The interval is widened to contain the start value so that
        # least_squares never sees an infeasible x0 or inverted bounds.
        if mask[9]:
            k = free_pos[9]
            lo = max(0.1, pv1_abs * 0.2)
            hi = pv1_abs * 5.0
            lb[k] = min(lo, pv1_init)
            ub[k] = max(hi, pv1_init)

        # Higher-order PV terms: limit contribution at field edge
        pv_frac = 0.3  # allow ~30% of linear term at the edge
        for m in range(2, pv_deg + 1):
            full_idx = 8 + m
            if not mask[full_idx]:
                continue
            pv_m = float(p0[full_idx])
            if theta_max_rad is not None:
                abs_bound = pv_frac * pv1_abs / (theta_max_rad ** (m - 1))
            else:
                abs_bound = max(abs(pv_m) * 5.0, 1e-6)
            # Keep initial value inside bounds
            abs_bound = max(abs_bound, abs(pv_m) * 2.0)
            k = free_pos[full_idx]
            lb[k] = pv_m - abs_bound
            ub[k] = pv_m + abs_bound

        return lb, ub

    def _fit_with_mask(base: WCS, allow_pv: bool):
        """Run one least-squares stage starting from `base`."""
        p0 = _pack_params(base, pv_deg=pv_deg)

        # Build a mask over parameters to optionally freeze blocks
        # indices: 0..1 CRPIX, 2..3 CRVAL, 4..7 CD, 8.. PV
        mask = np.ones_like(p0, dtype=bool)
        if not fit_crpix:
            mask[0:2] = False
        if not fit_crval:
            mask[2:4] = False
        if not fit_cd:
            mask[4:8] = False
        if not allow_pv:
            mask[8:] = False
        # When PV is free, PV2_1 (degenerate with the CD scale) is left
        # floating; the caller normalizes afterwards via _normalize_zpn_pv1().

        if not np.any(mask):
            return base, None

        free0 = p0[mask]
        lb, ub = _make_bounds(p0, mask, base, allow_pv=allow_pv)

        def make_wcs_from_free(free: np.ndarray) -> WCS:
            p = p0.copy()
            p[mask] = free
            return _unpack_params_to_wcs(base, p, pv_deg=pv_deg)

        def fun(free: np.ndarray) -> np.ndarray:
            w = make_wcs_from_free(free)
            resid = _sky_residuals_arcsec(w, xy, sky, center)
            if not np.all(np.isfinite(resid)):
                # Invalid projections yield NaN/inf; turn them into a large penalty
                resid = np.nan_to_num(resid, nan=1e6, posinf=1e6, neginf=-1e6)
            return resid

        result = least_squares(
            fun,
            free0,
            bounds=(lb, ub),
            loss=robust_loss,
            f_scale=float(f_scale_arcsec),
            max_nfev=int(max_nfev),
            x_scale="jac",
            verbose=0,
        )

        return make_wcs_from_free(result.x), result

    # Two-stage fit: first solve CRPIX/CRVAL/CD with PV fixed,
    # then allow PV with conservative bounds.
    w_curr = wcs_init
    res_last = None
    if fit_pv and (fit_crpix or fit_crval or fit_cd):
        w_curr, res_last = _fit_with_mask(w_curr, allow_pv=False)

    w_curr, res_last = _fit_with_mask(w_curr, allow_pv=fit_pv)

    # Normalize so PV2_1 = 1 (breaks the CD/PV degeneracy)
    if fit_pv:
        w_curr = _normalize_zpn_pv1(w_curr)

    return w_curr, res_last
def _fit_zpn_sip(
    wcs_zpn,
    xy,
    sky,
    sip_degree=2,
    pv_deg=5,
    n_iter=15,
    robust_loss="soft_l1",
    f_scale_arcsec=2.0,
    max_nfev=200,
):
    """Add SIP distortion corrections on top of a ZPN WCS.

    The SIP polynomials A(u,v), B(u,v) are meant to absorb distortions that
    the purely radial ZPN model cannot capture.  The routine alternates
    between (1) a linear least-squares fit of the SIP coefficients to the
    pixel-space residuals of the SIP-free WCS and (2) a re-fit of the ZPN
    parameters with the SIP applied, via ``fit_zpn_wcs_from_points`` called
    with its default ``fit_*`` switches.

    Parameters
    ----------
    wcs_zpn : WCS
        ZPN WCS with PV parameters already fitted (no SIP).  Not modified;
        the work is done on a deep copy.
    xy : (N, 2) array
        Pixel coordinates of matched sources (0-based).
    sky : SkyCoord
        Reference sky positions (accessed through ``.ra`` / ``.dec``).
    sip_degree : int
        SIP polynomial order (default 2).
    pv_deg : int
        ZPN PV polynomial degree for re-fitting.
    n_iter : int
        Maximum number of PV+SIP alternation iterations.
    robust_loss, f_scale_arcsec, max_nfev :
        Passed to ``fit_zpn_wcs_from_points`` for PV re-fitting.

    Returns
    -------
    wcs_sip : WCS
        ZPN-SIP WCS with both PV and SIP coefficients.  If too few sources
        survive the outlier clip in the first iteration, the returned WCS is
        a plain copy of `wcs_zpn` without SIP.
    """
    from astropy.wcs.wcs import Sip

    xy = np.asarray(xy, dtype=float)
    w = wcs_zpn.deepcopy()

    # SIP term indices: (p, q) with 2 <= p+q <= sip_degree
    pq = []
    for total in range(2, sip_degree + 1):
        for q in range(total + 1):
            pq.append((total - q, q))

    crpix = np.array(w.wcs.crpix, dtype=float)  # 1-based
    # NOTE: despite its name this holds the previous iteration's 90th
    # percentile of the residual distance (see the convergence check below).
    prev_med_dist = np.inf

    for iteration in range(n_iter):
        # Strip SIP to get the base ZPN-only WCS
        w_nosip = w.deepcopy()
        w_nosip.sip = None
        if '-SIP' in w_nosip.wcs.ctype[0]:
            w_nosip.wcs.ctype = [c.replace('-SIP', '') for c in w_nosip.wcs.ctype]

        # Pixel position at which the SIP-free ZPN model places each
        # catalog star, expressed relative to CRPIX (0-based offsets)
        x_cat, y_cat = w_nosip.all_world2pix(sky.ra.deg, sky.dec.deg, 0)
        u_cat = x_cat - (crpix[0] - 1)
        v_cat = y_cat - (crpix[1] - 1)

        # Measured pixel position of each source, relative to CRPIX
        u_obs = xy[:, 0] - (crpix[0] - 1)
        v_obs = xy[:, 1] - (crpix[1] - 1)

        # Target of the SIP polynomials, evaluated at the measured position:
        #   A(u_obs, v_obs) = u_cat - u_obs,  B(u_obs, v_obs) = v_cat - v_obs
        du = u_cat - u_obs
        dv = v_cat - v_obs

        # Reject outliers: keep sources whose residual distance is below
        # median + 5 * max(MAD, 0.5 pixel)
        dist = np.sqrt(du**2 + dv**2)
        med_dist = np.median(dist)
        mad = np.median(np.abs(dist - med_dist))
        good = dist < med_dist + 5 * max(mad, 0.5)  # generous 5-sigma clip

        if np.sum(good) < len(pq) + 5:
            # Not enough points for SIP fitting
            break

        # The SIP basis is evaluated at the measured (raw pixel) offsets
        u_fit = u_obs[good]
        v_fit = v_obs[good]
        du_fit = du[good]
        dv_fit = dv[good]

        # Normalize coordinates to prevent ill-conditioning for high
        # SIP degrees.  Without this, u^5 ~ 2000^5 ~ 3e16 creates a
        # design matrix with condition number > 1e14.
        u_scale = max(np.abs(u_fit).max(), 1.0)
        v_scale = max(np.abs(v_fit).max(), 1.0)
        u_norm = u_fit / u_scale
        v_norm = v_fit / v_scale

        # Build SIP basis matrix in normalized coordinates
        basis = np.column_stack([u_norm**p * v_norm**q for p, q in pq])

        # Fit A and B coefficients via least squares (normalized)
        ca_norm, _, _, _ = np.linalg.lstsq(basis, du_fit, rcond=None)
        cb_norm, _, _, _ = np.linalg.lstsq(basis, dv_fit, rcond=None)

        # Convert back to original (unnormalized) SIP coefficients
        a_vals = np.zeros((sip_degree + 1, sip_degree + 1))
        b_vals = np.zeros((sip_degree + 1, sip_degree + 1))
        for k, (p, q) in enumerate(pq):
            scale_pq = u_scale**p * v_scale**q
            a_vals[p, q] = ca_norm[k] / scale_pq
            b_vals[p, q] = cb_norm[k] / scale_pq

        # Apply SIP to the WCS; the inverse (AP/BP) polynomials are left at zero
        if '-SIP' not in w.wcs.ctype[0]:
            w.wcs.ctype = [c + '-SIP' for c in w.wcs.ctype]
        w.sip = Sip(
            a_vals,
            b_vals,
            np.zeros((sip_degree + 1, sip_degree + 1)),
            np.zeros((sip_degree + 1, sip_degree + 1)),
            crpix,
        )

        # Check convergence: stop when the 90th percentile of the residual
        # distance changes by less than 0.5% between iterations.  The test
        # is only made from the sixth iteration on (iteration > 4).
        curr_q90 = np.percentile(dist[good], 90)
        if iteration > 4 and prev_med_dist < np.inf:
            rel_change = abs(curr_q90 - prev_med_dist) / max(prev_med_dist, 1e-6)
            if rel_change < 0.005:
                break
        prev_med_dist = curr_q90

        # Re-fit the ZPN parameters with SIP now applied (skipped on the last
        # iteration so the returned SIP matches the returned CRPIX)
        if iteration < n_iter - 1:
            w, _ = fit_zpn_wcs_from_points(
                xy,
                sky,
                w,
                pv_deg=pv_deg,
                robust_loss=robust_loss,
                f_scale_arcsec=f_scale_arcsec,
                max_nfev=max_nfev,
            )
            # The refit may have moved CRPIX; SIP offsets follow it next round
            crpix = np.array(w.wcs.crpix, dtype=float)

    return w
def tan_wcs_to_zpn(
    w_tan: WCS,
    pv_deg: int = 7,
    n_samples: int = 256,
    theta_max_deg: float | None = None,
    drop_sip: bool = True,
) -> WCS:
    """
    Convert a celestial TAN WCS into a ZPN WCS with PV2_m initialized to approximate TAN.

    Notes
    -----
    TAN (gnomonic) radial law, with the zenith distance ``gamma`` in radians::

        R_deg = (180/pi) * tan(gamma)

    ZPN radial law as defined by the FITS standard (and implemented by
    WCSLIB), again with ``gamma`` in radians::

        R_deg = (180/pi) * sum_{m=0..M} PV2_m * gamma**m

    PV2_0 is set to 0 and PV2_1..PV2_M are obtained from a weighted
    least-squares fit of the TAN law over [0, theta_max_deg].  The fit itself
    is carried out in degrees (better conditioned for small fields) and the
    coefficients are then converted to the radian convention:
    ``PV2_m = c_m * (180/pi)**(m-1)``.

    Parameters
    ----------
    w_tan : astropy.wcs.WCS
        Input TAN WCS (2D celestial).
    pv_deg : int
        Highest PV degree to initialize (PV2_0..PV2_pv_deg); must be >= 1.
        5–9 is usually plenty; higher can get wiggly.
    n_samples : int
        Samples used for the polynomial fit.
    theta_max_deg : float or None
        Max angular radius (deg) over which to match TAN.
        If None, estimated from image footprint corners using pixel_shape.
    drop_sip : bool
        Kept for API compatibility.  The returned WCS is built from scratch,
        so no SIP or other distortion of `w_tan` is carried over regardless
        of this flag.

    Returns
    -------
    w_zpn : astropy.wcs.WCS
        A ZPN WCS with same CRVAL/CRPIX/CD and PV2_m initialized.

    Raises
    ------
    ValueError
        If `pv_deg` < 1, if the CTYPEs are shorter than 8 characters, or if
        `theta_max_deg` is None and ``w_tan.pixel_shape`` is not set.
    """
    if pv_deg < 1:
        raise ValueError("pv_deg must be >= 1")

    # Build a clean WCS with CD only (no PC) to avoid PC/CDELT overriding CD.
    w = WCS(naxis=2)

    # Compute CD from the input WCS
    if w_tan.wcs.has_cd():
        cd = np.array(w_tan.wcs.cd, dtype=float)
    elif w_tan.wcs.has_pc():
        pc = np.array(w_tan.wcs.pc, dtype=float)
        cdelt = np.array(w_tan.wcs.cdelt, dtype=float)
        # FITS: CD_ij = CDELT_i * PC_ij, i.e. CDELT scales the rows of PC
        cd = cdelt[:, None] * pc
    else:
        cdelt = np.array(w_tan.wcs.cdelt, dtype=float)
        cd = np.diag(cdelt)

    w.wcs.crpix = np.array(w_tan.wcs.crpix, dtype=float)
    w.wcs.crval = np.array(w_tan.wcs.crval, dtype=float)
    w.wcs.cd = cd

    # Switch projection to ZPN (keep axis names)
    ctype1, ctype2 = w_tan.wcs.ctype
    if len(ctype1) < 8 or len(ctype2) < 8:
        raise ValueError("Expected CTYPE like 'RA---TAN'/'DEC--TAN'.")
    w.wcs.ctype = (ctype1[:5] + "ZPN", ctype2[:5] + "ZPN")

    # Copy metadata / frame info when available (best effort: a keyword that
    # cannot be transferred is simply left at its default)
    for attr in ("cunit", "radesys", "equinox"):
        try:
            setattr(w.wcs, attr, getattr(w_tan.wcs, attr))
        except Exception:
            pass
    for attr in ("lonpole", "latpole"):
        try:
            value = getattr(w_tan.wcs, attr)
            if np.isfinite(value):
                setattr(w.wcs, attr, float(value))
        except Exception:
            pass

    if w_tan.pixel_shape is not None:
        w.pixel_shape = w_tan.pixel_shape

    # SIP/distortion is not standard for ZPN; keep behavior explicit
    if drop_sip:
        w.sip = None
        w.cpdis1 = None
        w.cpdis2 = None
        w.det2im1 = None
        w.det2im2 = None

    # Estimate theta_max from footprint if not provided
    if theta_max_deg is None:
        if w.pixel_shape is None:
            raise ValueError(
                "w_tan.pixel_shape is None; provide theta_max_deg explicitly "
                "or set w_tan.pixel_shape = (nx, ny)."
            )
        nx, ny = w.pixel_shape  # (NAXIS1, NAXIS2)

        # CRPIX is 1-based; all_pix2world is called with origin=0 below
        crpix0 = np.array(w.wcs.crpix, dtype=float) - 1.0

        # Corners in pixel coordinates (origin=0 convention for astropy WCS)
        corners = np.array(
            [
                [0.0, 0.0],
                [nx - 1.0, 0.0],
                [0.0, ny - 1.0],
                [nx - 1.0, ny - 1.0],
            ],
            dtype=float,
        )

        # Sky positions of corners and reference pixel under TAN WCS
        ra_c, dec_c = w_tan.all_pix2world(crpix0[0], crpix0[1], 0)
        center = SkyCoord(ra_c * u.deg, dec_c * u.deg, frame="icrs")

        ra_k, dec_k = w_tan.all_pix2world(corners[:, 0], corners[:, 1], 0)
        sky_k = SkyCoord(ra_k * u.deg, dec_k * u.deg, frame="icrs")

        theta_max_deg = float(np.max(center.separation(sky_k).to_value(u.deg)))

    # Safety floor (also guards the weights below against division by zero)
    theta_max_deg = max(float(theta_max_deg), 0.01)

    # Fit c_m so that sum_m c_m * theta_deg**m ~ TAN radial law in degrees
    theta = np.linspace(0.0, theta_max_deg, n_samples, dtype=float)
    theta_rad = np.deg2rad(theta)
    r_tan_deg = np.tan(theta_rad) * (180.0 / np.pi)

    # Build design matrix for m=1..pv_deg (PV2_0 fixed to 0)
    A = np.vstack([theta**m for m in range(1, pv_deg + 1)]).T

    # Mild weighting: emphasize central region (helps stability).
    # This is only an initialization, so the exact weights are not critical.
    wgt = 1.0 / (1.0 + (theta / (0.6 * theta_max_deg)) ** 2)
    Aw = A * wgt[:, None]
    bw = r_tan_deg * wgt

    coeffs, *_ = np.linalg.lstsq(Aw, bw, rcond=None)

    # Convert degree-based coefficients to the FITS (radian) convention:
    #   sum c_m theta_deg^m = (180/pi) sum PV2_m (theta_deg*pi/180)^m
    #   =>  PV2_m = c_m * (180/pi)^(m-1)
    rad_to_deg = 180.0 / np.pi
    pv_list = [(2, 0, 0.0)] + [
        (2, m, float(coeffs[m - 1] * rad_to_deg ** (m - 1))) for m in range(1, pv_deg + 1)
    ]
    w.wcs.set_pv(pv_list)

    return w
def convert_wcs_projection(
    wcs_input: WCS,
    target_projection: str,
    pv_deg: int = 5,
) -> WCS:
    """Convert a celestial WCS to a different projection type.

    Parameters
    ----------
    wcs_input : WCS
        Input WCS (any celestial projection).
    target_projection : str
        Target projection code, e.g. ``'TAN'``, ``'ZPN'``, ``'STG'``,
        ``'ARC'``, ``'ZEA'``, ``'SIN'``, etc.  Case and surrounding
        whitespace are ignored.
    pv_deg : int, optional
        ZPN PV polynomial degree (only used when *target_projection*
        is ``'ZPN'``).  Default 5.

    Returns
    -------
    WCS
        New WCS with the target projection.  When the input already uses the
        target projection a deep copy is returned.  For ZPN the PV
        coefficients come from ``tan_wcs_to_zpn``; a non-TAN input is first
        reduced to a linear TAN WCS.  For other projections
        CRPIX/CRVAL/CD are copied and CTYPE is replaced.
    """
    wanted = target_projection.upper().strip()

    # Projection code = last dash-separated token of CTYPE1, ignoring '-SIP'
    try:
        first_ctype = wcs_input.wcs.ctype[0]
        current = first_ctype.split('-')[-1]
        if current == 'SIP':
            current = first_ctype.replace('-SIP', '').split('-')[-1]
    except Exception:
        current = ''

    if current == wanted:
        return wcs_input.deepcopy()

    # Everything except ZPN only needs the linear terms and a new CTYPE
    if wanted != 'ZPN':
        return _wcs_to_linear(wcs_input, wanted)

    # ZPN needs PV coefficients, which tan_wcs_to_zpn derives from a TAN
    # WCS (with or without SIP); other inputs go through a linear TAN copy.
    tan_like = wcs_input if 'TAN' in current else _wcs_to_linear(wcs_input, 'TAN')
    return tan_wcs_to_zpn(tan_like, pv_deg=pv_deg)
def _wcs_to_linear(wcs_input: WCS, proj_code: str) -> WCS:
    """Create a clean linear WCS with the same pointing but a new projection.

    Copies CRPIX, CRVAL, CD (or derives CD from PC+CDELT), metadata, and
    pixel_shape.  SIP distortion and PV parameters are not carried over
    because the result is built from an empty 2-D WCS.

    Parameters
    ----------
    wcs_input : WCS
        Source WCS.
    proj_code : str
        Three-letter projection code appended to the axis prefixes
        (first five characters of each CTYPE, e.g. ``'RA---'``).

    Returns
    -------
    WCS
        New WCS; `wcs_input` is left untouched.
    """
    w = WCS(naxis=2)

    # CD matrix
    if wcs_input.wcs.has_cd():
        cd = np.array(wcs_input.wcs.cd, dtype=float)
    elif wcs_input.wcs.has_pc():
        pc = np.array(wcs_input.wcs.pc, dtype=float)
        cdelt = np.array(wcs_input.wcs.cdelt, dtype=float)
        # FITS: CD_ij = CDELT_i * PC_ij, i.e. CDELT scales the rows of PC
        cd = cdelt[:, None] * pc
    else:
        cdelt = np.array(wcs_input.wcs.cdelt, dtype=float)
        cd = np.diag(cdelt)

    w.wcs.crpix = np.array(wcs_input.wcs.crpix, dtype=float)
    w.wcs.crval = np.array(wcs_input.wcs.crval, dtype=float)
    w.wcs.cd = cd

    # Build CTYPE: preserve axis names (e.g. 'RA---' / 'DEC--')
    ctype1, ctype2 = wcs_input.wcs.ctype
    # Strip any existing projection + SIP suffix
    base1 = ctype1.replace('-SIP', '')[:5]
    base2 = ctype2.replace('-SIP', '')[:5]
    w.wcs.ctype = (base1 + proj_code, base2 + proj_code)

    # Copy metadata (best effort: keywords that cannot be set are skipped)
    for attr in ('cunit', 'radesys', 'equinox'):
        try:
            setattr(w.wcs, attr, getattr(wcs_input.wcs, attr))
        except Exception:
            pass
    for attr in ('lonpole', 'latpole'):
        try:
            val = getattr(wcs_input.wcs, attr)
            if np.isfinite(val):
                setattr(w.wcs, attr, float(val))
        except Exception:
            pass

    if wcs_input.pixel_shape is not None:
        w.pixel_shape = wcs_input.pixel_shape

    return w


def _fit_tan_sip_robust(
    xy,
    world_coords,
    proj_point="center",
    projection=None,
    sip_degree=2,
    robust_loss="soft_l1",
    f_scale=None,
):
    """Fit TAN+SIP WCS with robust loss function.

    Follows the two-stage scheme of astropy's ``fit_wcs_from_points``
    (linear WCS, then joint CD + CRPIX + SIP) and reuses astropy's private
    residual functions ``_linear_wcs_fit`` and ``_sip_fit``, but uses a
    robust loss in the linear stage and a two-pass approach for the SIP
    stage: first a plain least-squares fit to capture the distortion
    pattern, then a robust re-fit warm-started from it.

    Parameters
    ----------
    xy : tuple of arrays ``(x, y)``
        Pixel coordinates.
    world_coords : `~astropy.coordinates.SkyCoord`
        Reference sky positions.
    proj_point : str or SkyCoord
        Projection center ('center' or explicit).
    projection : WCS, str or None
        Template WCS (used as initial guess for the CD matrix), a projection
        code, or None for TAN.
    sip_degree : int
        SIP polynomial degree.
    robust_loss : str
        Loss function for `scipy.optimize.least_squares`.
    f_scale : float or None
        Soft margin for robust loss in the SIP stage.  If None (default),
        set to three times the median absolute residual of the first SIP
        pass.  The linear stage always uses a data-driven scale.

    Returns
    -------
    wcs : `~astropy.wcs.WCS`
    """
    from scipy.optimize import least_squares
    from astropy.wcs.utils import (
        _linear_wcs_fit,
        _sip_fit,
        celestial_frame_to_wcs,
    )
    from astropy.wcs.wcs import Sip

    xp, yp = xy
    try:
        lon, lat = world_coords.data.lon.deg, world_coords.data.lat.deg
    except AttributeError:
        unit_sph = world_coords.unit_spherical
        lon, lat = unit_sph.lon.deg, unit_sph.lat.deg

    use_center_as_proj_point = str(proj_point) == "center"

    # Build WCS template
    if isinstance(projection, str) or projection is None:
        proj_code = projection if isinstance(projection, str) else "TAN"
        wcs = celestial_frame_to_wcs(frame=world_coords.frame, projection=proj_code)
    else:
        wcs = projection.deepcopy()
        wcs.sip = None

    # The fit works on CD, so fold PC+CDELT into CD when PC is in use
    if wcs.wcs.has_pc():
        pc = np.array(wcs.wcs.pc, dtype=float)
        cdelt = np.array(wcs.wcs.cdelt, dtype=float)
        # FITS: CD_ij = CDELT_i * PC_ij, i.e. CDELT scales the rows of PC
        wcs.wcs.cd = cdelt[:, np.newaxis] * pc
        wcs.wcs.cdelt = (1.0, 1.0)
        wcs.wcs.__delattr__("pc")

    xpmin, xpmax = xp.min(), xp.max()
    ypmin, ypmax = yp.min(), yp.max()

    wcs.pixel_shape = (
        1 if xpmax <= 0.0 else int(np.ceil(xpmax)),
        1 if ypmax <= 0.0 else int(np.ceil(ypmax)),
    )

    if use_center_as_proj_point:
        # Midpoint of the diagonal of the lon/lat bounding box
        sc1 = SkyCoord(lon.min() * u.deg, lat.max() * u.deg)
        sc2 = SkyCoord(lon.max() * u.deg, lat.min() * u.deg)
        pa = sc1.position_angle(sc2)
        sep = sc1.separation(sc2)
        midpoint_sc = sc1.directional_offset_by(pa, sep / 2)
        wcs.wcs.crval = (midpoint_sc.data.lon.deg, midpoint_sc.data.lat.deg)
        wcs.wcs.crpix = ((xpmax + xpmin) / 2.0, (ypmax + ypmin) / 2.0)
    else:
        proj_point = proj_point.transform_to(world_coords.frame)
        wcs.wcs.crval = (proj_point.data.lon.deg, proj_point.data.lat.deg)
        # Initial CRPIX: pixel coordinate of the source closest to CRVAL
        # along each axis
        close = lambda l, p: p[np.argmin(np.abs(l))]
        wcs.wcs.crpix = (
            close(lon - wcs.wcs.crval[0], xp + 1),
            close(lat - wcs.wcs.crval[1], yp + 1),
        )

    # Avoid zero-width CRPIX bounds when all points share a coordinate
    if xpmin == xpmax:
        xpmin, xpmax = xpmin - 0.5, xpmax + 0.5
    if ypmin == ypmax:
        ypmin, ypmax = ypmin - 0.5, ypmax + 0.5

    # --- Stage 1: linear WCS (CD + CRPIX) ---
    # Use robust loss here: linear residuals include distortion as systematic
    # pattern that looks like outliers; robust loss prevents them from biasing
    # the linear fit.
    p0_lin = np.concatenate([wcs.wcs.cd.flatten(), wcs.wcs.crpix.flatten()])
    lin_resids = _linear_wcs_fit(p0_lin, lon, lat, xp, yp, wcs)
    lin_f_scale = max(float(np.median(np.abs(lin_resids)) * 3), 1.0 / 3600)

    fit = least_squares(
        _linear_wcs_fit,
        p0_lin,
        args=(lon, lat, xp, yp, wcs),
        bounds=[
            [-np.inf, -np.inf, -np.inf, -np.inf, xpmin + 1, ypmin + 1],
            [np.inf, np.inf, np.inf, np.inf, xpmax + 1, ypmax + 1],
        ],
        loss=robust_loss,
        f_scale=lin_f_scale,
        method="trf",
    )
    wcs.wcs.crpix = np.array(fit.x[4:6])
    wcs.wcs.cd = np.array(fit.x[0:4].reshape((2, 2)))

    # --- Stage 2: joint CD + CRPIX + SIP ---
    if "-SIP" not in wcs.wcs.ctype[0]:
        wcs.wcs.ctype = [x + "-SIP" for x in wcs.wcs.ctype]

    # SIP terms (i, j) with 2 <= i + j <= sip_degree, and their "i_j" names
    coef_orders = [
        (i, j)
        for i in range(sip_degree + 1)
        for j in range(sip_degree + 1)
        if (i + j) < (sip_degree + 1) and (i + j) > 1
    ]
    coef_names = [f"{i}_{j}" for i, j in coef_orders]
    n_coef = len(coef_names)

    sip_bounds = (
        [xpmin + 1, ypmin + 1] + [-np.inf] * (4 + 2 * n_coef),
        [xpmax + 1, ypmax + 1] + [np.inf] * (4 + 2 * n_coef),
    )

    # Parameter layout of _sip_fit: [CRPIX(2), CD(4), A_coeffs, B_coeffs]
    sip_args = (lon, lat, xp, yp, wcs, sip_degree, coef_names)

    # Pass 1: L2 fit to capture the full distortion pattern (SIP starts at 0)
    p0_sip = np.concatenate(
        (
            np.array(wcs.wcs.crpix),
            wcs.wcs.cd.flatten(),
            np.zeros(2 * n_coef),
        )
    )

    # Compute parameter scales for the optimizer.
    # SIP coefficients A_{p,q} multiply (u-crpix)^p * (v-crpix)^q where
    # pixel offsets can be ~U pixels.  Expected coefficient magnitude is
    # ~1/U^(p+q), which spans many orders of magnitude for high SIP orders.
    # Without proper scaling the optimizer cannot find meaningful SIP>=4
    # coefficients on wide-field images.
    U = max(
        abs(xpmax - wcs.wcs.crpix[0]),
        abs(xpmin - wcs.wcs.crpix[0]),
        abs(ypmax - wcs.wcs.crpix[1]),
        abs(ypmin - wcs.wcs.crpix[1]),
        1.0,
    )
    sip_x_scale = np.ones(len(p0_sip))
    # CRPIX: O(pixels) -> scale 1
    # CD: O(deg/pixel) -> use current magnitude
    cd_scale = max(np.abs(wcs.wcs.cd).max(), 1e-10)
    sip_x_scale[2:6] = cd_scale
    # SIP coefficients: O(1/U^(p+q))
    for k, (p_deg, q_deg) in enumerate(coef_orders):
        scale = 1.0 / U ** (p_deg + q_deg)
        sip_x_scale[6 + k] = scale
        sip_x_scale[6 + n_coef + k] = scale

    fit = least_squares(
        _sip_fit,
        p0_sip,
        args=sip_args,
        bounds=sip_bounds,
        x_scale=sip_x_scale,
    )

    # Pass 2: robust re-fit starting from L2 solution
    # f_scale estimated from L2 residuals so distortion signal is inlier
    sip_resids = fit.fun
    if f_scale is None:
        sip_f_scale = max(float(np.median(np.abs(sip_resids)) * 3), 1e-7)
    else:
        sip_f_scale = float(f_scale)

    fit = least_squares(
        _sip_fit,
        fit.x,  # warm-start from L2 solution
        args=sip_args,
        bounds=sip_bounds,
        loss=robust_loss,
        f_scale=sip_f_scale,
        method="trf",
        x_scale=sip_x_scale,
    )

    wcs.wcs.cd = fit.x[2:6].reshape((2, 2))
    wcs.wcs.crpix = fit.x[0:2]

    # Scatter the fitted A/B coefficients into their (p, q) matrix slots
    a_coefs = fit.x[6 : 6 + n_coef]
    b_coefs = fit.x[6 + n_coef :]
    a_vals = np.zeros((sip_degree + 1, sip_degree + 1))
    b_vals = np.zeros((sip_degree + 1, sip_degree + 1))
    for (p_deg, q_deg), a_coef, b_coef in zip(coef_orders, a_coefs, b_coefs):
        a_vals[p_deg, q_deg] = a_coef
        b_vals[p_deg, q_deg] = b_coef

    # Inverse (AP/BP) polynomials are left at zero
    wcs.sip = Sip(
        a_vals,
        b_vals,
        np.zeros((sip_degree + 1, sip_degree + 1)),
        np.zeros((sip_degree + 1, sip_degree + 1)),
        wcs.wcs.crpix,
    )

    return wcs
def fit_wcs_from_points(
    xy,
    world_coords,
    proj_point="center",
    projection=None,
    sip_degree=None,
    pv_deg=5,
):
    """Fit a WCS to matched pixel/sky positions, with ZPN and ZPN-SIP support.

    Wrapper around :func:`astropy.wcs.utils.fit_wcs_from_points` that
    dispatches to one of three fitters:

    * `projection` is a WCS whose first CTYPE contains ``ZPN``: the ZPN
      fitter of this module is used, optionally followed by a SIP fit;
    * otherwise, when a positive `sip_degree` is requested (and the template
      WCS, if any, is TAN-based): the robust TAN+SIP fitter of this module;
    * otherwise: astropy's fitter, called without SIP.

    Parameters
    ----------
    xy : tuple of arrays ``(x, y)`` or ``(2, N)`` array
        Pixel coordinates (same convention as the astropy function).
    world_coords : `~astropy.coordinates.SkyCoord`
        Reference sky positions.
    proj_point : str, optional
        Passed through for non-ZPN projections; ignored by the ZPN path.
    projection : `~astropy.wcs.WCS` or other, optional
        Projection template.
    sip_degree : int or None, optional
        SIP polynomial degree.  Ignored (treated as None) when `projection`
        is a WCS that is neither ZPN nor TAN-based.
    pv_deg : int, optional
        ZPN PV polynomial degree (``PV2_0 … PV2_pv_deg``).  Default 5.

    Returns
    -------
    wcs : `~astropy.wcs.WCS`
        Fitted WCS (same return type as the astropy function).
    """
    from astropy.wcs.utils import fit_wcs_from_points as _astropy_fit

    template_is_wcs = isinstance(projection, WCS)

    # Detect ZPN projection
    zpn_template = False
    if template_is_wcs:
        try:
            zpn_template = "ZPN" in projection.wcs.ctype[0]
        except Exception:
            pass

    if zpn_template:
        # Bring xy into the (N, 2) layout expected by fit_zpn_wcs_from_points
        xy_arr = np.asarray(xy, dtype=float)
        if isinstance(xy, (list, tuple)) and len(xy) == 2:
            # (x_array, y_array) form
            x_vals = np.asarray(xy[0], dtype=float)
            y_vals = np.asarray(xy[1], dtype=float)
            xy_arr = np.column_stack([x_vals, y_vals])
        elif xy_arr.ndim == 2 and xy_arr.shape[0] == 2 and xy_arr.shape[1] != 2:
            # (2, N) -> (N, 2)
            xy_arr = xy_arr.T

        zpn_deg = int(pv_deg)

        # Radial part first: ZPN PV parameters
        wcs_best, _result = fit_zpn_wcs_from_points(
            xy_arr, world_coords, wcs_init=projection, pv_deg=zpn_deg
        )

        # Optional SIP corrections on top of the radial model
        wants_sip = sip_degree is not None and int(sip_degree) > 0
        if not wants_sip:
            return wcs_best
        return _fit_zpn_sip(
            wcs_best,
            xy_arr,
            world_coords,
            sip_degree=int(sip_degree),
            pv_deg=zpn_deg,
        )

    # ---------- Non-ZPN ----------
    # SIP is only applied to TAN-based templates
    effective_sip = sip_degree
    if template_is_wcs:
        try:
            if "TAN" not in projection.wcs.ctype[0]:
                effective_sip = None
        except Exception:
            pass

    if effective_sip is None or not effective_sip > 0:
        return _astropy_fit(
            xy,
            world_coords,
            proj_point=proj_point,
            projection=projection,
        )

    return _fit_tan_sip_robust(
        xy,
        world_coords,
        proj_point=proj_point,
        projection=projection,
        sip_degree=int(effective_sip),
    )