Skip to content

Add support for optional conditioning in PatchInferer, SliceInferer, and SlidingWindowInferer #8400

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: dev
Choose a base branch
from
59 changes: 49 additions & 10 deletions monai/inferers/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,9 @@ def __call__(
kwargs: optional keyword args to be passed to ``network``.

"""
# check if there is a conditioning signal
condition = kwargs.pop("condition", None)

patches_locations: Iterable[tuple[torch.Tensor, Sequence[int]]] | MetaTensor
if self.splitter is None:
# handle situations where the splitter is not provided
Expand All @@ -344,20 +347,39 @@ def __call__(
f"The provided inputs type is {type(inputs)}."
)
patches_locations = inputs
if condition is not None:
condition_locations = condition
else:
# apply splitter
patches_locations = self.splitter(inputs)
if condition is not None:
# apply splitter to condition
condition_locations = self.splitter(condition)

ratios: list[float] = []
mergers: list[Merger] = []
for patches, locations, batch_size in self._batch_sampler(patches_locations):
# run inference
outputs = self._run_inference(network, patches, *args, **kwargs)
# initialize the mergers
if not mergers:
mergers, ratios = self._initialize_mergers(inputs, outputs, patches, batch_size)
# aggregate outputs
self._aggregate(outputs, locations, batch_size, mergers, ratios)
if condition is not None:
for (patches, locations, batch_size), (condition_patches, _, _) in zip(
self._batch_sampler(patches_locations), self._batch_sampler(condition_locations)
):
# add patched condition to kwargs
kwargs["condition"] = condition_patches
# run inference
outputs = self._run_inference(network, patches, *args, **kwargs)
# initialize the mergers
if not mergers:
mergers, ratios = self._initialize_mergers(inputs, outputs, patches, batch_size)
# aggregate outputs
self._aggregate(outputs, locations, batch_size, mergers, ratios)
else:
for patches, locations, batch_size in self._batch_sampler(patches_locations):
# run inference
outputs = self._run_inference(network, patches, *args, **kwargs)
# initialize the mergers
if not mergers:
mergers, ratios = self._initialize_mergers(inputs, outputs, patches, batch_size)
# aggregate outputs
self._aggregate(outputs, locations, batch_size, mergers, ratios)

# finalize the mergers and get the results
merged_outputs = [merger.finalize() for merger in mergers]
Expand Down Expand Up @@ -742,12 +764,24 @@ def __call__(
f"Currently, only 2D `roi_size` ({self.orig_roi_size}) with 3D `inputs` tensor (shape={inputs.shape}) is supported."
)

return super().__call__(inputs=inputs, network=lambda x: self.network_wrapper(network, x, *args, **kwargs))
# check if there is a conditioning signal
condition = kwargs.get("condition", None)
if condition is not None:
return super().__call__(
inputs=inputs,
network=lambda x, *args, **kwargs: self.network_wrapper(network, x, *args, **kwargs),
condition=condition,
)
else:
return super().__call__(
inputs=inputs, network=lambda x, *args, **kwargs: self.network_wrapper(network, x, *args, **kwargs)
)

def network_wrapper(
self,
network: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]],
x: torch.Tensor,
condition: torch.Tensor | None = None,
*args: Any,
**kwargs: Any,
) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]:
Expand All @@ -756,7 +790,12 @@ def network_wrapper(
"""
# Pass 4D input [N, C, H, W]/[N, C, D, W]/[N, C, D, H] to the model as it is 2D.
x = x.squeeze(dim=self.spatial_dim + 2)
out = network(x, *args, **kwargs)

if condition is not None:
condition = condition.squeeze(dim=self.spatial_dim + 2)
out = network(x, condition, *args, **kwargs)
else:
out = network(x, *args, **kwargs)

# Unsqueeze the network output so it is [N, C, D, H, W] as expected by
# the default SlidingWindowInferer class
Expand Down
16 changes: 13 additions & 3 deletions monai/inferers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ def sliding_window_inference(
device = device or inputs.device
sw_device = sw_device or inputs.device

condition = kwargs.pop("condition", None)

temp_meta = None
if isinstance(inputs, MetaTensor):
temp_meta = MetaTensor([]).copy_meta_from(inputs, copy_attr=False)
Expand All @@ -168,6 +170,8 @@ def sliding_window_inference(
pad_size.extend([half, diff - half])
if any(pad_size):
inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)
if condition is not None:
condition = F.pad(condition, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)

# Store all slices
scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)
Expand Down Expand Up @@ -220,13 +224,19 @@ def sliding_window_inference(
]
if sw_batch_size > 1:
win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device)
if condition is not None:
win_condition = torch.cat([condition[win_slice] for win_slice in unravel_slice]).to(sw_device)
kwargs["condition"] = win_condition
else:
win_data = inputs[unravel_slice[0]].to(sw_device)
if condition is not None:
win_condition = condition[unravel_slice[0]].to(sw_device)
kwargs["condition"] = win_condition

