import torch
import torch.nn.functional as F
from tqdm import tqdm


class AffineRegistration:
    """Multi-resolution (coarse-to-fine) affine registration of a moving image onto a static image."""

    def __init__(self, scales=(4, 2), iterations=(500, 100), is_3d=True, learning_rate=1e-2,
                 verbose=True, dissimilarity_function=torch.nn.MSELoss(), optimizer=torch.optim.Adam,
                 init_translation=None, init_rotation=None, init_zoom=None, init_shear=None,
                 with_translation=True, with_rotation=True, with_zoom=True, with_shear=False,
                 align_corners=True, interp_mode=None, padding_mode='border'):
        self.scales = scales
        self.iterations = iterations[:len(scales)]
        self.is_3d = is_3d
        self.learning_rate = learning_rate
        self.verbose = verbose
        self.dissimilarity_function = dissimilarity_function
        self.optimizer = optimizer
        self.inits = (init_translation, init_rotation, init_zoom, init_shear)
        self.withs = (with_translation, with_rotation, with_zoom, with_shear)
        self.align_corners = align_corners
        # Default to an interpolation mode matching the dimensionality unless one is given explicitly.
        self.interp_mode = interp_mode if interp_mode is not None else ('trilinear' if is_3d else 'bilinear')
        self.padding_mode = padding_mode
        self._parameters = None

    def __call__(self, moving, static, return_moved=True):
        n_dim = 4 + self.is_3d  # batch dim + channel dim + 2 or 3 spatial dims
        if moving.dim() != n_dim or static.dim() != n_dim:
            raise ValueError(f'Expected moving and static to be {n_dim}D Tensors (2 + Spatial Dims.). '
                             f'Got size {moving.shape} and {static.shape}.')
        if moving.shape != static.shape:
            raise ValueError(f'Expected moving and static to have the same size. '
                             f'Got size {moving.shape} and {static.shape}.')

        self._parameters = init_parameters(self.is_3d, len(static), static.device, *self.withs, *self.inits)
        interp_kwargs = {'mode': self.interp_mode, 'align_corners': self.align_corners}
        moving_ = F.interpolate(moving, static.shape[2:], **interp_kwargs)
        # Coarse-to-fine optimization: fit the affine on progressively finer resolutions.
        for scale, iters in zip(self.scales, self.iterations):
            moving_small = F.interpolate(moving_, scale_factor=1 / scale, **interp_kwargs)
            static_small = F.interpolate(static, scale_factor=1 / scale, **interp_kwargs)
            self._fit(moving_small, static_small, iters)
        return self.transform(moving, static.shape[2:]).detach() if return_moved else None

    def _fit(self, moving, static, iterations):
        optimizer = self.optimizer(self._parameters, self.learning_rate)
        progress_bar = tqdm(range(iterations), disable=not self.verbose)
        for self.iter in progress_bar:  # the current iteration stays readable as self.iter
            optimizer.zero_grad()
            moved = self.transform(moving, static.shape[2:], with_grad=True)
            loss = self.dissimilarity_function(moved, static)
            progress_bar.set_description(f'Shape: {[*static.shape]}; Dissimilarity: {loss.item()}')
            loss.backward()
            optimizer.step()

    def transform(self, moving, shape=None, with_grad=False):
        affine = self.get_affine(with_grad)
        return affine_transform(moving, affine, shape, self.interp_mode, self.padding_mode, self.align_corners)

    def get_affine(self, with_grad=False):
        affine = compose_affine(*self._parameters)
        return affine if with_grad else affine.detach()


def affine_transform(x, affine, shape=None, mode='bilinear', padding_mode='border', align_corners=True):
    shape = x.shape[2:] if shape is None else shape
    grid = F.affine_grid(affine, [len(x), len(shape), *shape], align_corners=align_corners)
    sample_mode = 'bilinear' if mode == 'trilinear' else mode  # grid_sample has no 'trilinear'; its 'bilinear' is trilinear on 5D inputs
    return F.grid_sample(x, grid, mode=sample_mode, padding_mode=padding_mode, align_corners=align_corners)

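# A minimal sanity sketch (this demo helper is not part of the original module):
# with align_corners=True, the identity affine reproduces the input exactly, up
# to floating-point error. Call it manually, e.g. from a REPL or a test.
def _demo_identity_affine_transform():
    x = torch.rand(2, 1, 16, 16)  # batch of two single-channel 2D images
    identity = torch.eye(2, 3).repeat(2, 1, 1)  # one 2x3 identity affine per batch item
    assert torch.allclose(affine_transform(x, identity), x, atol=1e-5)

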
def init_parameters(is_3d=True, batch_size=1, device='cpu', with_translation=True, with_rotation=True, with_zoom=True,
                    with_shear=True, init_translation=None, init_rotation=None, init_zoom=None, init_shear=None):
    _check_parameter_shapes(init_translation, init_rotation, init_zoom, init_shear, is_3d, batch_size)
    n_dim = 2 + is_3d
    translation = torch.zeros(batch_size, n_dim, device=device) if init_translation is None else init_translation
    rotation = torch.eye(n_dim, device=device).repeat(batch_size, 1, 1) if init_rotation is None else init_rotation
    zoom = torch.ones(batch_size, n_dim, device=device) if init_zoom is None else init_zoom
    shear = torch.zeros(batch_size, n_dim, device=device) if init_shear is None else init_shear
    params = [translation, rotation, zoom, shear]
    with_grad = [with_translation, with_rotation, with_zoom, with_shear]
    return [torch.nn.Parameter(param, requires_grad=grad) for param, grad in zip(params, with_grad)]


def compose_affine(translation, rotation, zoom, shear):
    _check_parameter_shapes(translation, rotation, zoom, shear, zoom.shape[-1] == 3, zoom.shape[0])
    # Build the linear part from the zoom on the diagonal and the shear above it,
    # rotate it, then append the translation as the last column: A = [R @ S | t].
    square_matrix = torch.diag_embed(zoom)
    if zoom.shape[-1] == 3:
        square_matrix[..., 0, 1:] = shear[..., :2]
        square_matrix[..., 1, 2] = shear[..., 2]
    else:
        square_matrix[..., 0, 1] = shear[..., 0]
    square_matrix = rotation @ square_matrix
    return torch.cat([square_matrix, translation[:, :, None]], dim=-1)


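# Another small sketch (not part of the original module): the default parameters
# produced by init_parameters compose to the identity affine [I | 0].
def _demo_identity_compose():
    translation, rotation, zoom, shear = init_parameters(is_3d=True, batch_size=1)
    affine = compose_affine(translation, rotation, zoom, shear)  # size (1, 3, 4)
    assert torch.allclose(affine.detach(), torch.eye(3, 4).unsqueeze(0))

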
def _check_parameter_shapes(translation, rotation, zoom, shear, is_3d=True, batch_size=1):
    n_dim = 2 + is_3d
    params = {'translation': translation, 'rotation': rotation, 'zoom': zoom, 'shear': shear}
    for name, param in params.items():
        if param is not None:
            desired_shape = (batch_size, n_dim, n_dim) if name == 'rotation' else (batch_size, n_dim)
            if param.shape != desired_shape:
                raise ValueError(f'Expected {name} to be size {desired_shape} since batch_size is {batch_size} '
                                 f'and is_3d is {is_3d} -> {2 + is_3d} dimensions. Got size {param.shape}.')
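

# End-to-end usage sketch (assumption: random tensors stand in for real images;
# the hyperparameters below are illustrative, not tuned). Registers a translated
# copy of a volume back onto the original and prints the recovered affine.
if __name__ == '__main__':
    torch.manual_seed(0)
    static = torch.rand(1, 1, 32, 32, 32)
    moving = torch.roll(static, shifts=3, dims=-1)  # simulate a pure translation
    registration = AffineRegistration(scales=(4, 2), iterations=(100, 30), verbose=False)
    moved = registration(moving, static)
    print('Recovered affine (size (1, 3, 4)):')
    print(registration.get_affine())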