Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions tests/test_focus_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,26 @@ def test_focus_estimator(tmp_path):

plot_path = tmp_path.joinpath("test.pdf")
data3D = np.random.random((11, 256, 256))
slice = focus.focus_from_transverse_band(
slice, stats = focus.focus_from_transverse_band(
data3D, NA_det, lambda_ill, ps, plot_path=str(plot_path)
)
assert slice >= 0
assert slice <= data3D.shape[0]
assert plot_path.exists()
assert isinstance(stats, dict)
assert stats["peak_index"] == slice
assert stats["peak_FWHM"] > 0

# Check single slice
slice = focus.focus_from_transverse_band(
slice, stats = focus.focus_from_transverse_band(
np.random.random((1, 10, 10)),
NA_det,
lambda_ill,
ps,
)
assert slice == 0
assert stats["peak_index"] is None
assert stats["peak_FWHM"] is None


def test_focus_estimator_snr(tmp_path):
Expand Down Expand Up @@ -80,7 +85,7 @@ def test_focus_estimator_snr(tmp_path):
ps,
plot_path=plot_path,
threshold_FWHM=5,
)
)[0]
assert plot_path.exists()
if slice is not None:
assert np.abs(slice - 10) <= 2
10 changes: 7 additions & 3 deletions waveorder/focus.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def focus_from_transverse_band(
return the index of the in-focus slice
else:
return None
peak_stats : dict
Dictionary with statistics of the detected peaks, currently 'peak_index' and 'peak_FWHM'.

Example
------
Expand All @@ -62,6 +64,7 @@ def focus_from_transverse_band(
>>> in_focus_data = data[slice,:,:]
"""
minmaxfunc = _mode_to_minmaxfunc(mode)
peak_stats = {'peak_index': None, 'peak_FWHM': None}

_check_focus_inputs(
zyx_array, NA_det, lambda_ill, pixel_size, midband_fractions
Expand All @@ -72,7 +75,7 @@ def focus_from_transverse_band(
warnings.warn(
"The dataset only contained a single slice. Returning trivial slice index = 0."
)
return 0
return 0, peak_stats

# Calculate coordinates
_, Y, X = zyx_array.shape
Expand All @@ -95,9 +98,10 @@ def focus_from_transverse_band(

peak_results = peak_widths(midband_sum, [peak_index])
peak_FWHM = peak_results[0][0]
peak_stats.update({'peak_index': peak_index, 'peak_FWHM': peak_FWHM})

if peak_FWHM >= threshold_FWHM:
in_focus_index = peak_index
in_focus_index = int(peak_index)
else:
in_focus_index = None

Expand All @@ -112,7 +116,7 @@ def focus_from_transverse_band(
threshold_FWHM,
)

return in_focus_index
return in_focus_index, peak_stats


def _mode_to_minmaxfunc(mode):
Expand Down
Loading