diff --git a/docs/examples/models/phase_thick_3d_sector_illumination.py b/docs/examples/models/phase_thick_3d_sector_illumination.py new file mode 100644 index 00000000..60b85f39 --- /dev/null +++ b/docs/examples/models/phase_thick_3d_sector_illumination.py @@ -0,0 +1,182 @@ +""" +phase thick 3d with sector illumination +======================================== + +# 3D phase reconstruction with oblique sector illumination +# This example demonstrates multi-channel phase reconstruction where each channel +# corresponds to a different illumination sector angle. +""" + +import napari +import numpy as np +import torch + +from waveorder.models import phase_thick_3d + +# Parameters +# all lengths must use consistent units e.g. um +simulation_arguments = { + "zyx_shape": (100, 256, 256), + "yx_pixel_size": 6.5 / 63, + "z_pixel_size": 0.25, + "index_of_refraction_media": 1.3, +} +phantom_arguments = {"index_of_refraction_sample": 1.50, "sphere_radius": 5} +transfer_function_arguments = { + "z_padding": 0, + "wavelength_illumination": 0.532, + "numerical_aperture_illumination": 0.9, + "numerical_aperture_detection": 1.2, +} + +# Define 9 sector illumination angles +# 8 sectors at 45-degree intervals + 1 full aperture +sector_angle = 45 +sector_angle_offset = -22.5 +illumination_sector_angles = [ + ( + (i * sector_angle + sector_angle_offset) % 360, + ((i + 1) * sector_angle + sector_angle_offset) % 360, + ) + for i in range(8) +] + [(0, 360)] + +print(f"Using {illumination_sector_angles} illumination sectors") + +# Create a phantom +zyx_phase = phase_thick_3d.generate_test_phantom( + **simulation_arguments, **phantom_arguments +) + +# Calculate multi-channel transfer function (one for each sector) +( + real_potential_transfer_function, + imag_potential_transfer_function, +) = phase_thick_3d.calculate_transfer_function( + **simulation_arguments, + **transfer_function_arguments, + illumination_sector_angles=illumination_sector_angles, +) + +print( + f"Transfer function shape: {real_potential_transfer_function.shape}" +) # Should be (C, Z, Y, X) + +# Display complete multi-channel transfer function +viewer = napari.Viewer() +zyx_scale = np.array( + [ + simulation_arguments["z_pixel_size"], + simulation_arguments["yx_pixel_size"], + simulation_arguments["yx_pixel_size"], + ] +) + +# Add full CZYX transfer function (imaginary part) as single 4D layer +# Match the visualization style from add_transfer_function_to_viewer +czyx_shape = imag_potential_transfer_function.shape +voxel_scale = np.array( + [ + czyx_shape[1] * zyx_scale[0], # Z extent + czyx_shape[2] * zyx_scale[1], # Y extent + czyx_shape[3] * zyx_scale[2], # X extent + ] +) +lim = 0.5 * torch.max(torch.abs(imag_potential_transfer_function)).item() + +viewer.add_image( + torch.fft.ifftshift( + torch.imag(imag_potential_transfer_function), dim=(-3, -2, -1) + ) + .cpu() + .numpy(), + name="Imag pot. TF (CZYX)", + colormap="bwr", + contrast_limits=(-lim, lim), + scale=(1,) + tuple(1 / voxel_scale), # No scaling on C dimension +) + +# Set up XZ view with C and Y as sliders +viewer.dims.order = [0, 2, 1, 3] # (C, Y, Z, X) for XZ display +viewer.dims.current_step = ( + 0, + czyx_shape[1] // 2, + czyx_shape[2] // 2, + czyx_shape[3] // 2, +) + +input( + "Showing CZYX OTF in XZ view (use C and Y sliders). Press to continue..." +) +viewer.layers.select_all() +viewer.layers.remove_selected() + +# Simulate multi-channel data (one channel per sector) +# In practice, these would come from your microscope as separate acquisitions +zyx_data_multi_channel = [] +for c in range(len(illumination_sector_angles)): + zyx_data_channel = phase_thick_3d.apply_transfer_function( + zyx_phase, + real_potential_transfer_function[c], + transfer_function_arguments["z_padding"], + brightness=1e3, + ) + zyx_data_multi_channel.append(zyx_data_channel) + +# Stack into (C, Z, Y, X) tensor +zyx_data_multi_channel = torch.stack(zyx_data_multi_channel, dim=0) +print(f"Multi-channel data shape: {zyx_data_multi_channel.shape}") + +# Reconstruct phase from all channels combined +zyx_recon = phase_thick_3d.apply_inverse_transfer_function( + zyx_data_multi_channel, + real_potential_transfer_function, + imag_potential_transfer_function, + transfer_function_arguments["z_padding"], +) + +# Display +viewer.add_image(zyx_phase.numpy(), name="Phantom", scale=zyx_scale) +viewer.add_image( + zyx_data_multi_channel.numpy(), + name="Data (CZYX)", + scale=zyx_scale, +) +viewer.add_image(zyx_recon.numpy(), name="Reconstruction", scale=zyx_scale) + +# Show comparison with single channel (full aperture) for reference +print("\nComparing with single-channel (full aperture) reconstruction...") +( + real_tf_single, + imag_tf_single, +) = phase_thick_3d.calculate_transfer_function( + **simulation_arguments, + **transfer_function_arguments, + illumination_sector_angles=None, # Full aperture +) +zyx_data_single = phase_thick_3d.apply_transfer_function( + zyx_phase, + real_tf_single[0], # Single channel + transfer_function_arguments["z_padding"], + brightness=1e3, +) +zyx_recon_single = phase_thick_3d.apply_inverse_transfer_function( + zyx_data_single[None, ...], # Add channel dimension + real_tf_single, + imag_tf_single, + transfer_function_arguments["z_padding"], +) +viewer.add_image( + zyx_recon_single.numpy(), + name="Reconstruction (single channel)", + scale=zyx_scale, +) + +print( + f"\nReconstruction error (multi-channel): {torch.mean(torch.abs(zyx_recon - zyx_phase)).item():.6f}" +) +print( + f"Reconstruction error (single channel): {torch.mean(torch.abs(zyx_recon_single - zyx_phase)).item():.6f}" +) + +input("\nShowing phantom, data, and reconstructions. Press to quit...") diff --git a/waveorder/cli/apply_inverse_models.py b/waveorder/cli/apply_inverse_models.py index 995c2e5d..b02596f2 100644 --- a/waveorder/cli/apply_inverse_models.py +++ b/waveorder/cli/apply_inverse_models.py @@ -92,19 +92,19 @@ def phase( # [phase only, 3] elif recon_dim == 3: - # Load transfer functions + # Load transfer functions (keep channel dimension) real_potential_transfer_function = torch.tensor( - transfer_function_dataset["real_potential_transfer_function"][0, 0] + transfer_function_dataset["real_potential_transfer_function"][0] ) imaginary_potential_transfer_function = torch.tensor( transfer_function_dataset["imaginary_potential_transfer_function"][ - 0, 0 + 0 ] ) - # Apply + # Apply (pass full CZYX data) output = phase_thick_3d.apply_inverse_transfer_function( - czyx_data[0], + czyx_data, real_potential_transfer_function, imaginary_potential_transfer_function, z_padding=settings_phase.transfer_function.z_padding, diff --git a/waveorder/cli/compute_transfer_function.py b/waveorder/cli/compute_transfer_function.py index 4374b19c..9eecbebf 100644 --- a/waveorder/cli/compute_transfer_function.py +++ b/waveorder/cli/compute_transfer_function.py @@ -211,14 +211,12 @@ def generate_and_save_phase_transfer_function( # Save dataset.create_image( "real_potential_transfer_function", - real_potential_transfer_function.cpu().numpy()[None, None, ...], + real_potential_transfer_function.cpu().numpy()[None, ...], chunks=(1, 1, 1, zyx_shape[1], zyx_shape[2]), ) dataset.create_image( "imaginary_potential_transfer_function", - imaginary_potential_transfer_function.cpu().numpy()[ - None, None, ... - ], + imaginary_potential_transfer_function.cpu().numpy()[None, ...], chunks=(1, 1, 1, zyx_shape[1], zyx_shape[2]), ) @@ -367,14 +365,15 @@ def compute_transfer_function_cli( print("Found z_focus_offset:", z_focus_offset) # Prepare output dataset - num_channels = ( + num_input_channel = len(settings.input_channel_names) + num_output_channels = ( 2 if settings.reconstruction_dimension == 2 else 1 ) # space for SVD output_dataset = open_ome_zarr( output_dirpath, layout="fov", mode="w", - channel_names=num_channels * ["None"], + channel_names=num_input_channel * num_output_channels * ["None"], ) # Pass settings to appropriate calculate_transfer_function and save diff --git a/waveorder/cli/settings.py b/waveorder/cli/settings.py index 9ecbf628..1d3b9e94 100644 --- a/waveorder/cli/settings.py +++ b/waveorder/cli/settings.py @@ -1,7 +1,7 @@ import os import warnings from pathlib import Path -from typing import List, Literal, Optional, Union +from typing import List, Literal, Optional, Tuple, Union from pydantic.v1 import ( BaseModel, @@ -100,6 +100,7 @@ class PhaseTransferFunctionSettings( ): numerical_aperture_illumination: NonNegativeFloat = 0.5 invert_phase_contrast: bool = False + illumination_sector_angles: Optional[List[Tuple[float, float]]] = None @validator("numerical_aperture_illumination") def na_ill(cls, v, values): @@ -110,6 +111,27 @@ def na_ill(cls, v, values): ) return v + @validator("illumination_sector_angles") + def validate_sector_angles(cls, v): + if v is None: + return v + if len(v) == 0: + raise ValueError( + "illumination_sector_angles must contain at least one sector" + ) + normalized = [] + for start, end in v: + if start >= end: + raise ValueError( + f"Sector start angle {start} must be less than end angle {end}" + ) + # Normalize angles to [0, 360) using modulo 360 + # Special case: preserve 360 for full aperture (don't reduce to 0) + normalized_start = start % 360 + normalized_end = end % 360 if end % 360 != 0 else 360 + normalized.append((normalized_start, normalized_end)) + return normalized + class FluorescenceTransferFunctionSettings(FourierTransferFunctionSettings): wavelength_emission: PositiveFloat = 0.507 @@ -171,17 +193,37 @@ def validate_reconstruction_types(cls, values): '"fluorescence" cannot be present alongside "birefringence" or "phase". Please use one configuration file for a "fluorescence" reconstruction and another configuration file for a "birefringence" and/or "phase" reconstructions.' ) num_channel_names = len(values.get("input_channel_names")) - if values.get("birefringence") is None: - if ( - values.get("phase") is None - and values.get("fluorescence") is None - ): - raise ValueError( - "Provide settings for either birefringence, phase, birefringence + phase, or fluorescence." - ) - if num_channel_names != 1: - raise ValueError( - f"{num_channel_names} channels names provided. Please provide a single channel for fluorescence/phase reconstructions." - ) + + # Check for sector illumination in phase reconstruction + phase_settings = values.get("phase") + if phase_settings is not None: + sector_angles = ( + phase_settings.transfer_function.illumination_sector_angles + ) + if sector_angles is not None: + # Multi-channel reconstruction with sector illumination + if len(sector_angles) != num_channel_names: + raise ValueError( + f"Number of illumination_sector_angles ({len(sector_angles)}) must match number of input_channel_names ({num_channel_names})" + ) + else: + # Single channel phase reconstruction without sector illumination + if ( + values.get("birefringence") is None + and num_channel_names != 1 + ): + raise ValueError( + f"{num_channel_names} channels names provided. Please provide a single channel for phase reconstructions without sector illumination." + ) + else: + if values.get("birefringence") is None: + if values.get("fluorescence") is None: + raise ValueError( + "Provide settings for either birefringence, phase, birefringence + phase, or fluorescence." + ) + if num_channel_names != 1: + raise ValueError( + f"{num_channel_names} channels names provided. Please provide a single channel for fluorescence reconstructions." + ) return values diff --git a/waveorder/models/isotropic_thin_3d.py b/waveorder/models/isotropic_thin_3d.py index 16f253b9..4b89ce9b 100644 --- a/waveorder/models/isotropic_thin_3d.py +++ b/waveorder/models/isotropic_thin_3d.py @@ -44,6 +44,7 @@ def calculate_transfer_function( numerical_aperture_illumination: float, numerical_aperture_detection: float, invert_phase_contrast: bool = False, + illumination_sector_angles: list[tuple[float, float]] = None, ) -> Tuple[Tensor, Tensor]: transverse_nyquist = sampling.transverse_nyquist( wavelength_illumination, @@ -52,45 +53,60 @@ def calculate_transfer_function( ) yx_factor = int(np.ceil(yx_pixel_size / transverse_nyquist)) - ( - absorption_2d_to_3d_transfer_function, - phase_2d_to_3d_transfer_function, - ) = _calculate_wrap_unsafe_transfer_function( + # Handle sector illumination case (or single channel with full aperture) + if illumination_sector_angles is None: + # Single channel with full aperture - wrap as [(0, 360)] + illumination_sector_angles = [(0, 360)] + + absorption_tfs = [] + phase_tfs = [] + + for start_angle, end_angle in illumination_sector_angles: ( - yx_shape[0] * yx_factor, - yx_shape[1] * yx_factor, - ), - yx_pixel_size / yx_factor, - z_position_list, - wavelength_illumination, - index_of_refraction_media, - numerical_aperture_illumination, - numerical_aperture_detection, - invert_phase_contrast=invert_phase_contrast, - ) + absorption_2d_to_3d_transfer_function, + phase_2d_to_3d_transfer_function, + ) = _calculate_wrap_unsafe_transfer_function( + ( + yx_shape[0] * yx_factor, + yx_shape[1] * yx_factor, + ), + yx_pixel_size / yx_factor, + z_position_list, + wavelength_illumination, + index_of_refraction_media, + numerical_aperture_illumination, + numerical_aperture_detection, + invert_phase_contrast=invert_phase_contrast, + sector_angle_start=start_angle, + sector_angle_end=end_angle, + ) - absorption_2d_to_3d_transfer_function_out = torch.zeros( - (len(z_position_list),) + tuple(yx_shape), dtype=torch.complex64 - ) - phase_2d_to_3d_transfer_function_out = torch.zeros( - (len(z_position_list),) + tuple(yx_shape), dtype=torch.complex64 - ) + absorption_2d_to_3d_transfer_function_out = torch.zeros( + (len(z_position_list),) + tuple(yx_shape), dtype=torch.complex64 + ) + phase_2d_to_3d_transfer_function_out = torch.zeros( + (len(z_position_list),) + tuple(yx_shape), dtype=torch.complex64 + ) - for z in range(len(z_position_list)): - absorption_2d_to_3d_transfer_function_out[z] = ( - sampling.nd_fourier_central_cuboid( - absorption_2d_to_3d_transfer_function[z], yx_shape + for z in range(len(z_position_list)): + absorption_2d_to_3d_transfer_function_out[z] = ( + sampling.nd_fourier_central_cuboid( + absorption_2d_to_3d_transfer_function[z], yx_shape + ) ) - ) - phase_2d_to_3d_transfer_function_out[z] = ( - sampling.nd_fourier_central_cuboid( - phase_2d_to_3d_transfer_function[z], yx_shape + phase_2d_to_3d_transfer_function_out[z] = ( + sampling.nd_fourier_central_cuboid( + phase_2d_to_3d_transfer_function[z], yx_shape + ) ) - ) + absorption_tfs.append(absorption_2d_to_3d_transfer_function_out) + phase_tfs.append(phase_2d_to_3d_transfer_function_out) + + # Always return (C, Z, Y, X) arrays, even for single channel return ( - absorption_2d_to_3d_transfer_function_out, - phase_2d_to_3d_transfer_function_out, + torch.stack(absorption_tfs, dim=0), + torch.stack(phase_tfs, dim=0), ) @@ -103,6 +119,8 @@ def _calculate_wrap_unsafe_transfer_function( numerical_aperture_illumination: float, numerical_aperture_detection: float, invert_phase_contrast: bool = False, + sector_angle_start: float = None, + sector_angle_end: float = None, ) -> Tuple[Tensor, Tensor]: if numerical_aperture_illumination >= numerical_aperture_detection: print( @@ -119,11 +137,25 @@ def _calculate_wrap_unsafe_transfer_function( yx_shape, yx_pixel_size ) - illumination_pupil = optics.generate_pupil( - radial_frequencies, - numerical_aperture_illumination, - wavelength_illumination, - ) + # Generate illumination pupil (sector or full aperture) + if sector_angle_start is not None and sector_angle_end is not None: + fyy, fxx = util.generate_frequencies(yx_shape, yx_pixel_size) + illumination_pupil = optics.generate_sector_pupil( + radial_frequencies, + fxx, + fyy, + numerical_aperture_illumination, + wavelength_illumination, + sector_angle_start, + sector_angle_end, + ) + else: + illumination_pupil = optics.generate_pupil( + radial_frequencies, + numerical_aperture_illumination, + wavelength_illumination, + ) + detection_pupil = optics.generate_pupil( radial_frequencies, numerical_aperture_detection, @@ -162,25 +194,30 @@ def calculate_singular_system( phase_2d_to_3d_transfer_function: Tensor, ) -> Tuple[Tensor, Tensor, Tensor]: """Calculates the singular system of the absoprtion and phase transfer - functions. + functions for multi-channel data. - Together, the transfer functions form a (2, Z, Vy, Vx) tensor, where - (2,) is the object-space dimension (abs, phase), (Z,) is the data-space - dimension, and (Vy, Vx) are the spatial frequency dimensions. + Together, the transfer functions form a (2, C, Z, Vy, Vx) tensor, where + (2,) is the object-space dimension (abs, phase), (C,) is the channel dimension, + (Z,) is the data-space dimension, and (Vy, Vx) are the spatial frequency dimensions. - The SVD is computed over the (2, Z) dimensions. + The SVD is computed over the (2, C*Z) dimensions, flattening channels and z together. Parameters ---------- absorption_2d_to_3d_transfer_function : Tensor - ZYX transfer function for absorption + CZYX transfer function for absorption phase_2d_to_3d_transfer_function : Tensor - ZYX transfer function for phase + CZYX transfer function for phase Returns ------- Tuple[Tensor, Tensor, Tensor] + U, S, Vh for the singular system """ + # absorption_2d_to_3d_transfer_function shape: (C, Z, Y, X) + # phase_2d_to_3d_transfer_function shape: (C, Z, Y, X) + + # Stack absorption and phase: (2, C, Z, Y, X) sfYX_transfer_function = torch.stack( ( absorption_2d_to_3d_transfer_function, @@ -188,6 +225,18 @@ def calculate_singular_system( ), dim=0, ) + + # Flatten C and Z dimensions: (2, C*Z, Y, X) + num_channels = sfYX_transfer_function.shape[1] + num_z = sfYX_transfer_function.shape[2] + sfYX_transfer_function = sfYX_transfer_function.reshape( + 2, + num_channels * num_z, + sfYX_transfer_function.shape[3], + sfYX_transfer_function.shape[4], + ) + + # Permute for SVD: (Y, X, 2, C*Z) YXsf_transfer_function = sfYX_transfer_function.permute(2, 3, 0, 1) Up, Sp, Vhp = torch.linalg.svd(YXsf_transfer_function, full_matrices=False) U = Up.permute(2, 3, 0, 1) @@ -289,17 +338,17 @@ def apply_inverse_transfer_function( TV_iterations: int = 10, bg_filter: bool = False, ) -> Tuple[Tensor, Tensor]: - """Reconstructs absorption and phase from zyx_data and a pair of - 3D-to-2D transfer functions named absorption_2d_to_3d_transfer_function and - phase_2d_to_3d_transfer_function, providing options for reconstruction - algorithms. + """Reconstructs absorption and phase from multi-channel zyx_data and the + singular system, combining all illumination channels into single absorption + and phase estimates. Parameters ---------- zyx_data : Tensor - 3D raw data, label-free defocus stack + Multi-channel 3D raw data with shape (C, Z, Y, X). + For single channel (full aperture), C=1. singular_system : Tuple[Tensor, Tensor, Tensor] - singular system of the transfer function bank + singular system of the multi-channel transfer function bank reconstruction_algorithm : Literal["Tikhonov";, "TV";], optional "Tikhonov" or "TV", by default "Tikhonov" "TV" is not implemented. @@ -318,16 +367,25 @@ def apply_inverse_transfer_function( Returns ------- Tuple[Tensor] - yx_absorption (unitless) - yx_phase (radians) + yx_absorption (unitless) with shape (Y, X) + yx_phase (radians) with shape (Y, X) Raises ------ NotImplementedError TV is not implemented """ + # zyx_data shape: (C, Z, Y, X) + num_channels = zyx_data.shape[0] + num_z = zyx_data.shape[1] + + # Flatten C and Z dimensions: (C*Z, Y, X) + czyx_data = zyx_data.reshape( + num_channels * num_z, zyx_data.shape[2], zyx_data.shape[3] + ) + # Normalize - zyx = util.inten_normalization(zyx_data, bg_filter=bg_filter) + czyx = util.inten_normalization(czyx_data, bg_filter=bg_filter) # TODO Consider refactoring with vectorial transfer function SVD if reconstruction_algorithm == "Tikhonov": @@ -338,7 +396,7 @@ def apply_inverse_transfer_function( "sj...,j...,jf...->fs...", U, S_reg, Vh ) - absorption_yx, phase_yx = apply_filter_bank(sfyx_inverse_filter, zyx) + absorption_yx, phase_yx = apply_filter_bank(sfyx_inverse_filter, czyx) # ADMM deconvolution with anisotropic TV regularization elif reconstruction_algorithm == "TV": diff --git a/waveorder/models/phase_thick_3d.py b/waveorder/models/phase_thick_3d.py index 5a2e2547..78e48d7a 100644 --- a/waveorder/models/phase_thick_3d.py +++ b/waveorder/models/phase_thick_3d.py @@ -43,7 +43,8 @@ def calculate_transfer_function( numerical_aperture_illumination: float, numerical_aperture_detection: float, invert_phase_contrast: bool = False, -) -> tuple[np.ndarray, np.ndarray]: + illumination_sector_angles: list[tuple[float, float]] = None, +) -> tuple[Tensor, Tensor]: transverse_nyquist = sampling.transverse_nyquist( wavelength_illumination, numerical_aperture_illumination, @@ -58,33 +59,54 @@ def calculate_transfer_function( yx_factor = int(np.ceil(yx_pixel_size / transverse_nyquist)) z_factor = int(np.ceil(z_pixel_size / axial_nyquist)) - ( - real_potential_transfer_function, - imag_potential_transfer_function, - ) = _calculate_wrap_unsafe_transfer_function( + # Handle sector illumination case (or single channel with full aperture) + if illumination_sector_angles is None: + # Single channel with full aperture - wrap as [(0, 360)] + illumination_sector_angles = [(0, 360)] + + real_tfs = [] + imag_tfs = [] + + for i, (start_angle, end_angle) in enumerate(illumination_sector_angles): + print( + f"Calculating transfer function {i+1}/{len(illumination_sector_angles)} for sector [{start_angle:.1f}, {end_angle:.1f}] degrees" + ) ( - zyx_shape[0] * z_factor, - zyx_shape[1] * yx_factor, - zyx_shape[2] * yx_factor, - ), - yx_pixel_size / yx_factor, - z_pixel_size / z_factor, - wavelength_illumination, - z_padding, - index_of_refraction_media, - numerical_aperture_illumination, - numerical_aperture_detection, - invert_phase_contrast=invert_phase_contrast, - ) + real_potential_transfer_function, + imag_potential_transfer_function, + ) = _calculate_wrap_unsafe_transfer_function( + ( + zyx_shape[0] * z_factor, + zyx_shape[1] * yx_factor, + zyx_shape[2] * yx_factor, + ), + yx_pixel_size / yx_factor, + z_pixel_size / z_factor, + wavelength_illumination, + z_padding, + index_of_refraction_media, + numerical_aperture_illumination, + numerical_aperture_detection, + invert_phase_contrast=invert_phase_contrast, + sector_angle_start=start_angle, + sector_angle_end=end_angle, + ) + zyx_out_shape = (zyx_shape[0] + 2 * z_padding,) + zyx_shape[1:] + real_tfs.append( + sampling.nd_fourier_central_cuboid( + real_potential_transfer_function, zyx_out_shape + ) + ) + imag_tfs.append( + sampling.nd_fourier_central_cuboid( + imag_potential_transfer_function, zyx_out_shape + ) + ) - zyx_out_shape = (zyx_shape[0] + 2 * z_padding,) + zyx_shape[1:] + # Always return (C, Z, Y, X) array, even for single channel return ( - sampling.nd_fourier_central_cuboid( - real_potential_transfer_function, zyx_out_shape - ), - sampling.nd_fourier_central_cuboid( - imag_potential_transfer_function, zyx_out_shape - ), + torch.stack(real_tfs, dim=0), + torch.stack(imag_tfs, dim=0), ) @@ -98,7 +120,9 @@ def _calculate_wrap_unsafe_transfer_function( numerical_aperture_illumination: float, numerical_aperture_detection: float, invert_phase_contrast: bool = False, -) -> tuple[np.ndarray, np.ndarray]: + sector_angle_start: float = None, + sector_angle_end: float = None, +) -> tuple[Tensor, Tensor]: radial_frequencies = util.generate_radial_frequencies( zyx_shape[1:], yx_pixel_size ) @@ -109,11 +133,25 @@ def _calculate_wrap_unsafe_transfer_function( if invert_phase_contrast: z_position_list = torch.flip(z_position_list, dims=(0,)) - ill_pupil = optics.generate_pupil( - radial_frequencies, - numerical_aperture_illumination, - wavelength_illumination, - ) + # Generate illumination pupil (sector or full aperture) + if sector_angle_start is not None and sector_angle_end is not None: + fyy, fxx = util.generate_frequencies(zyx_shape[1:], yx_pixel_size) + ill_pupil = optics.generate_sector_pupil( + radial_frequencies, + fxx, + fyy, + numerical_aperture_illumination, + wavelength_illumination, + sector_angle_start, + sector_angle_end, + ) + else: + ill_pupil = optics.generate_pupil( + radial_frequencies, + numerical_aperture_illumination, + wavelength_illumination, + ) + det_pupil = optics.generate_pupil( radial_frequencies, numerical_aperture_detection, @@ -150,8 +188,8 @@ def _calculate_wrap_unsafe_transfer_function( def visualize_transfer_function( viewer, - real_potential_transfer_function: np.ndarray, - imag_potential_transfer_function: np.ndarray, + real_potential_transfer_function: Tensor, + imag_potential_transfer_function: Tensor, zyx_scale: tuple[float, float, float], ) -> None: add_transfer_function_to_viewer( @@ -170,11 +208,11 @@ def visualize_transfer_function( def apply_transfer_function( - zyx_object: np.ndarray, - real_potential_transfer_function: np.ndarray, + zyx_object: Tensor, + real_potential_transfer_function: Tensor, z_padding: int, brightness: float, -) -> np.ndarray: +) -> Tensor: # This simplified forward model only handles phase, so it resuses the fluorescence forward model # TODO: extend to absorption return ( @@ -200,19 +238,18 @@ def apply_inverse_transfer_function( TV_rho_strength: float = 1e-3, TV_iterations: int = 10, ) -> Tensor: - """Reconstructs 3D phase from labelfree defocus zyx_data and a pair of - complex 3D transfer functions real_potential_transfer_function and - imag_potential_transfer_function, providing options for reconstruction - algorithms. + """Reconstructs 3D phase from labelfree defocus zyx_data and multi-channel + transfer functions, combining all illumination channels into a single phase estimate. Parameters ---------- zyx_data : Tensor - 3D raw data, label-free defocus stack + Multi-channel 3D raw data with shape (C, Z, Y, X). + For single channel (full aperture), C=1. real_potential_transfer_function : Tensor - Real potential transfer function, see calculate_transfer_function abov + Real potential transfer function with shape (C, Z, Y, X). imaginary_potential_transfer_function : Tensor - Imaginary potential transfer function, see calculate_transfer_function abov + Imaginary potential transfer function with shape (C, Z, Y, X). z_padding : int Padding for axial dimension. Use zero for defocus stacks that extend ~3 PSF widths beyond the sample. Pad by ~3 PSF widths otherwise. @@ -234,47 +271,51 @@ def apply_inverse_transfer_function( Returns ------- Tensor - zyx_phase (radians) + zyx_phase (radians) with shape (Z, Y, X) Raises ------ NotImplementedError TV is not implemented """ - # Handle padding - zyx_padded = util.pad_zyx_along_z(zyx_data, z_padding) - - # Normalize - zyx = util.inten_normalization_3D(zyx_padded) - - # Prepare TF - effective_transfer_function = ( - real_potential_transfer_function - + absorption_ratio * imaginary_potential_transfer_function - ) - - # Reconstruct - if reconstruction_algorithm == "Tikhonov": - inverse_filter = tikhonov_regularized_inverse_filter( - effective_transfer_function, regularization_strength - ) - - # [None]s and [0] are for applying a 1x1 "bank" of filters. - # For further uniformity, consider returning (1, Z, Y, X) - f_real = apply_filter_bank(inverse_filter[None, None], zyx[None])[0] - - elif reconstruction_algorithm == "TV": - raise NotImplementedError - f_real = util.single_variable_admm_tv_deconvolution_3D( - zyx, - effective_transfer_function, - reg_re=regularization_strength, - rho=TV_rho_strength, - itr=TV_iterations, + # Multi-channel reconstruction with sector illumination (or single channel) + # zyx_data shape: (C, Z, Y, X) + # TF shapes: (C, Z, Y, X) + num_channels = zyx_data.shape[0] + reconstructions = [] + + for c in range(num_channels): + print(f"Reconstructing channel {c+1}/{num_channels}") + # Handle padding + zyx_padded = util.pad_zyx_along_z(zyx_data[c], z_padding) + + # Normalize + zyx = util.inten_normalization_3D(zyx_padded) + + # Prepare TF for this channel + effective_transfer_function = ( + real_potential_transfer_function[c] + + absorption_ratio * imaginary_potential_transfer_function[c] ) - # Unpad - if z_padding != 0: - f_real = f_real[z_padding:-z_padding] + # Reconstruct this channel + if reconstruction_algorithm == "Tikhonov": + inverse_filter = tikhonov_regularized_inverse_filter( + effective_transfer_function, regularization_strength + ) + f_real = apply_filter_bank(inverse_filter[None, None], zyx[None])[ + 0 + ] + elif reconstruction_algorithm == "TV": + raise NotImplementedError + + # Unpad + if z_padding != 0: + f_real = f_real[z_padding:-z_padding] + + reconstructions.append(f_real) + + # Sum all channel reconstructions + f_real = torch.stack(reconstructions, dim=0).sum(dim=0) return f_real diff --git a/waveorder/optics.py b/waveorder/optics.py index 46204517..c5823006 100644 --- a/waveorder/optics.py +++ b/waveorder/optics.py @@ -148,6 +148,64 @@ def generate_pupil(frr, NA, lamb_in): return Pupil +def generate_sector_pupil( + radial_frequencies, + x_frequencies, + y_frequencies, + NA, + wavelength, + start_angle, + end_angle, +): + """ + Generate a sector pupil for a given angular range. + + Parameters + ---------- + radial_frequencies : torch.tensor + radial frequency coordinate in units of inverse length + x_frequencies : torch.tensor + x component of 2D spatial frequency array + y_frequencies : torch.tensor + y component of 2D spatial frequency array + NA : float + numerical aperture of the pupil function (normalized by the refractive index) + wavelength : float + wavelength of the light in units of length + start_angle : float + start angle of sector in degrees (0-360) + end_angle : float + end angle of sector in degrees (0-360) + + Returns + ------- + torch.tensor + sector pupil function + """ + # Start with circular pupil + pupil = torch.zeros(radial_frequencies.shape) + pupil[radial_frequencies < NA / wavelength] = 1 + + # If full aperture (0 to 360), return full pupil + if start_angle == 0 and end_angle == 360: + return pupil + + # Calculate angles in frequency space + # Note: atan2 returns angles in radians from -pi to pi + angles = torch.atan2(y_frequencies, x_frequencies) # radians, -pi to pi + angles_deg = torch.rad2deg(angles) % 360 # convert to degrees, 0 to 360 + + # Create sector mask + if end_angle > start_angle: + # Normal case: sector doesn't wrap around 0 + sector_mask = (angles_deg >= start_angle) & (angles_deg < end_angle) + else: + # Sector wraps around 0 degrees (e.g., 315 to 45) + sector_mask = (angles_deg >= start_angle) | (angles_deg < end_angle) + + return pupil * sector_mask.float() + + def gen_sector_Pupil(fxx, fyy, NA, lamb_in, sector_angle, rotation_angle): """