Skip to content

Commit 3b7f055

Browse files
fepegarCopilot
andcommitted
Fix RandomAnisotropy with copy=False producing wrong shape
Capture spatial_shape and affine as values before downsampling instead of reading them from the (now-mutated) image reference afterward. When copy=False, the downsample Resample modifies the image in-place, so the previously captured reference already reflects the downsampled state by the time the upsample target is constructed. This caused the upsample to be a no-op, leaving images at the smaller shape. Closes #1436 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 3318645 commit 3b7f055

2 files changed

Lines changed: 19 additions & 1 deletion

File tree

src/torchio/transforms/augmentation/spatial/random_anisotropy.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ def apply_transform(self, subject: Subject) -> Subject:
101101
# NOTE: If copy=False, the underlying image data will be modified in place.
102102
# We have to obtain the target spatial shape and affine before the transform
103103
image = subject.get_first_image()
104+
original_shape = image.spatial_shape
105+
original_affine = image.affine.copy()
104106
downsample = Resample(
105107
target=(
106108
float(target_spacing[0]),
@@ -118,7 +120,7 @@ def apply_transform(self, subject: Subject) -> Subject:
118120
)
119121
downsampled = downsample(subject)
120122
upsample = Resample(
121-
target=(image.spatial_shape, image.affine),
123+
target=(original_shape, original_affine),
122124
image_interpolation=self.image_interpolation,
123125
scalars_only=self.scalars_only,
124126
copy=self.copy,

tests/transforms/augmentation/test_random_anisotropy.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,22 @@ def test_2d_rgb(self):
4646
image = ScalarImage(tensor=torch.rand(3, 4, 5, 6))
4747
RandomAnisotropy()(image)
4848

49+
def test_copy_false_preserves_shape(self):
50+
"""Output shape and spacing must match input when copy=False (#1436)."""
51+
subject = tio.Subject(
52+
t1=tio.ScalarImage(tensor=torch.randn(1, 20, 22, 18)),
53+
)
54+
transform = RandomAnisotropy(
55+
axes=1,
56+
downsampling=(2, 2),
57+
copy=False,
58+
)
59+
original_shape = subject.shape
60+
original_spacing = subject.spacing
61+
result = transform(subject)
62+
assert result.shape == original_shape
63+
assert result.spacing == original_spacing
64+
4965
def test_2d_with_axis_2_warns(self):
5066
"""Applying to 2D image with axis 2 in axes warns and excludes it."""
5167
image = ScalarImage(tensor=torch.rand(1, 10, 10, 1))

0 commit comments

Comments
 (0)