Skip to content

Commit c4c65ae

Browse files
committed
fix inferer
1 parent 248d3c3 commit c4c65ae

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

monai/inferers/inferer.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -347,9 +347,14 @@ def __call__(
347347
f"The provided inputs type is {type(inputs)}."
348348
)
349349
patches_locations = inputs
350+
if condition is not None:
351+
condition_locations = condition
350352
else:
351353
# apply splitter
352354
patches_locations = self.splitter(inputs)
355+
if condition is not None:
356+
# apply splitter to condition
357+
condition_locations = self.splitter(condition)
353358

354359
ratios: list[float] = []
355360
mergers: list[Merger] = []
@@ -776,6 +781,7 @@ def network_wrapper(
776781
self,
777782
network: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]],
778783
x: torch.Tensor,
784+
condition: torch.Tensor | None = None,
779785
*args: Any,
780786
**kwargs: Any,
781787
) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]:
@@ -784,7 +790,12 @@ def network_wrapper(
784790
"""
785791
# Pass 4D input [N, C, H, W]/[N, C, D, W]/[N, C, D, H] to the model as it is 2D.
786792
x = x.squeeze(dim=self.spatial_dim + 2)
787-
out = network(x, *args, **kwargs)
793+
794+
if condition is not None:
795+
condition = condition.squeeze(dim=self.spatial_dim + 2)
796+
out = network(x, condition, *args, **kwargs)
797+
else:
798+
out = network(x, *args, **kwargs)
788799

789800
# Unsqueeze the network output so it is [N, C, D, H, W] as expected by
790801
# the default SlidingWindowInferer class

0 commit comments

Comments
 (0)