7474)
7575TypeControlPoints : TypeAlias = Tensor | npt .ArrayLike
7676TypeTargetSpace : TypeAlias = tuple [TypeThreeInts , AffineMatrix ]
77- TypeInterpolation : TypeAlias = Literal ["nearest" , "linear" ]
77+ TypeInterpolation : TypeAlias = Literal [
78+ "nearest" ,
79+ "linear" ,
80+ "quadratic" ,
81+ "cubic" ,
82+ "fourth" ,
83+ "fifth" ,
84+ "sixth" ,
85+ "seventh" ,
86+ ]
7887TypeCenter : TypeAlias = Literal ["image" , "origin" ]
7988TypePadValue : TypeAlias = Literal ["minimum" , "mean" , "otsu" ]
8089
81- _SUPPORTED_INTERPOLATIONS = ("nearest" , "linear" )
90+ _SUPPORTED_INTERPOLATIONS = (
91+ "nearest" ,
92+ "linear" ,
93+ "quadratic" ,
94+ "cubic" ,
95+ "fourth" ,
96+ "fifth" ,
97+ "sixth" ,
98+ "seventh" ,
99+ )
100+ _INTERPOLATION_TO_ORDER : dict [str , int ] = {
101+ "nearest" : 0 ,
102+ "linear" : 1 ,
103+ "quadratic" : 2 ,
104+ "cubic" : 3 ,
105+ "fourth" : 4 ,
106+ "fifth" : 5 ,
107+ "sixth" : 6 ,
108+ "seventh" : 7 ,
109+ }
82110_TORCH_INTERPOLATION_MODE = {
83111 "nearest" : "nearest" ,
84112 "linear" : "bilinear" ,
@@ -674,7 +702,8 @@ def _apply_spatial_to_batch(
674702 img_batch .data = _sample_batch (
675703 data ,
676704 grid ,
677- mode = _TORCH_INTERPOLATION_MODE [interpolation ],
705+ input_shape = input_shape ,
706+ interpolation = interpolation ,
678707 fill_value = fill_value ,
679708 )
680709 new_affine = output_affine .clone ()
@@ -791,7 +820,7 @@ def _build_sampling_grid(
791820 affine_first : bool ,
792821 device : torch .device ,
793822) -> Tensor :
794- """Build a normalized sampling grid for ``F.grid_sample`` .
823+ """Build a sampling grid in **input voxel coordinates** .
795824
796825 The grid maps each output voxel to its source location in the input
797826 volume. The mapping is:
@@ -806,8 +835,8 @@ def _build_sampling_grid(
806835 mapping or the elastic displacement is applied first.
807836
808837 Returns:
809- Grid tensor with shape ``(1, K_out, J_out, I_out , 3)`` in the
810- ``[-1, 1]`` range expected by ``F.grid_sample`` .
838+ Grid tensor with shape ``(I_out, J_out, K_out , 3)`` in input
839+ voxel coordinates .
811840 """
812841 mapping = _output_to_input_voxel_matrix (
813842 input_affine = input_affine ,
@@ -818,8 +847,7 @@ def _build_sampling_grid(
818847 output_coords = _output_voxel_coordinates (output_shape , device )
819848
820849 if control_points is None :
821- input_voxels = _apply_voxel_mapping (output_coords , mapping )
822- return _voxel_coordinates_to_grid (input_voxels , input_shape )
850+ return _apply_voxel_mapping (output_coords , mapping )
823851
824852 output_spacing = np .asarray (output_affine .spacing , dtype = np .float64 )
825853 if max_displacement is None :
@@ -855,7 +883,7 @@ def _build_sampling_grid(
855883 deformed_output = output_coords + displacement / output_spacing_t
856884 input_voxels = _apply_voxel_mapping (deformed_output , mapping )
857885
858- return _voxel_coordinates_to_grid ( input_voxels , input_shape )
886+ return input_voxels
859887
860888
861889def _output_to_input_voxel_matrix (
@@ -929,23 +957,60 @@ def _voxel_coordinates_to_grid(
929957
930958def _sample_batch (
931959 data : Tensor ,
932- grid : Tensor ,
960+ voxel_grid : Tensor ,
933961 * ,
934- mode : str ,
962+ input_shape : TypeThreeInts ,
963+ interpolation : str ,
935964 fill_value : float | Tensor ,
936965) -> Tensor :
937966 """Resample a 5D batch using a shared sampling grid.
938967
968+ For interpolation orders 0-1 (nearest, linear), the fast
969+ ``F.grid_sample`` path is used. For orders 2+ (quadratic,
970+ cubic, ...), ``interpol.grid_pull`` provides high-order B-spline
971+ interpolation.
972+
939973 Args:
940974 data: ``(B, C, I, J, K)`` image batch.
941- grid: ``(1, K_out, J_out, I_out, 3)`` sampling grid (broadcast to B).
942- mode: Interpolation mode for ``grid_sample``.
943- fill_value: Scalar or per-channel fill for out-of-bounds samples.
975+ voxel_grid: ``(I_out, J_out, K_out, 3)`` sampling grid in
976+ input voxel coordinates.
977+ input_shape: Spatial shape of the input volume ``(I, J, K)``.
978+ interpolation: Interpolation mode name.
979+ fill_value: Scalar or per-channel fill for out-of-bounds
980+ samples.
944981
945982 Returns:
946983 Resampled ``(B, C, I_out, J_out, K_out)`` tensor.
947984 """
985+ order = _INTERPOLATION_TO_ORDER [interpolation ]
986+ if order <= 1 :
987+ return _sample_batch_grid_sample (
988+ data ,
989+ voxel_grid ,
990+ input_shape = input_shape ,
991+ mode = _TORCH_INTERPOLATION_MODE [interpolation ],
992+ fill_value = fill_value ,
993+ )
994+ return _sample_batch_interpol (
995+ data ,
996+ voxel_grid ,
997+ order = order ,
998+ fill_value = fill_value ,
999+ )
1000+
1001+
1002+ def _sample_batch_grid_sample (
1003+ data : Tensor ,
1004+ voxel_grid : Tensor ,
1005+ * ,
1006+ input_shape : TypeThreeInts ,
1007+ mode : str ,
1008+ fill_value : float | Tensor ,
1009+ ) -> Tensor :
1010+ """Fast path: resample with F.grid_sample (orders 0-1)."""
9481011 batch_size = data .shape [0 ]
1012+ # Normalize voxel coords to [-1, 1] for grid_sample.
1013+ grid = _voxel_coordinates_to_grid (voxel_grid , input_shape )
9491014 # (B, C, I, J, K) -> (B, C, K, J, I) for grid_sample
9501015 input_5d = rearrange (data , "b c i j k -> b c k j i" ).float ()
9511016 # Expand grid from (1, ...) to (B, ...)
@@ -973,6 +1038,36 @@ def _sample_batch(
9731038 return rearrange (sampled , "b c k j i -> b c i j k" ).to (data .dtype )
9741039
9751040
1041+ def _sample_batch_interpol (
1042+ data : Tensor ,
1043+ voxel_grid : Tensor ,
1044+ * ,
1045+ order : int ,
1046+ fill_value : float | Tensor ,
1047+ ) -> Tensor :
1048+ """High-quality path: resample with interpol.grid_pull (orders 2+).
1049+
1050+ ``interpol.grid_pull`` works in voxel coordinates natively and
1051+ supports B-spline orders up to 7.
1052+ """
1053+ import interpol
1054+
1055+ batch_size = data .shape [0 ]
1056+ # interpol expects: input (B, C, *spatial), grid (B, *spatial, D)
1057+ # Our grid is (I_out, J_out, K_out, 3) — add batch dim.
1058+ grid_b = voxel_grid .unsqueeze (0 ).expand (batch_size , - 1 , - 1 , - 1 , - 1 )
1059+
1060+ sampled = interpol .grid_pull (
1061+ data .float (),
1062+ grid_b ,
1063+ interpolation = order ,
1064+ bound = "dct2" ,
1065+ extrapolate = False ,
1066+ prefilter = True ,
1067+ )
1068+ return sampled .to (data .dtype )
1069+
1070+
9761071def _antialias_batch (
9771072 data : Tensor ,
9781073 input_affine : AffineMatrix ,
@@ -1632,10 +1727,21 @@ def _parse_spacing(
16321727 return spacing
16331728
16341729
1635- def _parse_interpolation (interpolation : TypeInterpolation ) -> TypeInterpolation :
1636- """Validate and lower-case an interpolation mode string."""
1730+ def _parse_interpolation (
1731+ interpolation : TypeInterpolation | int ,
1732+ ) -> TypeInterpolation :
1733+ """Validate an interpolation mode (string or integer order).
1734+
1735+ Integer orders (0-7) are converted to their string equivalents.
1736+ """
1737+ if isinstance (interpolation , int ):
1738+ order_to_name = {v : k for k , v in _INTERPOLATION_TO_ORDER .items ()}
1739+ if interpolation not in order_to_name :
1740+ msg = f"Interpolation order { interpolation } is not supported. Must be 0-7."
1741+ raise ValueError (msg )
1742+ return cast (TypeInterpolation , order_to_name [interpolation ])
16371743 if not isinstance (interpolation , str ):
1638- msg = f"Interpolation must be a string, got { type (interpolation )} "
1744+ msg = f"Interpolation must be a string or int , got { type (interpolation )} "
16391745 raise TypeError (msg )
16401746 lowered = interpolation .lower ()
16411747 if lowered not in _SUPPORTED_INTERPOLATIONS :
0 commit comments