Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 146 additions & 0 deletions examples/ndimages/run_affine_on_nifti.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
"""
End-user demo: affine transformations on a 3D MRI volume
(using Heat affine_transform – final implementation).
"""

import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import heat as ht
from heat.ndimage.affine import affine_transform


# ============================================================
# Helper: normalize output to (D,H,W)
# ============================================================

def to_volume(y):
"""
Convert Heat affine output to plain (D,H,W) NumPy array,
regardless of whether a leading dimension exists.
"""
y_np = y.numpy()
if y_np.ndim == 4: # (1,D,H,W)
return y_np[0]
return y_np # (D,H,W)


# ============================================================
# STEP 1: Load MRI
# ============================================================

nii = nib.load(
"/Users/marka.k/1900_Image_transformations/heat/heat/datasets/flair.nii.gz"
)
x_np = nii.get_fdata().astype(np.float32)
x = ht.array(x_np)

D, H, W = x_np.shape
cx, cy, cz = D / 2, H / 2, W / 2

print("Loaded MRI with shape:", x_np.shape)


# ============================================================
# STEP 2: Define affine matrices
# ============================================================

# 1️⃣ Identity
M_identity = [
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
]

# 2️⃣ Scaling (zoom around center)
s = 1.4
M_scale = [
[s, 0, 0, cx * (1 - s)],
[0, s, 0, cy * (1 - s)],
[0, 0, s, cz * (1 - s)],
]

# 3️⃣ Rotation (around Z axis)
theta = np.deg2rad(20)
c, s_ = np.cos(theta), np.sin(theta)
M_rotate = [
[ c, -s_, 0, cx - c * cx + s_ * cy],
[ s_, c, 0, cy - s_ * cx - c * cy],
[ 0, 0, 1, 0],
]

# 4️⃣ Translation
tx = 15
M_translate = [
[1, 0, 0, tx],
[0, 1, 0, 0],
[0, 0, 1, 0],
]


# ============================================================
# STEP 3: Apply affine transforms
# ============================================================

print("Applying affine transformations...")

y_identity = to_volume(affine_transform(x, M_identity, order=1))
y_scale = to_volume(affine_transform(x, M_scale, order=1))
y_rotate = to_volume(affine_transform(x, M_rotate, order=1))
y_translate = to_volume(affine_transform(x, M_translate, order=1))

print("Transformations complete.")


# ============================================================
# STEP 4: Save transformed volumes
# ============================================================

nib.save(nib.Nifti1Image(y_identity, nii.affine), "mri_identity.nii.gz")
nib.save(nib.Nifti1Image(y_scale, nii.affine), "mri_scaled.nii.gz")
nib.save(nib.Nifti1Image(y_rotate, nii.affine), "mri_rotated.nii.gz")
nib.save(nib.Nifti1Image(y_translate, nii.affine), "mri_translated.nii.gz")

print("Saved transformed NIfTI files.")


# ============================================================
# STEP 5: Visualization (5x5 grid)
# ============================================================

slice_indices = np.linspace(0, D - 1, 5, dtype=int)

volumes = [
x_np,
y_identity,
y_scale,
y_rotate,
y_translate,
]

titles = [
"Original",
"Identity",
"Scale (1.4×)",
"Rotate (20°)",
"Translate (+x)",
]

fig, axes = plt.subplots(5, 5, figsize=(12, 12))

for row in range(5):
for col in range(5):
axes[row, col].imshow(
volumes[row][slice_indices[col]],
cmap="gray"
)
axes[row, col].axis("off")

if col == 0:
axes[row, col].set_ylabel(titles[row], fontsize=10)

if row == 0:
axes[row, col].set_title(f"Slice {slice_indices[col]}", fontsize=9)

plt.tight_layout()
plt.show()
97 changes: 97 additions & 0 deletions examples/ndimages/test_affine_real_mri.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import nibabel as nib
import numpy as np
import heat as ht
import torch
from mpi4py import MPI

from heat.ndimage.affine import distributed_affine_transform


def chunk_bounds_1d(n, rank, size):
base = n // size
rem = n % size
start = rank * base + min(rank, rem)
stop = start + base + (1 if rank < rem else 0)
return start, stop


def main():
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

# --------------------------------------------------
# Load MRI (rank 0 only)
# --------------------------------------------------
if rank == 0:
path = "/Users/marka.k/1900_Image_transformations/heat/heat/datasets/flair.nii.gz"
img = nib.load(path)
data = img.get_fdata().astype(np.float32)
print("Loaded MRI shape:", data.shape)
else:
data = None

# Broadcast for test setup only
data = comm.bcast(data, root=0)

# --------------------------------------------------
# Create distributed Heat array
# --------------------------------------------------
x = ht.array(data, split=0)

D, H, W = x.gshape
z0, z1 = chunk_bounds_1d(D, rank, size)

local = x.larray