if with_coord:
seg_prob_out = predictor(win_data, unravel_slice, *args, **kwargs) # batched patch
seg_prob_out = predictor(win_data, unravel_slice, *args, **kwargs)
else:
seg_prob_out = predictor(win_data, *args, **kwargs) # batched patch

seg_prob_out = predictor(win_data, *args, **kwargs)
# convert seg_prob_out to tuple seg_tuple, this does not allocate new memory.
dict_keys, seg_tuple = _flatten_struct(seg_prob_out)
if process_fn:
Expand Down
226 changes: 226 additions & 0 deletions tests/inferers/test_patch_inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,5 +305,231 @@ def test_patch_inferer_errors(self, inputs, arguments, expected_error):
inferer(inputs=inputs, network=lambda x: x)


# ----------------------------------------------------------------------------
# Test cases with conditioning
# ----------------------------------------------------------------------------

# Each case below is a list of four items, consumed positionally by the tests in
# ``PatchInfererTestsCond`` as ``(inputs, arguments, network, expected)``:
#   inputs    -- a tensor, or a list of ``(patch, location)`` tuples when already split
#   arguments -- keyword arguments used to construct ``PatchInferer``
#   network   -- callable taking ``(x, condition)``
#   expected  -- expected merged output
# The tests pass a clone of ``inputs`` as ``condition``, which is why ``x + condition``
# is expected to equal ``TENSOR_4x4 * 2`` in most cases.

# non-overlapping 2x2 patches
TEST_CASE_0_TENSOR_c = [
TENSOR_4x4,
dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger),
lambda x, condition: x + condition,
TENSOR_4x4 * 2,
]

# non-overlapping 2x2 patches using all default parameters (except for splitter)
TEST_CASE_1_TENSOR_c = [
TENSOR_4x4,
dict(splitter=SlidingWindowSplitter(patch_size=(2, 2))),
lambda x, condition: x + condition,
TENSOR_4x4 * 2,
]

# divisible batch_size
TEST_CASE_2_TENSOR_c = [
TENSOR_4x4,
dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger, batch_size=2),
lambda x, condition: x + condition,
TENSOR_4x4 * 2,
]

# non-divisible batch_size
TEST_CASE_3_TENSOR_c = [
TENSOR_4x4,
dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger, batch_size=3),
lambda x, condition: x + condition,
TENSOR_4x4 * 2,
]

# patches that are already split (Splitter should be None)
TEST_CASE_4_SPLIT_LIST_c = [
[
(TENSOR_4x4[..., :2, :2], (0, 0)),
(TENSOR_4x4[..., :2, 2:], (0, 2)),
(TENSOR_4x4[..., 2:, :2], (2, 0)),
(TENSOR_4x4[..., 2:, 2:], (2, 2)),
],
dict(splitter=None, merger_cls=AvgMerger, merged_shape=(2, 3, 4, 4)),
lambda x, condition: x + condition,
TENSOR_4x4 * 2,
]

