"""
Real-Bogus Classifier for Astronomical Object Detection
This module provides a CNN-based classifier to distinguish real astronomical sources
(stars, galaxies) from artifacts (cosmic rays, hot pixels, satellite trails) in
detected object catalogs.
Key Features:
- FWHM-invariant: Hybrid downscaling to canonical PSF size (no auxiliary FWHM input)
- Brightness-invariant: Peak normalization generalizes to any flux level
- Pure morphology: Classification based solely on source shape from 2-channel images
- 2-channel input: background-subtracted (linear), asinh-scaled (dynamic range compression)
- Lightweight 5-layer CNN (~100k parameters)
- Batch processing for efficient inference
- Optional TensorFlow dependency
Usage:
from stdpipe import photometry, realbogus
# Detect objects
obj = photometry.get_objects_sep(image, thresh=3.0)
# Classify and filter
obj_clean = realbogus.classify_realbogus(obj, image, threshold=0.5)
# Or add scores without filtering
obj = realbogus.classify_realbogus(obj, image, add_score=True, flag_bogus=False)
print(obj['rb_score'])
Author: STDPipe Contributors
"""
from __future__ import absolute_import, division, print_function
import numpy as np
import os
import warnings
from astropy.table import Column
# Conditional TensorFlow import
try:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
HAS_TENSORFLOW = True
except ImportError:
HAS_TENSORFLOW = False
warnings.warn(
"TensorFlow not found. Real-bogus classifier will not be available. "
"Install with: pip install stdpipe[ml]",
ImportWarning,
)
# Module-level defaults: model file location and preprocessing parameters.
# DEFAULT_MODEL_DIR / DEFAULT_MODEL_NAME together give the path used by
# load_realbogus_model() and save_realbogus_model() when no model_file is passed.
DEFAULT_MODEL_DIR = os.path.expanduser('~/.stdpipe/models')
DEFAULT_MODEL_NAME = 'realbogus_default.h5'
# Default `target_fwhm` of preprocess_cutout(): the PSF FWHM (in pixels) that
# cutouts are rescaled towards before being fed to the network.
TARGET_FWHM = 3.0 # Canonical FWHM for downscaling normalization
# Used by preprocess_cutout() when `asinh_softening` is None.
DEFAULT_ASINH_SOFTENING_SIGMA = 3.0 # Asinh softening in units of background sigma
def _check_tensorflow():
    """
    Guard for functions that need TensorFlow.

    Returns silently when the module-level import of TensorFlow succeeded.

    Raises
    ------
    ImportError
        If TensorFlow could not be imported; the message contains the
        install command.
    """
    if HAS_TENSORFLOW:
        return

    message = (
        "TensorFlow is required for real-bogus classification. "
        "Install with: pip install stdpipe[ml] or pip install tensorflow>=2.10"
    )
    raise ImportError(message)
