Skip to content

Commit b0619f6

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent d986169 commit b0619f6

File tree

3 files changed

+59
-72
lines changed

3 files changed

+59
-72
lines changed

examples/mri_testscan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def show_results(titles, images, save_path=None):
6464
# ============================================================
6565
def main():
6666
# -------- LOAD MRI --------
67-
vol = load_mri("/Users/marka.k/1900_Image_transformations/heat/heat/datasets/flair.nii.gz")
67+
vol = load_mri("/Users/marka.k/1900_Image_transformations/heat/heat/datasets/flair.nii.gz")
6868
print("Loaded MRI:", vol.shape)
6969

7070
orig = middle_slice(vol) # (H,W)

heat/ndimage/affine.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# Utility: normalize input → (N, C, spatial…)
1717
# ============================================================
1818

19+
1920
def _normalize_input(x, ND):
2021
"""
2122
Normalize a Heat array to include batch and channel dimensions.
@@ -38,14 +39,14 @@ def _normalize_input(x, ND):
3839
t = x.larray
3940

4041
if ND == 2:
41-
if x.ndim == 2: # (H, W)
42+
if x.ndim == 2: # (H, W)
4243
t = t.unsqueeze(0).unsqueeze(0)
43-
elif x.ndim == 3: # (C, H, W)
44+
elif x.ndim == 3: # (C, H, W)
4445
t = t.unsqueeze(0)
4546
else:
46-
if x.ndim == 3: # (D, H, W)
47+
if x.ndim == 3: # (D, H, W)
4748
t = t.unsqueeze(0).unsqueeze(0)
48-
elif x.ndim == 4: # (C, D, H, W)
49+
elif x.ndim == 4: # (C, D, H, W)
4950
t = t.unsqueeze(0)
5051

5152
return t, orig_shape
@@ -55,6 +56,7 @@ def _normalize_input(x, ND):
5556
# Utility: build coordinate grid in Heat order
5657
# ============================================================
5758

59+
5860
def _make_grid(spatial, device):
5961
"""
6062
Construct a coordinate grid in Heat axis order.
@@ -90,6 +92,7 @@ def _make_grid(spatial, device):
9092
# Padding helper
9193
# ============================================================
9294

95+
9396
def _apply_padding(pix, spatial, mode, constant_value):
9497
"""
9598
Apply boundary handling to sampled pixel indices.
@@ -140,6 +143,7 @@ def _apply_padding(pix, spatial, mode, constant_value):
140143
# Nearest neighbor sampler
141144
# ============================================================
142145

146+
143147
def _nearest_sample(x_local, coords_h, mode, constant_value):
144148
"""
145149
Sample an image using nearest-neighbor interpolation.
@@ -187,6 +191,7 @@ def _nearest_sample(x_local, coords_h, mode, constant_value):
187191
# Bilinear sampling (2D only)
188192
# ============================================================
189193

