diff --git a/tests/test_focus_estimator.py b/tests/test_focus_estimator.py index 8c093195..2460917c 100644 --- a/tests/test_focus_estimator.py +++ b/tests/test_focus_estimator.py @@ -156,3 +156,156 @@ def test_compute_midband_power_consistency(): expected_focus_slice = np.argmax(manual_powers) assert focus_slice == expected_focus_slice + + +def test_subpixel_precision(): + """Test that sub-pixel precision returns float values when enabled.""" + # Test parameters + ps = 6.5 / 100 + lambda_ill = 0.532 + NA_det = 1.4 + + # Create synthetic test data with a clear peak between slices + z_size, y_size, x_size = 11, 64, 64 + x = np.linspace(-1, 1, x_size) + y = np.linspace(-1, 1, y_size) + z = np.linspace(-5, 5, z_size) + + # Create a 3D Gaussian that peaks between slice indices + test_data = np.zeros((z_size, y_size, x_size)) + true_peak_z = 5.3 # Peak between slices 5 and 6 + + for i, z_val in enumerate(z): + # Create Gaussian centered at true_peak_z position in physical space + gaussian_2d = np.exp( + -( + (x[None, :] ** 2 + y[:, None] ** 2) + + (z_val - (true_peak_z - 5)) ** 2 + ) + ) + test_data[i] = gaussian_2d + + # Test without sub-pixel precision (should return integer) + focus_slice_int = focus.focus_from_transverse_band( + test_data, + NA_det, + lambda_ill, + ps, + polynomial_fit_order=4, + enable_subpixel_precision=False, + ) + assert isinstance(focus_slice_int, (int, np.integer)) + + # Test with sub-pixel precision (should return float) + focus_slice_float = focus.focus_from_transverse_band( + test_data, + NA_det, + lambda_ill, + ps, + polynomial_fit_order=4, + enable_subpixel_precision=True, + ) + + # Should return a float + assert isinstance(focus_slice_float, float) + + # Should be close to the true peak position + assert abs(focus_slice_float - true_peak_z) < 1.0 # Within 1 slice + + # Sub-pixel result should be different from integer result + assert focus_slice_float != focus_slice_int + + +def test_subpixel_precision_backward_compatibility(): + """Test that default behavior (integer results) is preserved.""" + ps = 6.5 / 100 + lambda_ill = 0.532 + NA_det = 1.4 + + # Create simple test data + test_data = np.random.random((5, 32, 32)).astype(np.float32) + + # Test default behavior (should return integer) + focus_slice = focus.focus_from_transverse_band( + test_data, + NA_det, + lambda_ill, + ps, + polynomial_fit_order=4, + ) + + assert isinstance(focus_slice, (int, np.integer)) + + +def test_subpixel_precision_with_plotting(tmp_path): + """Test that sub-pixel precision works with plotting.""" + ps = 6.5 / 100 + lambda_ill = 0.532 + NA_det = 1.4 + + # Create test data + test_data = np.random.random((7, 32, 32)).astype(np.float32) + plot_path = tmp_path / "subpixel_test.pdf" + + # Should work without errors + focus_slice = focus.focus_from_transverse_band( + test_data, + NA_det, + lambda_ill, + ps, + polynomial_fit_order=4, + enable_subpixel_precision=True, + plot_path=str(plot_path), + ) + + assert isinstance(focus_slice, float) + assert plot_path.exists() + + +def test_z_focus_offset_float_type(): + """Test that z_focus_offset can accept float values in settings.""" + from waveorder.cli.settings import FourierTransferFunctionSettings + + # Test that float values are accepted + settings = FourierTransferFunctionSettings(z_focus_offset=1.5) + assert settings.z_focus_offset == 1.5 + assert isinstance(settings.z_focus_offset, float) + + # Test that "auto" still works + settings_auto = FourierTransferFunctionSettings(z_focus_offset="auto") + assert settings_auto.z_focus_offset == "auto" + + # Test that integers are converted to float + settings_int = FourierTransferFunctionSettings(z_focus_offset=2) + assert settings_int.z_focus_offset == 2 + assert isinstance(settings_int.z_focus_offset, (int, float)) + + +def test_position_list_with_float_offset(): + """Test that _position_list_from_shape_scale_offset works correctly with float offsets.""" + from waveorder.cli.compute_transfer_function import ( + _position_list_from_shape_scale_offset, + ) + + # Test integer offset + pos_int = _position_list_from_shape_scale_offset(5, 1.0, 0) + expected_int = [2.0, 1.0, 0.0, -1.0, -2.0] + assert pos_int == expected_int + + # Test float offset + pos_float = _position_list_from_shape_scale_offset(5, 1.0, 0.5) + expected_float = [2.5, 1.5, 0.5, -0.5, -1.5] + assert pos_float == expected_float + + # Verify the difference is exactly the offset + import numpy as np + + diff = np.array(pos_float) - np.array(pos_int) + assert np.allclose(diff, 0.5) + + # Test with different scale and offset + pos_scaled = _position_list_from_shape_scale_offset(4, 2.0, 0.3) + # shape=4, shape//2=2, so indices are [0,1,2,3], + # positions are [(-0+2+0.3)*2, (-1+2+0.3)*2, (-2+2+0.3)*2, (-3+2+0.3)*2] = [4.6, 2.6, 0.6, -1.4] + expected_scaled = [4.6, 2.6, 0.6, -1.4] + assert np.allclose(pos_scaled, expected_scaled) diff --git a/waveorder/cli/settings.py b/waveorder/cli/settings.py index e28e3a15..2a885f25 100644 --- a/waveorder/cli/settings.py +++ b/waveorder/cli/settings.py @@ -62,7 +62,7 @@ class FourierTransferFunctionSettings(MyBaseModel): yx_pixel_size: PositiveFloat = 6.5 / 20 z_pixel_size: PositiveFloat = 2.0 z_padding: NonNegativeInt = 0 - z_focus_offset: Union[int, Literal["auto"]] = 0 + z_focus_offset: Union[float, Literal["auto"]] = 0 index_of_refraction_media: PositiveFloat = 1.3 numerical_aperture_detection: PositiveFloat = 1.2 diff --git a/waveorder/focus.py b/waveorder/focus.py index 38fcb521..c4401d8d 100644 --- a/waveorder/focus.py +++ b/waveorder/focus.py @@ -60,6 +60,7 @@ def focus_from_transverse_band( polynomial_fit_order: Optional[int] = None, plot_path: Optional[str] = None, threshold_FWHM: float = 0, + enable_subpixel_precision: bool = False, ): """Estimates the in-focus slice from a 3D stack by optimizing a transverse spatial frequency band. @@ -91,12 +92,16 @@ def focus_from_transverse_band( The default value, 0, applies no threshold, and the maximum midband power is always considered in focus. For values > 0, the peak's FWHM must be greater than the threshold for the slice to be considered in focus. If the peak does not meet this threshold, the function returns None. + enable_subpixel_precision: bool, optional + If True and polynomial_fit_order is provided, enables sub-pixel precision focus detection + by finding the continuous extremum of the polynomial fit. Default is False for backward compatibility. Returns - ------ - slice : int or None + ------- + slice : int, float, or None If peak's FWHM > peak_width_threshold: - return the index of the in-focus slice + return the index of the in-focus slice (int if enable_subpixel_precision=False, + float if enable_subpixel_precision=True and polynomial_fit_order is not None) else: return None @@ -140,9 +145,44 @@ def focus_from_transverse_band( else: x = np.arange(len(midband_sum)) coeffs = np.polyfit(x, midband_sum, polynomial_fit_order) - peak_index = minmaxfunc(np.poly1d(coeffs)(x)) + poly_func = np.poly1d(coeffs) + + if enable_subpixel_precision: + # Find the continuous extremum using derivative + poly_deriv = np.polyder(coeffs) + # Find roots of the derivative (critical points) + critical_points = np.roots(poly_deriv) + + # Filter for real roots within the data range + real_critical_points = [] + for cp in critical_points: + if np.isreal(cp) and 0 <= cp.real < len(midband_sum): + real_critical_points.append(cp.real) + + if real_critical_points: + # Evaluate the polynomial at critical points to find extremum + critical_values = [ + poly_func(cp) for cp in real_critical_points + ] + if mode == "max": + best_idx = np.argmax(critical_values) + else: # mode == "min" + best_idx = np.argmin(critical_values) + peak_index = real_critical_points[best_idx] + else: + # Fall back to discrete maximum if no valid critical points + peak_index = float(minmaxfunc(poly_func(x))) + else: + peak_index = minmaxfunc(poly_func(x)) - peak_results = peak_widths(midband_sum, [peak_index]) + # For peak width calculation, use integer peak index + if enable_subpixel_precision and polynomial_fit_order is not None: + # Use the closest integer index for peak width calculation + integer_peak_index = int(np.round(peak_index)) + else: + integer_peak_index = int(peak_index) + + peak_results = peak_widths(midband_sum, [integer_peak_index]) peak_FWHM = peak_results[0][0] if peak_FWHM >= threshold_FWHM: @@ -215,9 +255,19 @@ def _plot_focus_metric( ): _, ax = plt.subplots(1, 1, figsize=(4, 4)) ax.plot(midband_sum, "-k") + + # Handle floating-point peak_index for plotting + if isinstance(peak_index, float) and not peak_index.is_integer(): + # Use interpolation to get the y-value at the floating-point x-position + peak_y_value = np.interp( + peak_index, np.arange(len(midband_sum)), midband_sum + ) + else: + peak_y_value = midband_sum[int(peak_index)] + ax.plot( peak_index, - midband_sum[peak_index], + peak_y_value, "go" if in_focus_index is not None else "ro", ) ax.hlines(*peak_results[1:], color="k", linestyles="dashed")