[docs]
def create_realbogus_model(
    input_shape=(31, 31, 2), filters=(32, 64, 128), dense_units=64, dropout_rate=0.5
):
    """
    Build and compile the CNN used for real-bogus classification.

    Architecture
        - One block per entry of `filters`: 3x3 convolution (ReLU, 'same'
          padding) -> batch normalization -> 2x2 max pooling
        - Global average pooling (makes the head independent of the spatial
          input size)
        - Dense layer with ReLU, followed by dropout
        - Single sigmoid unit (binary classification)

    Design Philosophy
        - FWHM-invariant: images are rescaled to a canonical FWHM beforehand,
          so no auxiliary FWHM input is needed
        - Brightness-invariant: inputs are peak-normalized
        - Pure morphology: the decision is based on source shape only

    Input Channels
        - Channel 0: background-subtracted (linear scale), peak-normalized
        - Channel 1: asinh-scaled background-subtracted, peak-normalized

    Parameters
    ----------
    input_shape : tuple, optional
        Input shape (height, width, channels). Default: (31, 31, 2).
        Height/width can be None for variable-size inputs.
    filters : tuple, optional
        Number of filters of each convolutional block; its length sets the
        number of blocks. Default: (32, 64, 128)
    dense_units : int, optional
        Units in the dense layer. Default: 64
    dropout_rate : float, optional
        Dropout rate applied after the dense layer. Default: 0.5

    Returns
    -------
    model : keras.Model
        Model compiled with Adam (learning rate 0.001), binary cross-entropy
        loss and accuracy / AUC metrics, ready for training.

    Raises
    ------
    ImportError
        If TensorFlow is not installed.
    """
    _check_tensorflow()

    # Two-channel image input: linear and asinh-scaled background-subtracted data
    image_input = keras.Input(shape=input_shape, name='image_input')

    # Feature extractor: conv -> batch norm -> pool, repeated for every filter count
    tensor = image_input
    for stage, n_filters in enumerate(filters, start=1):
        tensor = layers.Conv2D(
            n_filters, (3, 3), activation='relu', padding='same', name=f'conv{stage}'
        )(tensor)
        tensor = layers.BatchNormalization(name=f'bn{stage}')(tensor)
        tensor = layers.MaxPooling2D((2, 2), name=f'pool{stage}')(tensor)

    # Collapse the spatial axes so the classification head sees a fixed-length vector
    tensor = layers.GlobalAveragePooling2D(name='global_pool')(tensor)

    # Classification head (image features only, no auxiliary inputs)
    tensor = layers.Dense(dense_units, activation='relu', name='dense1')(tensor)
    tensor = layers.Dropout(dropout_rate, name='dropout')(tensor)
    output = layers.Dense(1, activation='sigmoid', name='output')(tensor)

    model = keras.Model(inputs=image_input, outputs=output, name='realbogus_classifier')
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=0.001),
        loss='binary_crossentropy',
        metrics=['accuracy', keras.metrics.AUC(name='auc')],
    )
    return model
def _downscale_cutout(cutout, scale_factor, mode='mean'):
    """
    Shrink a 2D cutout by an integer factor using block aggregation.

    Parameters
    ----------
    cutout : ndarray
        2D image cutout
    scale_factor : int or float
        Downscaling factor. Values <= 1 return the input unchanged; otherwise
        the value is truncated to an integer block size.
    mode : str, optional
        'median' aggregates each block with the median; any other value
        uses the mean. Default: 'mean'

    Returns
    -------
    downscaled : ndarray
        Cutout with each (block x block) tile replaced by a single pixel.
        Rows/columns that do not fill a complete tile are dropped.
    """
    if scale_factor <= 1:
        return cutout

    block = int(scale_factor)
    n_rows, n_cols = cutout.shape
    rows_out = n_rows // block
    cols_out = n_cols // block

    # Drop the ragged bottom/right margin, then expose every tile on its own axes
    tiles = cutout[: rows_out * block, : cols_out * block].reshape(
        rows_out, block, cols_out, block
    )

    aggregate = np.median if mode == 'median' else np.mean
    return aggregate(tiles, axis=(1, 3))
