diff --git a/tests/test_focus_estimator.py b/tests/test_focus_estimator.py index 5e4442df..8c093195 100644 --- a/tests/test_focus_estimator.py +++ b/tests/test_focus_estimator.py @@ -1,5 +1,6 @@ import numpy as np import pytest +import torch from waveorder import focus @@ -85,3 +86,73 @@ def test_focus_estimator_snr(tmp_path): assert plot_path.exists() if slice is not None: assert np.abs(slice - 10) <= 2 + + +def test_compute_midband_power(): + """Test the compute_midband_power function with torch tensors.""" + # Test parameters + ps = 6.5 / 100 + lambda_ill = 0.532 + NA_det = 1.4 + midband_fractions = (0.125, 0.25) + + # Create test data + np.random.seed(42) + test_2d_np = np.random.random((64, 64)).astype(np.float32) + test_2d_torch = torch.from_numpy(test_2d_np) + + # Test the compute_midband_power function + result = focus.compute_midband_power( + test_2d_torch, NA_det, lambda_ill, ps, midband_fractions + ) + + # Check result properties + assert isinstance(result, torch.Tensor) + assert result.shape == torch.Size([]) # scalar tensor + assert result.item() > 0 # should be positive + + # Test with different midband fractions + result2 = focus.compute_midband_power( + test_2d_torch, NA_det, lambda_ill, ps, (0.1, 0.3) + ) + assert isinstance(result2, torch.Tensor) + assert result2.item() > 0 + + # Results should be different for different bands + assert abs(result.item() - result2.item()) > 1e-6 + + +def test_compute_midband_power_consistency(): + """Test that compute_midband_power is consistent with focus_from_transverse_band.""" + # Test parameters + ps = 6.5 / 100 + lambda_ill = 0.532 + NA_det = 1.4 + midband_fractions = (0.125, 0.25) + + # Create 3D test data + np.random.seed(42) + test_3d = np.random.random((3, 32, 32)).astype(np.float32) + + # Test focus_from_transverse_band still works + focus_slice = focus.focus_from_transverse_band( + test_3d, NA_det, lambda_ill, ps, midband_fractions + ) + + assert isinstance(focus_slice, (int, np.integer)) + assert 0 <= focus_slice < test_3d.shape[0] + + # Manually compute midband power for each slice + manual_powers = [] + for z in range(test_3d.shape[0]): + power = focus.compute_midband_power( + torch.from_numpy(test_3d[z]), + NA_det, + lambda_ill, + ps, + midband_fractions, + ) + manual_powers.append(power.item()) + + expected_focus_slice = np.argmax(manual_powers) + assert focus_slice == expected_focus_slice diff --git a/waveorder/focus.py b/waveorder/focus.py index 38128f3c..38fcb521 100644 --- a/waveorder/focus.py +++ b/waveorder/focus.py @@ -3,11 +3,53 @@ import matplotlib.pyplot as plt import numpy as np +import torch from scipy.signal import peak_widths from waveorder import util +def compute_midband_power( + yx_array: torch.Tensor, + NA_det: float, + lambda_ill: float, + pixel_size: float, + midband_fractions: tuple[float, float] = (0.125, 0.25), +) -> torch.Tensor: + """Compute midband spatial frequency power by summing over a 2D midband donut. + + Parameters + ---------- + yx_array : torch.Tensor + 2D tensor in (Y, X) order. + NA_det : float + Detection NA. + lambda_ill : float + Illumination wavelength. + Units are arbitrary, but must match [pixel_size]. + pixel_size : float + Object-space pixel size = camera pixel size / magnification. + Units are arbitrary, but must match [lambda_ill]. + midband_fractions : tuple[float, float], optional + The minimum and maximum fraction of the cutoff frequency that define the midband. + Default is (0.125, 0.25). + + Returns + ------- + torch.Tensor + Sum of absolute FFT values in the midband region. + """ + _, _, fxx, fyy = util.gen_coordinate(yx_array.shape, pixel_size) + frr = torch.tensor(np.sqrt(fxx**2 + fyy**2)) + xy_abs_fft = torch.abs(torch.fft.fftn(yx_array)) + cutoff = 2 * NA_det / lambda_ill + mask = torch.logical_and( + frr > cutoff * midband_fractions[0], + frr < cutoff * midband_fractions[1], + ) + return torch.sum(xy_abs_fft[mask]) + + def focus_from_transverse_band( zyx_array, NA_det, @@ -79,24 +121,20 @@ def focus_from_transverse_band( ) return 0 - # Calculate coordinates - _, Y, X = zyx_array.shape - _, _, fxx, fyy = util.gen_coordinate((Y, X), pixel_size) - frr = np.sqrt(fxx**2 + fyy**2) - - # Calculate fft - xy_abs_fft = np.abs(np.fft.fftn(zyx_array, axes=(1, 2))) - - # Calculate midband mask - cutoff = 2 * NA_det / lambda_ill - midband_mask = np.logical_and( - frr > cutoff * midband_fractions[0], - frr < cutoff * midband_fractions[1], + # Calculate midband power for each slice + midband_sum = np.array( + [ + compute_midband_power( + torch.from_numpy(zyx_array[z]), + NA_det, + lambda_ill, + pixel_size, + midband_fractions, + ).numpy() + for z in range(zyx_array.shape[0]) + ] ) - # Find slice index with min/max power in midband - midband_sum = np.sum(xy_abs_fft[:, midband_mask], axis=1) - if polynomial_fit_order is None: peak_index = minmaxfunc(midband_sum) else: