@@ -347,9 +347,14 @@ def __call__(
347
347
f"The provided inputs type is { type (inputs )} ."
348
348
)
349
349
patches_locations = inputs
350
+ if condition is not None :
351
+ condition_locations = condition
350
352
else :
351
353
# apply splitter
352
354
patches_locations = self .splitter (inputs )
355
+ if condition is not None :
356
+ # apply splitter to condition
357
+ condition_locations = self .splitter (condition )
353
358
354
359
ratios : list [float ] = []
355
360
mergers : list [Merger ] = []
@@ -776,6 +781,7 @@ def network_wrapper(
776
781
self ,
777
782
network : Callable [..., torch .Tensor | Sequence [torch .Tensor ] | dict [Any , torch .Tensor ]],
778
783
x : torch .Tensor ,
784
+ condition : torch .Tensor | None = None ,
779
785
* args : Any ,
780
786
** kwargs : Any ,
781
787
) -> torch .Tensor | tuple [torch .Tensor , ...] | dict [Any , torch .Tensor ]:
@@ -784,7 +790,12 @@ def network_wrapper(
784
790
"""
785
791
# Pass 4D input [N, C, H, W]/[N, C, D, W]/[N, C, D, H] to the model as it is 2D.
786
792
x = x .squeeze (dim = self .spatial_dim + 2 )
787
- out = network (x , * args , ** kwargs )
793
+
794
+ if condition is not None :
795
+ condition = condition .squeeze (dim = self .spatial_dim + 2 )
796
+ out = network (x , condition , * args , ** kwargs )
797
+ else :
798
+ out = network (x , * args , ** kwargs )
788
799
789
800
# Unsqueeze the network output so it is [N, C, D, H, W] as expected by
790
801
# the default SlidingWindowInferer class
0 commit comments