diff --git a/deep_field_metadetect/jaxify/jax_detection.py b/deep_field_metadetect/jaxify/jax_detection.py new file mode 100644 index 0000000..b63b9ca --- /dev/null +++ b/deep_field_metadetect/jaxify/jax_detection.py @@ -0,0 +1,446 @@ +from functools import partial +from typing import Tuple, Union + +import jax +import jax.numpy as jnp + + +@partial(jax.jit, static_argnames=["window_size"]) +def local_maxima_filter( + image: jnp.ndarray, + noise: Union[jnp.ndarray, float], + window_size: int = 5, +) -> jnp.ndarray: + """ + Find local maximas in an image within window_size + + Parameters: + ----------- + image : jnp.ndarray + 2D Input galaxy field + window_size : int + Size of the neighborhood for local maximum detection + noise : jnp.ndarray | float + Pixelwise noise sigma + Minimum pixel value of central pixel must > 3-sigma + + Returns: + -------- + jnp.ndarray + Binary mask indicating local maxima positions + """ + noise_array = jnp.broadcast_to(noise, image.shape) if jnp.isscalar(noise) else noise + + pad_size = window_size // 2 + padded_image = jnp.pad(image, pad_size, mode="constant", constant_values=-jnp.inf) + + def is_local_max(i, j): + center_val = padded_image[i + pad_size, j + pad_size] + threshold = 3 * noise_array[i, j] # noise is not padded + + neighborhood = jax.lax.dynamic_slice( + padded_image, (i, j), (window_size, window_size) + ) + + return (jnp.all(center_val >= neighborhood)) & (threshold < center_val) + + height, width = image.shape + i_indices, j_indices = jnp.meshgrid( + jnp.arange(height), jnp.arange(width), indexing="ij" + ) + + local_max_mask = jax.vmap( + jax.vmap( + is_local_max, + in_axes=(0, 0), + ), + in_axes=(0, 0), + )(i_indices, j_indices) + + return local_max_mask + + +@partial(jax.jit, static_argnames=["window_size", "max_objects"]) +def peak_finder( + image: jnp.ndarray, + noise: Union[jnp.ndarray, float], + window_size: int = 5, + max_objects: int = 100, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Find peaks in an image above a threshold + + Parameters: + ----------- + image : jnp.ndarray + 2D Input galaxy field + noise : jnp.ndarray | float + Pixelwise noise sigma + Minimum pixel value of central pixel must be > 3-sigma + window_size : int + Size of the neighborhood for local maximum detection + max_objects : int + Maximum number of objects to detect (to make functions jitable) + + Returns: + -------- + positions : jnp.ndarray + Array of peak coordinates (y, x) of shape (max_objects, 2) + Invalid entries filled with (-999, -999) + """ + local_max_mask = local_maxima_filter( + image=image, + noise=noise, + window_size=window_size, + ) + + positions = jnp.argwhere(local_max_mask, size=max_objects, fill_value=(-999, -999)) + + return positions + + +@partial(jax.jit, static_argnames=["window_size"]) +def refine_centroid( + image: jnp.ndarray, peak: Tuple[int, int], window_size: int = 5 +) -> Tuple[float, float, bool]: + """ + Refine peak position of single object using intensity-weighted centroid. + Skips refinement for objects too close to the border. + Returns whether object was near border for warning purposes. + + Parameters: + ----------- + image : jnp.ndarray + 2D Input galaxy field + peak: jnp.ndarray + Initial peak position + window_size : int + Size of window around peak for centroid calculation + if window crosses image boudary, optimization is skipped. + + Returns: + -------- + jnp.ndarray + Refined peak coordinates (refined_y, refined_x) : float + Note: original coordinatesare returned if near border + near_border : bool + True if object was near border and refinement was skipped + """ + half_window = window_size // 2 + height, width = image.shape + + # If near border, return original coordinates + near_border = ( + (peak[0] < half_window) + | (peak[0] >= height - half_window) + | (peak[1] < half_window) + | (peak[1] >= width - half_window) + ) + + def border_case(): + return jnp.array([peak[0], peak[1]]).astype(float) + + def normal_case(): + window = jax.lax.dynamic_slice( + image, + (peak[0] - half_window, peak[1] - half_window), + (window_size, window_size), + ) + + y_start = -half_window + x_start = -half_window + y_coords = jnp.arange(y_start, y_start + window_size) + x_coords = jnp.arange(x_start, x_start + window_size) + y_grid, x_grid = jnp.meshgrid(y_coords, x_coords, indexing="ij") + + total_intensity = jnp.sum(window) + + y_shift = jnp.sum((y_grid) * window) / total_intensity + x_shift = jnp.sum((x_grid) * window) / total_intensity + + refined_y = y_shift + peak[0] + refined_x = x_shift + peak[1] + + return jnp.array([refined_y, refined_x]) + + result = jax.lax.cond(near_border, border_case, normal_case) + + return jnp.array([result[0], result[1]]), near_border + + +@partial(jax.jit, static_argnames=["window_size"]) +def refine_centroid_in_cell( + image: jnp.ndarray, + peak_positions: jnp.ndarray, + window_size: int = 5, +): + """ + vmapped version of refine_centroid + """ + return jax.vmap(refine_centroid, in_axes=(None, 0, None))( + image, peak_positions, window_size + ) + + +@partial(jax.jit, static_argnames=["window_size", "refine_centroids", "max_objects"]) +def detect_galaxies( + image: jnp.ndarray, + noise: Union[jnp.ndarray, float], + window_size: int = 5, + refine_centroids: bool = True, + max_objects: int = 100, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Complete galaxy center detection pipeline with JIT compilation support. + + Parameters: + ----------- + image : jnp.ndarray + 2D Input galaxy field + noise : jnp.ndarray | float + Pixelwise noise sigma + Minimum pixel value of central pixel must be > 3-sigma + window_size : int + Minimum distance between detected peaks + refine_centroids : bool + Whether to refine peak positions using centroid calculation + max_objects : int + Maximum number of objects to detect (for fixed array sizes) + + Returns: + -------- + peak_positions : jnp.ndarray + Array of detected galaxy centers (y, x) of shape (max_objects, 2). + Returns only the integral pixel location. + Invalid entries filled with -999 + refined_positions : jnp.ndarray + Array of detected galaxy centers (y, x) after centroid refinement. + Returns the refined floating point values of the center. + border_flags : jnp.ndarray + Array indicating which objects were near border (shape max_objects,) + """ + peak_positions = peak_finder( + image=image, + noise=noise, + window_size=window_size, + max_objects=max_objects, + ) + + if not refine_centroids: + border_flags = jnp.zeros(max_objects, dtype=bool) + return peak_positions, peak_positions.astype(float), border_flags + + refined_positions, border_flags = refine_centroid_in_cell( + image, peak_positions, window_size=5 + ) + + return peak_positions, refined_positions, border_flags + + +@partial(jax.jit, static_argnames=["max_iterations"]) +def watershed_segmentation( + inverted_image: jnp.ndarray, + noise: Union[jnp.ndarray, float], + markers: jnp.ndarray, + mask: jnp.ndarray = None, + max_iterations: int = 30, +) -> jnp.ndarray: + """ + JAX implementation of watershed segmentation algorithm. + + Parameters: + ----------- + inverted_image : jnp.ndarray + 2D input image with inverted intensity + noise : jnp.ndarray | float + Pixelwise noise sigma + flooding continues to unmarked neightboring pixel within a limit of sigma + markers : jnp.ndarray + 2D array of initial markers (labeled regions) where positive values + indicate different watershed basins and 0 indicates unmarked pixels + mask : jnp.ndarray, optional + Binary mask indicating valid pixels for segmentation. + Pixels with non-zero masked values are masked. + max_iterations : int + Maximum number of iterations for the flooding process + + Returns: + -------- + labels : jnp.ndarray + 2D segmentation map with same shape as input image + """ + noise_array = ( + jnp.broadcast_to(noise, inverted_image.shape) if jnp.isscalar(noise) else noise + ) + + if mask is None: + mask = jnp.zeros_like(inverted_image, dtype=bool) + + labels = markers.copy() + height, width = inverted_image.shape + + def watershed_step(labels_prev): + """Single iteration of watershed flooding""" + labels_new = labels_prev.copy() + + def update_pixel(i, j): + # Skip if masked out + # Note: another option here would be skip if already labeled + current_label = labels_prev[i, j] + is_valid = ~mask[i, j] + + def check_neighbors(): + # Check 4-connected neighbors + neighbor_coords = jnp.array( + [[i - 1, j], [i + 1, j], [i, j - 1], [i, j + 1]] + ) + + in_bounds = ( + (neighbor_coords[:, 0] >= 0) + & (neighbor_coords[:, 0] < height) + & (neighbor_coords[:, 1] >= 0) + & (neighbor_coords[:, 1] < width) + ) + + neighbor_labels = labels_prev[ + neighbor_coords[:, 0], neighbor_coords[:, 1] + ] + neighbor_values = inverted_image[ + neighbor_coords[:, 0], neighbor_coords[:, 1] + ] + + # Mask for valid (labeled and in-bounds) neighbors + valid_mask = in_bounds & (neighbor_labels > 0) + + has_valid = jnp.any(valid_mask) + + def process_valid_neighbors(): + # Use large value for invalid neighbors in argmin + masked_values = jnp.where(valid_mask, neighbor_values, jnp.inf) + min_idx = jnp.argmin(masked_values) + + # Check if current pixel should be flooded + current_value = inverted_image[i, j] + current_noise = noise_array[i, j] + min_neighbor_value = neighbor_values[min_idx] + + # Flood if current value is >= minimum neighbor value + def should_flood(): + """Decides when to flood a pixel based on if it is marked""" + + def unmarked_pixel(): + return (current_value + current_noise) >= min_neighbor_value + + def marked_pixel(): + return (current_value) >= min_neighbor_value + + is_marked = current_label != 0 + + return jax.lax.cond( + is_marked, + marked_pixel, + unmarked_pixel, + ) + + to_flood = should_flood() + + # leave current value if update is not required + return jax.lax.cond( + to_flood, + lambda: neighbor_labels[min_idx], + lambda: current_label, + ) + + # If no valid neighbors, leave current value else process + return jax.lax.cond( + has_valid, process_valid_neighbors, lambda: current_label + ) + + # If pixel is not maked, check for neightbors + new_label = jax.lax.cond(is_valid, check_neighbors, lambda: current_label) + + return new_label + + # Vectorized update over all pixels + i_coords, j_coords = jnp.meshgrid( + jnp.arange(height), jnp.arange(width), indexing="ij" + ) + + labels_new = jax.vmap(jax.vmap(update_pixel, in_axes=(0, 0)), in_axes=(0, 0))( + i_coords, j_coords + ) + + return labels_new + + # Iterative flooding using scan + def scan_fn(labels_current, _): + labels_next = watershed_step(labels_current) + return labels_next, None + + final_labels, _ = jax.lax.scan(scan_fn, labels, jnp.arange(max_iterations)) + + return final_labels + + +@partial(jax.jit, static_argnames=["max_iterations"]) +def watershed_from_peaks( + image: jnp.ndarray, + noise: Union[jnp.ndarray, float], + peaks: jnp.ndarray, + mask: jnp.ndarray = None, + max_iterations: int = 30, +) -> jnp.ndarray: + """ + Perform watershed segmentation using detected peaks as markers. + + Parameters: + ----------- + image : jnp.ndarray + 2D input image + noise : jnp.ndarray | float + Pixelwise noise sigma + flooding continues to the neightboring pixel within a limit of sigma + peaks : jnp.ndarray + Array of peak positions (y, x) of shape (n_peaks, 2) + mask : jnp.ndarray + Array of masked pixels. + Pixels with non-zero masked values are masked. + max_iterations : int + Maximum iterations for watershed algorithm + + Returns: + -------- + watershed_labels : jnp.ndarray + 2D segmentation map from watershed algorithm + """ + height, width = image.shape + + inverted_image = -image # Invert so peaks become valleys + + markers = jnp.zeros((height, width), dtype=jnp.int32) + + # Place markers at peak positions + def place_marker(i, peak_pos): + y, x = peak_pos.astype(jnp.int32) + is_valid = (y >= 0) & (y < height) & (x >= 0) & (x < width) + + marker_value = jax.lax.cond(is_valid, lambda: i + 1, lambda: 0) # Label from 1 + + return jax.lax.cond( + is_valid, lambda: markers.at[y, x].set(marker_value), lambda: markers + ) + + # Sequential marker placement + for i in range(peaks.shape[0]): + markers = place_marker(i, peaks[i]) + + # Apply watershed algorithm + watershed_labels = watershed_segmentation( + inverted_image, + noise=noise, + markers=markers, + max_iterations=max_iterations, + mask=mask, + ) + + return watershed_labels diff --git a/deep_field_metadetect/jaxify/tests/test_jax_detection.py b/deep_field_metadetect/jaxify/tests/test_jax_detection.py new file mode 100644 index 0000000..242a2f7 --- /dev/null +++ b/deep_field_metadetect/jaxify/tests/test_jax_detection.py @@ -0,0 +1,356 @@ +import jax.numpy as jnp +import numpy as np + +from deep_field_metadetect.jaxify.jax_detection import ( + detect_galaxies, + local_maxima_filter, + peak_finder, + refine_centroid, + watershed_from_peaks, + watershed_segmentation, +) + + +def create_gaussian_blob(shape, center, sigma=1.0, amplitude=1.0): + """ + Create a 2D Gaussian blob for testing. + + Parameters: + ----------- + shape : tuple + Shape of the output array (height, width) + center : tuple + Center position (y, x) of the Gaussian + sigma : float + Standard deviation of the Gaussian + amplitude : float + Peak amplitude of the Gaussian + + Returns: + -------- + jnp.ndarray + 2D array containing the Gaussian blob + """ + y, x = jnp.meshgrid(jnp.arange(shape[0]), jnp.arange(shape[1]), indexing="ij") + cy, cx = center + + gaussian = amplitude * jnp.exp(-((y - cy) ** 2 + (x - cx) ** 2) / (2 * sigma**2)) + return gaussian + + +def create_multiple_gaussian_blobs(shape, centers, sigmas=None, amplitudes=None): + """ + Create multiple Gaussian blobs in a single image. + + Parameters: + ----------- + shape : tuple + Shape of the output array (height, width) + centers : list of tuples + List of center positions [(y1, x1), (y2, x2), ...] + sigmas : list of floats or None + Standard deviations for each blob. If None, uses 1.0 for all + amplitudes : list of floats or None + Amplitudes for each blob. If None, uses 1.0 for all + + Returns: + -------- + jnp.ndarray + 2D array containing all Gaussian blobs + """ + if sigmas is None: + sigmas = [1.0] * len(centers) + if amplitudes is None: + amplitudes = [1.0] * len(centers) + + image = jnp.zeros(shape) + for center, sigma, amplitude in zip(centers, sigmas, amplitudes): + blob = create_gaussian_blob(shape, center, sigma, amplitude) + image = image + blob + + return image + + +# ------------------- +# Test peak detection +# ------------------- + + +def single_gaussian(): + """Test detection of multiple well-separated Gaussian peaks.""" + centers = [(5, 5)] + amplitudes = [1.0] + max_objects = 10 + image = create_multiple_gaussian_blobs( + (10, 10), centers, sigmas=[1.0], amplitudes=amplitudes + ) + + noise = 0.0 + peak = peak_finder(image, noise=noise, max_objects=max_objects) + + assert len(peak) == max_objects + assert (peak[0][0] == centers[0][0]) & (peak[0][1] == centers[0][1]) + + +def test_multiple_separated_gaussians(): + """Test detection of multiple well-separated Gaussian peaks.""" + centers = [(2, 2), (2, 7), (7, 2), (7, 7)] + amplitudes = [1.0, 1.5, 2.0, 0.8] + image = create_multiple_gaussian_blobs( + (10, 10), centers, sigmas=[1.0] * 4, amplitudes=amplitudes + ) + + noise = 0.0 + result = local_maxima_filter(image, noise=noise, window_size=3) + + # All centers should be detected as peaks + for center in centers: + assert result[center[0], center[1]] + + +def test_threshold_filtering_gaussians(): + """Test that Gaussian peaks below threshold are filtered out.""" + centers = [(3, 3), (3, 9)] + amplitudes = [0.5, 2.0] # First below threshold, second above + image = create_multiple_gaussian_blobs((12, 12), centers, amplitudes=amplitudes) + + noise = 0.5 + result = local_maxima_filter(image, noise=noise, window_size=3) + + # Only the high amplitude peak should be detected + assert not result[3, 3] # Below threshold + assert result[3, 9] # Above threshold + + +def test_overlapping_gaussians(): + """Test behavior with overlapping Gaussian blobs.""" + # Two Gaussians close together + centers = [(4, 4), (4, 6)] + image = create_multiple_gaussian_blobs( + (9, 9), centers, sigmas=[1.5, 1.5], amplitudes=[1.0, 1.0] + ) + + noise = 0.0 + result = local_maxima_filter(image, noise=noise, window_size=3) + + # Depending on overlap, may detect one or both peaks + # At minimum, should detect at least one peak in the region + peak_region = result[3:6, 3:7] + assert jnp.any(peak_region) + + +def test_edge_case_detection(): + """Test detection of edge cases and boundary conditions.""" + # Single pixel "galaxy" + image = jnp.zeros((7, 7)) + image = image.at[5, 5].set(5.0) + + noise = 0.0 + peaks, refined, border_flags = detect_galaxies( + image=image, + noise=noise, + window_size=3, + refine_centroids=True, + max_objects=5, + ) + + print(peaks) + valid_peaks = peaks[peaks[:, 0] > 0] + + assert len(valid_peaks) == 1 + assert jnp.array_equal(valid_peaks[0], jnp.array([5, 5])) + + +# ------------------------ +# Test Centriod Refinement +# ------------------------ + + +def test_gaussian_centroid_refinement(): + """Test centroid refinement on slightly off-center Gaussian.""" + # Create Gaussian slightly off-grid + true_center = (4.3, 4.7) + image = create_gaussian_blob((9, 9), true_center, sigma=1.5, amplitude=2.0) + + # Start refinement from nearest grid point + initial_peak = (4, 5) + refined_peak, near_border = refine_centroid(image, initial_peak, window_size=5) + + # Refined position should be closer to true center + initial_distance = np.sqrt( + (initial_peak[0] - true_center[0]) ** 2 + + (initial_peak[1] - true_center[1]) ** 2 + ) + refined_distance = np.sqrt( + (refined_peak[0] - true_center[0]) ** 2 + + (refined_peak[1] - true_center[1]) ** 2 + ) + + assert refined_distance < initial_distance + assert not near_border + + +def test_near_border(): + """Test near border case.""" + # Create two overlapping Gaussians to make asymmetric peak + centers = [(3, 3)] + amplitudes = [1.0] + image = create_multiple_gaussian_blobs( + (5, 5), centers, sigmas=[1.0], amplitudes=amplitudes + ) + + refined_pos, near_border = refine_centroid(image, (4, 4), window_size=5) + + assert (refined_pos[0] == 4) & (refined_pos[1] == 4) # refined is same as input + assert near_border + + +# ----------------------------- +# Test galaxy dection in fields +# ----------------------------- + + +def test_complete_gaussian_detection(): + """Test complete detection pipeline on Gaussian galaxies.""" + centers = [(5, 5), (5, 15), (15, 5), (15, 15)] + amplitudes = [2.0, 1.5, 1.8, 1.2] + sigmas = [1.5, 1.2, 1.3, 1] + + image = create_multiple_gaussian_blobs( + (21, 21), centers, sigmas=sigmas, amplitudes=amplitudes + ) + + noise = 0.0 + peaks, refined, _ = detect_galaxies( + image, noise=noise, window_size=5, refine_centroids=True, max_objects=10 + ) + + valid_peaks = peaks[peaks[:, 0] > 0] + valid_refined = refined[peaks[:, 0] > 0] + + # Should detect all 4 galaxies + assert len(valid_peaks) == 4 + + assert np.all(valid_peaks == jnp.array(centers)) + + # Refinement should improve positions for off-grid centers + for i in range(len(valid_refined)): + # Refined positions should be reasonable + assert np.abs(np.asarray(centers)[i, 0] - valid_refined[i, 0]) < 0.5 + assert np.abs(np.asarray(centers)[i, 1] - valid_refined[i, 1]) < 0.5 + + +def test_detection_with_noise(): + """Test detection robustness with added noise.""" + np.random.seed(42) + # Create clean Gaussian + peak_location = (6, 6) + image_clean = create_gaussian_blob((15, 15), (6, 6), sigma=1.0, amplitude=2.0) + + # Add noise + noise = jnp.array(np.random.normal(0, 0.2, image_clean.shape)) + image_noisy = image_clean + noise + + noise = 0.2 + _, refined, _ = detect_galaxies( + image_noisy, noise=noise, window_size=5, refine_centroids=True, max_objects=5 + ) + + valid_peaks = refined[refined[:, 0] > 0] + + # Should still detect the main peak despite noise + assert len(valid_peaks) >= 1 + + # Main peak should be near expected position + main_peak = valid_peaks[0] + distance_to_true = np.sqrt( + (main_peak[0] - peak_location[0]) ** 2 + (main_peak[1] - peak_location[1]) ** 2 + ) + assert distance_to_true < 0.5 + + +def test_faint_galaxy_detection(): + """Test detection of faint galaxies.""" + centers = [(8, 6), (8, 12)] + amplitudes = [2.0, 0.8] # One bright, one faint + + image = create_multiple_gaussian_blobs( + (17, 17), centers, sigmas=[1.5, 1.5], amplitudes=amplitudes + ) + + # Test with threshold that should catch both + noise = 0.2 + peaks_low, _, _ = detect_galaxies(image, noise=noise, max_objects=5) + valid_low = peaks_low[peaks_low[:, 0] > 0] + + # Test with threshold that should only catch bright one + noise = 0.3 + peaks_high, _, _ = detect_galaxies(image, noise=noise, max_objects=5) + valid_high = peaks_high[peaks_high[:, 0] > 0] + + assert len(valid_low) == 2 + assert len(valid_high) == 1 + + +# ------------------------- +# Test watershed algorithm +# ------------------------- + + +def test_watershed_edge_cases(): + """Test watershed algorithm edge cases.""" + uniform_image = jnp.ones((3, 3)) * 2.0 + uniform_markers = jnp.zeros((3, 3), dtype=int) + uniform_markers = uniform_markers.at[1, 1].set(1) + + noise = 0.0 + result = watershed_segmentation( + uniform_image, noise, uniform_markers, max_iterations=5 + ) + # All pixels should eventually be labeled due to uniform flooding + assert jnp.all(result == 1) + + +def test_watershed_with_mask(): + """Test watershed segmentation with a mask.""" + image = jnp.ones((5, 5)) * 2.0 + image = image.at[1:4, 1:4].set(1.0) # Lower values in center + + # Create mask that excludes border pixels + mask = jnp.ones((5, 5), dtype=bool) + mask = mask.at[1:4, 1:4].set(False) + + markers = jnp.zeros((5, 5), dtype=int) + markers = markers.at[2, 2].set(1) + + noise = 0.0 + result = watershed_segmentation(image, noise, markers, mask=mask, max_iterations=5) + + # Only pixels within mask should be labeled + assert jnp.all(result[mask] == 0) + assert jnp.all(result[~mask] != 0) + + +def test_watershed_from_peaks_with_invalid(): + """Test watershed_from_peaks with invalid peak positions.""" + image = jnp.ones((5, 5)) + image = image.at[2, 2].set(3.0) + + # Include some invalid peaks (marked with -999) + peaks = jnp.array( + [ + [2, 2], # Valid peak + [1, 1], # Valid peak + [-999, -999], # Invalid peak + [-999, -999], # Invalid peak + ] + ) + + noise = 0.0 + result = watershed_from_peaks(image, noise, peaks, max_iterations=10) + + unique_labels = jnp.unique(result) + unique_labels = unique_labels[unique_labels > 0] + + # Should have exactly 2 regions (for the 2 valid peaks) + assert len(unique_labels) == 2