diff --git a/docs/examples/models/phase_thick_3d.py b/docs/examples/models/phase_thick_3d.py index bfbe6d30..65bd46f7 100644 --- a/docs/examples/models/phase_thick_3d.py +++ b/docs/examples/models/phase_thick_3d.py @@ -11,17 +11,22 @@ # Parameters # all lengths must use consistent units e.g. um simulation_arguments = { - "zyx_shape": (100, 256, 256), - "yx_pixel_size": 6.5 / 63, - "z_pixel_size": 0.25, - "index_of_refraction_media": 1.3, + "zyx_shape": (256, 256, 256), + "yx_pixel_size": 6.5 / 100, + "z_pixel_size": 0.1, + "index_of_refraction_media": 1.4, } -phantom_arguments = {"index_of_refraction_sample": 1.50, "sphere_radius": 5} +phantom_arguments = {"index_of_refraction_sample": 1.50, "sphere_radius": 0.5} transfer_function_arguments = { "z_padding": 0, - "wavelength_illumination": 0.532, - "numerical_aperture_illumination": 0.9, - "numerical_aperture_detection": 1.2, + "wavelength_illumination": 0.47, + "numerical_aperture_illumination": 0.52, + "numerical_aperture_detection": 1.35, + "tilt_angle_degrees": 30, + "index_of_refraction_o2": 1.0, + "numerical_aperture_o2": 0.95, + "index_of_refraction_o3": 1.5, + "numerical_aperture_o3": 1.0, } # Create a phantom diff --git a/waveorder/models/phase_thick_3d.py b/waveorder/models/phase_thick_3d.py index 5a2e2547..60701cbc 100644 --- a/waveorder/models/phase_thick_3d.py +++ b/waveorder/models/phase_thick_3d.py @@ -43,6 +43,11 @@ def calculate_transfer_function( numerical_aperture_illumination: float, numerical_aperture_detection: float, invert_phase_contrast: bool = False, + tilt_angle_degrees: float = 0.0, + index_of_refraction_o2: float = 1.0, + numerical_aperture_o2: float = 1.0, + index_of_refraction_o3: float = 1.0, + numerical_aperture_o3: float = 1.0, ) -> tuple[np.ndarray, np.ndarray]: transverse_nyquist = sampling.transverse_nyquist( wavelength_illumination, @@ -75,6 +80,11 @@ def calculate_transfer_function( numerical_aperture_illumination, numerical_aperture_detection, invert_phase_contrast=invert_phase_contrast, + tilt_angle_degrees=tilt_angle_degrees, + index_of_refraction_o2=index_of_refraction_o2, + numerical_aperture_o2=numerical_aperture_o2, + index_of_refraction_o3=index_of_refraction_o3, + numerical_aperture_o3=numerical_aperture_o3, ) zyx_out_shape = (zyx_shape[0] + 2 * z_padding,) + zyx_shape[1:] @@ -98,10 +108,17 @@ def _calculate_wrap_unsafe_transfer_function( numerical_aperture_illumination: float, numerical_aperture_detection: float, invert_phase_contrast: bool = False, + tilt_angle_degrees: float = 0.0, + index_of_refraction_o2: float = 1.0, + numerical_aperture_o2: float = 1.0, + index_of_refraction_o3: float = 1.0, + numerical_aperture_o3: float = 1.0, ) -> tuple[np.ndarray, np.ndarray]: radial_frequencies = util.generate_radial_frequencies( zyx_shape[1:], yx_pixel_size ) + fyy, fxx = util.generate_frequencies(zyx_shape[1:], yx_pixel_size) + z_total = zyx_shape[0] + 2 * z_padding z_position_list = torch.fft.ifftshift( (torch.arange(z_total) - z_total // 2) * z_pixel_size @@ -114,25 +131,48 @@ def _calculate_wrap_unsafe_transfer_function( numerical_aperture_illumination, wavelength_illumination, ) - det_pupil = optics.generate_pupil( + # det_pupil = optics.generate_pupil( + # radial_frequencies, + # numerical_aperture_detection, + # wavelength_illumination, + # ) + + det_pupil = optics.generate_tilted_pupil( + fxx, fyy, + wavelength_illumination, + tilt_angle_degrees=tilt_angle_degrees, + n1=index_of_refraction_media, + n2=index_of_refraction_o2, + n3=index_of_refraction_o3, + na_o1=numerical_aperture_illumination, + na_o2=numerical_aperture_o2, + na_o3=numerical_aperture_o3, + ) + + effective_n = ( + index_of_refraction_media + * index_of_refraction_o3 + / index_of_refraction_o2 + ) + outer_detection_pupil = optics.generate_pupil( radial_frequencies, - numerical_aperture_detection, + 0.99 * effective_n, wavelength_illumination, ) + propagation_kernel = optics.generate_propagation_kernel( radial_frequencies, - det_pupil, - wavelength_illumination / index_of_refraction_media, + outer_detection_pupil, + wavelength_illumination / effective_n, z_position_list, ) greens_function_z = optics.generate_greens_function_z( radial_frequencies, - det_pupil, - wavelength_illumination / index_of_refraction_media, + outer_detection_pupil, + wavelength_illumination / effective_n, z_position_list, axially_even=False, ) - ( real_potential_transfer_function, imag_potential_transfer_function, diff --git a/waveorder/optics.py b/waveorder/optics.py index 46204517..27654fab 100644 --- a/waveorder/optics.py +++ b/waveorder/optics.py @@ -144,10 +144,55 @@ def generate_pupil(frr, NA, lamb_in): Pupil = torch.zeros(frr.shape) Pupil[frr < NA / lamb_in] = 1 - + return Pupil +def generate_tilted_pupil( + fxx, + fyy, + wavelength_illumination, + tilt_angle_degrees=0, + n1=1.0, + n2=1.0, + n3=1.0, + na_o1=1.0, + na_o2=1.0, + na_o3=1.0, +): + tilt_angle_rad = np.deg2rad(tilt_angle_degrees) + frr = torch.sqrt(fxx**2 + fyy**2) + tt = torch.arcsin(frr * wavelength_illumination / (n1 * n3 / n2)) + pp = torch.arctan2(fyy, fxx) + + st = np.sin(tilt_angle_rad) + ct = np.cos(tilt_angle_rad) + theta_in = torch.arccos( + st * torch.sin(tt) * torch.cos(pp) + ct * torch.cos(tt) + ) + theta_out = torch.arcsin(n2 / n3 * torch.sin(theta_in)) + + pupil = torch.zeros_like(theta_out) + tt2 = torch.arcsin(frr * wavelength_illumination / (n1 * n3 / n2)) + # pupil = fxx + + pupil[theta_out < np.arcsin(na_o1 / n2)] = 1 + + # Minimally working?! + #pupil *= fxx + #pupil[pupil < 0] = 0 + + pupil = torch.nan_to_num(pupil, nan=0.0) + + tt2 = torch.arcsin(frr * wavelength_illumination / (n1 * n3 / n2)) + #pupil *= (n3 / n2) * np.cos(tt2) / np.sqrt(1 - (((n3 / n2) * np.sin(tt2)) ** 2)) + pupil *= np.cos(tt2) / np.sqrt(1 - 1.3*np.sin(tt2) ** 2) + pupil = torch.nan_to_num(pupil, nan=0.0) + + + return pupil + + def gen_sector_Pupil(fxx, fyy, NA, lamb_in, sector_angle, rotation_angle): """ @@ -944,6 +989,24 @@ def compute_weak_object_transfer_function_3D( detection_pupil[None, :, :] * greens_function_z, dim=(1, 2) ) + import napari + + v = napari.Viewer() + v.add_image( + np.fft.fftshift(np.array(detection_pupil)), name="detection_pupil" + ) + v.add_image( + np.fft.fftshift(np.array(torch.real(greens_function_z))), + name="greens_function_z", + ) + v.add_image( + np.fft.fftshift(torch.real(detection_pupil[None] * greens_function_z)), + name="PG", + ) + import pdb + + pdb.set_trace() + H1 = torch.fft.ifft2(torch.conj(SPHz_hat) * PG_hat, dim=(1, 2)) H1 = H1 * window[:, None, None] H1 = torch.fft.fft(H1, dim=0)