diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index bfb2756ebe..5d0ec152d8 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -324,6 +324,9 @@ def __call__( kwargs: optional keyword args to be passed to ``network``. """ + # check if there is a conditioning signal + condition = kwargs.pop("condition", None) + patches_locations: Iterable[tuple[torch.Tensor, Sequence[int]]] | MetaTensor if self.splitter is None: # handle situations where the splitter is not provided @@ -344,20 +347,39 @@ def __call__( f"The provided inputs type is {type(inputs)}." ) patches_locations = inputs + if condition is not None: + condition_locations = condition else: # apply splitter patches_locations = self.splitter(inputs) + if condition is not None: + # apply splitter to condition + condition_locations = self.splitter(condition) ratios: list[float] = [] mergers: list[Merger] = [] - for patches, locations, batch_size in self._batch_sampler(patches_locations): - # run inference - outputs = self._run_inference(network, patches, *args, **kwargs) - # initialize the mergers - if not mergers: - mergers, ratios = self._initialize_mergers(inputs, outputs, patches, batch_size) - # aggregate outputs - self._aggregate(outputs, locations, batch_size, mergers, ratios) + if condition is not None: + for (patches, locations, batch_size), (condition_patches, _, _) in zip( + self._batch_sampler(patches_locations), self._batch_sampler(condition_locations) + ): + # add patched condition to kwargs + kwargs["condition"] = condition_patches + # run inference + outputs = self._run_inference(network, patches, *args, **kwargs) + # initialize the mergers + if not mergers: + mergers, ratios = self._initialize_mergers(inputs, outputs, patches, batch_size) + # aggregate outputs + self._aggregate(outputs, locations, batch_size, mergers, ratios) + else: + for patches, locations, batch_size in self._batch_sampler(patches_locations): + # run inference + outputs = self._run_inference(network, patches, *args, **kwargs) + # initialize the mergers + if not mergers: + mergers, ratios = self._initialize_mergers(inputs, outputs, patches, batch_size) + # aggregate outputs + self._aggregate(outputs, locations, batch_size, mergers, ratios) # finalize the mergers and get the results merged_outputs = [merger.finalize() for merger in mergers] @@ -742,12 +764,24 @@ def __call__( f"Currently, only 2D `roi_size` ({self.orig_roi_size}) with 3D `inputs` tensor (shape={inputs.shape}) is supported." ) - return super().__call__(inputs=inputs, network=lambda x: self.network_wrapper(network, x, *args, **kwargs)) + # check if there is a conditioning signal + condition = kwargs.get("condition", None) + if condition is not None: + return super().__call__( + inputs=inputs, + network=lambda x, *args, **kwargs: self.network_wrapper(network, x, *args, **kwargs), + condition=condition, + ) + else: + return super().__call__( + inputs=inputs, network=lambda x, *args, **kwargs: self.network_wrapper(network, x, *args, **kwargs) + ) def network_wrapper( self, network: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]], x: torch.Tensor, + condition: torch.Tensor | None = None, *args: Any, **kwargs: Any, ) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]: @@ -756,7 +790,12 @@ def network_wrapper( """ # Pass 4D input [N, C, H, W]/[N, C, D, W]/[N, C, D, H] to the model as it is 2D. 
 x = x.squeeze(dim=self.spatial_dim + 2)
-        out = network(x, *args, **kwargs)
+
+        if condition is not None:
+            condition = condition.squeeze(dim=self.spatial_dim + 2)
+            out = network(x, condition, *args, **kwargs)
+        else:
+            out = network(x, *args, **kwargs)
 
         # Unsqueeze the network output so it is [N, C, D, H, W] as expected by
         # the default SlidingWindowInferer class
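With the `inferer.py` changes above, a conditioning tensor laid out like the inputs can be passed straight through `PatchInferer` (and, via `network_wrapper`, through `SliceInferer`): it is split with the same splitter and delivered to the network patch by patch through the `condition` keyword. A minimal usage sketch of that calling convention; the toy network, shapes, and variable names are illustrative only, not part of the patch:

import torch

from monai.inferers import PatchInferer, SlidingWindowSplitter


def net(x, condition):
    # toy network: consumes the per-patch conditioning signal
    return x + condition


inputs = torch.rand(1, 3, 4, 4)
condition = torch.ones_like(inputs)  # must share the spatial layout of `inputs`
inferer = PatchInferer(splitter=SlidingWindowSplitter(patch_size=(2, 2)))
output = inferer(inputs=inputs, network=net, condition=condition)  # same spatial shape as `inputs`
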
diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py
index 8adba8fa25..766486a807 100644
--- a/monai/inferers/utils.py
+++ b/monai/inferers/utils.py
@@ -153,6 +153,8 @@ def sliding_window_inference(
     device = device or inputs.device
     sw_device = sw_device or inputs.device
 
+    condition = kwargs.pop("condition", None)
+
     temp_meta = None
     if isinstance(inputs, MetaTensor):
         temp_meta = MetaTensor([]).copy_meta_from(inputs, copy_attr=False)
@@ -168,6 +170,8 @@ def sliding_window_inference(
         pad_size.extend([half, diff - half])
     if any(pad_size):
         inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)
+        if condition is not None:
+            condition = F.pad(condition, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)
 
     # Store all slices
     scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)
@@ -220,13 +224,19 @@ def sliding_window_inference(
         ]
         if sw_batch_size > 1:
             win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device)
+            if condition is not None:
+                win_condition = torch.cat([condition[win_slice] for win_slice in unravel_slice]).to(sw_device)
+                kwargs["condition"] = win_condition
         else:
             win_data = inputs[unravel_slice[0]].to(sw_device)
+            if condition is not None:
+                win_condition = condition[unravel_slice[0]].to(sw_device)
+                kwargs["condition"] = win_condition
+
         if with_coord:
-            seg_prob_out = predictor(win_data, unravel_slice, *args, **kwargs)  # batched patch
+            seg_prob_out = predictor(win_data, unravel_slice, *args, **kwargs)
         else:
-            seg_prob_out = predictor(win_data, *args, **kwargs)  # batched patch
-
+            seg_prob_out = predictor(win_data, *args, **kwargs)
         # convert seg_prob_out to tuple seg_tuple, this does not allocate new memory.
         dict_keys, seg_tuple = _flatten_struct(seg_prob_out)
         if process_fn:
diff --git a/tests/inferers/test_patch_inferer.py b/tests/inferers/test_patch_inferer.py
index 964f08e6fe..02c6a37837 100644
--- a/tests/inferers/test_patch_inferer.py
+++ b/tests/inferers/test_patch_inferer.py
@@ -305,5 +305,231 @@ def test_patch_inferer_errors(self, inputs, arguments, expected_error):
             inferer(inputs=inputs, network=lambda x: x)
 
 
+# ----------------------------------------------------------------------------
+# Test cases with conditioning
+# ----------------------------------------------------------------------------
+
+# no-overlapping 2x2 patches
+TEST_CASE_0_TENSOR_c = [
+    TENSOR_4x4,
+    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger),
+    lambda x, condition: x + condition,
+    TENSOR_4x4 * 2,
+]
+
+# no-overlapping 2x2 patches using all default parameters (except for splitter)
+TEST_CASE_1_TENSOR_c = [
+    TENSOR_4x4,
+    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2))),
+    lambda x, condition: x + condition,
+    TENSOR_4x4 * 2,
+]
+
+# divisible batch_size
+TEST_CASE_2_TENSOR_c = [
+    TENSOR_4x4,
+    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger, batch_size=2),
+    lambda x, condition: x + condition,
+    TENSOR_4x4 * 2,
+]
+
+# non-divisible batch_size
+TEST_CASE_3_TENSOR_c = [
+    TENSOR_4x4,
+    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger, batch_size=3),
+    lambda x, condition: x + condition,
+    TENSOR_4x4 * 2,
+]
+
+# patches that are already split (Splitter should be None)
+TEST_CASE_4_SPLIT_LIST_c = [
+    [
+        (TENSOR_4x4[..., :2, :2], (0, 0)),
+        (TENSOR_4x4[..., :2, 2:], (0, 2)),
+        (TENSOR_4x4[..., 2:, :2], (2, 0)),
+        (TENSOR_4x4[..., 2:, 2:], (2, 2)),
+    ],
+    dict(splitter=None, merger_cls=AvgMerger, merged_shape=(2, 3, 4, 4)),
+    lambda x, condition: x + condition,
+    TENSOR_4x4 * 2,
+]
+
+# using all default parameters (patches are already split)
+TEST_CASE_5_SPLIT_LIST_c = [
+    [
+        (TENSOR_4x4[..., :2, :2], (0, 0)),
+        (TENSOR_4x4[..., :2, 2:], (0, 2)),
+        (TENSOR_4x4[..., 2:, :2], (2, 0)),
+        (TENSOR_4x4[..., 2:, 2:], (2, 2)),
+    ],
+    dict(merger_cls=AvgMerger, merged_shape=(2, 3, 4, 4)),
+    lambda x, condition: x + condition,
+    TENSOR_4x4 * 2,
+]
+
+# output smaller than input patches
+TEST_CASE_6_SMALLER_c = [
+    TENSOR_4x4,
+    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger),
+    lambda x, condition: torch.mean(x, dim=(-1, -2), keepdim=True) + torch.mean(condition, dim=(-1, -2), keepdim=True),
+    TENSOR_2x2 * 2,
+]
+
+# preprocess patches
+TEST_CASE_7_PREPROCESS_c = [
+    TENSOR_4x4,
+    dict(
+        splitter=SlidingWindowSplitter(patch_size=(2, 2)),
+        merger_cls=AvgMerger,
+        preprocessing=lambda x: 2 * x,
+        postprocessing=None,
+    ),
+    lambda x, condition: x + condition,
+    2 * TENSOR_4x4 + TENSOR_4x4,
+]
+
+# postprocess patches
+TEST_CASE_8_POSTPROCESS_c = [
+    TENSOR_4x4,
+    dict(
+        splitter=SlidingWindowSplitter(patch_size=(2, 2)),
+        merger_cls=AvgMerger,
+        preprocessing=None,
+        postprocessing=lambda x: 4 * x,
+    ),
+    lambda x, condition: x + condition,
+    4 * TENSOR_4x4 * 2,
+]
+
+# str merger as the class name
+TEST_CASE_9_STR_MERGER_c = [
+    TENSOR_4x4,
+    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls="AvgMerger"),
+    lambda x, condition: x + condition,
+    TENSOR_4x4 * 2,
+]
+
+# str merger as dotted path
+TEST_CASE_10_STR_MERGER_c = [
+    TENSOR_4x4,
+    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls="monai.inferers.merger.AvgMerger"),
+    lambda x, condition: x + condition,
+    TENSOR_4x4 * 2,
+]
+
+# non-divisible patch_size leading to larger image (without matching spatial shape)
+TEST_CASE_11_PADDING_c = [
+    TENSOR_4x4,
+    dict(
+        splitter=SlidingWindowSplitter(patch_size=(2, 3), pad_mode="constant", pad_value=0.0),
+        merger_cls=AvgMerger,
+        match_spatial_shape=False,
+    ),
+    lambda x, condition: x + condition,
+    pad(TENSOR_4x4, (0, 2), value=0.0) * 2,
+]
+
+# non-divisible patch_size with matching spatial shapes
+TEST_CASE_12_MATCHING_c = [
+    TENSOR_4x4,
+    dict(splitter=SlidingWindowSplitter(patch_size=(2, 3), pad_mode=None), merger_cls=AvgMerger),
+    lambda x, condition: x + condition,
+    pad(TENSOR_4x4[..., :3], (0, 1), value=float("nan")) * 2,
+]
+
+# non-divisible patch_size with padding and matching spatial shapes
+TEST_CASE_13_PADDING_MATCHING_c = [
+    TENSOR_4x4,
+    dict(splitter=SlidingWindowSplitter(patch_size=(2, 3)), merger_cls=AvgMerger),
+    lambda x, condition: x + condition,
+    TENSOR_4x4 * 2,
+]
+
+# multi-threading
+TEST_CASE_14_MULTITHREAD_BUFFER_c = [
+    TENSOR_4x4,
+    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger, buffer_size=2),
+    lambda x, condition: x + condition,
+    TENSOR_4x4 * 2,
+]
+
+# multi-threading with batch
+TEST_CASE_15_MULTITHREAD_BUFFER_c = [
+    TENSOR_4x4,
+    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger, buffer_size=4, batch_size=4),
+    lambda x, condition: x + condition,
+    TENSOR_4x4 * 2,
+]
+
+# list of tensor output
+TEST_CASE_0_LIST_TENSOR_c = [
+    TENSOR_4x4,
+    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger),
+    lambda x, condition: (x + condition, x + condition),
+    (TENSOR_4x4 * 2, TENSOR_4x4 * 2),
+]
+
+# dict of tensor output
+TEST_CASE_0_DICT_c = [
+    TENSOR_4x4,
+    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger),
+    lambda x, condition: {"model_output": x + condition},
+    {"model_output": TENSOR_4x4 * 2},
+]
+
+
+class PatchInfererTestsCond(unittest.TestCase):
+    @parameterized.expand(
+        [
+            TEST_CASE_0_TENSOR_c,
+            TEST_CASE_1_TENSOR_c,
+            TEST_CASE_2_TENSOR_c,
+            TEST_CASE_3_TENSOR_c,
+            TEST_CASE_4_SPLIT_LIST_c,
+            TEST_CASE_5_SPLIT_LIST_c,
+            TEST_CASE_6_SMALLER_c,
+            TEST_CASE_7_PREPROCESS_c,
+            TEST_CASE_8_POSTPROCESS_c,
+            TEST_CASE_9_STR_MERGER_c,
+            TEST_CASE_10_STR_MERGER_c,
+            TEST_CASE_11_PADDING_c,
+            TEST_CASE_12_MATCHING_c,
+            TEST_CASE_13_PADDING_MATCHING_c,
+            TEST_CASE_14_MULTITHREAD_BUFFER_c,
+            TEST_CASE_15_MULTITHREAD_BUFFER_c,
+        ]
+    )
+    def test_patch_inferer_tensor(self, inputs, arguments, network, expected):
+        if isinstance(inputs, list):  # case 4 and 5
+            condition = [(x[0].clone(), x[1]) for x in inputs]
+        else:
+            condition = inputs.clone()
+        inferer = PatchInferer(**arguments)
+        output = inferer(inputs=inputs, network=network, condition=condition)
+        assert_allclose(output, expected)
+
+    @parameterized.expand([TEST_CASE_0_LIST_TENSOR_c])
+    def test_patch_inferer_list_tensor(self, inputs, arguments, network, expected):
+        if isinstance(inputs, list):  # case 4 and 5
+            condition = [(x[0].clone(), x[1]) for x in inputs]
+        else:
+            condition = inputs.clone()
+        inferer = PatchInferer(**arguments)
+        output = inferer(inputs=inputs, network=network, condition=condition)
+        for out, exp in zip(output, expected):
+            assert_allclose(out, exp)
+
+    @parameterized.expand([TEST_CASE_0_DICT_c])
+    def test_patch_inferer_dict(self, inputs, arguments, network, expected):
+        if isinstance(inputs, list):  # case 4 and 5
+            condition = [(x[0].clone(), x[1]) for x in inputs]
+        else:
+            condition = inputs.clone()
+        inferer = PatchInferer(**arguments)
+        output = inferer(inputs=inputs, network=network, condition=condition)
+        for k in expected:
+            assert_allclose(output[k], expected[k])
+
+
 if __name__ == "__main__":
     unittest.main()
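The functional path mirrors the class-based one: `sliding_window_inference` pops `condition` from its keyword arguments, pads and windows it in lockstep with `inputs`, and re-injects each window as `kwargs["condition"]` before calling the predictor. A minimal sketch under the same assumptions (toy predictor, illustrative shapes):

import torch

from monai.inferers import sliding_window_inference


def predictor(patch, condition):
    # toy predictor: receives the matching window of the conditioning signal
    return patch + condition


image = torch.rand(1, 1, 64, 64)
condition = torch.ones_like(image)
output = sliding_window_inference(image, roi_size=(32, 32), sw_batch_size=4, predictor=predictor, condition=condition)
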
diff --git a/tests/inferers/test_slice_inferer.py b/tests/inferers/test_slice_inferer.py
index 526542943e..8dac3a4e76 100644
--- a/tests/inferers/test_slice_inferer.py
+++ b/tests/inferers/test_slice_inferer.py
@@ -53,5 +53,39 @@ def test_shape(self, spatial_dim):
         result = inferer(input_volume, model)
 
 
+class TestSliceInfererCond(unittest.TestCase):
+
+    @parameterized.expand(TEST_CASES)
+    def test_shape(self, spatial_dim):
+        spatial_dim = int(spatial_dim)
+
+        model = UNet(
+            spatial_dims=2, in_channels=1, out_channels=1, channels=(4, 8, 16), strides=(2, 2), num_res_units=2
+        )
+
+        # overwrite the forward method to test the inferer with a model that takes a condition
+        model.forward = lambda x, condition: x + condition if condition is not None else x
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        model.to(device)
+        model.eval()
+
+        # Initialize a dummy 3D tensor volume with shape (N,C,D,H,W)
+        input_volume = torch.ones(1, 1, 64, 256, 256, device=device)
+        condition_volume = torch.ones(1, 1, 64, 256, 256, device=device)
+        # Remove spatial dim to slide across from the roi_size
+        roi_size = list(input_volume.shape[2:])
+        roi_size.pop(spatial_dim)
+
+        # Initialize and run inferer
+        inferer = SliceInferer(roi_size=roi_size, spatial_dim=spatial_dim, sw_batch_size=1, cval=-1)
+        result = inferer(input_volume, model, condition=condition_volume)
+
+        self.assertTupleEqual(result.shape, input_volume.shape)
+        self.assertEqual(result.sum(), (input_volume + condition_volume).sum())
+        # test that the inferer can be run multiple times
+        result = inferer(input_volume, model, condition=condition_volume)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/inferers/test_sliding_window_inference.py b/tests/inferers/test_sliding_window_inference.py
index 997822edd3..f97cbb9299 100644
--- a/tests/inferers/test_sliding_window_inference.py
+++ b/tests/inferers/test_sliding_window_inference.py
@@ -373,5 +373,317 @@ def compute_dict(data):
         np.testing.assert_allclose(result_dict[rr].cpu().numpy(), expected_dict[rr], rtol=1e-4)
 
 
+class TestSlidingWindowInferenceCond(unittest.TestCase):
+    @parameterized.expand(TEST_CASES)
+    def test_sliding_window_default(self, image_shape, roi_shape, sw_batch_size, overlap, mode, device):
+        n_total = np.prod(image_shape)
+        if mode == "constant":
+            inputs = torch.arange(n_total, dtype=torch.float).reshape(*image_shape)
+        else:
+            inputs = torch.ones(*image_shape, dtype=torch.float)
+        if device.type == "cuda" and not torch.cuda.is_available():
+            device = torch.device("cpu:0")
+
+        # conditioning signal with the same shape as the input
+        condition = torch.ones(*image_shape, dtype=torch.float)
+
+        def compute(data, condition):
+            return data + condition
+
+        if mode == "constant":
+            expected_val = np.arange(n_total, dtype=np.float32).reshape(*image_shape) + 1.0
+        else:
+            expected_val = np.ones(image_shape, dtype=np.float32) + 1.0
+
+        result = sliding_window_inference(
+            inputs.to(device), roi_shape, sw_batch_size, compute, overlap, mode=mode, condition=condition.to(device)
+        )
+        np.testing.assert_string_equal(device.type, result.device.type)
+        np.testing.assert_allclose(result.cpu().numpy(), expected_val)
+
+        result = SlidingWindowInferer(roi_shape, sw_batch_size, overlap, mode)(
+            inputs.to(device), compute, condition=condition.to(device)
+        )
+        np.testing.assert_string_equal(device.type, 
result.device.type) + np.testing.assert_allclose(result.cpu().numpy(), expected_val) + + @parameterized.expand([[x] for x in TEST_TORCH_AND_META_TENSORS]) + def test_default_device(self, data_type): + device = "cuda" if torch.cuda.is_available() else "cpu:0" + inputs = data_type(torch.ones((3, 16, 15, 7))).to(device=device) + condition = torch.ones((3, 16, 15, 7)).to(device=device) + inputs = list_data_collate([inputs]) # make a proper batch + condition = list_data_collate([condition]) # make a proper batch + roi_shape = (4, 10, 7) + sw_batch_size = 10 + + def compute(data, condition): + return data + condition + + inputs.requires_grad = True + result = sliding_window_inference(inputs, roi_shape, sw_batch_size, compute, condition=condition) + self.assertTrue(result.requires_grad) + np.testing.assert_string_equal(inputs.device.type, result.device.type) + expected_val = np.ones((1, 3, 16, 15, 7), dtype=np.float32) + 1 + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val) + + @parameterized.expand(list(itertools.product(TEST_TORCH_AND_META_TENSORS, ("cpu", "cuda"), ("cpu", "cuda", None)))) + @skip_if_no_cuda + def test_sw_device(self, data_type, device, sw_device): + inputs = data_type(torch.ones((3, 16, 15, 7))).to(device=device) + condition = torch.ones((3, 16, 15, 7)).to(device=device) + inputs = list_data_collate([inputs]) # make a proper batch + condition = list_data_collate([condition]) # make a proper batch + roi_shape = (4, 10, 7) + sw_batch_size = 10 + + def compute(data, condition): + self.assertEqual(data.device.type, sw_device or device) + self.assertEqual(condition.device.type, sw_device or device) + + return data + condition + + result = sliding_window_inference( + inputs, roi_shape, sw_batch_size, compute, sw_device=sw_device, device="cpu", condition=condition + ) + np.testing.assert_string_equal("cpu", result.device.type) + expected_val = np.ones((1, 3, 16, 15, 7), dtype=np.float32) + 1 + np.testing.assert_allclose(result.cpu().numpy(), expected_val) + + def test_sigma(self): + device = "cuda" if torch.cuda.is_available() else "cpu:0" + inputs = torch.ones((1, 1, 7, 7)).to(device=device) + roi_shape = (3, 3) + sw_batch_size = 10 + + class _Pred: + add = 1 + + def compute(self, data): + self.add += 1 + return data + self.add + + result = sliding_window_inference( + inputs, + roi_shape, + sw_batch_size, + _Pred().compute, + overlap=0.5, + padding_mode="constant", + cval=-1, + mode="constant", + sigma_scale=1.0, + ) + + expected = np.array( + [ + [ + [ + [3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000], + [3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000], + [3.3333, 3.3333, 3.3333, 3.3333, 3.3333, 3.3333, 3.3333], + [3.6667, 3.6667, 3.6667, 3.6667, 3.6667, 3.6667, 3.6667], + [4.3333, 4.3333, 4.3333, 4.3333, 4.3333, 4.3333, 4.3333], + [4.5000, 4.5000, 4.5000, 4.5000, 4.5000, 4.5000, 4.5000], + [5.0000, 5.0000, 5.0000, 5.0000, 5.0000, 5.0000, 5.0000], + ] + ] + ] + ) + np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4) + result = sliding_window_inference( + inputs, + roi_shape, + sw_batch_size, + _Pred().compute, + overlap=0.5, + padding_mode="constant", + cval=-1, + mode="gaussian", + sigma_scale=1.0, + progress=has_tqdm, + ) + expected = np.array( + [ + [ + [ + [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0], + [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0], + [3.3271625, 3.3271623, 3.3271623, 3.3271623, 3.3271623, 3.3271623, 3.3271625], + [3.6728377, 3.6728377, 3.6728377, 3.6728377, 3.6728377, 3.6728377, 3.6728377], + [4.3271623, 4.3271623, 
4.3271627, 4.3271627, 4.3271627, 4.3271623, 4.3271623], + [4.513757, 4.513757, 4.513757, 4.513757, 4.513757, 4.513757, 4.513757], + [4.9999995, 5.0, 5.0, 5.0, 5.0, 5.0, 4.9999995], + ] + ] + ] + ) + np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4) + + result = SlidingWindowInferer(roi_shape, sw_batch_size, overlap=0.5, mode="gaussian", sigma_scale=1.0)( + inputs, _Pred().compute + ) + np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4) + + result = SlidingWindowInferer(roi_shape, sw_batch_size, overlap=0.5, mode="gaussian", sigma_scale=[1.0, 1.0])( + inputs, _Pred().compute + ) + np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4) + + result = SlidingWindowInferer( + roi_shape, sw_batch_size, overlap=0.5, mode="gaussian", sigma_scale=[1.0, 1.0], cache_roi_weight_map=True + )(inputs, _Pred().compute) + np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4) + + def test_cval(self): + device = "cuda" if torch.cuda.is_available() else "cpu:0" + inputs = torch.ones((1, 1, 3, 3)).to(device=device) + condition = torch.ones((1, 1, 3, 3)).to(device=device) + roi_shape = (5, 5) + sw_batch_size = 10 + + def compute(data, condition): + return data + data.sum() + condition + + result = sliding_window_inference( + inputs, + roi_shape, + sw_batch_size, + compute, + overlap=0.5, + padding_mode="constant", + cval=-1, + mode="constant", + sigma_scale=1.0, + condition=condition, + ) + expected = np.ones((1, 1, 3, 3)) * -6.0 + 1.0 + np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4) + + result = SlidingWindowInferer(roi_shape, sw_batch_size, overlap=0.5, mode="constant", cval=-1)( + inputs, compute, condition=condition + ) + np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4) + + def test_args_kwargs(self): + device = "cuda" if torch.cuda.is_available() else "cpu:0" + inputs = torch.ones((1, 1, 3, 3)).to(device=device) + condition = torch.ones((1, 1, 3, 3)).to(device=device) + t1 = torch.ones(1).to(device=device) + t2 = torch.ones(1).to(device=device) + roi_shape = (5, 5) + sw_batch_size = 10 + + def compute(data, test1, test2, condition): + return data + test1 + test2 + condition + + result = sliding_window_inference( + inputs, + roi_shape, + sw_batch_size, + compute, + 0.5, + "constant", + 1.0, + "constant", + 0.0, + device, + device, + has_tqdm, + None, + None, + None, + 0, + False, + t1, + condition=condition, + test2=t2, + ) + expected = np.ones((1, 1, 3, 3)) + 3.0 + np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4) + + result = SlidingWindowInferer( + roi_shape, sw_batch_size, overlap=0.5, mode="constant", cval=-1, progress=has_tqdm + )(inputs, compute, t1, condition=condition, test2=t2) + np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4) + + result = SlidingWindowInfererAdapt( + roi_shape, sw_batch_size, overlap=0.5, mode="constant", cval=-1, progress=has_tqdm + )(inputs, compute, t1, condition=condition, test2=t2) + np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4) + + def test_multioutput(self): + device = "cuda" if torch.cuda.is_available() else "cpu:0" + inputs = torch.ones((1, 6, 20, 20)).to(device=device) + condition = torch.ones((1, 6, 20, 20)).to(device=device) + roi_shape = (8, 8) + sw_batch_size = 10 + + def compute(data, condition): + return ( + data + 1 + condition, + data[:, ::3, ::2, ::2] + 2 + condition[:, ::3, ::2, ::2], + data[:, ::2, ::4, ::4] + 3 + condition[:, ::2, ::4, ::4], + ) + + def compute_dict(data, 
condition): + return { + 1: data + 1 + condition, + 2: data[:, ::3, ::2, ::2] + 2 + condition[:, ::3, ::2, ::2], + 3: data[:, ::2, ::4, ::4] + 3 + condition[:, ::2, ::4, ::4], + } + + result = sliding_window_inference( + inputs, + roi_shape, + sw_batch_size, + compute, + 0.5, + "constant", + 1.0, + "constant", + 0.0, + device, + device, + has_tqdm, + None, + condition=condition, + ) + result_dict = sliding_window_inference( + inputs, + roi_shape, + sw_batch_size, + compute_dict, + 0.5, + "constant", + 1.0, + "constant", + 0.0, + device, + device, + has_tqdm, + None, + condition=condition, + ) + expected = (np.ones((1, 6, 20, 20)) + 2, np.ones((1, 2, 10, 10)) + 3, np.ones((1, 3, 5, 5)) + 4) + expected_dict = {1: np.ones((1, 6, 20, 20)) + 2, 2: np.ones((1, 2, 10, 10)) + 3, 3: np.ones((1, 3, 5, 5)) + 4} + for rr, ee in zip(result, expected): + np.testing.assert_allclose(rr.cpu().numpy(), ee, rtol=1e-4) + for rr, _ in zip(result_dict, expected_dict): + np.testing.assert_allclose(result_dict[rr].cpu().numpy(), expected_dict[rr], rtol=1e-4) + + result = SlidingWindowInferer( + roi_shape, sw_batch_size, overlap=0.5, mode="constant", cval=-1, progress=has_tqdm + )(inputs, compute, condition=condition) + for rr, ee in zip(result, expected): + np.testing.assert_allclose(rr.cpu().numpy(), ee, rtol=1e-4) + + result_dict = SlidingWindowInferer( + roi_shape, sw_batch_size, overlap=0.5, mode="constant", cval=-1, progress=has_tqdm + )(inputs, compute_dict, condition=condition) + for rr, _ in zip(result_dict, expected_dict): + np.testing.assert_allclose(result_dict[rr].cpu().numpy(), expected_dict[rr], rtol=1e-4) + + if __name__ == "__main__": unittest.main()
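
One wiring detail the tests above exercise implicitly: `PatchInferer` and `sliding_window_inference` hand the windowed signal to the network as the keyword argument `condition`, while `SliceInferer.network_wrapper` passes it positionally right after the input slice. A forward signature of the form `forward(self, x, condition=None)` satisfies both conventions. The adapter below is a hypothetical sketch (not part of this patch) of one way a real model could consume the signal, by concatenating it as extra input channels:

from __future__ import annotations

import torch
import torch.nn as nn


class ConditionedWrapper(nn.Module):
    """Hypothetical adapter: feeds the conditioning signal to `net` as extra input channels."""

    def __init__(self, net: nn.Module):
        super().__init__()
        self.net = net

    def forward(self, x: torch.Tensor, condition: torch.Tensor | None = None) -> torch.Tensor:
        if condition is None:
            return self.net(x)
        # `condition` has already been split/windowed to match `x`, so channel concat is safe
        return self.net(torch.cat([x, condition], dim=1))

Because the inferers split `condition` with the same splitter or sliding window as the inputs, the per-patch shapes always line up, so the wrapped `net` only needs to accept the widened channel dimension.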