1
+ import imgaug # https://github.com/aleju/imgaug
2
+ from imgaug import augmenters as iaa
3
+ import imgaug as ia
4
+ import os
5
+
6
+ ####
7
+ class Config (object ):
8
+ def __init__ (self , _args = None ):
9
+ if _args is not None :
10
+ self .__dict__ .update (_args .__dict__ )
11
+ self .seed = self .seed
12
+ self .init_lr = 1.0e-4
13
+ self .lr_steps = 20 # decrease at every n-th epoch
14
+ self .gamma = 0.2
15
+ self .train_batch_size = 64
16
+ self .infer_batch_size = 1
17
+ self .nr_classes = 3
18
+ self .nr_epochs = 60
19
+ self .epoch_length = 50
20
+
21
+ # nr of processes for parallel processing input
22
+ self .nr_procs_train = 8
23
+ self .nr_procs_valid = 8
24
+
25
+ self .nr_fold = 5
26
+ self .fold_idx = 0
27
+ self .cross_valid = False
28
+
29
+ self .load_network = False
30
+ self .save_net_path = ""
31
+
32
+ #
33
+ self .dataset = 'colon_manual'
34
+ self .logging = True # True for debug run only
35
+
36
+ self .log_path = '/data4/doanhbc/ViTOnly_prostate_hv/'
37
+ if not os .path .exists (self .log_path ):
38
+ os .makedirs (self .log_path , exist_ok = True )
39
+
40
+ self .chkpts_prefix = 'model'
41
+ if _args is not None :
42
+ self .__dict__ .update (_args .__dict__ )
43
+ self .task_type = self .run_info .split ('_' )[0 ]
44
+ self .loss_type = self .run_info .replace (self .task_type + "_" , "" )
45
+ self .model_name = f'/{ self .task_type } _{ self .loss_type } _cancer_Effi_seed{ self .seed } _BS64'
46
+ self .log_dir = self .log_path + self .model_name
47
+ print (self .model_name )
48
+
49
+ def train_augmentors (self ):
50
+ if self .dataset == "prostate_hv" :
51
+ shape_augs = [
52
+ iaa .Resize (0.5 , interpolation = 'nearest' ),
53
+ iaa .CropToFixedSize (width = 350 , height = 350 ),
54
+ ]
55
+ else :
56
+ shape_augs = [
57
+ # iaa.Resize(dict(height=384, width=384), interpolation='nearest')
58
+ iaa .Resize (dict (height = 384 , width = 384 ), interpolation = 'nearest' )
59
+ ]
60
+ #
61
+ sometimes = lambda aug : iaa .Sometimes (0.2 , aug )
62
+ input_augs = iaa .Sequential (
63
+ [
64
+ # apply the following augmenters to most images
65
+ iaa .Fliplr (0.5 ), # horizontally flip 50% of all images
66
+ iaa .Flipud (0.5 ), # vertically flip 50% of all images
67
+ sometimes (iaa .Affine (
68
+ rotate = (- 45 , 45 ), # rotate by -45 to +45 degrees
69
+ shear = (- 16 , 16 ), # shear by -16 to +16 degrees
70
+ order = [0 , 1 ], # use nearest neighbour or bilinear interpolation (fast)
71
+ cval = (0 , 255 ), # if mode is constant, use a cval between 0 and 255
72
+ mode = 'symmetric'
73
+ # use any of scikit-image's warping modes (see 2nd image from the top for examples)
74
+ )),
75
+ # execute 0 to 5 of the following (less important) augmenters per image
76
+ # don't execute all of them, as that would often be way too strong
77
+ iaa .SomeOf ((0 , 5 ),
78
+ [
79
+ iaa .OneOf ([
80
+ iaa .GaussianBlur ((0 , 3.0 )), # blur images with a sigma between 0 and 3.0
81
+ iaa .AverageBlur (k = (2 , 7 )),
82
+ # blur image using local means with kernel sizes between 2 and 7
83
+ iaa .MedianBlur (k = (3 , 11 )),
84
+ # blur image using local medians with kernel sizes between 2 and 7
85
+ ]),
86
+ iaa .AdditiveGaussianNoise (loc = 0 , scale = (0.0 , 0.05 * 255 ), per_channel = 0.5 ),
87
+ # add gaussian noise to images
88
+ iaa .Dropout ((0.01 , 0.1 ), per_channel = 0.5 ), # randomly remove up to 10% of the pixels
89
+ # change brightness of images (by -10 to 10 of original value)
90
+ iaa .AddToHueAndSaturation ((- 20 , 20 )), # change hue and saturation
91
+ iaa .LinearContrast ((0.5 , 2.0 ), per_channel = 0.5 ), # improve or worsen the contrast
92
+ ],
93
+ random_order = True
94
+ )
95
+ ],
96
+ random_order = True
97
+ )
98
+ return shape_augs , input_augs
99
+
100
+ ####
101
+ def infer_augmentors (self ):
102
+ if self .dataset == "prostate_hv" :
103
+ shape_augs = [
104
+ iaa .Resize (0.5 , interpolation = 'nearest' ),
105
+ iaa .CropToFixedSize (width = 350 , height = 350 , position = "center" ),
106
+ ]
107
+ else :
108
+ shape_augs = [
109
+ # iaa.Resize(dict(height=384, width=384), interpolation='nearest')
110
+ iaa .Resize (dict (height = 384 , width = 384 ), interpolation = 'nearest' )
111
+ ]
112
+ return shape_augs , None
113
+
114
+ ###########################################################################
0 commit comments