Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions tests/test_focus_estimator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pytest
import torch

from waveorder import focus

Expand Down Expand Up @@ -85,3 +86,73 @@ def test_focus_estimator_snr(tmp_path):
assert plot_path.exists()
if slice is not None:
assert np.abs(slice - 10) <= 2


def test_compute_midband_power():
"""Test the compute_midband_power function with torch tensors."""
# Test parameters
ps = 6.5 / 100
lambda_ill = 0.532
NA_det = 1.4
midband_fractions = (0.125, 0.25)

# Create test data
np.random.seed(42)
test_2d_np = np.random.random((64, 64)).astype(np.float32)
test_2d_torch = torch.from_numpy(test_2d_np)

# Test the compute_midband_power function
result = focus.compute_midband_power(
test_2d_torch, NA_det, lambda_ill, ps, midband_fractions
)

# Check result properties
assert isinstance(result, torch.Tensor)
assert result.shape == torch.Size([]) # scalar tensor
assert result.item() > 0 # should be positive

# Test with different midband fractions
result2 = focus.compute_midband_power(
test_2d_torch, NA_det, lambda_ill, ps, (0.1, 0.3)
)
assert isinstance(result2, torch.Tensor)
assert result2.item() > 0

# Results should be different for different bands
assert abs(result.item() - result2.item()) > 1e-6


def test_compute_midband_power_consistency():
"""Test that compute_midband_power is consistent with focus_from_transverse_band."""
# Test parameters
ps = 6.5 / 100
lambda_ill = 0.532
NA_det = 1.4
midband_fractions = (0.125, 0.25)

# Create 3D test data
np.random.seed(42)
test_3d = np.random.random((3, 32, 32)).astype(np.float32)

# Test focus_from_transverse_band still works
focus_slice = focus.focus_from_transverse_band(
test_3d, NA_det, lambda_ill, ps, midband_fractions
)

assert isinstance(focus_slice, (int, np.integer))
assert 0 <= focus_slice < test_3d.shape[0]

# Manually compute midband power for each slice
manual_powers = []
for z in range(test_3d.shape[0]):
power = focus.compute_midband_power(
torch.from_numpy(test_3d[z]),
NA_det,
lambda_ill,
ps,
midband_fractions,
)
manual_powers.append(power.item())

expected_focus_slice = np.argmax(manual_powers)
assert focus_slice == expected_focus_slice
70 changes: 54 additions & 16 deletions waveorder/focus.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,53 @@

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.signal import peak_widths

from waveorder import util


def compute_midband_power(
yx_array: torch.Tensor,
NA_det: float,
lambda_ill: float,
pixel_size: float,
midband_fractions: tuple[float, float] = (0.125, 0.25),
) -> torch.Tensor:
"""Compute midband spatial frequency power by summing over a 2D midband donut.

Parameters
----------
yx_array : torch.Tensor
2D tensor in (Y, X) order.
NA_det : float
Detection NA.
lambda_ill : float
Illumination wavelength.
Units are arbitrary, but must match [pixel_size].
pixel_size : float
Object-space pixel size = camera pixel size / magnification.
Units are arbitrary, but must match [lambda_ill].
midband_fractions : tuple[float, float], optional
The minimum and maximum fraction of the cutoff frequency that define the midband.
Default is (0.125, 0.25).

Returns
-------
torch.Tensor
Sum of absolute FFT values in the midband region.
"""
_, _, fxx, fyy = util.gen_coordinate(yx_array.shape, pixel_size)
frr = torch.tensor(np.sqrt(fxx**2 + fyy**2))
xy_abs_fft = torch.abs(torch.fft.fftn(yx_array))
cutoff = 2 * NA_det / lambda_ill
mask = torch.logical_and(
frr > cutoff * midband_fractions[0],
frr < cutoff * midband_fractions[1],
)
return torch.sum(xy_abs_fft[mask])


def focus_from_transverse_band(
zyx_array,
NA_det,
Expand Down Expand Up @@ -79,24 +121,20 @@ def focus_from_transverse_band(
)
return 0

# Calculate coordinates
_, Y, X = zyx_array.shape
_, _, fxx, fyy = util.gen_coordinate((Y, X), pixel_size)
frr = np.sqrt(fxx**2 + fyy**2)

# Calculate fft
xy_abs_fft = np.abs(np.fft.fftn(zyx_array, axes=(1, 2)))

# Calculate midband mask
cutoff = 2 * NA_det / lambda_ill
midband_mask = np.logical_and(
frr > cutoff * midband_fractions[0],
frr < cutoff * midband_fractions[1],
# Calculate midband power for each slice
midband_sum = np.array(
[
compute_midband_power(
torch.from_numpy(zyx_array[z]),
NA_det,
lambda_ill,
pixel_size,
midband_fractions,
).numpy()
for z in range(zyx_array.shape[0])
]
)

# Find slice index with min/max power in midband
midband_sum = np.sum(xy_abs_fft[:, midband_mask], axis=1)

if polynomial_fit_order is None:
peak_index = minmaxfunc(midband_sum)
else:
Expand Down
Loading