|
| 1 | +""" |
| 2 | +*Preliminary* pytorch implementation. |
| 3 | +
|
| 4 | +Losses for VoxelMorph |
| 5 | +""" |
| 6 | + |
| 7 | +import math |
| 8 | +import torch |
| 9 | +import numpy as np |
| 10 | +from Model.config import Config as args |
| 11 | +import torch.nn.functional as F |
| 12 | +import pystrum.pynd.ndutils as nd |
| 13 | + |
| 14 | + |
| 15 | +def gradient_loss(s, penalty='l2'): |
| 16 | + dy = torch.abs(s[:, :, 1:, :] - s[:, :, :-1, :]) |
| 17 | + dx = torch.abs(s[:, :, :, 1:] - s[:, :, :, :-1]) |
| 18 | + |
| 19 | + if penalty == 'l2': |
| 20 | + dy = dy * dy |
| 21 | + dx = dx * dx |
| 22 | + |
| 23 | + d = torch.mean(dx) + torch.mean(dy) |
| 24 | + return d / 2.0 |
| 25 | + |
| 26 | + |
| 27 | +def mse_loss(x, y): |
| 28 | + return torch.mean((x - y) ** 2) |
| 29 | + |
| 30 | + |
| 31 | +def compute_label_dice(pred, gt): |
| 32 | + return DSC(gt == 255, pred == 255) |
| 33 | + |
| 34 | + |
| 35 | +def DSC(pred, target): |
| 36 | + smooth = 1e-5 |
| 37 | + intersection = torch.mul(pred, target).sum() |
| 38 | + return (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth) |
| 39 | + |
| 40 | + |
| 41 | +def ncc_loss(I, J, win=None): |
| 42 | + ''' |
| 43 | + 输入大小是[B,C,D,W,H]格式的,在计算ncc时用卷积来实现指定窗口内求和 |
| 44 | + ''' |
| 45 | + device = f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu' |
| 46 | + ndims = len(list(I.size())) - 2 |
| 47 | + assert ndims in [1, 2, 3], "volumes should be 1 to 3 dimensions. found: %d" % ndims |
| 48 | + if win is None: |
| 49 | + win = [9] * ndims |
| 50 | + sum_filt = torch.ones([1, 1, *win]).to(device) |
| 51 | + pad_no = math.floor(win[0] / 2) |
| 52 | + stride = [1] * ndims |
| 53 | + padding = [pad_no] * ndims |
| 54 | + I_var, J_var, cross = compute_local_sums(I, J, sum_filt, stride, padding, win) |
| 55 | + cc = cross * cross / (I_var * J_var + 1e-5) |
| 56 | + return -1 * torch.mean(cc) |
| 57 | + |
| 58 | + |
| 59 | +def compute_local_sums(I, J, filt, stride, padding, win): |
| 60 | + I2, J2, IJ = I * I, J * J, I * J |
| 61 | + I_sum = F.conv3d(I, filt, stride=stride, padding=padding) |
| 62 | + J_sum = F.conv3d(J, filt, stride=stride, padding=padding) |
| 63 | + I2_sum = F.conv3d(I2, filt, stride=stride, padding=padding) |
| 64 | + J2_sum = F.conv3d(J2, filt, stride=stride, padding=padding) |
| 65 | + IJ_sum = F.conv3d(IJ, filt, stride=stride, padding=padding) |
| 66 | + win_size = np.prod(win) |
| 67 | + u_I = I_sum / win_size |
| 68 | + u_J = J_sum / win_size |
| 69 | + cross = IJ_sum - u_J * I_sum - u_I * J_sum + u_I * u_J * win_size |
| 70 | + I_var = I2_sum - 2 * u_I * I_sum + u_I * u_I * win_size |
| 71 | + J_var = J2_sum - 2 * u_J * J_sum + u_J * u_J * win_size |
| 72 | + return I_var, J_var, cross |
| 73 | + |
| 74 | + |
| 75 | +def cc_loss(x, y): |
| 76 | + # 根据互相关公式进行计算 |
| 77 | + dim = [2, 3, 4] |
| 78 | + mean_x = torch.mean(x, dim, keepdim=True) |
| 79 | + mean_y = torch.mean(y, dim, keepdim=True) |
| 80 | + mean_x2 = torch.mean(x ** 2, dim, keepdim=True) |
| 81 | + mean_y2 = torch.mean(y ** 2, dim, keepdim=True) |
| 82 | + stddev_x = torch.sum(torch.sqrt(mean_x2 - mean_x ** 2), dim, keepdim=True) |
| 83 | + stddev_y = torch.sum(torch.sqrt(mean_y2 - mean_y ** 2), dim, keepdim=True) |
| 84 | + return -torch.mean((x - mean_x) * (y - mean_y) / (stddev_x * stddev_y)) |
| 85 | + |
| 86 | + |
| 87 | +def jacobian_determinant(disp): |
| 88 | + """ |
| 89 | + jacobian determinant of a displacement field. |
| 90 | + NB: to compute the spatial gradients, we use np.gradient. |
| 91 | +
|
| 92 | + Parameters: |
| 93 | + disp: 2D or 3D displacement field of size [*vol_shape, nb_dims], |
| 94 | + where vol_shape is of len nb_dims |
| 95 | +
|
| 96 | + Returns: |
| 97 | + jacobian determinant (scalar) |
| 98 | + """ |
| 99 | + |
| 100 | + # check input |
| 101 | + disp = disp.transpose(1, 2, 3, 0) |
| 102 | + volshape = disp.shape[:-1] |
| 103 | + nb_dims = len(volshape) |
| 104 | + assert len(volshape) in (2, 3), 'flow has to be 2D or 3D' |
| 105 | + |
| 106 | + # compute grid |
| 107 | + grid_lst = nd.volsize2ndgrid(volshape) |
| 108 | + grid = np.stack(grid_lst, len(volshape)) |
| 109 | + |
| 110 | + # compute gradients |
| 111 | + J = np.gradient(disp + grid) |
| 112 | + |
| 113 | + # 3D glow |
| 114 | + if nb_dims == 3: |
| 115 | + dx = J[0] |
| 116 | + dy = J[1] |
| 117 | + dz = J[2] |
| 118 | + |
| 119 | + # compute jacobian components |
| 120 | + Jdet0 = dx[..., 0] * (dy[..., 1] * dz[..., 2] - dy[..., 2] * dz[..., 1]) |
| 121 | + Jdet1 = dx[..., 1] * (dy[..., 0] * dz[..., 2] - dy[..., 2] * dz[..., 0]) |
| 122 | + Jdet2 = dx[..., 2] * (dy[..., 0] * dz[..., 1] - dy[..., 1] * dz[..., 0]) |
| 123 | + |
| 124 | + return Jdet0 - Jdet1 + Jdet2 |
| 125 | + |
| 126 | + else: # must be 2 |
| 127 | + |
| 128 | + dfdx = J[0] |
| 129 | + dfdy = J[1] |
| 130 | + |
| 131 | + return dfdx[..., 0] * dfdy[..., 1] - dfdy[..., 0] * dfdx[..., 1] |
| 132 | + |
| 133 | + |
| 134 | +def charbonnier_loss(flow_diff, alpha=0.45, beta=1.0, epsilon=0.01): |
| 135 | + normalization = int(flow_diff.numel()) |
| 136 | + error = torch.pow(torch.square(flow_diff * beta) + epsilon, alpha) |
| 137 | + return torch.sum(error) / normalization |
0 commit comments