Skip to content

Commit e4221fd

Browse files
committed
init
1 parent 199f8f0 commit e4221fd

21 files changed

+3673
-0
lines changed

att_utils.py

+634
Large diffs are not rendered by default.

config.py

+114
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import imgaug # https://github.com/aleju/imgaug
2+
from imgaug import augmenters as iaa
3+
import imgaug as ia
4+
import os
5+
6+
####
7+
class Config(object):
8+
def __init__(self, _args=None):
9+
if _args is not None:
10+
self.__dict__.update(_args.__dict__)
11+
self.seed = self.seed
12+
self.init_lr = 1.0e-4
13+
self.lr_steps = 20 # decrease at every n-th epoch
14+
self.gamma = 0.2
15+
self.train_batch_size = 64
16+
self.infer_batch_size = 1
17+
self.nr_classes = 3
18+
self.nr_epochs = 60
19+
self.epoch_length = 50
20+
21+
# nr of processes for parallel processing input
22+
self.nr_procs_train = 8
23+
self.nr_procs_valid = 8
24+
25+
self.nr_fold = 5
26+
self.fold_idx = 0
27+
self.cross_valid = False
28+
29+
self.load_network = False
30+
self.save_net_path = ""
31+
32+
#
33+
self.dataset = 'colon_manual'
34+
self.logging = True # True for debug run only
35+
36+
self.log_path = '/data4/doanhbc/ViTOnly_prostate_hv/'
37+
if not os.path.exists(self.log_path):
38+
os.makedirs(self.log_path, exist_ok=True)
39+
40+
self.chkpts_prefix = 'model'
41+
if _args is not None:
42+
self.__dict__.update(_args.__dict__)
43+
self.task_type = self.run_info.split('_')[0]
44+
self.loss_type = self.run_info.replace(self.task_type + "_", "")
45+
self.model_name = f'/{self.task_type}_{self.loss_type}_cancer_Effi_seed{self.seed}_BS64'
46+
self.log_dir = self.log_path + self.model_name
47+
print(self.model_name)
48+
49+
def train_augmentors(self):
50+
if self.dataset == "prostate_hv":
51+
shape_augs = [
52+
iaa.Resize(0.5, interpolation='nearest'),
53+
iaa.CropToFixedSize(width=350, height=350),
54+
]
55+
else:
56+
shape_augs = [
57+
# iaa.Resize(dict(height=384, width=384), interpolation='nearest')
58+
iaa.Resize(dict(height=384, width=384), interpolation='nearest')
59+
]
60+
#
61+
sometimes = lambda aug: iaa.Sometimes(0.2, aug)
62+
input_augs = iaa.Sequential(
63+
[
64+
# apply the following augmenters to most images
65+
iaa.Fliplr(0.5), # horizontally flip 50% of all images
66+
iaa.Flipud(0.5), # vertically flip 50% of all images
67+
sometimes(iaa.Affine(
68+
rotate=(-45, 45), # rotate by -45 to +45 degrees
69+
shear=(-16, 16), # shear by -16 to +16 degrees
70+
order=[0, 1], # use nearest neighbour or bilinear interpolation (fast)
71+
cval=(0, 255), # if mode is constant, use a cval between 0 and 255
72+
mode='symmetric'
73+
# use any of scikit-image's warping modes (see 2nd image from the top for examples)
74+
)),
75+
# execute 0 to 5 of the following (less important) augmenters per image
76+
# don't execute all of them, as that would often be way too strong
77+
iaa.SomeOf((0, 5),
78+
[
79+
iaa.OneOf([
80+
iaa.GaussianBlur((0, 3.0)), # blur images with a sigma between 0 and 3.0
81+
iaa.AverageBlur(k=(2, 7)),
82+
# blur image using local means with kernel sizes between 2 and 7
83+
iaa.MedianBlur(k=(3, 11)),
84+
# blur image using local medians with kernel sizes between 2 and 7
85+
]),
86+
iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05 * 255), per_channel=0.5),
87+
# add gaussian noise to images
88+
iaa.Dropout((0.01, 0.1), per_channel=0.5), # randomly remove up to 10% of the pixels
89+
# change brightness of images (by -10 to 10 of original value)
90+
iaa.AddToHueAndSaturation((-20, 20)), # change hue and saturation
91+
iaa.LinearContrast((0.5, 2.0), per_channel=0.5), # improve or worsen the contrast
92+
],
93+
random_order=True
94+
)
95+
],
96+
random_order=True
97+
)
98+
return shape_augs, input_augs
99+
100+
####
101+
def infer_augmentors(self):
102+
if self.dataset == "prostate_hv":
103+
shape_augs = [
104+
iaa.Resize(0.5, interpolation='nearest'),
105+
iaa.CropToFixedSize(width=350, height=350, position="center"),
106+
]
107+
else:
108+
shape_augs = [
109+
# iaa.Resize(dict(height=384, width=384), interpolation='nearest')
110+
iaa.Resize(dict(height=384, width=384), interpolation='nearest')
111+
]
112+
return shape_augs, None
113+
114+
###########################################################################

0 commit comments

Comments
 (0)