diff --git a/docs/examples/configs/birefringence-and-phase.yml b/docs/examples/configs/birefringence-and-phase.yml index 6c53dfe6..1f92c609 100644 --- a/docs/examples/configs/birefringence-and-phase.yml +++ b/docs/examples/configs/birefringence-and-phase.yml @@ -20,6 +20,7 @@ phase: yx_pixel_size: 0.325 z_pixel_size: 2.0 z_padding: 0 + z_focus_offset: 0 index_of_refraction_media: 1.3 numerical_aperture_detection: 1.2 numerical_aperture_illumination: 0.5 diff --git a/docs/examples/configs/fluorescence.yml b/docs/examples/configs/fluorescence.yml index 3e84d884..96d4ef0c 100644 --- a/docs/examples/configs/fluorescence.yml +++ b/docs/examples/configs/fluorescence.yml @@ -7,6 +7,7 @@ fluorescence: yx_pixel_size: 0.325 z_pixel_size: 2.0 z_padding: 0 + z_focus_offset: 0 index_of_refraction_media: 1.3 numerical_aperture_detection: 1.2 wavelength_emission: 0.507 diff --git a/docs/examples/configs/phase.yml b/docs/examples/configs/phase.yml index 381b487e..0c1b287c 100644 --- a/docs/examples/configs/phase.yml +++ b/docs/examples/configs/phase.yml @@ -8,6 +8,7 @@ phase: yx_pixel_size: 0.325 z_pixel_size: 2.0 z_padding: 0 + z_focus_offset: 0 index_of_refraction_media: 1.3 numerical_aperture_detection: 1.2 numerical_aperture_illumination: 0.5 diff --git a/docs/examples/models/isotropic_thin_3d.py b/docs/examples/models/isotropic_thin_3d.py index 4d9701cb..577c4064 100644 --- a/docs/examples/models/isotropic_thin_3d.py +++ b/docs/examples/models/isotropic_thin_3d.py @@ -19,16 +19,25 @@ phantom_arguments = {"index_of_refraction_sample": 1.33, "sphere_radius": 5} z_shape = 100 z_pixel_size = 0.25 +zyx_scale = np.array( + [ + z_pixel_size, + simulation_arguments["yx_pixel_size"], + simulation_arguments["yx_pixel_size"], + ] +) transfer_function_arguments = { "z_position_list": (np.arange(z_shape) - z_shape // 2) * z_pixel_size, "numerical_aperture_illumination": 0.9, "numerical_aperture_detection": 1.2, } -# Create a phantom +# Create a disk phantom yx_absorption, yx_phase = isotropic_thin_3d.generate_test_phantom( **simulation_arguments, **phantom_arguments ) +yx_absorption[:, 128:] = 0 # half absorbing +yx_phase[128:] = 0 # half phase # Calculate transfer function ( @@ -38,6 +47,12 @@ **simulation_arguments, **transfer_function_arguments ) +# Calculate singular system +singular_system = isotropic_thin_3d.calculate_singular_system( + absorption_2d_to_3d_transfer_function, + phase_2d_to_3d_transfer_function, +) + # Display transfer function viewer = napari.Viewer() zyx_scale = np.array( @@ -70,8 +85,8 @@ yx_phase_recon, ) = isotropic_thin_3d.apply_inverse_transfer_function( zyx_data, - absorption_2d_to_3d_transfer_function, - phase_2d_to_3d_transfer_function, + singular_system, + regularization_strength=1e-2, ) # Display @@ -84,5 +99,10 @@ ] for array in arrays: - viewer.add_image(array[0].cpu().numpy(), name=array[1]) + scale = zyx_scale[1:] if array[0].ndim == 2 else zyx_scale + viewer.add_image(array[0].cpu().numpy(), name=array[1], scale=scale) + +viewer.grid.enabled = True +viewer.dims.current_step = (z_shape // 2, 0, 0) + input("Showing object, data, and recon. Press to quit...") diff --git a/tests/cli_tests/test_compute_tf.py b/tests/cli_tests/test_compute_tf.py index 2c6e1425..a31d5b9e 100644 --- a/tests/cli_tests/test_compute_tf.py +++ b/tests/cli_tests/test_compute_tf.py @@ -1,7 +1,10 @@ +import numpy as np +import pytest from click.testing import CliRunner from waveorder.cli import settings from waveorder.cli.compute_transfer_function import ( + _position_list_from_shape_scale_offset, generate_and_save_birefringence_transfer_function, generate_and_save_fluorescence_transfer_function, generate_and_save_phase_transfer_function, @@ -10,6 +13,18 @@ from waveorder.io import utils +@pytest.mark.parametrize( + "shape, scale, offset, expected", + [ + (5, 1.0, 0.0, [2.0, 1.0, 0.0, -1.0, -2.0]), + (4, 0.5, 1.0, [1.5, 1.0, 0.5, 0.0]), + ], +) +def test_position_list_from_shape_scale_offset(shape, scale, offset, expected): + result = _position_list_from_shape_scale_offset(shape, scale, offset) + np.testing.assert_allclose(result, expected) + + def test_compute_transfer(tmp_path, example_plate): recon_settings = settings.ReconstructionSettings( input_channel_names=[f"State{i}" for i in range(4)], @@ -116,9 +131,10 @@ def test_phase_3dim_write(birefringence_phase_recon_settings_function): settings, dataset = birefringence_phase_recon_settings_function settings.reconstruction_dimension = 2 generate_and_save_phase_transfer_function(settings, dataset, (3, 4, 5)) - assert dataset["absorption_transfer_function"] - assert dataset["phase_transfer_function"] - assert dataset["phase_transfer_function"].shape == (1, 1, 3, 4, 5) + assert dataset["singular_system_U"] + assert dataset["singular_system_U"].shape == (1, 2, 2, 4, 5) + assert dataset["singular_system_S"] + assert dataset["singular_system_Vh"] assert "real_potential_transfer_function" not in dataset assert "imaginary_potential_transfer_function" not in dataset diff --git a/waveorder/cli/apply_inverse_models.py b/waveorder/cli/apply_inverse_models.py index 52f72c3d..c86c6189 100644 --- a/waveorder/cli/apply_inverse_models.py +++ b/waveorder/cli/apply_inverse_models.py @@ -65,23 +65,27 @@ def phase( # [phase only, 2] if recon_dim == 2: # Load transfer functions - absorption_transfer_function = torch.tensor( - transfer_function_dataset["absorption_transfer_function"][0, 0] + U = torch.from_numpy(transfer_function_dataset["singular_system_U"][0]) + S = torch.from_numpy( + transfer_function_dataset["singular_system_S"][0, 0] ) - phase_transfer_function = torch.tensor( - transfer_function_dataset["phase_transfer_function"][0, 0] + Vh = torch.from_numpy( + transfer_function_dataset["singular_system_Vh"][0] ) # Apply ( - _, - output, + absorption_yx, + phase_yx, ) = isotropic_thin_3d.apply_inverse_transfer_function( czyx_data[0], - absorption_transfer_function, - phase_transfer_function, + (U, S, Vh), **settings_phase.apply_inverse.dict(), ) + # Stack to C1YX + output = phase_yx[None, None] + # TODO: Write phase and absorption to CZYX + # torch.stack((phase_yx[None], absorption_yx[None])) # [phase only, 3] elif recon_dim == 3: @@ -127,12 +131,13 @@ def birefringence_and_phase( # [biref and phase, 2] if recon_dim == 2: - # Load phase transfer functions - absorption_transfer_function = torch.tensor( - transfer_function_dataset["absorption_transfer_function"][0, 0] + # Load transfer functions + U = torch.from_numpy(transfer_function_dataset["singular_system_U"][0]) + S = torch.from_numpy( + transfer_function_dataset["singular_system_S"][0, 0] ) - phase_transfer_function = torch.tensor( - transfer_function_dataset["phase_transfer_function"][0, 0] + Vh = torch.from_numpy( + transfer_function_dataset["singular_system_Vh"][0] ) # Apply @@ -163,8 +168,7 @@ def birefringence_and_phase( yx_phase, ) = isotropic_thin_3d.apply_inverse_transfer_function( brightfield_3d, - absorption_transfer_function, - phase_transfer_function, + (U, S, Vh), **settings_phase.apply_inverse.dict(), ) diff --git a/waveorder/cli/apply_inverse_transfer_function.py b/waveorder/cli/apply_inverse_transfer_function.py index 368cf496..4c086852 100644 --- a/waveorder/cli/apply_inverse_transfer_function.py +++ b/waveorder/cli/apply_inverse_transfer_function.py @@ -79,6 +79,7 @@ def get_reconstruction_output_metadata(position_path: Path, config_path: Path): if recon_phase: if recon_dim == 2: channel_names.append("Phase2D") + # channel_names.append("Absorption2D") elif recon_dim == 3: channel_names.append("Phase3D") if recon_fluo: diff --git a/waveorder/cli/compute_transfer_function.py b/waveorder/cli/compute_transfer_function.py index 992301b8..24497005 100644 --- a/waveorder/cli/compute_transfer_function.py +++ b/waveorder/cli/compute_transfer_function.py @@ -4,6 +4,7 @@ import numpy as np from iohub.ngff import Position, open_ome_zarr +from waveorder import focus from waveorder.cli.parsing import ( config_filepath, input_position_dirpaths, @@ -20,6 +21,22 @@ ) +def _position_list_from_shape_scale_offset( + shape: int, scale: float, offset: float +) -> list: + """ + Generates a list of positions based on the given array shape, pixel size (scale), and offset. + + Examples + -------- + >>> _position_list_from_shape_scale_offset(5, 1.0, 0.0) + [2.0, 1.0, 0.0, -1.0, -2.0] + >>> _position_list_from_shape_scale_offset(4, 0.5, 1.0) + [1.5, 1.0, 0.5, 0.0] + """ + return list((-np.arange(shape) + (shape // 2) + offset) * scale) + + def generate_and_save_birefringence_transfer_function(settings, dataset): """Generates and saves the birefringence transfer function to the dataset, based on the settings. @@ -61,18 +78,22 @@ def generate_and_save_phase_transfer_function( echo_headline("Generating phase transfer function with settings:") echo_settings(settings.phase.transfer_function) + settings_dict = settings.phase.transfer_function.dict() if settings.reconstruction_dimension == 2: # Convert zyx_shape and z_pixel_size into yx_shape and z_position_list - settings_dict = settings.phase.transfer_function.dict() settings_dict["yx_shape"] = [zyx_shape[1], zyx_shape[2]] - settings_dict["z_position_list"] = list( - -(np.arange(zyx_shape[0]) - zyx_shape[0] // 2) - * settings_dict["z_pixel_size"] + settings_dict["z_position_list"] = ( + _position_list_from_shape_scale_offset( + shape=zyx_shape[0], + scale=settings_dict["z_pixel_size"], + offset=settings_dict["z_focus_offset"], + ) ) # Remove unused parameters settings_dict.pop("z_pixel_size") settings_dict.pop("z_padding") + settings_dict.pop("z_focus_offset") # Calculate transfer functions ( @@ -82,26 +103,36 @@ def generate_and_save_phase_transfer_function( **settings_dict, ) + # Calculate singular system + U, S, Vh = isotropic_thin_3d.calculate_singular_system( + absorption_transfer_function, + phase_transfer_function, + ) + # Save dataset.create_image( - "absorption_transfer_function", - absorption_transfer_function.cpu().numpy()[None, None, ...], - chunks=(1, 1, 1, zyx_shape[1], zyx_shape[2]), + "singular_system_U", + U.cpu().numpy()[None], ) dataset.create_image( - "phase_transfer_function", - phase_transfer_function.cpu().numpy()[None, None, ...], - chunks=(1, 1, 1, zyx_shape[1], zyx_shape[2]), + "singular_system_S", + S.cpu().numpy()[None, None], + ) + dataset.create_image( + "singular_system_Vh", + Vh.cpu().numpy()[None], ) elif settings.reconstruction_dimension == 3: + settings_dict.pop("z_focus_offset") # not used in 3D + # Calculate transfer functions ( real_potential_transfer_function, imaginary_potential_transfer_function, ) = phase_thick_3d.calculate_transfer_function( zyx_shape=zyx_shape, - **settings.phase.transfer_function.dict(), + **settings_dict, ) # Save dataset.create_image( @@ -133,6 +164,9 @@ def generate_and_save_fluorescence_transfer_function( """ echo_headline("Generating fluorescence transfer function with settings:") echo_settings(settings.fluorescence.transfer_function) + # Remove unused parameters + settings_dict = settings.fluorescence.transfer_function.dict() + settings_dict.pop("z_focus_offset") if settings.reconstruction_dimension == 2: raise NotImplementedError @@ -141,7 +175,7 @@ def generate_and_save_fluorescence_transfer_function( optical_transfer_function = ( isotropic_fluorescent_thick_3d.calculate_transfer_function( zyx_shape=zyx_shape, - **settings.fluorescence.transfer_function.dict(), + **settings_dict, ) ) # Save @@ -182,9 +216,40 @@ def compute_transfer_function_cli( f"Each of the input_channel_names = {settings.input_channel_names} in {config_filepath} must appear in the dataset {input_position_dirpaths[0]} which currently contains channel_names = {input_dataset.channel_names}." ) + # Find in-focus slices for 2D reconstruction in "auto" mode + if ( + settings.phase is not None + and settings.reconstruction_dimension == 2 + and settings.phase.transfer_function.z_focus_offset == "auto" + ): + + c_idx = input_dataset.get_channel_index( + settings.input_channel_names[0] + ) + zyx_array = input_dataset["0"][0, c_idx] + + in_focus_index = focus.focus_from_transverse_band( + zyx_array, + NA_det=settings.phase.transfer_function.numerical_aperture_detection, + lambda_ill=settings.phase.transfer_function.wavelength_illumination, + pixel_size=settings.phase.transfer_function.yx_pixel_size, + mode="min", + polynomial_fit_order=4, + ) + + z_focus_offset = in_focus_index - (zyx_shape[0] // 2) + settings.phase.transfer_function.z_focus_offset = z_focus_offset + print("Found z_focus_offset:", z_focus_offset) + # Prepare output dataset + num_channels = ( + 2 if settings.reconstruction_dimension == 2 else 1 + ) # space for SVD output_dataset = open_ome_zarr( - output_dirpath, layout="fov", mode="w", channel_names=["None"] + output_dirpath, + layout="fov", + mode="w", + channel_names=num_channels * ["None"], ) # Pass settings to appropriate calculate_transfer_function and save diff --git a/waveorder/cli/settings.py b/waveorder/cli/settings.py index dca82607..903614be 100644 --- a/waveorder/cli/settings.py +++ b/waveorder/cli/settings.py @@ -61,6 +61,7 @@ class FourierTransferFunctionSettings(MyBaseModel): yx_pixel_size: PositiveFloat = 6.5 / 20 z_pixel_size: PositiveFloat = 2.0 z_padding: NonNegativeInt = 0 + z_focus_offset: Union[int, Literal["auto"]] = 0 index_of_refraction_media: PositiveFloat = 1.3 numerical_aperture_detection: PositiveFloat = 1.2 diff --git a/waveorder/focus.py b/waveorder/focus.py index 5d470b57..38128f3c 100644 --- a/waveorder/focus.py +++ b/waveorder/focus.py @@ -15,6 +15,7 @@ def focus_from_transverse_band( pixel_size, midband_fractions=(0.125, 0.25), mode: Literal["min", "max"] = "max", + polynomial_fit_order: Optional[int] = None, plot_path: Optional[str] = None, threshold_FWHM: float = 0, ): @@ -36,8 +37,10 @@ def focus_from_transverse_band( midband_fractions: Tuple[float, float], optional The minimum and maximum fraction of the cutoff frequency that define the midband. Requires: 0 <= midband_fractions[0] < midband_fractions[1] <= 1. - mode: {'max', 'min'}, optional - Option to choose the in-focus slice by minimizing or maximizing the midband frequency. + mode: {'min', 'max'}, optional + Option to choose the in-focus slice by minimizing or maximizing the midband power. By default 'max'. + polynomial_fit_order: int, optional + Default None is no fit. If integer, a polynomial of that degree is fit to the midband power before choosing the extreme point as the in-focus slice. plot_path: str or None, optional File name for a diagnostic plot (supports matplotlib filetypes .png, .pdf, .svg, etc.). Use None to skip. @@ -93,7 +96,13 @@ def focus_from_transverse_band( # Find slice index with min/max power in midband midband_sum = np.sum(xy_abs_fft[:, midband_mask], axis=1) - peak_index = minmaxfunc(midband_sum) + + if polynomial_fit_order is None: + peak_index = minmaxfunc(midband_sum) + else: + x = np.arange(len(midband_sum)) + coeffs = np.polyfit(x, midband_sum, polynomial_fit_order) + peak_index = minmaxfunc(np.poly1d(coeffs)(x)) peak_results = peak_widths(midband_sum, [peak_index]) peak_FWHM = peak_results[0][0] diff --git a/waveorder/models/isotropic_thin_3d.py b/waveorder/models/isotropic_thin_3d.py index 3a006f22..85a6a21a 100644 --- a/waveorder/models/isotropic_thin_3d.py +++ b/waveorder/models/isotropic_thin_3d.py @@ -5,6 +5,7 @@ from torch import Tensor from waveorder import optics, sampling, util +from waveorder.filter import apply_filter_bank def generate_test_phantom( @@ -29,7 +30,7 @@ def generate_test_phantom( / wavelength_illumination ) # phase in radians - yx_absorption = 0.02 * sphere[1] + yx_absorption = torch.clone(yx_phase) return yx_absorption, yx_phase @@ -103,9 +104,17 @@ def _calculate_wrap_unsafe_transfer_function( numerical_aperture_detection: float, invert_phase_contrast: bool = False, ) -> Tuple[Tensor, Tensor]: - if invert_phase_contrast: - z_position_list = torch.flip(torch.tensor(z_position_list), dims=(0,)) + if numerical_aperture_illumination >= numerical_aperture_detection: + print( + "Warning: numerical_aperture_illumination is >= " + "numerical_aperture_detection. Setting " + "numerical_aperture_illumination to 0.9 * " + "numerical_aperture_detection to avoid singularities." + ) + numerical_aperture_illumination = 0.9 * numerical_aperture_detection + if invert_phase_contrast: + z_position_list = [-1 * x for x in z_position_list] radial_frequencies = util.generate_radial_frequencies( yx_shape, yx_pixel_size ) @@ -148,6 +157,45 @@ def _calculate_wrap_unsafe_transfer_function( ) +def calculate_singular_system( + absorption_2d_to_3d_transfer_function: Tensor, + phase_2d_to_3d_transfer_function: Tensor, +) -> Tuple[Tensor, Tensor, Tensor]: + """Calculates the singular system of the absoprtion and phase transfer + functions. + + Together, the transfer functions form a (2, Z, Vy, Vx) tensor, where + (2,) is the object-space dimension (abs, phase), (Z,) is the data-space + dimension, and (Vy, Vx) are the spatial frequency dimensions. + + The SVD is computed over the (2, Z) dimensions. + + Parameters + ---------- + absorption_2d_to_3d_transfer_function : Tensor + ZYX transfer function for absorption + phase_2d_to_3d_transfer_function : Tensor + ZYX transfer function for phase + + Returns + ------- + Tuple[Tensor, Tensor, Tensor] + """ + sfYX_transfer_function = torch.stack( + ( + absorption_2d_to_3d_transfer_function, + phase_2d_to_3d_transfer_function, + ), + dim=0, + ) + YXsf_transfer_function = sfYX_transfer_function.permute(2, 3, 0, 1) + Up, Sp, Vhp = torch.linalg.svd(YXsf_transfer_function, full_matrices=False) + U = Up.permute(2, 3, 0, 1) + S = Sp.permute(2, 0, 1) + Vh = Vhp.permute(2, 3, 0, 1) + return U, S, Vh + + def visualize_transfer_function( viewer, absorption_2d_to_3d_transfer_function: Tensor, @@ -202,8 +250,8 @@ def visualize_point_spread_function( def apply_transfer_function( yx_absorption: Tensor, yx_phase: Tensor, - phase_2d_to_3d_transfer_function: Tensor, absorption_2d_to_3d_transfer_function: Tensor, + phase_2d_to_3d_transfer_function: Tensor, ) -> Tensor: # Very simple simulation, consider adding noise and bkg knobs @@ -233,14 +281,13 @@ def apply_transfer_function( def apply_inverse_transfer_function( zyx_data: Tensor, - absorption_2d_to_3d_transfer_function: Tensor, - phase_2d_to_3d_transfer_function: Tensor, + singular_system: Tuple[Tensor, Tensor, Tensor], reconstruction_algorithm: Literal["Tikhonov", "TV"] = "Tikhonov", - regularization_strength: float = 1e-6, + regularization_strength: float = 1e-3, reg_p: float = 1e-6, # TODO: use this parameter TV_rho_strength: float = 1e-3, TV_iterations: int = 10, - bg_filter: bool = True, + bg_filter: bool = False, ) -> Tuple[Tensor, Tensor]: """Reconstructs absorption and phase from zyx_data and a pair of 3D-to-2D transfer functions named absorption_2d_to_3d_transfer_function and @@ -251,15 +298,13 @@ def apply_inverse_transfer_function( ---------- zyx_data : Tensor 3D raw data, label-free defocus stack - absorption_2d_to_3d_transfer_function : Tensor - 3D-to-2D absorption transfer function, see calculate_transfer_function above - phase_2d_to_3d_transfer_function : Tensor - 3D-to-2D phase transfer function, see calculate_transfer_function above - reconstruction_algorithm : Literal["Tikhonov", "TV"], optional + singular_system : Tuple[Tensor, Tensor, Tensor] + singular system of the transfer function bank + reconstruction_algorithm : Literal["Tikhonov";, "TV";], optional "Tikhonov" or "TV", by default "Tikhonov" "TV" is not implemented. regularization_strength : float, optional - regularization parameter, by default 1e-6 + regularization parameter, by default 1e-3 reg_p : float, optional TV-specific phase regularization parameter, by default 1e-6 "TV" is not implemented. @@ -268,7 +313,7 @@ def apply_inverse_transfer_function( "TV" is not implemented. bg_filter : bool, optional option for slow-varying 2D background normalization with uniform filter - by default True + by default False Returns ------- @@ -281,66 +326,22 @@ def apply_inverse_transfer_function( NotImplementedError TV is not implemented """ - zyx_data_normalized = util.inten_normalization( - zyx_data, bg_filter=bg_filter - ) + # Normalize + zyx = util.inten_normalization(zyx_data, bg_filter=bg_filter) - zyx_data_hat = torch.fft.fft2(zyx_data_normalized, dim=(1, 2)) - - # TODO AHA and b_vec calculations should be moved into tikhonov/tv calculations - # TODO Reformulate to use filter.apply_filter_bank - AHA = [ - torch.sum(torch.abs(absorption_2d_to_3d_transfer_function) ** 2, dim=0) - + regularization_strength, - torch.sum( - torch.conj(absorption_2d_to_3d_transfer_function) - * phase_2d_to_3d_transfer_function, - dim=0, - ), - torch.sum( - torch.conj( - phase_2d_to_3d_transfer_function, - ) - * absorption_2d_to_3d_transfer_function, - dim=0, - ), - torch.sum( - torch.abs( - phase_2d_to_3d_transfer_function, - ) - ** 2, - dim=0, - ) - + reg_p, - ] - - b_vec = [ - torch.sum( - torch.conj(absorption_2d_to_3d_transfer_function) * zyx_data_hat, - dim=0, - ), - torch.sum( - torch.conj( - phase_2d_to_3d_transfer_function, - ) - * zyx_data_hat, - dim=0, - ), - ] - - # Deconvolution with Tikhonov regularization + # TODO Consider refactoring with vectorial transfer function SVD if reconstruction_algorithm == "Tikhonov": - absorption, phase = util.dual_variable_tikhonov_deconvolution_2d( - AHA, b_vec + print("Computing inverse filter") + U, S, Vh = singular_system + S_reg = S / (S**2 + regularization_strength) + sfyx_inverse_filter = torch.einsum( + "sj...,j...,jf...->fs...", U, S_reg, Vh ) + absorption_yx, phase_yx = apply_filter_bank(sfyx_inverse_filter, zyx) + # ADMM deconvolution with anisotropic TV regularization elif reconstruction_algorithm == "TV": raise NotImplementedError - absorption, phase = util.dual_variable_admm_tv_deconv_2d( - AHA, b_vec, rho=TV_rho_strength, itr=TV_iterations - ) - - phase -= torch.mean(phase) - return absorption, phase + return absorption_yx, phase_yx diff --git a/waveorder/util.py b/waveorder/util.py index f5e13602..17ccff72 100644 --- a/waveorder/util.py +++ b/waveorder/util.py @@ -714,7 +714,7 @@ def inten_normalization(img_stack, bg_filter=True): img_stack[i], size=X // 2 ) else: - img_norm_stack[i] = img_stack[i].copy() + img_norm_stack[i] = img_stack[i] img_norm_stack[i] /= torch.mean(img_norm_stack[i]) img_norm_stack[i] -= 1