
Commit 9acdafe

update code

1 parent c833adc commit 9acdafe

File tree

150 files changed (+24400, −0 lines)

MyLoss/ND_Crossentropy.py

+220

@@ -0,0 +1,220 @@
"""
CrossentropyND and TopKLoss are from: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/training/loss_functions/ND_Crossentropy.py
"""

import torch
import torch.nn.functional as F
from scipy.ndimage import distance_transform_edt
import numpy as np

class CrossentropyND(torch.nn.CrossEntropyLoss):
    """
    Network has to have NO NONLINEARITY!
    """
    def forward(self, inp, target):
        target = target.long()
        num_classes = inp.size()[1]

        i0 = 1
        i1 = 2

        while i1 < len(inp.shape):  # this is ugly but torch only allows to transpose two axes at once
            inp = inp.transpose(i0, i1)
            i0 += 1
            i1 += 1

        inp = inp.contiguous()
        inp = inp.view(-1, num_classes)

        target = target.view(-1,)

        return super(CrossentropyND, self).forward(inp, target)

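A minimal usage sketch (shapes are illustrative, not part of the commit): the loss takes raw N-D logits with the class axis second, plus an integer label map of matching spatial shape.

loss_fn = CrossentropyND()
logits = torch.randn(2, 3, 8, 16, 16)         # (N, C, D, H, W) raw scores, no softmax applied
labels = torch.randint(0, 3, (2, 8, 16, 16))  # (N, D, H, W) integer class ids
print(loss_fn(logits, labels))                # scalar loss
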
class TopKLoss(CrossentropyND):
    """
    Network has to have NO NONLINEARITY!
    k is the percentage of hardest voxels kept (k=10 keeps the top 10%).
    """
    def __init__(self, weight=None, ignore_index=-100, k=10):
        self.k = k
        # reduction='none' keeps the per-voxel losses so topk can select the hardest ones
        # (replaces the deprecated size_average/reduce arguments)
        super(TopKLoss, self).__init__(weight=weight, ignore_index=ignore_index, reduction='none')

    def forward(self, inp, target):
        target = target[:, 0].long()
        res = super(TopKLoss, self).forward(inp, target)
        num_voxels = np.prod(res.shape)
        res, _ = torch.topk(res.view((-1, )), int(num_voxels * self.k / 100), sorted=False)
        return res.mean()

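Unlike CrossentropyND, forward() indexes target[:, 0], so the labels need a singleton channel axis. A sketch with illustrative shapes:

loss_fn = TopKLoss(k=10)
logits = torch.randn(2, 3, 8, 16, 16)
labels = torch.randint(0, 3, (2, 1, 8, 16, 16))  # note the extra channel axis
print(loss_fn(logits, labels))                   # mean over the hardest 10% of voxels
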
class WeightedCrossEntropyLoss(torch.nn.CrossEntropyLoss):
    """
    Network has to have NO NONLINEARITY!
    """
    def __init__(self, weight=None):
        super(WeightedCrossEntropyLoss, self).__init__()
        self.weight = weight

    def forward(self, inp, target):
        target = target.long()
        num_classes = inp.size()[1]

        i0 = 1
        i1 = 2

        while i1 < len(inp.shape):  # this is ugly but torch only allows to transpose two axes at once
            inp = inp.transpose(i0, i1)
            i0 += 1
            i1 += 1

        inp = inp.contiguous()
        inp = inp.view(-1, num_classes)

        target = target.view(-1,)
        wce_loss = torch.nn.CrossEntropyLoss(weight=self.weight)

        return wce_loss(inp, target)

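For completeness, a small sketch (illustrative shapes and weights; the weight tensor must live on the same device as the logits):

loss_fn = WeightedCrossEntropyLoss(weight=torch.tensor([0.2, 0.8]))
logits = torch.randn(2, 2, 8, 16, 16)
labels = torch.randint(0, 2, (2, 8, 16, 16))
print(loss_fn(logits, labels))
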
class WeightedCrossEntropyLossV2(torch.nn.Module):
    """
    WeightedCrossEntropyLoss (WCE) as described in https://arxiv.org/pdf/1707.03237.pdf
    Network has to have NO NONLINEARITY!
    copy from: https://github.com/wolny/pytorch-3dunet/blob/6e5a24b6438f8c631289c10638a17dea14d42051/unet3d/losses.py#L121
    """

    def forward(self, net_output, gt):
        # compute weight
        # shp_x = net_output.shape
        # shp_y = gt.shape
        # print(shp_x, shp_y)
        # with torch.no_grad():
        #     if len(shp_x) != len(shp_y):
        #         gt = gt.view((shp_y[0], 1, *shp_y[1:]))
        #
        #     if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
        #         # if this is the case then gt is probably already a one hot encoding
        #         y_onehot = gt
        #     else:
        #         gt = gt.long()
        #         y_onehot = torch.zeros(shp_x)
        #         if net_output.device.type == "cuda":
        #             y_onehot = y_onehot.cuda(net_output.device.index)
        #         y_onehot.scatter_(1, gt, 1)
        # y_onehot = y_onehot.transpose(0, 1).contiguous()
        # class_weights = (torch.einsum("cbxyz->c", y_onehot).type(torch.float32) + 1e-10) / torch.numel(y_onehot)
        # print('class_weights', class_weights)
        # class_weights = class_weights.view(-1)

        # hard-coded binary class weights; created on the logits' device so this also runs on CPU
        class_weights = torch.tensor([0.2, 0.8], device=net_output.device)
        gt = gt.long()
        num_classes = net_output.size()[1]
        # class_weights = self._class_weights(inp)

        i0 = 1
        i1 = 2

        while i1 < len(net_output.shape):  # this is ugly but torch only allows to transpose two axes at once
            net_output = net_output.transpose(i0, i1)
            i0 += 1
            i1 += 1

        net_output = net_output.contiguous()
        net_output = net_output.view(-1, num_classes)  # shape=(vox_num, class_num)

        gt = gt.view(-1,)
        # note: class_weights is computed above but deliberately left out of this call
        return F.cross_entropy(net_output, gt)  # , weight=class_weights

    # @staticmethod
    # def _class_weights(input):
    #     # normalize the input first
    #     input = F.softmax(input, _stacklevel=5)
    #     flattened = flatten(input)
    #     nominator = (1. - flattened).sum(-1)
    #     denominator = flattened.sum(-1)
    #     class_weights = Variable(nominator / denominator, requires_grad=False)
    #     return class_weights

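Called the same way as the other CE losses; note that as committed the hard-coded class_weights are never passed to F.cross_entropy, so this currently behaves like unweighted CE (sketch with illustrative shapes):

loss_fn = WeightedCrossEntropyLossV2()
logits = torch.randn(2, 2, 8, 8, 8)
labels = torch.randint(0, 2, (2, 8, 8, 8))
print(loss_fn(logits, labels))
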
def flatten(tensor):
    """Flattens a given tensor such that the channel axis is first.
    The shapes are transformed as follows:
       (N, C, D, H, W) -> (C, N * D * H * W)
    """
    C = tensor.size(1)
    # new axis order
    axis_order = (1, 0) + tuple(range(2, tensor.dim()))
    # Transpose: (N, C, D, H, W) -> (C, N, D, H, W)
    transposed = tensor.permute(axis_order)
    # Flatten: (C, N, D, H, W) -> (C, N * D * H * W)
    transposed = transposed.contiguous()
    return transposed.view(C, -1)

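A quick shape check of the helper (illustrative sizes):

x = torch.randn(2, 4, 3, 5, 5)  # (N, C, D, H, W)
print(flatten(x).shape)         # torch.Size([4, 150]) == (C, N*D*H*W)
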
def compute_edts_forPenalizedLoss(GT):
    """
    GT.shape = (batch_size, x, y, z), boolean foreground mask
    only for binary segmentation
    Assumes every sample contains both foreground and background voxels,
    otherwise the np.max() normalizations below divide by zero.
    """
    GT = np.squeeze(GT)  # note: this drops the batch axis if batch_size == 1
    res = np.zeros(GT.shape)
    for i in range(GT.shape[0]):
        posmask = GT[i]
        negmask = ~posmask
        # distance of each foreground voxel to the nearest background voxel, inverted
        pos_edt = distance_transform_edt(posmask)
        pos_edt = (np.max(pos_edt) - pos_edt) * posmask
        # distance of each background voxel to the nearest foreground voxel, inverted
        neg_edt = distance_transform_edt(negmask)
        neg_edt = (np.max(neg_edt) - neg_edt) * negmask
        # normalize each side to [0, 1]; the resulting weights peak at the object boundary
        res[i] = pos_edt / np.max(pos_edt) + neg_edt / np.max(neg_edt)
    return res

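A toy example (hypothetical mask, not from the commit): each sample gets a penalty map that peaks near the object boundary and decays toward the region interiors.

gt = np.zeros((2, 8, 8, 8), dtype=bool)
gt[:, 2:6, 2:6, 2:6] = True                # a cube of foreground in each sample
penalty = compute_edts_forPenalizedLoss(gt)
print(penalty.shape, penalty.max())        # (2, 8, 8, 8); largest values near the cube surface
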
class DisPenalizedCE(torch.nn.Module):
    """
    Only for binary 3D segmentation

    Network has to have NO NONLINEARITY!
    """

    def forward(self, inp, target):
        # inp: (batch, 2, x, y, z) raw logits; target: (batch, x, y, z) binary ground truth
        # compute distance map of ground truth
        with torch.no_grad():
            dist = compute_edts_forPenalizedLoss(target.cpu().numpy() > 0.5) + 1.0

        # move to the logits' device and dtype unconditionally (the numpy result is float64 on CPU)
        dist = torch.from_numpy(dist).to(inp.device).type(torch.float32)
        dist = dist.view(-1,)

        target = target.long()
        num_classes = inp.size()[1]

        i0 = 1
        i1 = 2

        while i1 < len(inp.shape):  # this is ugly but torch only allows to transpose two axes at once
            inp = inp.transpose(i0, i1)
            i0 += 1
            i1 += 1

        inp = inp.contiguous()
        inp = inp.view(-1, num_classes)
        log_sm = torch.nn.LogSoftmax(dim=1)
        inp_logs = log_sm(inp)

        target = target.view(-1,)
        # loss = nll_loss(inp_logs, target)
        loss = -inp_logs[range(target.shape[0]), target]
        # per-voxel cross entropy scaled by the distance penalty
        weighted_loss = loss * dist

        return weighted_loss.mean()

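A usage sketch (illustrative shapes; the labels must contain both classes so the distance transform is well-defined):

loss_fn = DisPenalizedCE()
logits = torch.randn(2, 2, 8, 8, 8)              # (batch, 2, x, y, z) raw logits
labels = (torch.rand(2, 8, 8, 8) > 0.5).float()  # binary ground truth
print(loss_fn(logits, labels))
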
def nll_loss(input, target):
    """
    customized nll loss
    source: https://medium.com/@zhang_yang/understanding-cross-entropy-implementation-in-pytorch-softmax-log-softmax-nll-cross-entropy-416a2b200e34
    """
    loss = -input[range(target.shape[0]), target]
    return loss.mean()
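
This helper mirrors F.nll_loss with mean reduction; a quick equivalence check (illustrative inputs):

x = torch.randn(4, 3)
t = torch.randint(0, 3, (4,))
logp = F.log_softmax(x, dim=1)
print(nll_loss(logp, t), F.nll_loss(logp, t))  # the two values should agree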

MyLoss/__init__.py

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
from .loss_factory import create_loss
from .boundary_loss import BDLoss, SoftDiceLoss, DC_and_BD_loss, HDDTBinaryLoss,\
    DC_and_HDBinary_loss, DistBinaryDiceLoss
# note: dice_loss also exports SoftDiceLoss, so this import shadows the one from boundary_loss
from .dice_loss import GDiceLoss, GDiceLossV2, SSLoss, SoftDiceLoss,\
    IoULoss, TverskyLoss, FocalTversky_loss, AsymLoss, DC_and_CE_loss,\
    PenaltyGDiceLoss, DC_and_topk_loss, ExpLog_loss
from .focal_loss import FocalLoss
from .hausdorff import HausdorffDTLoss, HausdorffERLoss
from .lovasz_loss import LovaszSoftmax
from .ND_Crossentropy import CrossentropyND, TopKLoss, WeightedCrossEntropyLoss,\
    WeightedCrossEntropyLossV2, DisPenalizedCE
8 binary files changed (contents not shown).
