|
| 1 | +""" |
| 2 | +CrossentropyND and TopKLoss are from: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/training/loss_functions/ND_Crossentropy.py |
| 3 | +""" |
| 4 | + |
| 5 | +import torch |
| 6 | +import torch.nn.functional as F |
| 7 | +from scipy.ndimage import distance_transform_edt |
| 8 | +import numpy as np |
| 9 | + |
| 10 | + |
| 11 | +class CrossentropyND(torch.nn.CrossEntropyLoss): |
| 12 | + """ |
| 13 | + Network has to have NO NONLINEARITY! |
| 14 | + """ |
| 15 | + def forward(self, inp, target): |
| 16 | + target = target.long() |
| 17 | + num_classes = inp.size()[1] |
| 18 | + |
| 19 | + i0 = 1 |
| 20 | + i1 = 2 |
| 21 | + |
| 22 | + while i1 < len(inp.shape): # this is ugly but torch only allows to transpose two axes at once |
| 23 | + inp = inp.transpose(i0, i1) |
| 24 | + i0 += 1 |
| 25 | + i1 += 1 |
| 26 | + |
| 27 | + inp = inp.contiguous() |
| 28 | + inp = inp.view(-1, num_classes) |
| 29 | + |
| 30 | + target = target.view(-1,) |
| 31 | + |
| 32 | + return super(CrossentropyND, self).forward(inp, target) |
| 33 | + |
| 34 | +class TopKLoss(CrossentropyND): |
| 35 | + """ |
| 36 | + Network has to have NO LINEARITY! |
| 37 | + """ |
| 38 | + def __init__(self, weight=None, ignore_index=-100, k=10): |
| 39 | + self.k = k |
| 40 | + super(TopKLoss, self).__init__(weight, False, ignore_index, reduce=False) |
| 41 | + |
| 42 | + def forward(self, inp, target): |
| 43 | + target = target[:, 0].long() |
| 44 | + res = super(TopKLoss, self).forward(inp, target) |
| 45 | + num_voxels = np.prod(res.shape) |
| 46 | + res, _ = torch.topk(res.view((-1, )), int(num_voxels * self.k / 100), sorted=False) |
| 47 | + return res.mean() |
| 48 | + |
| 49 | + |
| 50 | +class WeightedCrossEntropyLoss(torch.nn.CrossEntropyLoss): |
| 51 | + """ |
| 52 | + Network has to have NO NONLINEARITY! |
| 53 | + """ |
| 54 | + def __init__(self, weight=None): |
| 55 | + super(WeightedCrossEntropyLoss, self).__init__() |
| 56 | + self.weight = weight |
| 57 | + |
| 58 | + def forward(self, inp, target): |
| 59 | + target = target.long() |
| 60 | + num_classes = inp.size()[1] |
| 61 | + |
| 62 | + i0 = 1 |
| 63 | + i1 = 2 |
| 64 | + |
| 65 | + while i1 < len(inp.shape): # this is ugly but torch only allows to transpose two axes at once |
| 66 | + inp = inp.transpose(i0, i1) |
| 67 | + i0 += 1 |
| 68 | + i1 += 1 |
| 69 | + |
| 70 | + inp = inp.contiguous() |
| 71 | + inp = inp.view(-1, num_classes) |
| 72 | + |
| 73 | + target = target.view(-1,) |
| 74 | + wce_loss = torch.nn.CrossEntropyLoss(weight=self.weight) |
| 75 | + |
| 76 | + return wce_loss(inp, target) |
| 77 | + |
| 78 | +class WeightedCrossEntropyLossV2(torch.nn.Module): |
| 79 | + """ |
| 80 | + WeightedCrossEntropyLoss (WCE) as described in https://arxiv.org/pdf/1707.03237.pdf |
| 81 | + Network has to have NO LINEARITY! |
| 82 | + copy from: https://github.com/wolny/pytorch-3dunet/blob/6e5a24b6438f8c631289c10638a17dea14d42051/unet3d/losses.py#L121 |
| 83 | + """ |
| 84 | + |
| 85 | + def forward(self, net_output, gt): |
| 86 | + # compute weight |
| 87 | + # shp_x = net_output.shape |
| 88 | + # shp_y = gt.shape |
| 89 | + # print(shp_x, shp_y) |
| 90 | + # with torch.no_grad(): |
| 91 | + # if len(shp_x) != len(shp_y): |
| 92 | + # gt = gt.view((shp_y[0], 1, *shp_y[1:])) |
| 93 | + |
| 94 | + # if all([i == j for i, j in zip(net_output.shape, gt.shape)]): |
| 95 | + # # if this is the case then gt is probably already a one hot encoding |
| 96 | + # y_onehot = gt |
| 97 | + # else: |
| 98 | + # gt = gt.long() |
| 99 | + # y_onehot = torch.zeros(shp_x) |
| 100 | + # if net_output.device.type == "cuda": |
| 101 | + # y_onehot = y_onehot.cuda(net_output.device.index) |
| 102 | + # y_onehot.scatter_(1, gt, 1) |
| 103 | + # y_onehot = y_onehot.transpose(0,1).contiguous() |
| 104 | + # class_weights = (torch.einsum("cbxyz->c", y_onehot).type(torch.float32) + 1e-10)/torch.numel(y_onehot) |
| 105 | + # print('class_weights', class_weights) |
| 106 | + # class_weights = class_weights.view(-1) |
| 107 | + class_weights = torch.cuda.FloatTensor([0.2,0.8]) |
| 108 | + gt = gt.long() |
| 109 | + num_classes = net_output.size()[1] |
| 110 | + # class_weights = self._class_weights(inp) |
| 111 | + |
| 112 | + i0 = 1 |
| 113 | + i1 = 2 |
| 114 | + |
| 115 | + while i1 < len(net_output.shape): # this is ugly but torch only allows to transpose two axes at once |
| 116 | + net_output = net_output.transpose(i0, i1) |
| 117 | + i0 += 1 |
| 118 | + i1 += 1 |
| 119 | + |
| 120 | + net_output = net_output.contiguous() |
| 121 | + net_output = net_output.view(-1, num_classes) #shape=(vox_num, class_num) |
| 122 | + |
| 123 | + gt = gt.view(-1,) |
| 124 | + # print('*'*20) |
| 125 | + return F.cross_entropy(net_output, gt) # , weight=class_weights |
| 126 | + |
| 127 | + # @staticmethod |
| 128 | + # def _class_weights(input): |
| 129 | + # # normalize the input first |
| 130 | + # input = F.softmax(input, _stacklevel=5) |
| 131 | + # flattened = flatten(input) |
| 132 | + # nominator = (1. - flattened).sum(-1) |
| 133 | + # denominator = flattened.sum(-1) |
| 134 | + # class_weights = Variable(nominator / denominator, requires_grad=False) |
| 135 | + # return class_weights |
| 136 | + |
| 137 | +def flatten(tensor): |
| 138 | + """Flattens a given tensor such that the channel axis is first. |
| 139 | + The shapes are transformed as follows: |
| 140 | + (N, C, D, H, W) -> (C, N * D * H * W) |
| 141 | + """ |
| 142 | + C = tensor.size(1) |
| 143 | + # new axis order |
| 144 | + axis_order = (1, 0) + tuple(range(2, tensor.dim())) |
| 145 | + # Transpose: (N, C, D, H, W) -> (C, N, D, H, W) |
| 146 | + transposed = tensor.permute(axis_order) |
| 147 | + # Flatten: (C, N, D, H, W) -> (C, N * D * H * W) |
| 148 | + transposed = transposed.contiguous() |
| 149 | + return transposed.view(C, -1) |
| 150 | + |
| 151 | +def compute_edts_forPenalizedLoss(GT): |
| 152 | + """ |
| 153 | + GT.shape = (batch_size, x,y,z) |
| 154 | + only for binary segmentation |
| 155 | + """ |
| 156 | + GT = np.squeeze(GT) |
| 157 | + res = np.zeros(GT.shape) |
| 158 | + for i in range(GT.shape[0]): |
| 159 | + posmask = GT[i] |
| 160 | + negmask = ~posmask |
| 161 | + pos_edt = distance_transform_edt(posmask) |
| 162 | + pos_edt = (np.max(pos_edt)-pos_edt)*posmask |
| 163 | + neg_edt = distance_transform_edt(negmask) |
| 164 | + neg_edt = (np.max(neg_edt)-neg_edt)*negmask |
| 165 | + res[i] = pos_edt/np.max(pos_edt) + neg_edt/np.max(neg_edt) |
| 166 | + return res |
| 167 | + |
| 168 | +class DisPenalizedCE(torch.nn.Module): |
| 169 | + """ |
| 170 | + Only for binary 3D segmentation |
| 171 | +
|
| 172 | + Network has to have NO NONLINEARITY! |
| 173 | + """ |
| 174 | + |
| 175 | + def forward(self, inp, target): |
| 176 | + # print(inp.shape, target.shape) # (batch, 2, xyz), (batch, 2, xyz) |
| 177 | + # compute distance map of ground truth |
| 178 | + with torch.no_grad(): |
| 179 | + dist = compute_edts_forPenalizedLoss(target.cpu().numpy()>0.5) + 1.0 |
| 180 | + |
| 181 | + dist = torch.from_numpy(dist) |
| 182 | + if dist.device != inp.device: |
| 183 | + dist = dist.to(inp.device).type(torch.float32) |
| 184 | + dist = dist.view(-1,) |
| 185 | + |
| 186 | + target = target.long() |
| 187 | + num_classes = inp.size()[1] |
| 188 | + |
| 189 | + i0 = 1 |
| 190 | + i1 = 2 |
| 191 | + |
| 192 | + while i1 < len(inp.shape): # this is ugly but torch only allows to transpose two axes at once |
| 193 | + inp = inp.transpose(i0, i1) |
| 194 | + i0 += 1 |
| 195 | + i1 += 1 |
| 196 | + |
| 197 | + inp = inp.contiguous() |
| 198 | + inp = inp.view(-1, num_classes) |
| 199 | + log_sm = torch.nn.LogSoftmax(dim=1) |
| 200 | + inp_logs = log_sm(inp) |
| 201 | + |
| 202 | + target = target.view(-1,) |
| 203 | + # loss = nll_loss(inp_logs, target) |
| 204 | + loss = -inp_logs[range(target.shape[0]), target] |
| 205 | + # print(loss.type(), dist.type()) |
| 206 | + weighted_loss = loss*dist |
| 207 | + |
| 208 | + return loss.mean() |
| 209 | + |
| 210 | + |
| 211 | +def nll_loss(input, target): |
| 212 | + """ |
| 213 | + customized nll loss |
| 214 | + source: https://medium.com/@zhang_yang/understanding-cross-entropy- |
| 215 | + implementation-in-pytorch-softmax-log-softmax-nll-cross-entropy-416a2b200e34 |
| 216 | + """ |
| 217 | + loss = -input[range(target.shape[0]), target] |
| 218 | + return loss.mean() |
| 219 | + |
| 220 | + |
0 commit comments