def _infer_cutout_radius_from_model(model, default_radius=15, log=None):
    """
    Derive the cutout radius from the spatial size of the model input.

    Parameters
    ----------
    model : object
        Object exposing `input_shape`, either a (batch, height, width,
        channels) sequence or a list whose first element is one.
    default_radius : int, optional
        Radius returned when the shape is missing, too short, or has
        undefined height/width. Default: 15
    log : callable, optional
        Function receiving diagnostic messages. Default: no logging.

    Returns
    -------
    radius : int
        (height - 1) // 2 for a fixed square input, else `default_radius`.

    Raises
    ------
    ValueError
        If the fixed input is not square or has an even size.
    """
    if log is None:
        def log(*args, **kwargs):
            return None

    input_shape = model.input_shape
    if isinstance(input_shape, list):
        # Multi-input style: only the first entry (if any) is considered
        input_shape = input_shape[0] if input_shape else None

    if input_shape is None or len(input_shape) < 4:
        log(
            f"Model input shape unavailable; using default cutout size {2 * default_radius + 1} "
            f"(radius {default_radius})"
        )
        return default_radius

    _, height, width, channels = input_shape[0], input_shape[1], input_shape[2], input_shape[3]

    if height is None or width is None:
        log(
            f"Model input shape is dynamic; using default cutout size {2 * default_radius + 1} "
            f"(radius {default_radius})"
        )
        return default_radius

    if height != width:
        raise ValueError(
            f"Model input shape must be square for cutout extraction (got {height}x{width})."
        )
    if height % 2 == 0:
        raise ValueError(f"Model input size must be odd for symmetric cutouts (got {height}).")

    radius = int((height - 1) // 2)
    log(f"Using cutout size {height}x{width} (radius {radius}) from model input shape")

    # A channel mismatch is only reported, never treated as an error
    if channels is not None and channels != 2:
        log(f"Warning: model expects {channels} channels; realbogus generates 2-channel cutouts")

    return radius
def _upscale_cutout(cutout, scale_factor):
    """
    Enlarge a 2D cutout by an integer factor using pixel replication.

    Parameters
    ----------
    cutout : ndarray
        2D image cutout
    scale_factor : int or float
        Upscaling factor. Values <= 1 return the input unchanged; otherwise
        the value is truncated to an integer and every pixel becomes a
        (factor x factor) block of identical values.

    Returns
    -------
    upscaled : ndarray
        Upscaled cutout
    """
    if scale_factor <= 1:
        return cutout

    replication = int(scale_factor)

    # Replicate columns first, then rows; the result does not depend on the order
    widened = np.repeat(cutout, replication, axis=1)
    return np.repeat(widened, replication, axis=0)
def _pad_to_size(cutout, target_size, mode='edge'):
    """
    Center-crop and/or pad a 2D cutout to a square of the requested size.

    Each axis is handled independently: an axis longer than `target_size`
    is cropped around its center, a shorter one is padded on both sides
    (the extra pixel of an odd padding goes to the bottom/right).

    Parameters
    ----------
    cutout : ndarray
        2D image cutout
    target_size : int
        Target size (square)
    mode : str, optional
        Padding mode passed to np.pad. Default: 'edge'

    Returns
    -------
    padded : ndarray
        Cutout of shape (target_size, target_size)
    """
    n_rows, n_cols = cutout.shape

    # Crop whichever axes are too long
    if n_rows > target_size:
        top = (n_rows - target_size) // 2
        cutout = cutout[top : top + target_size, :]
        n_rows = target_size
    if n_cols > target_size:
        left = (n_cols - target_size) // 2
        cutout = cutout[:, left : left + target_size]
        n_cols = target_size

    # After cropping the deficits are never negative
    missing_rows = target_size - n_rows
    missing_cols = target_size - n_cols
    if missing_rows <= 0 and missing_cols <= 0:
        return cutout

    rows_before = missing_rows // 2
    cols_before = missing_cols // 2
    pad_widths = (
        (rows_before, missing_rows - rows_before),
        (cols_before, missing_cols - cols_before),
    )
    return np.pad(cutout, pad_widths, mode=mode)
[docs]
def preprocess_cutout(
    cutout_sci,
    cutout_bg=None,
    cutout_err=None,
    fwhm=None,
    target_fwhm=TARGET_FWHM,
    target_size=31,
    downscale_threshold=1.5,
    normalize=True,
    asinh_softening=None,
):
    """
    Preprocess cutout for CNN input.

    Steps:
        1. Optional scaling (downscale or upscale) to canonical FWHM
        2. Create 2-channel input (background-subtracted linear, asinh-scaled)
        3. Peak normalization (each channel normalized by its own peak value)
        4. Pad/crop to target size

    FWHM Scaling Strategy (Symmetric):
        - Downscaling (FWHM > target_fwhm * threshold): integer block averaging
        - No scaling (target_fwhm / threshold <= FWHM <= target_fwhm * threshold)
        - Upscaling (FWHM < target_fwhm / threshold): integer pixel replication

        Default: target_fwhm=3.0, threshold=1.5
        -> downscale if FWHM > 4.5, upscale if FWHM < 2.0, else unchanged

    Channel Design:
        - Channel 0: background-subtracted (linear scale), peak-normalized
        - Channel 1: asinh-scaled background-subtracted, peak-normalized

    Peak normalization divides each channel by the maximum of its absolute
    value, so every channel ends up in the [-1, 1] range regardless of the
    source brightness. The asinh channel complements the linear one with a
    compressed dynamic range.

    Parameters
    ----------
    cutout_sci : ndarray
        2D science image cutout. Non-floating-point input (e.g. raw integer
        frames) is promoted to float64 so that background subtraction cannot
        truncate or wrap around.
    cutout_bg : ndarray or scalar, optional
        Background cutout or a single background level. If None, the
        background is estimated as the median of the cutout border pixels.
    cutout_err : ndarray or scalar, optional
        Error/noise cutout or value, used only to set the noise level (sigma)
        for asinh softening: the value itself for a scalar, the NaN-ignoring
        median for an array. If None, sigma is the MAD-based standard
        deviation of `cutout_sci`. A non-finite or non-positive sigma is
        replaced by 1.0.
    fwhm : float, optional
        Image FWHM in pixels. If given and positive, the cutout is downscaled
        or upscaled towards `target_fwhm` as described above.
    target_fwhm : float, optional
        Target FWHM for scaling normalization. Default: 3.0
    target_size : int, optional
        Target cutout size (square). Default: 31
    downscale_threshold : float, optional
        Scaling is applied only if fwhm/target_fwhm is above this threshold
        (downscale) or below its inverse (upscale). Default: 1.5
    normalize : bool, optional
        Apply peak normalization to each channel. Channels whose peak does
        not exceed 1e-10 are left unchanged. Default: True
    asinh_softening : float, optional
        Asinh softening in units of sigma. If None, uses
        DEFAULT_ASINH_SOFTENING_SIGMA. The actual softening is
        (asinh_softening * sigma); if that is non-finite or non-positive,
        sigma itself is used.

    Returns
    -------
    preprocessed : ndarray
        float32 array of shape (target_size, target_size, 2)
    scale_factor : float
        The ratio fwhm / target_fwhm (1.0 when `fwhm` is not given or not
        positive). Note this is the raw ratio: the factor actually applied is
        its rounded value (or rounded inverse), and only when the ratio lies
        outside the no-scaling band.
    """
    # Work in floating point: with integer input np.full_like() would truncate
    # the background level, and unsigned data would wrap around on subtraction.
    cutout_sci = np.asanyarray(cutout_sci)
    if cutout_sci.dtype.kind != 'f':
        cutout_sci = cutout_sci.astype(np.float64)

    # Handle background
    if cutout_bg is None:
        # Estimate from the one-pixel border of the cutout
        edge_pixels = np.concatenate(
            [cutout_sci[0, :], cutout_sci[-1, :], cutout_sci[:, 0], cutout_sci[:, -1]]
        )
        cutout_bg = np.median(edge_pixels)
    if np.ndim(cutout_bg) == 0:
        # Scalars and 0-d arrays are broadcast to a full background map
        cutout_bg = np.full_like(cutout_sci, cutout_bg)
    else:
        cutout_bg = np.asanyarray(cutout_bg)

    # Estimate noise sigma for asinh softening. Only a single number is needed,
    # so compute it before any rescaling.
    if cutout_err is None:
        from astropy.stats import mad_std

        sigma = float(mad_std(cutout_sci))
    elif np.ndim(cutout_err) == 0:
        sigma = float(cutout_err)
    else:
        sigma = float(np.nanmedian(cutout_err))
    if not np.isfinite(sigma) or sigma <= 0:
        sigma = 1.0

    # Scaling to canonical FWHM (both downscale and upscale)
    scale_factor = 1.0
    if fwhm is not None and fwhm > 0:
        scale_factor = fwhm / target_fwhm
        if scale_factor > downscale_threshold:
            # PSF too broad: block-average by the nearest integer factor
            factor = round(scale_factor)
            cutout_sci = _downscale_cutout(cutout_sci, factor)
            cutout_bg = _downscale_cutout(cutout_bg, factor)
        elif scale_factor < 1.0 / downscale_threshold:
            # PSF too sharp: replicate pixels (symmetric with downscaling)
            factor = round(1.0 / scale_factor)
            cutout_sci = _upscale_cutout(cutout_sci, factor)
            cutout_bg = _upscale_cutout(cutout_bg, factor)

    # Channel 0: background-subtracted, linear scale
    cutout_bgsub = cutout_sci - cutout_bg

    # Channel 1: asinh scaling (compresses high values, ~linear for low values)
    if asinh_softening is None:
        asinh_softening = DEFAULT_ASINH_SOFTENING_SIGMA
    softening = float(asinh_softening) * sigma
    if not np.isfinite(softening) or softening <= 0:
        softening = sigma
    cutout_asinh = np.arcsinh(cutout_bgsub / softening)

    channels = [cutout_bgsub, cutout_asinh]

    # Peak normalization for brightness invariance
    if normalize:
        for i, ch in enumerate(channels):
            peak = np.max(np.abs(ch))
            if peak > 1e-10:
                channels[i] = ch / peak

    # Pad/crop to target size, then stack into a 2-channel image
    channels = [_pad_to_size(ch, target_size) for ch in channels]
    preprocessed = np.stack(channels, axis=-1).astype(np.float32)

    return preprocessed, scale_factor
[docs]
def load_realbogus_model(model_file=None, verbose=False):
    """
    Load a trained real-bogus model from disk.

    Parameters
    ----------
    model_file : str, optional
        Path to model file (.h5 or SavedModel directory).
        If None, the default model inside DEFAULT_MODEL_DIR
        (~/.stdpipe/models/) is used.
    verbose : bool, optional
        Print loading information. Default: False

    Returns
    -------
    model : keras.Model
        Loaded Keras model

    Raises
    ------
    ImportError
        If TensorFlow is not installed.
    FileNotFoundError
        If the model path does not exist.
    """
    _check_tensorflow()

    if verbose:
        log = print
    else:
        def log(*args, **kwargs):
            return None

    # Fall back to the default location when no path is given
    if model_file is None:
        model_file = os.path.join(DEFAULT_MODEL_DIR, DEFAULT_MODEL_NAME)

    model_exists = os.path.exists(model_file)
    if not model_exists:
        raise FileNotFoundError(
            f"Model file not found: {model_file}\n"
            "Train a model using train_realbogus_classifier() or specify model_file."
        )

    log(f"Loading model from {model_file}")
    loaded = keras.models.load_model(model_file)
    log(f"Model loaded successfully ({loaded.count_params()} parameters)")
    return loaded
[docs]
def save_realbogus_model(model, model_file=None, verbose=False):
    """
    Write a trained real-bogus model to disk.

    The parent directory of the output path is created if needed.

    Parameters
    ----------
    model : keras.Model
        Trained model
    model_file : str, optional
        Output path. If None, saves to ~/.stdpipe/models/realbogus_default.h5
    verbose : bool, optional
        Print saving information. Default: False

    Raises
    ------
    ImportError
        If TensorFlow is not installed.
    """
    _check_tensorflow()

    if verbose:
        log = print
    else:
        def log(*args, **kwargs):
            return None

    use_default_location = model_file is None
    if use_default_location:
        os.makedirs(DEFAULT_MODEL_DIR, exist_ok=True)
        model_file = os.path.join(DEFAULT_MODEL_DIR, DEFAULT_MODEL_NAME)

    # A bare file name has no directory component, so there is nothing to create
    parent_dir = os.path.dirname(model_file)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    log(f"Saving model to {model_file}")
    model.save(model_file)
    log("Model saved successfully")
[docs]
def classify_realbogus(
    obj,
    image,
    model=None,
    model_file=None,
    bg=None,
    err=None,
    mask=None,
    fwhm=None,
    asinh_softening=None,
    threshold=0.5,
    add_score=True,
    flag_bogus=True,
    batch_size=128,
    verbose=False,
):
    """
    Classify detected objects as real or bogus using CNN.

    This is the main entry point for real-bogus classification.

    Parameters
    ----------
    obj : astropy.table.Table
        Object catalog with 'x' and 'y' columns (from photometry.get_objects_*).
        The table is modified in place: an 'rb_score' column is added or
        overwritten when `add_score` is set, and a 'flags' column is added
        (if missing) and updated when `flag_bogus` is set.
    image : ndarray
        Science image
    model : keras.Model, optional
        Pre-loaded model. If None, loads from model_file.
    model_file : str, optional
        Path to model file. If None, uses default model.
    bg : ndarray or float, optional
        Background map or scalar value
    err : ndarray or float, optional
        Error/noise map or scalar value
    mask : ndarray, optional
        Boolean mask (True = masked pixels)
    fwhm : float, optional
        Image FWHM. If None, estimated from catalog.
    asinh_softening : float, optional
        Asinh softening in units of background sigma. If None, uses
        DEFAULT_ASINH_SOFTENING_SIGMA.
    threshold : float, optional
        Classification threshold (0-1). Objects with score below the
        threshold are bogus; score >= threshold is real. Default: 0.5
    add_score : bool, optional
        Add 'rb_score' column to the catalog. Default: True
    flag_bogus : bool, optional
        Set flags |= 0x1000 for bogus objects and filter them out.
        Default: True
    batch_size : int, optional
        Batch size for inference. Default: 128
    verbose : bool or callable, optional
        Print progress. Can be callable for custom logging. Default: False

    Returns
    -------
    obj_filtered : astropy.table.Table
        Filtered catalog with real sources only (if flag_bogus=True)
        or the full input catalog (if flag_bogus=False).

    Raises
    ------
    ImportError
        If TensorFlow is not installed.

    Notes
    -----
    The cutout size is inferred from the model input shape. If the model has
    dynamic spatial dimensions, it defaults to 31x31 (radius 15).

    Objects for which no cutout could be extracted get a NaN score and are
    treated as bogus when `flag_bogus` is set. An empty catalog, or a catalog
    without any valid cutout, is handled without running inference.

    Examples
    --------
    >>> from stdpipe import photometry, realbogus
    >>> obj = photometry.get_objects_sep(image, thresh=3.0)
    >>> obj_clean = realbogus.classify_realbogus(obj, image)
    >>> print(f"Kept {len(obj_clean)}/{len(obj)} objects")
    """
    _check_tensorflow()

    # `verbose` may be a bool or a custom logging callable
    if not verbose:
        def log(*args, **kwargs):
            return None
    elif callable(verbose):
        log = verbose
    else:
        log = print

    # Load model if not provided
    if model is None:
        model = load_realbogus_model(model_file=model_file, verbose=verbose)

    cutout_radius = _infer_cutout_radius_from_model(model, default_radius=15, log=log)

    n_total = len(obj)
    log(f"Classifying {n_total} objects (threshold={threshold:.2f})")

    # Score array covering the whole catalog (NaN = no valid cutout)
    full_scores = np.full(n_total, np.nan, dtype=float)

    if n_total > 0:
        cutouts, valid_indices = extract_cutouts(
            obj,
            image,
            bg=bg,
            err=err,
            mask=mask,
            radius=cutout_radius,
            fwhm=fwhm,
            asinh_softening=asinh_softening,
            verbose=verbose,
        )

        if len(cutouts) > 0:
            log(f"Running inference on {len(cutouts)} cutouts (batch_size={batch_size})")
            predictions = model.predict(
                cutouts, batch_size=batch_size, verbose=1 if verbose else 0
            )
            full_scores[valid_indices] = np.asarray(predictions).ravel()
        else:
            # Nothing to feed to the network; every object keeps its NaN score
            log("No valid cutouts extracted; skipping inference")

    # Add score column if requested
    if add_score:
        if 'rb_score' in obj.colnames:
            obj['rb_score'] = full_scores
        else:
            obj.add_column(Column(full_scores, name='rb_score'))

    if not flag_bogus:
        return obj

    # Ensure flags column exists
    if 'flags' not in obj.colnames:
        obj.add_column(Column(np.zeros(n_total, dtype=int), name='flags'))

    # Bogus = below threshold, or no score at all (invalid cutout)
    invalid = np.isnan(full_scores)
    with np.errstate(invalid='ignore'):
        is_bogus = full_scores < threshold
    is_bogus[invalid] = True
    obj['flags'][is_bogus] |= 0x1000

    obj_filtered = obj[~is_bogus]

    # Guard the percentage: the message is formatted even when logging is off,
    # so an empty catalog must not divide by zero here.
    retained_pct = 100 * len(obj_filtered) / n_total if n_total else 0.0
    log(
        f"Filtered: {len(obj_filtered)}/{n_total} objects retained "
        f"({retained_pct:.1f}%)"
    )

    return obj_filtered
[docs]
def train_realbogus_classifier(
    training_data=None,
    n_simulated=1000,
    image_size=(2048, 2048),
    fwhm_range=(1.5, 8.0),
    real_source_types=None,
    validation_split=0.15,
    model=None,
    model_file=None,
    epochs=50,
    batch_size=64,
    class_weight='balanced',
    callbacks=None,
    verbose=True,
):
    """
    Train real-bogus classifier on simulated or real data.

    Parameters
    ----------
    training_data : tuple or dict, optional
        Pre-generated training data: (X, y) tuple, legacy (X, y, fwhm_features)
        tuple (the third element is ignored), or dict with 'X' and 'y' keys.
        If None, generates simulated data using
        simulation.generate_realbogus_training_data().
    n_simulated : int, optional
        Number of simulated images to generate (if training_data=None). Default: 1000
    image_size : tuple, optional
        Size of simulated images (width, height). Default: (2048, 2048)
    fwhm_range : tuple, optional
        Range of FWHM values for simulated images. Default: (1.5, 8.0)
    real_source_types : list, optional
        List of source types to consider 'real' (if training_data=None).
        Default (None): ['star'], i.e. only stars are real and galaxies are bogus.
        Use ['star', 'galaxy'] to train a classifier that treats both as real.
    validation_split : float, optional
        Fraction of data for validation. Default: 0.15
    model : keras.Model, optional
        Model to train. If None, creates new model matching the shape of X.
    model_file : str, optional
        Path to save the trained model. If None (default), the model is
        not saved.
    epochs : int, optional
        Training epochs. Default: 50
    batch_size : int, optional
        Batch size. Default: 64
    class_weight : str or dict, optional
        Class weights for imbalanced data. 'balanced' or dict {0: w0, 1: w1}.
        With 'balanced', weighting is disabled if only one class is present.
        Default: 'balanced'
    callbacks : list, optional
        Keras callbacks. If None, early stopping (patience 10, best weights
        restored) and learning-rate reduction on plateau (factor 0.5,
        patience 5) are used; they monitor 'val_loss', or 'loss' when
        `validation_split` is zero and no validation metric exists.
    verbose : bool, optional
        Print training progress. Default: True

    Returns
    -------
    model : keras.Model
        Trained model
    history : keras.callbacks.History
        Training history

    Raises
    ------
    ImportError
        If TensorFlow is not installed.

    Examples
    --------
    >>> from stdpipe import realbogus
    >>> # Train on simulated data (default: only stars are real)
    >>> model, history = realbogus.train_realbogus_classifier(
    ...     n_simulated=500,
    ...     epochs=30,
    ...     verbose=True
    ... )
    >>> # Treat both stars and galaxies as real
    >>> model, history = realbogus.train_realbogus_classifier(
    ...     n_simulated=500,
    ...     real_source_types=['star', 'galaxy'],
    ...     epochs=30,
    ...     verbose=True
    ... )
    >>> # Or use pre-generated data
    >>> from stdpipe import simulation
    >>> data = simulation.generate_realbogus_training_data(...)
    >>> model, history = realbogus.train_realbogus_classifier(
    ...     training_data=(data['X'], data['y']),
    ...     epochs=30
    ... )
    """
    _check_tensorflow()

    if verbose:
        log = print
    else:
        def log(*args, **kwargs):
            return None

    # A fresh list per call instead of a shared mutable default argument
    if real_source_types is None:
        real_source_types = ['star']

    # Generate or load training data
    if training_data is None:
        log(f"Generating training data from {n_simulated} simulated images...")
        from . import simulation

        data = simulation.generate_realbogus_training_data(
            n_images=n_simulated,
            image_size=image_size,
            fwhm_range=fwhm_range,
            real_source_types=real_source_types,
            augment=True,
            verbose=verbose,
        )
        X = data['X']
        y = data['y']
    elif isinstance(training_data, dict):
        X = training_data['X']
        y = training_data['y']
    elif len(training_data) == 3:
        X, y, _fwhm_features = training_data  # Legacy 3-tuple format
    else:
        X, y = training_data

    # Accept plain sequences as well: `.shape` and the class counts need arrays
    X = np.asarray(X)
    y = np.asarray(y)

    n_samples = len(y)
    n_real = np.sum(y)
    n_bogus = n_samples - n_real
    log(f"Training data: {len(X)} samples, {n_real} real, {n_bogus} bogus")

    # Create model if not provided
    if model is None:
        input_shape = X.shape[1:]  # (height, width, channels)
        model = create_realbogus_model(input_shape=input_shape)
        model.summary(print_fn=log)

    # Calculate class weights (inverse class frequency, normalized to two classes)
    if class_weight == 'balanced':
        if n_real == 0 or n_bogus == 0:
            log("Warning: Only one class present in training data; disabling class weighting")
            class_weight = None
        else:
            class_weight = {
                0: float(n_samples / (2 * n_bogus)),
                1: float(n_samples / (2 * n_real)),
            }
            log(f"Class weights: {class_weight}")

    # Default callbacks
    if callbacks is None:
        # Without a validation subset there is no 'val_loss' to monitor, which
        # would leave both callbacks inactive; fall back to the training loss.
        monitor = 'val_loss' if validation_split and validation_split > 0 else 'loss'
        callbacks = [
            keras.callbacks.EarlyStopping(
                monitor=monitor, patience=10, restore_best_weights=True
            ),
            keras.callbacks.ReduceLROnPlateau(
                monitor=monitor, factor=0.5, patience=5, min_lr=1e-6
            ),
        ]

    # Train model (pure image-based, no FWHM auxiliary input)
    log("Starting training...")
    history = model.fit(
        X,
        y,
        validation_split=validation_split,
        epochs=epochs,
        batch_size=batch_size,
        class_weight=class_weight,
        callbacks=callbacks,
        verbose=1 if verbose else 0,
    )
    log("Training complete")

    # Save model only when explicitly requested
    if model_file is not None:
        save_realbogus_model(model, model_file=model_file, verbose=verbose)

    return model, history