Commit 2475b66 (initial commit; commit message: "tes")
File tree: 4 files changed (+458, -0 lines)

README.md

+20 lines (new file)

# AtlasNet

This is a PyTorch implementation of the article
***Deep Atlas Network for Efficient 3D Left Ventricle Segmentation on Echocardiography***,
published in ***Medical Image Analysis***.

## install

Use the following command to create the conda environment:

```bash
conda env create -f atlasnet.yaml
```
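
Once created, the environment can be activated in the usual way, e.g. `conda activate atlasnet` (the name comes from the `name:` field in atlasnet.yaml).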

## dataset

We use the CETUS dataset in our code; all images are resized to 128x128x128.

The folder ./CETUS contains a few example cases that illustrate the expected data format, and the sketch below shows how test.py reads them.
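
As a minimal, non-authoritative sketch (mirroring the calls in test.py from this commit): each sample is a dict holding the end-systolic and end-diastolic volumes and their label maps. The path and batch size below are illustrative placeholders.

```python
import torch.utils.data as Data
from Model.dataset import CETUS

# Sketch only: mirrors how test.py builds its data pipeline.
# Keys per sample: 'es', 'ed', 'es_gt', 'ed_gt' (volumes and label maps).
test_ds = CETUS('./CETUS', 'test')        # './CETUS' is an illustrative path (see args.data_path in test.py)
sample = test_ds[0]
print(sample['es'].shape)                 # expected to be (1, 128, 128, 128) after resizing

test_dl = Data.DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=0)
for batch in test_dl:
    es, ed = batch['es'], batch['ed']     # moving / fixed volumes, shape [B, 1, 128, 128, 128]
    break
```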

## train/test

After setting the parameters in ./Model/config.py,
run train.py or test.py directly to start training or testing.
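
The contents of ./Model/config.py are not shown in this commit. As a hypothetical illustration, the fields below are the ones test.py actually reads (attribute names taken from test.py); the values are placeholders.

```python
# Hypothetical sketch of ./Model/config.py: only the attribute names are taken
# from test.py, the values are placeholders to be replaced with your own setup.
class Config:
    gpu = 0                              # device index, used as 'cuda:{gpu}'
    data_path = './CETUS'                # dataset root folder
    batch_size = 1                       # test-time batch size
    model = 'vm2'                        # "vm1" selects the smaller decoder in test.py
    checkpoint_path = './Checkpoint'     # folder holding the trained weights
    saved_unet_name = 'best_unet.pth'    # file name of the saved U_Network weights
    saved_tnet_name = 'best_tnet.pth'    # file name of the saved TransformerNet weights
```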

atlasnet.yaml

+111 lines (new file)
```yaml
name: atlasnet
channels:
  - pytorch
  - nvidia
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/r/
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
  - defaults
dependencies:
  - blas=1.0=mkl
  - brotli-python=1.0.9=py39hd77b12b_8
  - ca-certificates=2024.3.11=haa95532_0
  - certifi=2024.2.2=py39haa95532_0
  - charset-normalizer=2.0.4=pyhd3eb1b0_0
  - cuda-cccl=12.4.127=0
  - cuda-cudart=12.1.105=0
  - cuda-cudart-dev=12.1.105=0
  - cuda-cupti=12.1.105=0
  - cuda-libraries=12.1.0=0
  - cuda-libraries-dev=12.1.0=0
  - cuda-nvrtc=12.1.105=0
  - cuda-nvrtc-dev=12.1.105=0
  - cuda-nvtx=12.1.105=0
  - cuda-opencl=12.4.127=0
  - cuda-opencl-dev=12.4.127=0
  - cuda-profiler-api=12.4.127=0
  - cuda-runtime=12.1.0=0
  - filelock=3.13.1=py39haa95532_0
  - freetype=2.4.10=0
  - gmpy2=2.1.2=py39h7f96b67_0
  - idna=3.7=py39haa95532_0
  - intel-openmp=2021.4.0=haa95532_3556
  - jinja2=3.1.3=py39haa95532_0
  - jpeg=8d=0
  - libcublas=12.1.0.26=0
  - libcublas-dev=12.1.0.26=0
  - libcufft=11.0.2.4=0
  - libcufft-dev=11.0.2.4=0
  - libcurand=10.3.5.147=0
  - libcurand-dev=10.3.5.147=0
  - libcusolver=11.4.4.55=0
  - libcusolver-dev=11.4.4.55=0
  - libcusparse=12.0.2.55=0
  - libcusparse-dev=12.0.2.55=0
  - libjpeg-turbo=2.0.0=h196d8e1_0
  - libnpp=12.0.2.50=0
  - libnpp-dev=12.0.2.50=0
  - libnvjitlink=12.1.105=0
  - libnvjitlink-dev=12.1.105=0
  - libnvjpeg=12.1.1.14=0
  - libnvjpeg-dev=12.1.1.14=0
  - libpng=1.6.17=0
  - libtiff=4.0.2=1
  - libuv=1.44.2=h2bbff1b_0
  - libwebp=1.2.4=h2bbff1b_0
  - libwebp-base=1.2.4=h2bbff1b_1
  - markupsafe=2.1.3=py39h2bbff1b_0
  - mkl=2021.4.0=haa95532_640
  - mkl-service=2.4.0=py39h2bbff1b_0
  - mkl_fft=1.3.1=py39h277e83a_0
  - mkl_random=1.2.2=py39hf11a4ad_0
  - mpc=1.1.0=h7edee0f_1
  - mpfr=4.0.2=h62dcd97_1
  - mpir=3.0.0=hec2e145_1
  - mpmath=1.3.0=py39haa95532_0
  - networkx=3.1=py39haa95532_0
  - numpy=1.24.3=py39hf95b240_0
  - numpy-base=1.24.3=py39h005ec55_0
  - openssl=3.0.13=h2bbff1b_1
  - pillow=9.3.0=py39hdc2b20a_1
  - pip=24.0=py39haa95532_0
  - pysocks=1.7.1=py39haa95532_0
  - python=3.9.19=h1aa4202_1
  - pytorch=2.3.0=py3.9_cuda12.1_cudnn8_0
  - pytorch-cuda=12.1=hde6ce7c_5
  - pytorch-mutex=1.0=cuda
  - pyyaml=6.0.1=py39h2bbff1b_0
  - requests=2.31.0=py39haa95532_1
  - setuptools=69.5.1=py39haa95532_0
  - six=1.16.0=pyhd3eb1b0_1
  - sqlite=3.45.3=h2bbff1b_0
  - sympy=1.12=py39haa95532_0
  - tk=8.6.14=h0416ee5_0
  - typing_extensions=4.11.0=py39haa95532_0
  - tzdata=2024a=h04d1e81_0
  - urllib3=2.2.1=py39haa95532_0
  - vc=14.2=h21ff451_1
  - vs2015_runtime=14.27.29016=h5e58377_2
  - wheel=0.43.0=py39haa95532_0
  - win_inet_pton=1.1.0=py39haa95532_0
  - yaml=0.2.5=he774522_0
  - zlib=1.2.13=h8cc25b3_1
  - pip:
      - contourpy==1.2.1
      - cycler==0.12.1
      - fonttools==4.51.0
      - importlib-resources==6.4.0
      - kiwisolver==1.4.5
      - matplotlib==3.9.0
      - packaging==24.0
      - pyparsing==3.1.2
      - pystrum==0.4
      - python-dateutil==2.9.0.post0
      - scipy==1.13.0
      - torchaudio==2.3.0
      - torchvision==0.18.0
      - zipp==3.18.1
prefix: D:\anaconda\envs\atlasnet
```
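
Optionally, once the environment is active, a quick check (a sketch, not part of the repository) confirms that the pinned PyTorch 2.3.0 / CUDA 12.1 build from the file above is importable:

```python
# Optional sanity check for the environment created from atlasnet.yaml.
import torch

print(torch.__version__)           # expected: 2.3.0 (pinned above)
print(torch.version.cuda)          # expected: 12.1 (pinned above)
print(torch.cuda.is_available())   # True if an NVIDIA GPU and driver are visible
```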

test.py

+89 lines (new file)
```python
import os
import torch
import numpy as np
import torch.utils.data as Data
from Model import losses
from Model.config import Config as args
from Model.dataset import CETUS
from Model.model import U_Network, SpatialTransformer, TransformerNet, AffineCOMTransform
import torch.nn.functional as F
import warnings

warnings.filterwarnings('ignore')


def test():
    device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() else 'cpu')

    test_ds = CETUS(args.data_path, 'test')
    vol_size = test_ds[0]['es'].shape[1:]
    test_dl = Data.DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=False)

    # create affine
    affine_transform = AffineCOMTransform(device)

    # create Unet
    nf_enc = [16, 32, 32, 32]
    if args.model == "vm1":
        nf_dec = [32, 32, 32, 32, 8, 8]
    else:
        nf_dec = [32, 32, 32, 32, 32, 16, 16]

    # create STN
    STN_label = SpatialTransformer(vol_size, mode='nearest').to(device)

    dice_fn = losses.compute_label_dice

    best_UNet_model = os.path.join(args.checkpoint_path, args.saved_unet_name)
    best_UNet = U_Network(len(vol_size), nf_enc, nf_dec).to(device)
    best_UNet.load_state_dict(torch.load(best_UNet_model))
    best_UNet.eval()

    best_tnet_model = os.path.join(args.checkpoint_path, args.saved_tnet_name)
    best_tnet = TransformerNet().to(device)
    best_tnet.load_state_dict(torch.load(best_tnet_model))
    best_tnet.eval()

    with torch.no_grad():
        dice_list = []
        dice_before_list = []
        jac_list = []
        for test_iter_, test_d in enumerate(test_dl):
            m, f, ml, fl = test_d['es'], test_d['ed'], test_d['es_gt'], test_d['ed_gt']
            # [B, C, D, H, W]
            moving_label = ml.to(device).float()
            fixed_label = fl.to(device).float()
            moving = m.to(device).float()
            fixed = f.to(device).float()

            # Run the data through the model to produce warp and flow field
            affine_param = best_tnet(moving, fixed)
            affine_moving, affine_matrix = affine_transform(moving, affine_param)

            affine_moving_label = F.grid_sample(moving_label,
                                                F.affine_grid(affine_matrix, moving_label.shape,
                                                              align_corners=True),
                                                mode="nearest", align_corners=True)

            flow_m2f = best_UNet(affine_moving, fixed)
            m2f_label = STN_label(affine_moving_label, flow_m2f)

            # Calculate dice score
            dice_score = dice_fn(m2f_label, fixed_label)
            dice_list.append(dice_score.item())
            dice_before_list.append(dice_fn(moving_label, fixed_label).item())

            tar = moving.detach().cpu().numpy()[0, 0, ...]
            jac_den = np.prod(tar.shape)
            for flow_item in flow_m2f:
                jac_det = losses.jacobian_determinant(flow_item.detach().cpu().numpy())
                jac_list.append(np.sum(jac_det <= 0) / jac_den)

    mean_dice = np.array(dice_list).mean()
    before_mean_dice = np.array(dice_before_list).mean()
    print(f'test dice: {mean_dice:.5f}, original dice: {before_mean_dice:.5f}')
    print(f'test jacob mean: {np.array(jac_list).mean()}, jacob std: {np.array(jac_list).std()}')


if __name__ == "__main__":
    args = args()
    test()
```
