Skip to content

Commit af7ebf1

Browse files
committed
Initial commit
1 parent 8425ebf commit af7ebf1

File tree

13 files changed

+1702
-2
lines changed

13 files changed

+1702
-2
lines changed

LICENSE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
MIT License
22

3-
Copyright (c) 2023 Lukas Fisch
3+
Copyright (c) 2023 codingfisch
44

55
Permission is hereby granted, free of charge, to any person obtaining a copy
66
of this software and associated documentation files (the "Software"), to deal

README.md

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,91 @@
11
# torchreg
2-
Lightweight image registration library using PyTorch
2+
3+
torchreg is a tiny (~100 lines) PyTorch-based library for 2D and 3D image registration.
4+
5+
<p float="left", align="center">
6+
<img src="https://github.com/codingfisch/torchreg/blob/main/examples/alice_big.jpg" width="256"/>
7+
<img src="https://github.com/codingfisch/torchreg/blob/main/examples/alice_small.jpg" width="256"/>
8+
<img src="https://github.com/codingfisch/torchreg/assets/55840648/dbf414cc-75e5-477c-9794-32f97a16ea21" width="256"/>
9+
</p>
10+
11+
## Usage
12+
Affine Registration of two image tensors is done via:
13+
```python
14+
from torchreg import AffineRegistration
15+
16+
# Load images as torch Tensors
17+
small_alice = ... # Tensor with shape [1, 3 (color channel), 1024 (pixel), 1024 (pixel)]
18+
big_alice = ... # Tensor with shape [1, 3 (color channel), 1024 (pixel), 1024 (pixel)]
19+
# Intialize AffineRegistration
20+
reg = AffineRegistration(is_3d=False)
21+
# Run it!
22+
moved_alice = reg(small_alice, big_alice)
23+
```
24+
25+
## Features
26+
27+
Multiresolution approach to save compute (per default 1/4 + 1/2 of original resolution for 500 + 100 iterations)
28+
```python
29+
reg = AffineRegistration(scales=(4, 2), iterations=(500, 100))
30+
```
31+
Choosing which operations (translation, rotation, zoom, shear) to optimize
32+
```python
33+
reg = AffineRegistration(with_zoom=False, with_shear=False)
34+
```
35+
Custom initial parameters
36+
```python
37+
reg = AffineRegistration(zoom=torch.Tensor([[1.5, 2.]]))
38+
```
39+
Custom dissimilarity functions and optimizers
40+
```python
41+
def dice_loss(x1, x2):
42+
dim = [2, 3, 4] if len(x2.shape) == 5 else [2, 3]
43+
inter = torch.sum(x1 * x2, dim=dim)
44+
union = torch.sum(x1 + x2, dim=dim)
45+
return 1 - (2. * inter / union).mean()
46+
47+
reg = AffineRegistration(dissimilairity_function=dice_loss, optimizer=torch.optim.Adam)
48+
```
49+
CUDA support (NVIDIA GPU)
50+
```python
51+
moved_alice = reg(moving=big_alice.cuda(), static=small_alice.cuda())
52+
```
53+
MPS support (Apple M1 or M2)
54+
```python
55+
moved_alice = reg(moving=big_alice.to('mps'), static=small_alice.to('mps'))
56+
```
57+
58+
After the registration is run, you can apply it to new images (coregistration)
59+
```python
60+
another_moved_alice = reg.transform(another_alice, shape=(256, 256))
61+
```
62+
with desired output shape.
63+
64+
You can access the affine
65+
```python
66+
affine = reg.get_affine()
67+
```
68+
and the four parameters (translation, rotation, zoom, shear)
69+
```python
70+
translation = reg.parameters[0]
71+
rotation = reg.parameters[1]
72+
zoom = reg.parameters[2]
73+
shear = reg.parameters[3]
74+
```
75+
76+
## Installation
77+
```bash
78+
pip install torchreg
79+
```
80+
81+
## Examples/Tutorials
82+
83+
There are three example notebooks:
84+
85+
- [examples/basics.ipynb](https://github.com/codingfisch/torchreg/blob/main/examples/basic.ipynb) shows the basics by using small cubes/squares as image data
86+
- [examples/images.ipynb](https://github.com/codingfisch/torchreg/blob/main/examples/image.ipynb) shows how to register alice_big.jpg to alice_small.jpg
87+
- [examples/mri.ipynb](https://github.com/codingfisch/torchreg/blob/main/examples/mri.ipynb) shows how to register MR images (Nifti files) including co-, parallel and multimodal registration
88+
89+
## Background
90+
91+
If you want to know how the core of this package works, read [the blog post]()!

examples/alice_big.jpg

162 KB
Loading

examples/alice_small.jpg

158 KB
Loading

examples/basic.ipynb

Lines changed: 484 additions & 0 deletions
Large diffs are not rendered by default.

examples/image.ipynb

Lines changed: 239 additions & 0 deletions
Large diffs are not rendered by default.

examples/mri.ipynb

Lines changed: 687 additions & 0 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
[tool.poetry]
2+
name = 'torchreg'
3+
version = '0.0.1'
4+
description = 'Lightweight image registration library using PyTorch'
5+
authors = ['codingfisch <[email protected]>']
6+
license = 'MIT'
7+
readme = 'README.md'
8+
repository = 'https://github.com/codingfisch/torchreg'
9+
classifiers = [
10+
'Programming Language :: Python :: 3',
11+
'Operating System :: OS Independent',
12+
'Intended Audience :: Science/Research'
13+
]
14+
15+
16+
[tool.poetry.dependencies]
17+
python = '^3.0'
18+
torch = '*'
19+
tqdm = '*'
20+
21+
22+
[tool.poetry.group.test.dependencies]
23+
pytest = '*'
24+
25+
26+
[build-system]
27+
requires = ['poetry-core']
28+
build-backend = 'poetry.core.masonry.api'

test/__init__.py

Whitespace-only changes.

test/test_affine.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import torch
2+
from unittest import TestCase
3+
from torchreg import AffineRegistration
4+
from torchreg.affine import compose_affine, affine_transform, init_parameters, _check_parameter_shapes
5+
6+
7+
class TestAffineRegistration(TestCase):
8+
def test_fit(self):
9+
for batch_size in [1, 2]:
10+
for n_dim in [2, 3]:
11+
reg = AffineRegistration(scales=(1,), is_3d=n_dim == 3, learning_rate=1e-1, verbose=False)
12+
moving = synthetic_image(batch_size, n_dim, shift=1)
13+
static = synthetic_image(batch_size, n_dim, shift=0)
14+
fitted_moved = reg(moving, static, return_moved=True)
15+
fitted_affine = reg.get_affine()
16+
affine = torch.stack(batch_size * [torch.eye(n_dim + 1)[:n_dim]])
17+
affine[:, -1, -1] += -1/3
18+
self.assertTrue(torch.allclose(fitted_affine, affine, atol=1e-2))
19+
moved = affine_transform(moving, affine)
20+
self.assertTrue(torch.allclose(fitted_moved, moved, atol=1e-2))
21+
22+
def test_affine_transform(self):
23+
for batch_size in [1, 2]:
24+
for n_dim in [2, 3]:
25+
moving = synthetic_image(batch_size, n_dim, shift=1)
26+
static = synthetic_image(batch_size, n_dim, shift=0)
27+
affine = torch.stack(batch_size * [torch.eye(n_dim + 1)[:n_dim]])
28+
affine[:, -1, -1] += -1/3
29+
moved = affine_transform(moving, affine)
30+
self.assertTrue(torch.allclose(moved, static, atol=1e-6))
31+
32+
def test_init_parameters(self):
33+
for batch_size in [1, 2]:
34+
for is_3d in [False, True]:
35+
params = init_parameters(is_3d=is_3d, batch_size=batch_size)
36+
self.assertIsInstance(params, list)
37+
self.assertEqual(len(params), 4)
38+
for param in params:
39+
self.assertTrue(isinstance(param, torch.nn.Parameter))
40+
_check_parameter_shapes(*params, is_3d=is_3d, batch_size=batch_size)
41+
42+
def test_compose_affine(self):
43+
for batch_size in [1, 2]:
44+
for n_dim in [2, 3]:
45+
translation = torch.zeros(batch_size, n_dim)
46+
rotation = torch.stack(batch_size * [torch.eye(n_dim)])
47+
zoom = torch.ones(batch_size, n_dim)
48+
shear = torch.zeros(batch_size, n_dim)
49+
affine = compose_affine(translation, rotation, zoom, shear)
50+
id_affine = torch.stack(batch_size * [torch.eye(n_dim + 1)[:n_dim]])
51+
self.assertTrue(torch.equal(affine, id_affine))
52+
53+
54+
def synthetic_image(batch_size, n_dim, shift):
55+
shape = [batch_size, 1, 7, 7, 7][:2 + n_dim]
56+
x = torch.zeros(*shape)
57+
if n_dim == 3:
58+
x[:, :, 2 - shift:5 - shift, 2:5, 2:5] = 1
59+
else:
60+
x[:, :, 2 - shift:5 - shift, 2:5] = 1
61+
return x

0 commit comments

Comments
 (0)