Skip to content

Commit ee55e01

Browse files
committed
Add torch-interpol for better interpolation
1 parent 33d7d71 commit ee55e01

4 files changed

Lines changed: 193 additions & 25 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ dependencies = [
4646
"rich",
4747
"simpleitk",
4848
"torch",
49+
"torch-interpol>=0.2.6",
4950
"typing_extensions>=4.0",
5051
"tyro>=1.0.12",
5152
]

src/torchio/transforms/spatial.py

Lines changed: 123 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,39 @@
7474
)
7575
TypeControlPoints: TypeAlias = Tensor | npt.ArrayLike
7676
TypeTargetSpace: TypeAlias = tuple[TypeThreeInts, AffineMatrix]
77-
TypeInterpolation: TypeAlias = Literal["nearest", "linear"]
77+
TypeInterpolation: TypeAlias = Literal[
78+
"nearest",
79+
"linear",
80+
"quadratic",
81+
"cubic",
82+
"fourth",
83+
"fifth",
84+
"sixth",
85+
"seventh",
86+
]
7887
TypeCenter: TypeAlias = Literal["image", "origin"]
7988
TypePadValue: TypeAlias = Literal["minimum", "mean", "otsu"]
8089

81-
_SUPPORTED_INTERPOLATIONS = ("nearest", "linear")
90+
_SUPPORTED_INTERPOLATIONS = (
91+
"nearest",
92+
"linear",
93+
"quadratic",
94+
"cubic",
95+
"fourth",
96+
"fifth",
97+
"sixth",
98+
"seventh",
99+
)
100+
_INTERPOLATION_TO_ORDER: dict[str, int] = {
101+
"nearest": 0,
102+
"linear": 1,
103+
"quadratic": 2,
104+
"cubic": 3,
105+
"fourth": 4,
106+
"fifth": 5,
107+
"sixth": 6,
108+
"seventh": 7,
109+
}
82110
_TORCH_INTERPOLATION_MODE = {
83111
"nearest": "nearest",
84112
"linear": "bilinear",
@@ -674,7 +702,8 @@ def _apply_spatial_to_batch(
674702
img_batch.data = _sample_batch(
675703
data,
676704
grid,
677-
mode=_TORCH_INTERPOLATION_MODE[interpolation],
705+
input_shape=input_shape,
706+
interpolation=interpolation,
678707
fill_value=fill_value,
679708
)
680709
new_affine = output_affine.clone()
@@ -791,7 +820,7 @@ def _build_sampling_grid(
791820
affine_first: bool,
792821
device: torch.device,
793822
) -> Tensor:
794-
"""Build a normalized sampling grid for ``F.grid_sample``.
823+
"""Build a sampling grid in **input voxel coordinates**.
795824
796825
The grid maps each output voxel to its source location in the input
797826
volume. The mapping is:
@@ -806,8 +835,8 @@ def _build_sampling_grid(
806835
mapping or the elastic displacement is applied first.
807836
808837
Returns:
809-
Grid tensor with shape ``(1, K_out, J_out, I_out, 3)`` in the
810-
``[-1, 1]`` range expected by ``F.grid_sample``.
838+
Grid tensor with shape ``(I_out, J_out, K_out, 3)`` in input
839+
voxel coordinates.
811840
"""
812841
mapping = _output_to_input_voxel_matrix(
813842
input_affine=input_affine,
@@ -818,8 +847,7 @@ def _build_sampling_grid(
818847
output_coords = _output_voxel_coordinates(output_shape, device)
819848

820849
if control_points is None:
821-
input_voxels = _apply_voxel_mapping(output_coords, mapping)
822-
return _voxel_coordinates_to_grid(input_voxels, input_shape)
850+
return _apply_voxel_mapping(output_coords, mapping)
823851

824852
output_spacing = np.asarray(output_affine.spacing, dtype=np.float64)
825853
if max_displacement is None:
@@ -855,7 +883,7 @@ def _build_sampling_grid(
855883
deformed_output = output_coords + displacement / output_spacing_t
856884
input_voxels = _apply_voxel_mapping(deformed_output, mapping)
857885

858-
return _voxel_coordinates_to_grid(input_voxels, input_shape)
886+
return input_voxels
859887

860888

861889
def _output_to_input_voxel_matrix(
@@ -929,23 +957,60 @@ def _voxel_coordinates_to_grid(
929957

930958
def _sample_batch(
931959
data: Tensor,
932-
grid: Tensor,
960+
voxel_grid: Tensor,
933961
*,
934-
mode: str,
962+
input_shape: TypeThreeInts,
963+
interpolation: str,
935964
fill_value: float | Tensor,
936965
) -> Tensor:
937966
"""Resample a 5D batch using a shared sampling grid.
938967
968+
For interpolation orders 0-1 (nearest, linear), the fast
969+
``F.grid_sample`` path is used. For orders 2+ (quadratic,
970+
cubic, ...), ``interpol.grid_pull`` provides high-order B-spline
971+
interpolation.
972+
939973
Args:
940974
data: ``(B, C, I, J, K)`` image batch.
941-
grid: ``(1, K_out, J_out, I_out, 3)`` sampling grid (broadcast to B).
942-
mode: Interpolation mode for ``grid_sample``.
943-
fill_value: Scalar or per-channel fill for out-of-bounds samples.
975+
voxel_grid: ``(I_out, J_out, K_out, 3)`` sampling grid in
976+
input voxel coordinates.
977+
input_shape: Spatial shape of the input volume ``(I, J, K)``.
978+
interpolation: Interpolation mode name.
979+
fill_value: Scalar or per-channel fill for out-of-bounds
980+
samples.
944981
945982
Returns:
946983
Resampled ``(B, C, I_out, J_out, K_out)`` tensor.
947984
"""
985+
order = _INTERPOLATION_TO_ORDER[interpolation]
986+
if order <= 1:
987+
return _sample_batch_grid_sample(
988+
data,
989+
voxel_grid,
990+
input_shape=input_shape,
991+
mode=_TORCH_INTERPOLATION_MODE[interpolation],
992+
fill_value=fill_value,
993+
)
994+
return _sample_batch_interpol(
995+
data,
996+
voxel_grid,
997+
order=order,
998+
fill_value=fill_value,
999+
)
1000+
1001+
1002+
def _sample_batch_grid_sample(
1003+
data: Tensor,
1004+
voxel_grid: Tensor,
1005+
*,
1006+
input_shape: TypeThreeInts,
1007+
mode: str,
1008+
fill_value: float | Tensor,
1009+
) -> Tensor:
1010+
"""Fast path: resample with F.grid_sample (orders 0-1)."""
9481011
batch_size = data.shape[0]
1012+
# Normalize voxel coords to [-1, 1] for grid_sample.
1013+
grid = _voxel_coordinates_to_grid(voxel_grid, input_shape)
9491014
# (B, C, I, J, K) -> (B, C, K, J, I) for grid_sample
9501015
input_5d = rearrange(data, "b c i j k -> b c k j i").float()
9511016
# Expand grid from (1, ...) to (B, ...)
@@ -973,6 +1038,36 @@ def _sample_batch(
9731038
return rearrange(sampled, "b c k j i -> b c i j k").to(data.dtype)
9741039

9751040

1041+
def _sample_batch_interpol(
1042+
data: Tensor,
1043+
voxel_grid: Tensor,
1044+
*,
1045+
order: int,
1046+
fill_value: float | Tensor,
1047+
) -> Tensor:
1048+
"""High-quality path: resample with interpol.grid_pull (orders 2+).
1049+
1050+
``interpol.grid_pull`` works in voxel coordinates natively and
1051+
supports B-spline orders up to 7.
1052+
"""
1053+
import interpol
1054+
1055+
batch_size = data.shape[0]
1056+
# interpol expects: input (B, C, *spatial), grid (B, *spatial, D)
1057+
# Our grid is (I_out, J_out, K_out, 3) — add batch dim.
1058+
grid_b = voxel_grid.unsqueeze(0).expand(batch_size, -1, -1, -1, -1)
1059+
1060+
sampled = interpol.grid_pull(
1061+
data.float(),
1062+
grid_b,
1063+
interpolation=order,
1064+
bound="dct2",
1065+
extrapolate=False,
1066+
prefilter=True,
1067+
)
1068+
return sampled.to(data.dtype)
1069+
1070+
9761071
def _antialias_batch(
9771072
data: Tensor,
9781073
input_affine: AffineMatrix,
@@ -1632,10 +1727,21 @@ def _parse_spacing(
16321727
return spacing
16331728

16341729

1635-
def _parse_interpolation(interpolation: TypeInterpolation) -> TypeInterpolation:
1636-
"""Validate and lower-case an interpolation mode string."""
1730+
def _parse_interpolation(
1731+
interpolation: TypeInterpolation | int,
1732+
) -> TypeInterpolation:
1733+
"""Validate an interpolation mode (string or integer order).
1734+
1735+
Integer orders (0-7) are converted to their string equivalents.
1736+
"""
1737+
if isinstance(interpolation, int):
1738+
order_to_name = {v: k for k, v in _INTERPOLATION_TO_ORDER.items()}
1739+
if interpolation not in order_to_name:
1740+
msg = f"Interpolation order {interpolation} is not supported. Must be 0-7."
1741+
raise ValueError(msg)
1742+
return cast(TypeInterpolation, order_to_name[interpolation])
16371743
if not isinstance(interpolation, str):
1638-
msg = f"Interpolation must be a string, got {type(interpolation)}"
1744+
msg = f"Interpolation must be a string or int, got {type(interpolation)}"
16391745
raise TypeError(msg)
16401746
lowered = interpolation.lower()
16411747
if lowered not in _SUPPORTED_INTERPOLATIONS:

tests/test_spatial.py

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -463,11 +463,19 @@ def test_parse_control_points_axis_too_small(self) -> None:
463463

464464
def test_parse_interpolation_invalid(self) -> None:
465465
with pytest.raises(ValueError, match="not supported"):
466-
_parse_interpolation("cubic") # type: ignore[arg-type]
466+
_parse_interpolation("bicubic") # type: ignore[arg-type]
467+
468+
def test_parse_interpolation_int(self) -> None:
469+
assert _parse_interpolation(3) == "cubic"
470+
assert _parse_interpolation(0) == "nearest"
471+
472+
def test_parse_interpolation_int_invalid(self) -> None:
473+
with pytest.raises(ValueError, match="not supported"):
474+
_parse_interpolation(99)
467475

468476
def test_parse_interpolation_not_string(self) -> None:
469-
with pytest.raises(TypeError, match="string"):
470-
_parse_interpolation(42) # type: ignore[arg-type]
477+
with pytest.raises(TypeError, match="string or int"):
478+
_parse_interpolation(42.5) # type: ignore[arg-type]
471479

472480
def test_parse_default_pad_value_invalid_string(self) -> None:
473481
with pytest.raises(ValueError, match="minimum"):
@@ -639,7 +647,7 @@ def test_build_grid_elastic_no_max_displacement(self) -> None:
639647
affine_first=True,
640648
device=torch.device("cpu"),
641649
)
642-
assert grid.shape == (1, 11, 11, 11, 3)
650+
assert grid.shape == (11, 11, 11, 3)
643651

644652
def test_batch_fill_value_bad_type(self) -> None:
645653
"""_batch_fill_value raises TypeError for non-str non-number (lines 907-911)."""
@@ -693,3 +701,56 @@ def test_root_exports_expose_transform_and_matrix(self) -> None:
693701
assert hasattr(tio, "ElasticDeformation")
694702
assert tio.Affine is AffineTransform
695703
assert tio.AffineMatrix is AffineMatrix
704+
705+
706+
# ---------------------------------------------------------------------------
707+
# High-order interpolation (torch-interpol)
708+
# ---------------------------------------------------------------------------
709+
710+
711+
class TestHighOrderInterpolation:
712+
def test_cubic_produces_different_result_from_linear(self) -> None:
713+
subject = tio.Subject(t1=tio.ScalarImage(torch.rand(1, 16, 16, 16)))
714+
linear = tio.Affine(
715+
degrees=10,
716+
image_interpolation="linear",
717+
)(subject)
718+
cubic = tio.Affine(
719+
degrees=10,
720+
image_interpolation="cubic",
721+
)(subject)
722+
# Same params won't be sampled, so compare shapes at least
723+
assert linear.t1.data.shape == cubic.t1.data.shape
724+
725+
def test_cubic_resample(self) -> None:
726+
subject = tio.Subject(t1=tio.ScalarImage(torch.rand(1, 16, 16, 16)))
727+
result = tio.Resample(
728+
target=2.0,
729+
image_interpolation="cubic",
730+
)(subject)
731+
assert result.t1.data.shape[1:] == (8, 8, 8)
732+
733+
def test_quadratic_interpolation(self) -> None:
734+
subject = tio.Subject(t1=tio.ScalarImage(torch.rand(1, 10, 10, 10)))
735+
result = tio.Affine(
736+
degrees=5,
737+
image_interpolation="quadratic",
738+
)(subject)
739+
assert result.t1.data.shape == subject.t1.data.shape
740+
741+
def test_int_order_3(self) -> None:
742+
subject = tio.Subject(t1=tio.ScalarImage(torch.rand(1, 10, 10, 10)))
743+
result = tio.Affine(
744+
degrees=5,
745+
image_interpolation=3,
746+
)(subject)
747+
assert result.t1.data.shape == subject.t1.data.shape
748+
749+
def test_order_0_uses_fast_path(self) -> None:
750+
"""Nearest interpolation should still work via F.grid_sample."""
751+
subject = tio.Subject(t1=tio.ScalarImage(torch.rand(1, 10, 10, 10)))
752+
result = tio.Affine(
753+
degrees=5,
754+
image_interpolation="nearest",
755+
)(subject)
756+
assert result.t1.data.shape == subject.t1.data.shape

tests/test_visualization.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,17 +115,17 @@ def test_orientation_labels_show_tensor_axis(self) -> None:
115115
ax = fig.axes[0]
116116
xlabel = ax.get_xlabel()
117117
ylabel = ax.get_ylabel()
118-
# Default (voxels=False): "Anterior (J)" format
119-
assert any(c in xlabel for c in ("I", "J", "K"))
120-
assert any(c in ylabel for c in ("I", "J", "K"))
118+
# Default (voxels=False): "Anterior [mm] (j)" format
119+
assert any(c in xlabel for c in ("i", "j", "k"))
120+
assert any(c in ylabel for c in ("i", "j", "k"))
121121

122122
def test_voxel_labels_show_arrow(self) -> None:
123123
img = tio.ScalarImage(torch.rand(1, 10, 10, 10))
124124
fig = img.plot(show=False, voxels=True)
125125
ax = fig.axes[0]
126126
xlabel = ax.get_xlabel()
127127
ylabel = ax.get_ylabel()
128-
# voxels=True: "J (A ↔ P)" format
128+
# voxels=True: "j (A ↔ P)" format
129129
assert "↔" in xlabel
130130
assert "↔" in ylabel
131131

0 commit comments

Comments
 (0)