Commit 78a9e6a

test
1 parent 2475b66 commit 78a9e6a

13 files changed: +462, -0 lines changed

CETUS/test/patient03.pkl (32 MB, binary file not shown)

CETUS/train/patient01.pkl (32 MB, binary file not shown)

CETUS/valid/patient16.pkl (32 MB, binary file not shown)

Six additional binary files not shown: 689 Bytes, 1.02 KB, 2.96 KB, 4.13 KB, 5.06 KB, 7.79 KB.

Model/config.py

+18
@@ -0,0 +1,18 @@
class Config:
    gpu = '0'
    data_path = './CETUS/'
    model = 'vm2'
    result_dir = './Result'
    lr = 1e-5
    max_epochs = 201
    sim_loss = 'mse'
    alpha = 0.04
    batch_size = 1
    n_save_epoch = 5
    model_dir = './Checkpoint/exp_'
    log_dir = './Log'
    saved_unet_name = 'unet_model.pth'
    saved_tnet_name = 'tnet_model.pth'

    # Test-time parameters
    checkpoint_path = "./Checkpoint/exp_1715837905.2272804/"

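For reference, a minimal sketch of how this Config is consumed elsewhere in the commit; losses.py below imports it the same way, and the device-selection line mirrors the pattern used in ncc_loss. The print line is illustrative only:

import torch
from Model.config import Config as args

# Select the GPU named in the config, falling back to CPU
# (the same pattern ncc_loss uses below).
device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')

# Config holds plain class attributes, so no instantiation is needed.
print(args.data_path, args.sim_loss, args.alpha)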
Model/dataset.py

+31
@@ -0,0 +1,31 @@
import os

import numpy as np
from torch.utils.data import Dataset
import pickle


class CETUS(Dataset):

    def __init__(self, dataroot='./dataset_pkl/', split='train'):
        # Initialization
        self.target_file = dataroot + split + '/'
        self.files = os.listdir(self.target_file)
        self.data_len = len(self.files)

    def __len__(self):
        # Return the size of the dataset
        return self.data_len

    def __getitem__(self, index):
        # Fetch one sample by index; preprocessing could also be applied here.
        # The index parameter is required, though its name is arbitrary.
        file_current = self.files[index]
        with open(self.target_file + file_current, 'rb') as f:
            data = pickle.load(f)
        for key in data:
            data[key] = data[key][np.newaxis, ...]
        return data

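A minimal sketch of driving this dataset with a DataLoader, assuming each .pkl file holds a dict of equally shaped numpy volumes (the np.newaxis in __getitem__ adds the channel axis, and the DataLoader adds the batch axis); the dict keys are not visible in this commit:

from torch.utils.data import DataLoader
from Model.dataset import CETUS

# batch_size=1 matches Config.batch_size; split selects the
# CETUS/train, CETUS/valid, or CETUS/test folder.
train_set = CETUS(dataroot='./CETUS/', split='train')
loader = DataLoader(train_set, batch_size=1, shuffle=True)

batch = next(iter(loader))
# Each value now has shape [B, 1, *spatial_dims].
print({key: value.shape for key, value in batch.items()})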
Model/losses.py

+137
@@ -0,0 +1,137 @@
"""
*Preliminary* pytorch implementation.

Losses for VoxelMorph
"""

import math
import torch
import numpy as np
from Model.config import Config as args
import torch.nn.functional as F
import pystrum.pynd.ndutils as nd


def gradient_loss(s, penalty='l2'):
    dy = torch.abs(s[:, :, 1:, :] - s[:, :, :-1, :])
    dx = torch.abs(s[:, :, :, 1:] - s[:, :, :, :-1])

    if penalty == 'l2':
        dy = dy * dy
        dx = dx * dx

    d = torch.mean(dx) + torch.mean(dy)
    return d / 2.0


def mse_loss(x, y):
    return torch.mean((x - y) ** 2)


def compute_label_dice(pred, gt):
    return DSC(gt == 255, pred == 255)


def DSC(pred, target):
    smooth = 1e-5
    intersection = torch.mul(pred, target).sum()
    return (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)


def ncc_loss(I, J, win=None):
    '''
    Inputs have shape [B, C, D, W, H]; NCC is computed with a convolution
    that sums over the specified local window.
    '''
    device = f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu'
    ndims = len(list(I.size())) - 2
    assert ndims in [1, 2, 3], "volumes should be 1 to 3 dimensions. found: %d" % ndims
    if win is None:
        win = [9] * ndims
    sum_filt = torch.ones([1, 1, *win]).to(device)
    pad_no = math.floor(win[0] / 2)
    stride = [1] * ndims
    padding = [pad_no] * ndims
    I_var, J_var, cross = compute_local_sums(I, J, sum_filt, stride, padding, win)
    cc = cross * cross / (I_var * J_var + 1e-5)
    return -1 * torch.mean(cc)


def compute_local_sums(I, J, filt, stride, padding, win):
    I2, J2, IJ = I * I, J * J, I * J
    I_sum = F.conv3d(I, filt, stride=stride, padding=padding)
    J_sum = F.conv3d(J, filt, stride=stride, padding=padding)
    I2_sum = F.conv3d(I2, filt, stride=stride, padding=padding)
    J2_sum = F.conv3d(J2, filt, stride=stride, padding=padding)
    IJ_sum = F.conv3d(IJ, filt, stride=stride, padding=padding)
    win_size = np.prod(win)
    u_I = I_sum / win_size
    u_J = J_sum / win_size
    cross = IJ_sum - u_J * I_sum - u_I * J_sum + u_I * u_J * win_size
    I_var = I2_sum - 2 * u_I * I_sum + u_I * u_I * win_size
    J_var = J2_sum - 2 * u_J * J_sum + u_J * u_J * win_size
    return I_var, J_var, cross


def cc_loss(x, y):
    # Computed from the cross-correlation formula
    dim = [2, 3, 4]
    mean_x = torch.mean(x, dim, keepdim=True)
    mean_y = torch.mean(y, dim, keepdim=True)
    mean_x2 = torch.mean(x ** 2, dim, keepdim=True)
    mean_y2 = torch.mean(y ** 2, dim, keepdim=True)
    stddev_x = torch.sum(torch.sqrt(mean_x2 - mean_x ** 2), dim, keepdim=True)
    stddev_y = torch.sum(torch.sqrt(mean_y2 - mean_y ** 2), dim, keepdim=True)
    return -torch.mean((x - mean_x) * (y - mean_y) / (stddev_x * stddev_y))


def jacobian_determinant(disp):
    """
    Jacobian determinant of a displacement field.
    NB: to compute the spatial gradients, we use np.gradient.

    Parameters:
        disp: displacement field of size [nb_dims, *vol_shape]
              (channel-first; transposed to channel-last below),
              where vol_shape is of len nb_dims

    Returns:
        jacobian determinant map (one value per voxel)
    """

    # check input: move the displacement components to the last axis
    disp = disp.transpose(1, 2, 3, 0)
    volshape = disp.shape[:-1]
    nb_dims = len(volshape)
    assert len(volshape) in (2, 3), 'flow has to be 2D or 3D'

    # compute grid
    grid_lst = nd.volsize2ndgrid(volshape)
    grid = np.stack(grid_lst, len(volshape))

    # compute gradients
    J = np.gradient(disp + grid)

    # 3D flow
    if nb_dims == 3:
        dx = J[0]
        dy = J[1]
        dz = J[2]

        # compute jacobian components
        Jdet0 = dx[..., 0] * (dy[..., 1] * dz[..., 2] - dy[..., 2] * dz[..., 1])
        Jdet1 = dx[..., 1] * (dy[..., 0] * dz[..., 2] - dy[..., 2] * dz[..., 0])
        Jdet2 = dx[..., 2] * (dy[..., 0] * dz[..., 1] - dy[..., 1] * dz[..., 0])

        return Jdet0 - Jdet1 + Jdet2

    else:  # must be 2

        dfdx = J[0]
        dfdy = J[1]

        return dfdx[..., 0] * dfdy[..., 1] - dfdy[..., 0] * dfdx[..., 1]


def charbonnier_loss(flow_diff, alpha=0.45, beta=1.0, epsilon=0.01):
    normalization = int(flow_diff.numel())
    error = torch.pow(torch.square(flow_diff * beta) + epsilon, alpha)
    return torch.sum(error) / normalization

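A minimal sketch of how these losses combine into the usual VoxelMorph-style training objective, using the Config defaults (sim_loss='mse', alpha=0.04); the random tensors are hypothetical stand-ins for the registration network's outputs and are not part of this commit:

import torch
from Model.config import Config as args
from Model.losses import mse_loss, ncc_loss, gradient_loss

# Hypothetical stand-ins for the network outputs.
warped = torch.rand(1, 1, 32, 64, 64)  # moving image after warping
fixed = torch.rand(1, 1, 32, 64, 64)   # fixed target image
flow = torch.rand(1, 3, 32, 64, 64)    # predicted displacement field

sim = mse_loss(warped, fixed) if args.sim_loss == 'mse' else ncc_loss(warped, fixed)
loss = sim + args.alpha * gradient_loss(flow)  # similarity + smoothness regularizer
print(float(loss))

Note that jacobian_determinant, unlike the torch losses above, operates on a numpy array and, per its internal transpose, expects a channel-first field of shape [3, D, H, W].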