# using all default parameters (patches are already split)
TEST_CASE_5_SPLIT_LIST_c = [
[
(TENSOR_4x4[..., :2, :2], (0, 0)),
(TENSOR_4x4[..., :2, 2:], (0, 2)),
(TENSOR_4x4[..., 2:, :2], (2, 0)),
(TENSOR_4x4[..., 2:, 2:], (2, 2)),
],
dict(merger_cls=AvgMerger, merged_shape=(2, 3, 4, 4)),
lambda x, condition: x + condition,
TENSOR_4x4 * 2,
]

# output smaller than input patches
TEST_CASE_6_SMALLER_c = [
TENSOR_4x4,
dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger),
lambda x, condition: torch.mean(x, dim=(-1, -2), keepdim=True) + torch.mean(condition, dim=(-1, -2), keepdim=True),
TENSOR_2x2 * 2,
]

# preprocess patches
# NOTE: the expected value (2 * input + input) encodes that preprocessing is applied
# to the input patches only, not to the condition patches.
TEST_CASE_7_PREPROCESS_c = [
TENSOR_4x4,
dict(
splitter=SlidingWindowSplitter(patch_size=(2, 2)),
merger_cls=AvgMerger,
preprocessing=lambda x: 2 * x,
postprocessing=None,
),
lambda x, condition: x + condition,
2 * TENSOR_4x4 + TENSOR_4x4,
]

# postprocess patches
TEST_CASE_8_POSTPROCESS_c = [
TENSOR_4x4,
dict(
splitter=SlidingWindowSplitter(patch_size=(2, 2)),
merger_cls=AvgMerger,
preprocessing=None,
postprocessing=lambda x: 4 * x,
),
lambda x, condition: x + condition,
4 * TENSOR_4x4 * 2,
]

# str merger as the class name
TEST_CASE_9_STR_MERGER_c = [
TENSOR_4x4,
dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls="AvgMerger"),
lambda x, condition: x + condition,
TENSOR_4x4 * 2,
]

# str merger as dotted path
TEST_CASE_10_STR_MERGER_c = [
TENSOR_4x4,
dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls="monai.inferers.merger.AvgMerger"),
lambda x, condition: x + condition,
TENSOR_4x4 * 2,
]

# non-divisible patch_size leading to larger image (without matching spatial shape)
TEST_CASE_11_PADDING_c = [
TENSOR_4x4,
dict(
splitter=SlidingWindowSplitter(patch_size=(2, 3), pad_mode="constant", pad_value=0.0),
merger_cls=AvgMerger,
match_spatial_shape=False,
),
lambda x, condition: x + condition,
pad(TENSOR_4x4, (0, 2), value=0.0) * 2,
]

# non-divisible patch_size without padding (pad_mode=None), with matching spatial shapes
TEST_CASE_12_MATCHING_c = [
TENSOR_4x4,
dict(splitter=SlidingWindowSplitter(patch_size=(2, 3), pad_mode=None), merger_cls=AvgMerger),
lambda x, condition: x + condition,
pad(TENSOR_4x4[..., :3], (0, 1), value=float("nan")) * 2,
]

# non-divisible patch_size with the splitter's default pad_mode and matching spatial shapes
TEST_CASE_13_PADDING_MATCHING_c = [
TENSOR_4x4,
dict(splitter=SlidingWindowSplitter(patch_size=(2, 3)), merger_cls=AvgMerger),
lambda x, condition: x + condition,
TENSOR_4x4 * 2,
]

# multi-threading
TEST_CASE_14_MULTITHREAD_BUFFER_c = [
TENSOR_4x4,
dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger, buffer_size=2),
lambda x, condition: x + condition,
TENSOR_4x4 * 2,
]

# multi-threading with batch
TEST_CASE_15_MULTITHREADD_BUFFER_c = [
TENSOR_4x4,
dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger, buffer_size=4, batch_size=4),
lambda x, condition: x + condition,
TENSOR_4x4 * 2,
]

