Skip to content

Commit af7ebf1

Browse files
committed
Initial commit
1 parent 8425ebf commit af7ebf1

13 files changed

+1702
-2
lines changed

LICENSE

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
MIT License
22

3-
Copyright (c) 2023 Lukas Fisch
3+
Copyright (c) 2023 codingfisch
44

55
Permission is hereby granted, free of charge, to any person obtaining a copy
66
of this software and associated documentation files (the "Software"), to deal

README.md

+90-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,91 @@
11
# torchreg
2-
Lightweight image registration library using PyTorch
2+
3+
torchreg is a tiny (~100 lines) PyTorch-based library for 2D and 3D image registration.
4+
5+
<p float="left", align="center">
6+
<img src="https://github.com/codingfisch/torchreg/blob/main/examples/alice_big.jpg" width="256"/>
7+
<img src="https://github.com/codingfisch/torchreg/blob/main/examples/alice_small.jpg" width="256"/>
8+
<img src="https://github.com/codingfisch/torchreg/assets/55840648/dbf414cc-75e5-477c-9794-32f97a16ea21" width="256"/>
9+
</p>
10+
11+
## Usage
12+
Affine Registration of two image tensors is done via:
13+
```python
14+
from torchreg import AffineRegistration
15+
16+
# Load images as torch Tensors
17+
small_alice = ... # Tensor with shape [1, 3 (color channel), 1024 (pixel), 1024 (pixel)]
18+
big_alice = ... # Tensor with shape [1, 3 (color channel), 1024 (pixel), 1024 (pixel)]
19+
# Intialize AffineRegistration
20+
reg = AffineRegistration(is_3d=False)
21+
# Run it!
22+
moved_alice = reg(small_alice, big_alice)
23+
```
24+
25+
## Features
26+
27+
Multiresolution approach to save compute (per default 1/4 + 1/2 of original resolution for 500 + 100 iterations)
28+
```python
29+
reg = AffineRegistration(scales=(4, 2), iterations=(500, 100))
30+
```
31+
Choosing which operations (translation, rotation, zoom, shear) to optimize
32+
```python
33+
reg = AffineRegistration(with_zoom=False, with_shear=False)
34+
```
35+
Custom initial parameters
36+
```python
37+
reg = AffineRegistration(zoom=torch.Tensor([[1.5, 2.]]))
38+
```
39+
Custom dissimilarity functions and optimizers
40+
```python
41+
def dice_loss(x1, x2):
42+
dim = [2, 3, 4] if len(x2.shape) == 5 else [2, 3]
43+
inter = torch.sum(x1 * x2, dim=dim)
44+
union = torch.sum(x1 + x2, dim=dim)
45+
return 1 - (2. * inter / union).mean()
46+
47+
reg = AffineRegistration(dissimilairity_function=dice_loss, optimizer=torch.optim.Adam)
48+
```
49+
CUDA support (NVIDIA GPU)
50+
```python
51+
moved_alice = reg(moving=big_alice.cuda(), static=small_alice.cuda())
52+
```
53+
MPS support (Apple M1 or M2)
54+
```python
55+
moved_alice = reg(moving=big_alice.to('mps'), static=small_alice.to('mps'))
56+
```
57+
58+
After the registration is run, you can apply it to new images (coregistration)
59+
```python
60+
another_moved_alice = reg.transform(another_alice, shape=(256, 256))
61+
```
62+
with desired output shape.
63+
64+
You can access the affine
65+
```python
66+
affine = reg.get_affine()
67+
```
68+
and the four parameters (translation, rotation, zoom, shear)
69+
```python
70+
translation = reg.parameters[0]
71+
rotation = reg.parameters[1]
72+
zoom = reg.parameters[2]
73+
shear = reg.parameters[3]
74+
```
75+
76+
## Installation
77+
```bash
78+
pip install torchreg
79+
```
80+
81+
## Examples/Tutorials
82+
83+
There are three example notebooks:
84+
85+
- [examples/basics.ipynb](https://github.com/codingfisch/torchreg/blob/main/examples/basic.ipynb) shows the basics by using small cubes/squares as image data
86+
- [examples/images.ipynb](https://github.com/codingfisch/torchreg/blob/main/examples/image.ipynb) shows how to register alice_big.jpg to alice_small.jpg
87+
- [examples/mri.ipynb](https://github.com/codingfisch/torchreg/blob/main/examples/mri.ipynb) shows how to register MR images (Nifti files) including co-, parallel and multimodal registration
88+
89+
## Background
90+
91+
If you want to know how the core of this package works, read [the blog post]()!

examples/alice_big.jpg

162 KB
Loading

examples/alice_small.jpg

158 KB
Loading

examples/basic.ipynb

+484
Large diffs are not rendered by default.

examples/image.ipynb

+239
Large diffs are not rendered by default.

examples/mri.ipynb

+687
Large diffs are not rendered by default.

pyproject.toml

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
[tool.poetry]
2+
name = 'torchreg'
3+
version = '0.0.1'
4+
description = 'Lightweight image registration library using PyTorch'
5+
authors = ['codingfisch <[email protected]>']
6+
license = 'MIT'
7+
readme = 'README.md'
8+
repository = 'https://github.com/codingfisch/torchreg'
9+
classifiers = [
10+
'Programming Language :: Python :: 3',
11+
'Operating System :: OS Independent',
12+
'Intended Audience :: Science/Research'
13+
]
14+
15+
16+
[tool.poetry.dependencies]
17+
python = '^3.0'
18+
torch = '*'
19+
tqdm = '*'
20+
21+
22+
[tool.poetry.group.test.dependencies]
23+
pytest = '*'
24+
25+
26+
[build-system]
27+
requires = ['poetry-core']
28+
build-backend = 'poetry.core.masonry.api'

test/__init__.py

Whitespace-only changes.

test/test_affine.py

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import torch
2+
from unittest import TestCase
3+
from torchreg import AffineRegistration
4+
from torchreg.affine import compose_affine, affine_transform, init_parameters, _check_parameter_shapes
5+
6+
7+
class TestAffineRegistration(TestCase):
8+
def test_fit(self):
9+
for batch_size in [1, 2]:
10+
for n_dim in [2, 3]:
11+
reg = AffineRegistration(scales=(1,), is_3d=n_dim == 3, learning_rate=1e-1, verbose=False)
12+
moving = synthetic_image(batch_size, n_dim, shift=1)
13+
static = synthetic_image(batch_size, n_dim, shift=0)
14+
fitted_moved = reg(moving, static, return_moved=True)
15+
fitted_affine = reg.get_affine()
16+
affine = torch.stack(batch_size * [torch.eye(n_dim + 1)[:n_dim]])
17+
affine[:, -1, -1] += -1/3
18+
self.assertTrue(torch.allclose(fitted_affine, affine, atol=1e-2))
19+
moved = affine_transform(moving, affine)
20+
self.assertTrue(torch.allclose(fitted_moved, moved, atol=1e-2))
21+
22+
def test_affine_transform(self):
23+
for batch_size in [1, 2]:
24+
for n_dim in [2, 3]:
25+
moving = synthetic_image(batch_size, n_dim, shift=1)
26+
static = synthetic_image(batch_size, n_dim, shift=0)
27+
affine = torch.stack(batch_size * [torch.eye(n_dim + 1)[:n_dim]])
28+
affine[:, -1, -1] += -1/3
29+
moved = affine_transform(moving, affine)
30+
self.assertTrue(torch.allclose(moved, static, atol=1e-6))
31+
32+
def test_init_parameters(self):
33+
for batch_size in [1, 2]:
34+
for is_3d in [False, True]:
35+
params = init_parameters(is_3d=is_3d, batch_size=batch_size)
36+
self.assertIsInstance(params, list)
37+
self.assertEqual(len(params), 4)
38+
for param in params:
39+
self.assertTrue(isinstance(param, torch.nn.Parameter))
40+
_check_parameter_shapes(*params, is_3d=is_3d, batch_size=batch_size)
41+
42+
def test_compose_affine(self):
43+
for batch_size in [1, 2]:
44+
for n_dim in [2, 3]:
45+
translation = torch.zeros(batch_size, n_dim)
46+
rotation = torch.stack(batch_size * [torch.eye(n_dim)])
47+
zoom = torch.ones(batch_size, n_dim)
48+
shear = torch.zeros(batch_size, n_dim)
49+
affine = compose_affine(translation, rotation, zoom, shear)
50+
id_affine = torch.stack(batch_size * [torch.eye(n_dim + 1)[:n_dim]])
51+
self.assertTrue(torch.equal(affine, id_affine))
52+
53+
54+
def synthetic_image(batch_size, n_dim, shift):
55+
shape = [batch_size, 1, 7, 7, 7][:2 + n_dim]
56+
x = torch.zeros(*shape)
57+
if n_dim == 3:
58+
x[:, :, 2 - shift:5 - shift, 2:5, 2:5] = 1
59+
else:
60+
x[:, :, 2 - shift:5 - shift, 2:5] = 1
61+
return x

torchreg/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .affine import AffineRegistration

torchreg/affine.py

+103
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import torch
2+
import torch.nn.functional as F
3+
from tqdm import tqdm
4+
5+
6+
class AffineRegistration:
7+
def __init__(self, scales=(4, 2), iterations=(500, 100), is_3d=True, learning_rate=1e-2,
8+
verbose=True, dissimilarity_function=torch.nn.MSELoss(), optimizer=torch.optim.Adam,
9+
init_translation=None, init_rotation=None, init_zoom=None, init_shear=None,
10+
with_translation=True, with_rotation=True, with_zoom=True, with_shear=False,
11+
align_corners=True, interp_mode=None, padding_mode='border'):
12+
self.scales = scales
13+
self.iterations = iterations[:len(scales)]
14+
self.is_3d = is_3d
15+
self.learning_rate = learning_rate
16+
self.verbose = verbose
17+
self.dissimilarity_function = dissimilarity_function
18+
self.optimizer = optimizer
19+
self.inits = (init_translation, init_rotation, init_zoom, init_shear)
20+
self.withs = (with_translation, with_rotation, with_zoom, with_shear)
21+
self.align_corners = align_corners
22+
self.interp_mode = 'trilinear' if is_3d else 'bilinear' if interp_mode is None else interp_mode
23+
self.padding_mode = padding_mode
24+
self._parameters = None
25+
26+
def __call__(self, moving, static, return_moved=True):
27+
if len(moving.shape) - 4 != self.is_3d or len(static.shape) - 4 != self.is_3d:
28+
raise ValueError(f'Expected moving and static to be {4 + self.is_3d}D Tensors (2 + Spatial Dims.). '
29+
f'Got size {moving.shape} and {static.shape}.')
30+
if moving.shape != static.shape:
31+
raise ValueError(f'Expected moving and static to have the same size. '
32+
f'Got size {moving.shape} and {static.shape}.')
33+
34+
self._parameters = init_parameters(self.is_3d, len(static), static.device, *self.withs, *self.inits)
35+
interp_kwargs = {'mode': self.interp_mode, 'align_corners': self.align_corners}
36+
moving_ = F.interpolate(moving, static.shape[2:], **interp_kwargs)
37+
for scale, iters in zip(self.scales, self.iterations):
38+
moving_small = F.interpolate(moving_, scale_factor=1 / scale, **interp_kwargs)
39+
static_small = F.interpolate(static, scale_factor=1 / scale, **interp_kwargs)
40+
self._fit(moving_small, static_small, iters)
41+
return self.transform(moving, static.shape[2:]).detach() if return_moved else None
42+
43+
def _fit(self, moving, static, iterations):
44+
optimizer = self.optimizer(self._parameters, self.learning_rate)
45+
progress_bar = tqdm(range(iterations), disable=not self.verbose)
46+
for self.iter in progress_bar:
47+
optimizer.zero_grad()
48+
moved = self.transform(moving, static.shape[2:], with_grad=True)
49+
loss = self.dissimilarity_function(moved, static)
50+
progress_bar.set_description(f'Shape: {[*static.shape]}; Dissimiliarity: {loss.item()}')
51+
loss.backward()
52+
optimizer.step()
53+
54+
def transform(self, moving, shape=None, with_grad=False):
55+
affine = self.get_affine(with_grad)
56+
return affine_transform(moving, affine, shape, self.interp_mode, self.padding_mode, self.align_corners)
57+
58+
def get_affine(self, with_grad=False):
59+
affine = compose_affine(*self._parameters)
60+
return affine if with_grad else affine.detach()
61+
62+
63+
def affine_transform(x, affine, shape=None, mode='bilinear', padding_mode='border', align_corners=True):
64+
shape = x.shape[2:] if shape is None else shape
65+
grid = F.affine_grid(affine, [len(x), len(shape), *shape], align_corners)
66+
sample_mode = 'bilinear' if mode == 'trilinear' else mode # grid_sample converts 'bi-' to 'trilinear' internally
67+
return F.grid_sample(x, grid, sample_mode, padding_mode, align_corners)
68+
69+
70+
def init_parameters(is_3d=True, batch_size=1, device='cpu', with_translation=True, with_rotation=True, with_zoom=True,
71+
with_shear=True, init_translation=None, init_rotation=None, init_zoom=None, init_shear=None):
72+
_check_parameter_shapes(init_translation, init_rotation, init_zoom, init_shear, is_3d, batch_size)
73+
n_dim = 2 + is_3d
74+
translation = torch.zeros(batch_size, n_dim).to(device) if init_translation is None else init_translation
75+
rotation = torch.stack(batch_size * [torch.eye(n_dim)]).to(device) if init_rotation is None else init_rotation
76+
zoom = torch.ones(batch_size, n_dim).to(device) if init_zoom is None else init_zoom
77+
shear = torch.zeros(batch_size, n_dim).to(device) if init_shear is None else init_shear
78+
params = [translation, rotation, zoom, shear]
79+
with_grad = [with_translation, with_rotation, with_zoom, with_shear]
80+
return [torch.nn.Parameter(param, requires_grad=grad) for param, grad in zip(params, with_grad)]
81+
82+
83+
def compose_affine(translation, rotation, zoom, shear):
84+
_check_parameter_shapes(translation, rotation, zoom, shear, zoom.shape[-1] == 3, zoom.shape[0])
85+
square_matrix = torch.diag_embed(zoom)
86+
if zoom.shape[-1] == 3:
87+
square_matrix[..., 0, 1:] = shear[..., :2]
88+
square_matrix[..., 1, 2] = shear[..., 2]
89+
else:
90+
square_matrix[..., 0, 1] = shear[..., 0]
91+
square_matrix = rotation @ square_matrix
92+
return torch.cat([square_matrix, translation[:, :, None]], dim=-1)
93+
94+
95+
def _check_parameter_shapes(translation, rotation, zoom, shear, is_3d=True, batch_size=1):
96+
n_dim = 2 + is_3d
97+
params = {'translation': translation, 'rotation': rotation, 'zoom': zoom, 'shear': shear}
98+
for name, param in params.items():
99+
if param is not None:
100+
desired_shape = (batch_size, n_dim, n_dim) if name == 'rotation' else (batch_size, n_dim)
101+
if param.shape != desired_shape:
102+
raise ValueError(f'Expected {name} to be size {desired_shape} since batch_size is {batch_size} '
103+
f'and is_3d is {is_3d} -> {2 + is_3d} dimensions. Got size {param.shape}.')

torchreg/metrics.py

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import torch
2+
3+
4+
def dice_loss(x1, x2):
5+
dim = [2, 3, 4] if len(x2.shape) == 5 else [2, 3]
6+
inter = torch.sum(x1 * x2, dim=dim)
7+
union = torch.sum(x1 + x2, dim=dim)
8+
return 1 - (2. * inter / union).mean()

0 commit comments

Comments
 (0)