Skip to content

Commit 74f37fa

Browse files
committed
examples
1 parent 20ec2d9 commit 74f37fa

File tree

5 files changed

+184
-0
lines changed

5 files changed

+184
-0
lines changed
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
"""
2+
End-user example: apply an affine transformation to a 3D NIfTI image
3+
and visualize the result directly in Python.
4+
5+
This script:
6+
1. Loads x.nii.gz
7+
2. Applies a 3D affine transformation using heat.ndimage.affine_transform
8+
3. Saves the transformed volume as x_transformed.nii.gz
9+
4. Displays a side-by-side comparison of the middle slice
10+
11+
Requirements:
12+
- nibabel
13+
- matplotlib
14+
- heat
15+
"""
16+
17+
import nibabel as nib
18+
import numpy as np
19+
import matplotlib.pyplot as plt
20+
import heat as ht
21+
from heat.ndimage.affine import affine_transform
22+
23+
24+
# ============================================================
25+
# STEP 1: Load NIfTI file
26+
# ============================================================
27+
28+
print("Loading x.nii.gz ...")
29+
30+
nii = nib.load("heat/datasets/flair.nii.gz")
31+
x_np = nii.get_fdata().astype(np.float32)
32+
33+
print("Input shape:", x_np.shape)
34+
35+
36+
# ============================================================
37+
# STEP 2: Convert to Heat array
38+
# ============================================================
39+
40+
x = ht.array(x_np)
41+
42+
print("Converted to Heat array.")
43+
44+
45+
# ============================================================
46+
# STEP 3: Define affine transform (3D)
47+
# ============================================================
48+
49+
"""
50+
Affine matrix (3x4):
51+
52+
[ a11 a12 a13 tx ]
53+
[ a21 a22 a23 ty ]
54+
[ a31 a32 a33 tz ]
55+
56+
Below: translate volume by +20 voxels in x-direction
57+
"""
58+
D, H, W = x_np.shape
59+
cx, cy, cz = D / 2, H / 2, W / 2
60+
s = 1.4
61+
62+
M = [
63+
[s, 0, 0, cx * (1 - s)],
64+
[0, s, 0, cy * (1 - s)],
65+
[0, 0, s, cz * (1 - s)],
66+
]
67+
68+
69+
# ============================================================
70+
# STEP 4: Apply affine transform
71+
# ============================================================
72+
73+
print("Applying affine transform...")
74+
75+
y = affine_transform(
76+
x,
77+
M,
78+
order=1, # bilinear interpolation
79+
mode="nearest"
80+
)
81+
82+
print("Transformation complete.")
83+
84+
85+
# ============================================================
86+
# STEP 5: Convert back to NumPy
87+
# ============================================================
88+
89+
y_np = y.numpy()
90+
91+
# Remove leading batch/channel dimension if present
92+
if y_np.ndim == 4:
93+
y_np = y_np[0]
94+
95+
print("Output shape:", y_np.shape)
96+
97+
98+
# ============================================================
99+
# STEP 6: Save transformed volume
100+
# ============================================================
101+
102+
out_nii = nib.Nifti1Image(y_np, affine=nii.affine)
103+
nib.save(out_nii, "x_transformed.nii.gz")
104+
105+
print("Saved x_transformed.nii.gz")
106+
107+
108+
# ============================================================
109+
# STEP 7: Visualize middle slice
110+
# ============================================================
111+
112+
mid = x_np.shape[0] // 2
113+
114+
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
115+
116+
ax[0].imshow(x_np[mid], cmap="gray")
117+
ax[0].set_title("Original (middle slice)")
118+
ax[0].axis("off")
119+
120+
ax[1].imshow(y_np[mid], cmap="gray")
121+
ax[1].set_title("Transformed (middle slice)")
122+
ax[1].axis("off")
123+
124+
plt.tight_layout()
125+
plt.show()
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import nibabel as nib
2+
import matplotlib.pyplot as plt
3+
4+
# ============================================================
5+
# Load original and transformed MRI
6+
# ============================================================
7+
8+
orig_nii = nib.load("heat/datasets/flair.nii.gz")
9+
trans_nii = nib.load("heat/datasets/x_transformed.nii.gz")
10+
11+
orig = orig_nii.get_fdata()
12+
trans = trans_nii.get_fdata()
13+
14+
# Sanity check
15+
assert orig.shape == trans.shape, "Original and transformed shapes do not match!"
16+
17+
num_slices = orig.shape[0]
18+
slice_idx = num_slices // 2 # start in the middle
19+
20+
21+
# ============================================================
22+
# Create figure
23+
# ============================================================
24+
25+
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
26+
27+
img_orig = ax[0].imshow(orig[slice_idx], cmap="gray")
28+
ax[0].set_title("Original")
29+
ax[0].axis("off")
30+
31+
img_trans = ax[1].imshow(trans[slice_idx], cmap="gray")
32+
ax[1].set_title("Transformed")
33+
ax[1].axis("off")
34+
35+
fig.suptitle(f"Slice {slice_idx}/{num_slices - 1}")
36+
37+
38+
# ============================================================
39+
# Keyboard interaction
40+
# ============================================================
41+
42+
def on_key(event):
43+
global slice_idx
44+
45+
if event.key == "up":
46+
slice_idx = min(slice_idx + 1, num_slices - 1)
47+
elif event.key == "down":
48+
slice_idx = max(slice_idx - 1, 0)
49+
else:
50+
return
51+
52+
img_orig.set_data(orig[slice_idx])
53+
img_trans.set_data(trans[slice_idx])
54+
fig.suptitle(f"Slice {slice_idx}/{num_slices - 1}")
55+
fig.canvas.draw_idle()
56+
57+
58+
fig.canvas.mpl_connect("key_press_event", on_key)
59+
plt.show()

heat/datasets/x_transformed.nii.gz

4.52 MB
Binary file not shown.

x_transformed.nii.gz

4.52 MB
Binary file not shown.

0 commit comments

Comments
 (0)