diff --git a/examples/ndimages/run_affine_on_nifti.py b/examples/ndimages/run_affine_on_nifti.py new file mode 100644 index 0000000000..25fd3b686f --- /dev/null +++ b/examples/ndimages/run_affine_on_nifti.py @@ -0,0 +1,146 @@ +""" +End-user demo: affine transformations on a 3D MRI volume +(using Heat affine_transform – final implementation). +""" + +import nibabel as nib +import numpy as np +import matplotlib.pyplot as plt +import heat as ht +from heat.ndimage.affine import affine_transform + + +# ============================================================ +# Helper: normalize output to (D,H,W) +# ============================================================ + +def to_volume(y): + """ + Convert Heat affine output to plain (D,H,W) NumPy array, + regardless of whether a leading dimension exists. + """ + y_np = y.numpy() + if y_np.ndim == 4: # (1,D,H,W) + return y_np[0] + return y_np # (D,H,W) + + +# ============================================================ +# STEP 1: Load MRI +# ============================================================ + +nii = nib.load( + "/Users/marka.k/1900_Image_transformations/heat/heat/datasets/flair.nii.gz" +) +x_np = nii.get_fdata().astype(np.float32) +x = ht.array(x_np) + +D, H, W = x_np.shape +cx, cy, cz = D / 2, H / 2, W / 2 + +print("Loaded MRI with shape:", x_np.shape) + + +# ============================================================ +# STEP 2: Define affine matrices +# ============================================================ + +# 1️⃣ Identity +M_identity = [ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], +] + +# 2️⃣ Scaling (zoom around center) +s = 1.4 +M_scale = [ + [s, 0, 0, cx * (1 - s)], + [0, s, 0, cy * (1 - s)], + [0, 0, s, cz * (1 - s)], +] + +# 3️⃣ Rotation (around Z axis) +theta = np.deg2rad(20) +c, s_ = np.cos(theta), np.sin(theta) +M_rotate = [ + [ c, -s_, 0, cx - c * cx + s_ * cy], + [ s_, c, 0, cy - s_ * cx - c * cy], + [ 0, 0, 1, 0], +] + +# 4️⃣ Translation +tx = 15 +M_translate = [ + [1, 0, 0, tx], + [0, 1, 0, 0], + [0, 0, 1, 0], +] + + +# ============================================================ +# STEP 3: Apply affine transforms +# ============================================================ + +print("Applying affine transformations...") + +y_identity = to_volume(affine_transform(x, M_identity, order=1)) +y_scale = to_volume(affine_transform(x, M_scale, order=1)) +y_rotate = to_volume(affine_transform(x, M_rotate, order=1)) +y_translate = to_volume(affine_transform(x, M_translate, order=1)) + +print("Transformations complete.") + + +# ============================================================ +# STEP 4: Save transformed volumes +# ============================================================ + +nib.save(nib.Nifti1Image(y_identity, nii.affine), "mri_identity.nii.gz") +nib.save(nib.Nifti1Image(y_scale, nii.affine), "mri_scaled.nii.gz") +nib.save(nib.Nifti1Image(y_rotate, nii.affine), "mri_rotated.nii.gz") +nib.save(nib.Nifti1Image(y_translate, nii.affine), "mri_translated.nii.gz") + +print("Saved transformed NIfTI files.") + + +# ============================================================ +# STEP 5: Visualization (5x5 grid) +# ============================================================ + +slice_indices = np.linspace(0, D - 1, 5, dtype=int) + +volumes = [ + x_np, + y_identity, + y_scale, + y_rotate, + y_translate, +] + +titles = [ + "Original", + "Identity", + "Scale (1.4×)", + "Rotate (20°)", + "Translate (+x)", +] + +fig, axes = plt.subplots(5, 5, figsize=(12, 12)) + +for row in range(5): + for col in range(5): + axes[row, col].imshow( + volumes[row][slice_indices[col]], + cmap="gray" + ) + axes[row, col].axis("off") + + if col == 0: + axes[row, col].set_ylabel(titles[row], fontsize=10) + + if row == 0: + axes[row, col].set_title(f"Slice {slice_indices[col]}", fontsize=9) + +plt.tight_layout() +plt.show() diff --git a/examples/ndimages/test_affine_real_mri.py b/examples/ndimages/test_affine_real_mri.py new file mode 100644 index 0000000000..868f26a936 --- /dev/null +++ b/examples/ndimages/test_affine_real_mri.py @@ -0,0 +1,97 @@ +import nibabel as nib +import numpy as np +import heat as ht +import torch +from mpi4py import MPI + +from heat.ndimage.affine import distributed_affine_transform + + +def chunk_bounds_1d(n, rank, size): + base = n // size + rem = n % size + start = rank * base + min(rank, rem) + stop = start + base + (1 if rank < rem else 0) + return start, stop + + +def main(): + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + size = comm.Get_size() + + # -------------------------------------------------- + # Load MRI (rank 0 only) + # -------------------------------------------------- + if rank == 0: + path = "/Users/marka.k/1900_Image_transformations/heat/heat/datasets/flair.nii.gz" + img = nib.load(path) + data = img.get_fdata().astype(np.float32) + print("Loaded MRI shape:", data.shape) + else: + data = None + + # Broadcast for test setup only + data = comm.bcast(data, root=0) + + # -------------------------------------------------- + # Create distributed Heat array + # -------------------------------------------------- + x = ht.array(data, split=0) + + D, H, W = x.gshape + z0, z1 = chunk_bounds_1d(D, rank, size) + + local = x.larray + + # -------------------------------------------------- + # HARD PROOF OF SPLITTING + # -------------------------------------------------- + nonzero_z = torch.nonzero(local.sum(dim=(1, 2))).flatten() + + print( + f"\nRank {rank}\n" + f" owns global z range : [{z0}, {z1})\n" + f" local tensor shape : {tuple(local.shape)}\n" + f" local sum : {float(local.sum()):.2f}\n" + f" first nonzero local z : {int(nonzero_z[0]) if len(nonzero_z) else 'NONE'}\n" + f" last nonzero local z : {int(nonzero_z[-1]) if len(nonzero_z) else 'NONE'}\n" + ) + + comm.Barrier() + + # -------------------------------------------------- + # Apply distributed affine (small translation) + # -------------------------------------------------- + M = torch.tensor( + [ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 5], + ], + dtype=torch.float32, + ) + + y = distributed_affine_transform(x, M, order=0, mode="constant", cval=0.0) + + comm.Barrier() + + local_out = y.larray + nonzero_out_z = torch.nonzero(local_out.sum(dim=(1, 2))).flatten() + + print( + f"Rank {rank} AFTER affine\n" + f" output local shape : {tuple(local_out.shape)}\n" + f" output local sum : {float(local_out.sum()):.2f}\n" + f" output first z : {int(nonzero_out_z[0]) if len(nonzero_out_z) else 'NONE'}\n" + f" output last z : {int(nonzero_out_z[-1]) if len(nonzero_out_z) else 'NONE'}\n" + ) + + comm.Barrier() + + if rank == 0: + print("\n✅ SPLIT VERIFICATION COMPLETE") + + +if __name__ == "__main__": + main() diff --git a/examples/ndimages/test_cube.py b/examples/ndimages/test_cube.py new file mode 100644 index 0000000000..554a9de5d8 --- /dev/null +++ b/examples/ndimages/test_cube.py @@ -0,0 +1,140 @@ +""" +Synthetic 3D star / landmark cube test for affine_transform. + +Works for: +- batched outputs (1, D, H, W) +- unbatched inputs (D, H, W) +- MPI (mpirun -np 2) +""" + +import numpy as np +import heat as ht +import matplotlib.pyplot as plt +from heat.ndimage.affine import affine_transform + + +# ============================================================ +# 1. Create a synthetic 3D star cube +# ============================================================ + +def make_star_cube(size=64, arm_length=20, arm_thickness=2): + cube = np.zeros((size, size, size), dtype=np.float32) + c = size // 2 + + # Main axes + cube[c-arm_length:c+arm_length, c, c] = 5 + cube[c, c-arm_length:c+arm_length, c] = 5 + cube[c, c, c-arm_length:c+arm_length] = 5 + + # Diagonals + for i in range(-arm_length, arm_length): + cube[c+i, c+i, c] = 4 + cube[c+i, c-i, c] = 4 + + # Landmarks + cube[c, c, c] = 10 + cube[c, c, c+10] = 8 + cube[c, c+10, c] = 8 + cube[c+10, c, c] = 8 + + return ht.array(cube) + + +# ============================================================ +# 2. Build test volume +# ============================================================ + +x = make_star_cube() +D, H, W = x.shape +c = D // 2 + +print("Cube shape:", x.shape) + + +# ============================================================ +# 3. Affine matrices +# ============================================================ + +M_identity = [ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], +] + +M_translate = [ + [1, 0, 0, 8], + [0, 1, 0, 0], + [0, 0, 1, 0], +] + +theta = np.deg2rad(30) +ct, st = np.cos(theta), np.sin(theta) +M_rotate = [ + [ ct, -st, 0, 0], + [ st, ct, 0, 0], + [ 0, 0, 1, 0], +] + + +# ============================================================ +# 4. Apply affine transforms +# ============================================================ + +print("Applying affine transforms...") + +y_id = affine_transform(x, M_identity, order=1) +y_tr = affine_transform(x, M_translate, order=0) +y_rot = affine_transform(x, M_rotate, order=1) + + +# ============================================================ +# 5. Numeric checks +# ============================================================ + +print("\nNumeric checks:") + +assert ht.allclose(x, y_id) +print("✓ Identity transform OK") + +assert y_tr.numpy()[0, c, c, c+18] == 8 +print("✓ Translation moves landmark correctly") + + +# ============================================================ +# 6. Visualization (robust) +# ============================================================ + +def to_volume(arr): + """ + Convert numpy array to (D, H, W) no matter what. + """ + if arr.ndim == 4: # (1, D, H, W) + return arr[0] + if arr.ndim == 3: # (D, H, W) + return arr + raise ValueError(f"Unexpected shape {arr.shape}") + + +def show_slice(arr, title): + vol = to_volume(arr) + z = vol.shape[0] // 2 + plt.imshow(vol[z], cmap="gray") + plt.title(title) + plt.axis("off") + + +plt.figure(figsize=(12, 4)) + +plt.subplot(1, 3, 1) +show_slice(x.numpy(), "Original") + +plt.subplot(1, 3, 2) +show_slice(y_tr.numpy(), "Translate +X") + +plt.subplot(1, 3, 3) +show_slice(y_rot.numpy(), "Rotate 30°") + +plt.tight_layout() +plt.show() + +print("\nAll tests completed successfully.") diff --git a/examples/ndimages/view_mri_scroll.py b/examples/ndimages/view_mri_scroll.py new file mode 100644 index 0000000000..b0fbe96c33 --- /dev/null +++ b/examples/ndimages/view_mri_scroll.py @@ -0,0 +1,72 @@ +import nibabel as nib +import matplotlib.pyplot as plt + +# ============================================================ +# Paths (ABSOLUTE – adjust if needed) +# ============================================================ + +BASE = "/Users/marka.k/1900_Image_transformations/heat/heat/datasets" + +paths = { + "Original": f"{BASE}/flair.nii.gz", + "Identity": f"{BASE}/mri_identity.nii.gz", + "Scaled": f"{BASE}/mri_scaled.nii.gz", + "Rotated": f"{BASE}/mri_rotated.nii.gz", + "Translated": f"{BASE}/mri_translated.nii.gz", +} + +# ============================================================ +# Load volumes +# ============================================================ + +volumes = {} +for name, path in paths.items(): + volumes[name] = nib.load(path).get_fdata() + +# Sanity check: all shapes equal +shapes = {v.shape for v in volumes.values()} +assert len(shapes) == 1, "Not all volumes have the same shape!" + +D, H, W = next(iter(shapes)) +slice_idx = D // 2 + +# ============================================================ +# Create figure +# ============================================================ + +titles = list(volumes.keys()) +data = list(volumes.values()) + +fig, axes = plt.subplots(1, len(data), figsize=(4 * len(data), 5)) +images = [] + +for ax, title, vol in zip(axes, titles, data): + img = ax.imshow(vol[slice_idx], cmap="gray") + ax.set_title(title) + ax.axis("off") + images.append(img) + +fig.suptitle(f"Slice {slice_idx}/{D - 1}") + +# ============================================================ +# Keyboard navigation +# ============================================================ + +def on_key(event): + global slice_idx + + if event.key == "up": + slice_idx = min(slice_idx + 1, D - 1) + elif event.key == "down": + slice_idx = max(slice_idx - 1, 0) + else: + return + + for img, vol in zip(images, data): + img.set_data(vol[slice_idx]) + + fig.suptitle(f"Slice {slice_idx}/{D - 1}") + fig.canvas.draw_idle() + +fig.canvas.mpl_connect("key_press_event", on_key) +plt.show() diff --git a/heat/datasets/flair.nii.gz b/heat/datasets/flair.nii.gz new file mode 100644 index 0000000000..0764a0f2de Binary files /dev/null and b/heat/datasets/flair.nii.gz differ diff --git a/heat/datasets/mri_sample_LICENSE.txt b/heat/datasets/mri_sample_LICENSE.txt new file mode 100644 index 0000000000..22bc3e69c8 --- /dev/null +++ b/heat/datasets/mri_sample_LICENSE.txt @@ -0,0 +1,10 @@ +MIT License + +Copyright (c) 2018 Adam Wolf + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this file to deal in the file without restriction, including without +limitation the rights to use, copy, modify, merge, publish, distribute, +sublicense, and/or sell copies of the file. + +THE FILE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND. diff --git a/heat/ndimage/affine.py b/heat/ndimage/affine.py new file mode 100644 index 0000000000..6ad3b23fa4 --- /dev/null +++ b/heat/ndimage/affine.py @@ -0,0 +1,439 @@ +""" +Affine transformations for N-dimensional Heat arrays. + +This module implements backward-warping affine transformations +(translation, rotation, scaling) for 2D and 3D data stored as +Heat DNDarrays, using a PyTorch backend. + +The affine matrix M is interpreted as a *forward* transform +in affine (x, y [, z]) coordinates: + + out = A @ inp + b + +where M = [A | b] has shape (ND, ND+1). + +Backward warping is used internally for resampling: + + inp = A^{-1} @ (out - b) + +Heat uses the following spatial axis conventions: +- 2D: (H, W) == (y, x) +- 3D: (D, H, W) == (z, y, x) + +Interpolation and boundary handling: +- order=0: nearest-neighbor +- order=1: bilinear (2D only; 3D falls back to nearest) +- padding modes: 'nearest', 'wrap', 'reflect', 'constant' + +Distributed arrays: +- Distributed inputs are currently gathered, transformed locally, + and re-distributed to the original split axis. +- This is MPI-safe but may incur communication overhead. +- More advanced distributed strategies are intended for future work. + +The public entry point is `affine_transform`. +""" + +import numpy as np +import torch +import heat as ht + + +# ============================================================ +# Helpers +# ============================================================ + + +def _is_identity_affine(M, ND): + """ + Check whether an affine matrix represents the identity transform. + + Parameters + ---------- + M : array-like + Affine matrix of shape (ND, ND+1). + ND : int + Number of spatial dimensions. + + Returns + ------- + bool + True if A is the identity matrix and b is zero. + """ + A = M[:, :ND] + b = M[:, ND:] + return np.allclose(A, np.eye(ND)) and np.allclose(b, 0) + + +def _normalize_input(x, ND): + """ + Normalize a Heat array to the internal sampling layout. + + Converts input arrays to include batch and channel dimensions, + producing tensors of shape: + - 2D: (N, C, H, W) + - 3D: (N, C, D, H, W) + + Parameters + ---------- + x : ht.DNDarray + Input array. + ND : int + Number of spatial dimensions. + + Returns + ------- + torch.Tensor + Local torch tensor with batch and channel dimensions. + tuple + Original shape of the input array. + """ + orig_shape = x.shape + t = x.larray + + if ND == 2: + if x.ndim == 2: + t = t.unsqueeze(0).unsqueeze(0) + elif x.ndim == 3: + t = t.unsqueeze(0) + else: + if x.ndim == 3: + t = t.unsqueeze(0).unsqueeze(0) + elif x.ndim == 4: + t = t.unsqueeze(0) + + return t, orig_shape + + +def _make_grid(spatial, device): + """ + Construct a coordinate grid in Heat spatial axis order. + + Parameters + ---------- + spatial : tuple + Spatial shape (H, W) or (D, H, W). + device : torch.device + Target device. + + Returns + ------- + torch.Tensor + Coordinate grid of shape (ND, *spatial) in Heat order. + """ + if len(spatial) == 2: + H, W = spatial + y = torch.arange(H, device=device) + x = torch.arange(W, device=device) + gy, gx = torch.meshgrid(y, x, indexing="ij") + return torch.stack([gy, gx], dim=0) + else: + D, H, W = spatial + z = torch.arange(D, device=device) + y = torch.arange(H, device=device) + x = torch.arange(W, device=device) + gz, gy, gx = torch.meshgrid(z, y, x, indexing="ij") + return torch.stack([gz, gy, gx], dim=0) + + +# ============================================================ +# Padding +# ============================================================ + + +def _apply_padding(pix, spatial, mode): + """ + Apply boundary handling to integer pixel coordinates. + + Parameters + ---------- + pix : torch.Tensor + Integer pixel indices in Heat order. + spatial : tuple + Spatial dimensions of the input. + mode : str + Padding mode ('nearest', 'wrap', 'reflect', 'constant'). + + Returns + ------- + torch.Tensor + Adjusted pixel coordinates. + torch.Tensor + Boolean mask indicating valid coordinates (for constant mode). + """ + ND = len(spatial) + final = pix.clone() + valid = torch.ones_like(pix[0], dtype=torch.bool) + + for d in range(ND): + size = spatial[d] + p = pix[d] + + if mode == "constant": + ok = (p >= 0) & (p < size) + valid &= ok + final[d] = p.clamp(0, size - 1) + elif mode == "nearest": + final[d] = p.clamp(0, size - 1) + elif mode == "wrap": + final[d] = torch.remainder(p, size) + elif mode == "reflect": + if size == 1: + final[d] = torch.zeros_like(p) + else: + r = torch.abs(p) + r = torch.remainder(r, 2 * size - 2) + final[d] = torch.where(r < size, r, 2 * size - 2 - r) + + return final, valid + + +# ============================================================ +# Sampling +# ============================================================ + + +def _nearest_sample(x, coords, mode, constant_value): + """ + Nearest-neighbor sampling. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (N, C, ...). + coords : torch.Tensor + Sampling coordinates in Heat order. + mode : str + Boundary handling mode. + constant_value : float + Fill value for constant padding. + + Returns + ------- + torch.Tensor + Sampled output tensor. + """ + ND = coords.shape[0] + pix = coords.round().long() + spatial = x.shape[2:] + + pix_c, valid = _apply_padding(pix, spatial, mode) + + if ND == 2: + y, x_ = pix_c + out = x[:, :, y, x_] + else: + z, y, x_ = pix_c + out = x[:, :, z, y, x_] + + if mode == "constant": + const = torch.full_like(out, constant_value) + out = torch.where(valid.unsqueeze(0).unsqueeze(0), out, const) + + return out + + +def _bilinear_sample(x, coords, mode, constant_value): + """ + Bilinear interpolation for 2D inputs. + + For non-2D data, this function falls back to nearest-neighbor sampling. + """ + if coords.shape[0] != 2: + return _nearest_sample(x, coords, mode, constant_value) + + y, x_ = coords + H, W = x.shape[2], x.shape[3] + + y0 = torch.floor(y).long() + x0 = torch.floor(x_).long() + y1 = y0 + 1 + x1 = x0 + 1 + + pix = torch.stack([y0, x0]) + pix_c, valid = _apply_padding(pix, (H, W), mode) + + y0c, x0c = pix_c + y1c = y1.clamp(0, H - 1) + x1c = x1.clamp(0, W - 1) + + Ia = x[:, :, y0c, x0c] + Ib = x[:, :, y0c, x1c] + Ic = x[:, :, y1c, x0c] + Id = x[:, :, y1c, x1c] + + wy = y - y0.float() + wx = x_ - x0.float() + + out = Ia * (1 - wy) * (1 - wx) + Ib * (1 - wy) * wx + Ic * wy * (1 - wx) + Id * wy * wx + + if mode == "constant": + const = torch.full_like(out, constant_value) + out = torch.where(valid.unsqueeze(0).unsqueeze(0), out, const) + + return out + + +# ============================================================ +# Local affine (NO MPI logic) +# ============================================================ + + +def _affine_transform_local(x, M, order, mode, constant_value, expand): + """ + Apply an affine transformation to a local (non-distributed) Heat array. + + Parameters + ---------- + x : ht.DNDarray + Local input array (split=None). + M : array-like + Affine matrix of shape (2,3) or (3,4). + order : int + Interpolation order. + mode : str + Boundary handling mode. + constant_value : float + Fill value for constant padding. + expand : bool + Whether to expand the output with a leading batch dimension. + + Returns + ------- + ht.DNDarray + Transformed array with split=None. + """ + M = np.asarray(M) + + if M.shape == (2, 3): + ND = 2 + elif M.shape == (3, 4): + ND = 3 + else: + raise ValueError("M must have shape (2,3) or (3,4)") + + is_identity = _is_identity_affine(M, ND) + + x_local, orig_shape = _normalize_input(x, ND) + device = x_local.device + + A = torch.tensor(M[:, :ND], device=device, dtype=torch.float64) + b = torch.tensor(M[:, ND:], device=device, dtype=torch.float64).reshape(ND, 1) + A_inv = torch.inverse(A) + + spatial = x_local.shape[2:] + grid_heat = _make_grid(spatial, device).reshape(ND, -1).double() + + if ND == 2: + grid_affine = grid_heat[[1, 0]] + else: + z, y, x_ = grid_heat + grid_affine = torch.stack([x_, y, z], dim=0) + + coords_affine = (A_inv @ grid_affine) - (A_inv @ b) + + if ND == 2: + coords_heat = coords_affine[[1, 0]].reshape((2, *spatial)) + else: + cx, cy, cz = coords_affine + coords_heat = torch.stack( + [ + cz.reshape(spatial), + cy.reshape(spatial), + cx.reshape(spatial), + ], + dim=0, + ) + + if order == 0: + out = _nearest_sample(x_local, coords_heat, mode, constant_value) + else: + out = _bilinear_sample(x_local, coords_heat, mode, constant_value) + + out = out.squeeze(0) + + if expand: + if out.ndim == ND + 1: + out = out.squeeze(0) + return ht.array(out, split=None).expand_dims(0) + + if ND == 2: + if order == 0 or is_identity: + return ht.array(out.squeeze(0).reshape(orig_shape), split=None) + return ht.array(out, split=None) + + if is_identity: + return ht.array(out.squeeze(0).reshape(orig_shape), split=None) + + return ht.array(out, split=None) + + +# ============================================================ +# Public API (MPI-safe) +# ============================================================ + + +def affine_transform(x, M, order=0, mode="nearest", constant_value=0.0, expand=False): + """ + Apply an affine transformation to a Heat array. + + Distributed behavior: + - If split is non-spatial → transform locally (no communication) + - If split is spatial: + * translation or axis-aligned scaling → resplit → local transform → resplit back + * rotation / shear → NotImplementedError + """ + # ------------------------------------------------------------ + # Determine spatial dimensionality + # ------------------------------------------------------------ + M = np.asarray(M) + if M.shape == (2, 3): + ND = 2 + elif M.shape == (3, 4): + ND = 3 + else: + raise ValueError("M must have shape (2,3) or (3,4)") + + # Helper predicates + A = M[:, :ND] + b = M[:, ND:] + + def is_translation(): + return np.allclose(A, np.eye(ND)) and not np.allclose(b, 0) + + def is_axis_aligned_scaling(): + return np.allclose(A, np.diag(np.diag(A))) + + if x.split is None: + return _affine_transform_local(x, M, order, mode, constant_value, expand) + + # Identify spatial axes (last ND axes) + spatial_axes = set(range(x.ndim - ND, x.ndim)) + + # Fast path: split on non-spatial axis → safe + if x.split not in spatial_axes: + return _affine_transform_local(x, M, order, mode, constant_value, expand) + + if is_translation(): + # translation along spatial split requires full axis coverage + x_tmp = x.resplit(None) + y_tmp = _affine_transform_local(x_tmp, M, order, mode, constant_value, expand) + return y_tmp.resplit(x.split) + + if is_axis_aligned_scaling(): + # scaling still requires a non-spatial axis + safe_axes = [ax for ax in range(x.ndim) if ax not in spatial_axes] + if not safe_axes: + raise NotImplementedError( + "Axis-aligned scaling on fully spatial arrays requires a non-spatial axis" + ) + safe_axis = safe_axes[0] + x_tmp = x.resplit(safe_axis) + y_tmp = _affine_transform_local(x_tmp, M, order, mode, constant_value, expand) + return y_tmp.resplit(x.split) + + # Rotation / shear on spatial split → not supported + raise NotImplementedError( + "Affine transforms with axis mixing (rotation/shear) on spatially " + "distributed axes are not supported. Explicit halo exchange is required." + ) diff --git a/heat/ndimage/tests/test_affine_transform_distributed.py b/heat/ndimage/tests/test_affine_transform_distributed.py new file mode 100644 index 0000000000..b617169ab6 --- /dev/null +++ b/heat/ndimage/tests/test_affine_transform_distributed.py @@ -0,0 +1,109 @@ +import numpy as np +import pytest +import heat as ht +from mpi4py import MPI + +from heat.ndimage.affine import affine_transform + +comm = MPI.COMM_WORLD +rank = comm.Get_rank() +size = comm.Get_size() + +@pytest.mark.mpi +def test_undistributed_affine_translation_backward(): + """ + Backward warping with nearest padding. + + out[z, y, x] = in[z, y, x - 1] + with x < 0 clamped to 0. + """ + data = np.arange(24, dtype=np.float32).reshape(4, 3, 2) + x = ht.array(data, split=None) + + M = np.array( + [ + [1, 0, 0, 1], + [0, 1, 0, 0], + [0, 0, 1, 0], + ], + dtype=np.float64, + ) + + y = affine_transform(x, M, order=0, mode="nearest").numpy() + + # correct backward-warp reference + ref = np.zeros_like(data) + ref[:, :, 0] = data[:, :, 0] + ref[:, :, 1] = data[:, :, 0] + + assert np.allclose(y, ref) + + +@pytest.mark.mpi +def test_distributed_non_split_axis_translation_matches_undistributed(): + data = np.arange(48, dtype=np.float32).reshape(6, 4, 2) + + M = np.array( + [ + [1, 0, 0, 1], + [0, 1, 0, 0], + [0, 0, 1, 0], + ], + dtype=np.float64, + ) + + x_full = ht.array(data, split=None) + y_ref = affine_transform(x_full, M, order=0).numpy() + + x_dist = ht.array(data, split=0) + y_dist = affine_transform(x_dist, M, order=0) + + assert y_dist.split == 0 + assert np.allclose(y_dist.resplit(None).numpy(), y_ref) + + +@pytest.mark.mpi +def test_split_axis_translation_supported_via_resplit(): + data = np.zeros((8, 4, 4), dtype=np.float32) + if rank == 0: + data[1, 2, 2] = 1.0 + data = comm.bcast(data, root=0) + + x = ht.array(data, split=0) + + # translate +3 along z + M = np.array( + [ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 3], + ], + dtype=np.float64, + ) + + y = affine_transform(x, M, order=0) + + ref = affine_transform(ht.array(data, split=None), M, order=0).numpy() + got = y.resplit(None).numpy() + + assert np.allclose(got, ref) + + +@pytest.mark.mpi +def test_distributed_vs_undistributed_equivalence(): + rng = np.random.default_rng(0) + data = rng.normal(size=(8, 5, 4)).astype(np.float32) + + M = np.array( + [ + [1, 0, 0, 1], + [0, 1, 0, 0], + [0, 0, 1, 0], + ], + dtype=np.float64, + ) + + y_ref = affine_transform(ht.array(data, split=None), M, order=0).numpy() + y_dist = affine_transform(ht.array(data, split=0), M, order=0) + + assert np.allclose(y_dist.resplit(None).numpy(), y_ref) diff --git a/pyproject.toml b/pyproject.toml index 5168e89f48..2c5b04eeaf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -177,3 +177,8 @@ convention = "numpy" [tool.ruff.format] docstring-code-format = true + +[tool.pytest.ini_options] +markers = [ + "mpi: tests that require mpirun / MPI execution", +]