# --------------------------------------------------
# HARD PROOF OF SPLITTING
# --------------------------------------------------
nonzero_z = torch.nonzero(local.sum(dim=(1, 2))).flatten()

print(
f"\nRank {rank}\n"
f" owns global z range : [{z0}, {z1})\n"
f" local tensor shape : {tuple(local.shape)}\n"
f" local sum : {float(local.sum()):.2f}\n"
f" first nonzero local z : {int(nonzero_z[0]) if len(nonzero_z) else 'NONE'}\n"
f" last nonzero local z : {int(nonzero_z[-1]) if len(nonzero_z) else 'NONE'}\n"
)

comm.Barrier()

# --------------------------------------------------
# Apply distributed affine (small translation)
# --------------------------------------------------
M = torch.tensor(
[
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 5],
],
dtype=torch.float32,
)

y = distributed_affine_transform(x, M, order=0, mode="constant", cval=0.0)

comm.Barrier()

local_out = y.larray
nonzero_out_z = torch.nonzero(local_out.sum(dim=(1, 2))).flatten()

print(
f"Rank {rank} AFTER affine\n"
f" output local shape : {tuple(local_out.shape)}\n"
f" output local sum : {float(local_out.sum()):.2f}\n"
f" output first z : {int(nonzero_out_z[0]) if len(nonzero_out_z) else 'NONE'}\n"
f" output last z : {int(nonzero_out_z[-1]) if len(nonzero_out_z) else 'NONE'}\n"
)

comm.Barrier()

if rank == 0:
print("\n✅ SPLIT VERIFICATION COMPLETE")


if __name__ == "__main__":
main()
140 changes: 140 additions & 0 deletions examples/ndimages/test_cube.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
"""
Synthetic 3D star / landmark cube test for affine_transform.

Works for:
- batched outputs (1, D, H, W)
- unbatched inputs (D, H, W)
- MPI (mpirun -np 2)
"""

import numpy as np
import heat as ht
import matplotlib.pyplot as plt
from heat.ndimage.affine import affine_transform


# ============================================================
# 1. Create a synthetic 3D star cube
# ============================================================

def make_star_cube(size=64, arm_length=20, arm_thickness=2):
cube = np.zeros((size, size, size), dtype=np.float32)
c = size // 2

# Main axes
cube[c-arm_length:c+arm_length, c, c] = 5
cube[c, c-arm_length:c+arm_length, c] = 5
cube[c, c, c-arm_length:c+arm_length] = 5

# Diagonals
for i in range(-arm_length, arm_length):
cube[c+i, c+i, c] = 4
cube[c+i, c-i, c] = 4

# Landmarks
cube[c, c, c] = 10
cube[c, c, c+10] = 8
cube[c, c+10, c] = 8
cube[c+10, c, c] = 8

return ht.array(cube)


# ============================================================
# 2. Build test volume
# ============================================================

x = make_star_cube()
D, H, W = x.shape
c = D // 2

print("Cube shape:", x.shape)


# ============================================================
# 3. Affine matrices
# ============================================================

M_identity = [
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
]

M_translate = [
[1, 0, 0, 8],
[0, 1, 0, 0],
[0, 0, 1, 0],
]

theta = np.deg2rad(30)
ct, st = np.cos(theta), np.sin(theta)
M_rotate = [
[ ct, -st, 0, 0],
[ st, ct, 0, 0],
[ 0, 0, 1, 0],
]


# ============================================================
# 4. Apply affine transforms
# ============================================================

print("Applying affine transforms...")

y_id = affine_transform(x, M_identity, order=1)
y_tr = affine_transform(x, M_translate, order=0)
y_rot = affine_transform(x, M_rotate, order=1)


# ============================================================
# 5. Numeric checks
# ============================================================

print("\nNumeric checks:")

assert ht.allclose(x, y_id)
print("✓ Identity transform OK")

assert y_tr.numpy()[0, c, c, c+18] == 8
print("✓ Translation moves landmark correctly")


# ============================================================
# 6. Visualization (robust)
# ============================================================

def to_volume(arr):
"""
Convert numpy array to (D, H, W) no matter what.
"""
if arr.ndim == 4: # (1, D, H, W)
return arr[0]
if arr.ndim == 3: # (D, H, W)
return arr
raise ValueError(f"Unexpected shape {arr.shape}")


def show_slice(arr, title):
vol = to_volume(arr)
z = vol.shape[0] // 2
plt.imshow(vol[z], cmap="gray")
plt.title(title)
plt.axis("off")


plt.figure(figsize=(12, 4))

plt.subplot(1, 3, 1)
show_slice(x.numpy(), "Original")

plt.subplot(1, 3, 2)
show_slice(y_tr.numpy(), "Translate +X")

plt.subplot(1, 3, 3)
show_slice(y_rot.numpy(), "Rotate 30°")

plt.tight_layout()
plt.show()

print("\nAll tests completed successfully.")
Loading