194+
190195
def _bilinear_sample(x_local, coords_h, mode, constant_value):
191196
"""
192197
Sample a 2D image using bilinear interpolation.
@@ -229,6 +234,7 @@ def _bilinear_sample(x_local, coords_h, mode, constant_value):
229234
# Public API
230235
# ============================================================
231236

237+
232238
def affine_transform(
233239
x,
234240
M,
@@ -301,9 +307,7 @@ def affine_transform(
301307
else:
302308
cx, cy, cz = coords_pt
303309
coords_h = torch.stack(
304-
[cz.reshape(spatial),
305-
cy.reshape(spatial),
306-
cx.reshape(spatial)],
310+
[cz.reshape(spatial), cy.reshape(spatial), cx.reshape(spatial)],
307311
dim=0,
308312
)
309313

heat/ndimage/mri_testscan.py

Lines changed: 47 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def show_results(titles, images, save_path=None):
4646
ax.set_title(title)
4747
ax.axis("off")
4848

49-
for ax in axs[len(images):]:
49+
for ax in axs[len(images) :]:
5050
ax.axis("off")
5151

5252
plt.tight_layout()
@@ -60,10 +60,12 @@ def show_results(titles, images, save_path=None):
6060
# ============================================================
6161
def main():
6262
# -------- LOAD MRI --------
63-
vol = load_mri("/Users/marka.k/1900_Image_transformations/heat/heat/datasets/flair.nii.gz") # You already downloaded your own
63+
vol = load_mri(
64+
"/Users/marka.k/1900_Image_transformations/heat/heat/datasets/flair.nii.gz"
65+
) # You already downloaded your own
6466
print("Loaded MRI:", vol.shape)
6567

66-
orig = middle_slice(vol) # (H,W)
68+
orig = middle_slice(vol) # (H,W)
6769
x = to_heat_slice(orig)
6870

6971
H, W = orig.shape
@@ -72,93 +74,74 @@ def main():
7274
# Define transforms (all center-aware)
7375
# ========================================================
7476

75-
cx, cy = W/2, H/2
77+
cx, cy = W / 2, H / 2
7678

7779
# Helper to shift center for rotation/scale
7880
def recenter(M):
79-
"""
80-
Input: 2x3 affine matrix.
81-
Output: 2x3 affine matrix recentered around the image center.
82-
"""
83-
84-
cx, cy = W/2, H/2
85-
86-
# Convert 2×3 → 3×3 homogeneous
87-
M3 = np.array([
88-
[M[0,0], M[0,1], M[0,2]],
89-
[M[1,0], M[1,1], M[1,2]],
90-
[0, 0, 1 ]
91-
], dtype=np.float32)
92-
93-
# Center shift matrices
94-
T1 = np.array([
95-
[1, 0, -cx],
96-
[0, 1, -cy],
97-
[0, 0, 1 ]
98-
], np.float32)
99-
100-
T2 = np.array([
101-
[1, 0, cx],
102-
[0, 1, cy],
103-
[0, 0, 1 ]
104-
], np.float32)
105-
106-
# Recenter: T2 * M * T1
107-
M_centered = T2 @ M3 @ T1
108-
109-
# Return as 2×3
110-
return np.array([
111-
[M_centered[0,0], M_centered[0,1], M_centered[0,2]],
112-
[M_centered[1,0], M_centered[1,1], M_centered[1,2]],
113-
], dtype=np.float32)
81+
"""
82+
Input: 2x3 affine matrix.
83+
Output: 2x3 affine matrix recentered around the image center.
84+
"""
85+
cx, cy = W / 2, H / 2
86+
87+
# Convert 2×3 → 3×3 homogeneous
88+
M3 = np.array(
89+
[[M[0, 0], M[0, 1], M[0, 2]], [M[1, 0], M[1, 1], M[1, 2]], [0, 0, 1]], dtype=np.float32
90+
)
91+
92+
# Center shift matrices
93+
T1 = np.array([[1, 0, -cx], [0, 1, -cy], [0, 0, 1]], np.float32)
94+
95+
T2 = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]], np.float32)
96+
97+
# Recenter: T2 * M * T1
98+
M_centered = T2 @ M3 @ T1
99+
100+
# Return as 2×3
101+
return np.array(
102+
[
103+
[M_centered[0, 0], M_centered[0, 1], M_centered[0, 2]],
104+
[M_centered[1, 0], M_centered[1, 1], M_centered[1, 2]],
105+
],
106+
dtype=np.float32,
107+
)
114108

115109
# ROTATION
116110
angle = np.radians(20)
117-
M_rot = np.array([
118-
[np.cos(angle), -np.sin(angle), 0],
119-
[np.sin(angle), np.cos(angle), 0]
120-
], np.float32)
111+
M_rot = np.array(
112+
[[np.cos(angle), -np.sin(angle), 0], [np.sin(angle), np.cos(angle), 0]], np.float32
113+
)
121114
M_rot = recenter(M_rot)
122115

123116
# SCALE
124117
s = 1.2
125-
M_scale = np.array([
126-
[s, 0, 0],
127-
[0, s, 0]
128-
], np.float32)
118+
M_scale = np.array([[s, 0, 0], [0, s, 0]], np.float32)
129119
M_scale = recenter(M_scale)
130120

131121
# TRANSLATE
132-
M_trans = np.array([
133-
[1, 0, 20],
134-
[0, 1, -20]
135-
], np.float32)
122+
M_trans = np.array([[1, 0, 20], [0, 1, -20]], np.float32)
136123

137124
# SHEAR
138125
sh = 0.3
139-
M_shear = np.array([
140-
[1, sh, 0],
141-
[0, 1, 0]
142-
], np.float32)
126+
M_shear = np.array([[1, sh, 0], [0, 1, 0]], np.float32)
143127
M_shear = recenter(M_shear)
144128

145129
# 3D ROTATION ABOUT Z AXIS APPLIED TO 2D SLICE (equivalent)
146130
angle = np.radians(35)
147-
M_rotZ = np.array([
148-
[np.cos(angle), -np.sin(angle), 0],
149-
[np.sin(angle), np.cos(angle), 0]
150-
], np.float32)
131+
M_rotZ = np.array(
132+
[[np.cos(angle), -np.sin(angle), 0], [np.sin(angle), np.cos(angle), 0]], np.float32
133+
)
151134
M_rotZ = recenter(M_rotZ)
152135

153136
# ========================================================
154137
# Run transformations
155138
# ========================================================
156139

157-
out_rot = affine_transform(x, M_rot, order=1).numpy()
158-
out_scale = affine_transform(x, M_scale, order=1).numpy()
159-
out_trans = affine_transform(x, M_trans, order=1).numpy()
160-
out_shear = affine_transform(x, M_shear, order=1).numpy()
161-
out_rotZ = affine_transform(x, M_rotZ, order=1).numpy()
140+
out_rot = affine_transform(x, M_rot, order=1).numpy()
141+
out_scale = affine_transform(x, M_scale, order=1).numpy()
142+
out_trans = affine_transform(x, M_trans, order=1).numpy()
143+
out_shear = affine_transform(x, M_shear, order=1).numpy()
144+
out_rotZ = affine_transform(x, M_rotZ, order=1).numpy()
162145

163146
# ========================================================
164147
# Show + save

0 commit comments

Comments
 (0)