# multiple tensor outputs (the network returns a tuple of tensors)
TEST_CASE_0_LIST_TENSOR_c = [
TENSOR_4x4,
dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger),
lambda x, condition: (x + condition, x + condition),
(TENSOR_4x4 * 2, TENSOR_4x4 * 2),
]

# dict of tensor output
TEST_CASE_0_DICT_c = [
TENSOR_4x4,
dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger),
lambda x, condition: {"model_output": x + condition},
{"model_output": TENSOR_4x4 * 2},
]


class PatchInfererTestsCond(unittest.TestCase):
    """Tests for ``PatchInferer`` when a ``condition`` keyword argument is supplied.

    Every test builds the conditioning signal as a copy of ``inputs`` and passes it
    to the inferer via ``condition=...``; the parameterized ``network`` callables
    take ``(x, condition)`` and the expected values are written accordingly.
    """

    @staticmethod
    def _clone_as_condition(inputs):
        """Return a conditioning signal that mirrors ``inputs``.

        Args:
            inputs: either a single tensor, or a list of ``(patch, location)``
                tuples for the cases where the patches are already split.

        Returns:
            A cloned tensor, or a list of ``(cloned_patch, location)`` tuples with
            the same locations as ``inputs``.
        """
        if isinstance(inputs, list):
            # pre-split inputs: clone each patch and keep its location untouched
            return [(patch.clone(), location) for patch, location in inputs]
        return inputs.clone()

    @parameterized.expand(
        [
            TEST_CASE_0_TENSOR_c,
            TEST_CASE_1_TENSOR_c,
            TEST_CASE_2_TENSOR_c,
            TEST_CASE_3_TENSOR_c,
            TEST_CASE_4_SPLIT_LIST_c,
            TEST_CASE_5_SPLIT_LIST_c,
            TEST_CASE_6_SMALLER_c,
            TEST_CASE_7_PREPROCESS_c,
            TEST_CASE_8_POSTPROCESS_c,
            TEST_CASE_9_STR_MERGER_c,
            TEST_CASE_10_STR_MERGER_c,
            TEST_CASE_11_PADDING_c,
            TEST_CASE_12_MATCHING_c,
            TEST_CASE_13_PADDING_MATCHING_c,
            TEST_CASE_14_MULTITHREAD_BUFFER_c,
            TEST_CASE_15_MULTITHREADD_BUFFER_c,
        ]
    )
    def test_patch_inferer_tensor(self, inputs, arguments, network, expected):
        """The network returns a single tensor; compare the merged output to ``expected``."""
        condition = self._clone_as_condition(inputs)
        inferer = PatchInferer(**arguments)
        output = inferer(inputs=inputs, network=network, condition=condition)
        assert_allclose(output, expected)

    @parameterized.expand([TEST_CASE_0_LIST_TENSOR_c])
    def test_patch_inferer_list_tensor(self, inputs, arguments, network, expected):
        """The network returns a sequence of tensors; compare each merged output in order."""
        condition = self._clone_as_condition(inputs)
        inferer = PatchInferer(**arguments)
        output = inferer(inputs=inputs, network=network, condition=condition)
        # zip() stops at the shorter sequence, so check the lengths first to avoid
        # silently passing when an output tensor is missing
        self.assertEqual(len(output), len(expected))
        for out, exp in zip(output, expected):
            assert_allclose(out, exp)

    @parameterized.expand([TEST_CASE_0_DICT_c])
    def test_patch_inferer_dict(self, inputs, arguments, network, expected):
        """The network returns a dict of tensors; compare the merged output for each expected key."""
        condition = self._clone_as_condition(inputs)
        inferer = PatchInferer(**arguments)
        output = inferer(inputs=inputs, network=network, condition=condition)
        for k in expected:
            assert_allclose(output[k], expected[k])


if __name__ == "__main__":
unittest.main()
Loading
Loading