Commit 971405c

Merge pull request #51 from minerva-ml/dev
Evaluation in chunks added, erosion-pre / dilation-post approach added, multiclass problem definition enabled
2 parents 5d6da2a + 7210a61 commit 971405c

9 files changed (+198 -34 lines)
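The erosion-pre / dilation-post idea in brief: ground-truth masks are eroded before training (so the network learns slightly shrunken, well-separated targets) and predicted masks are dilated back during postprocessing. Below is a minimal sketch of the idea with scikit-image; erode_mask is a hypothetical stand-in for what overlay_masks does with its new erode argument, and dilate_prediction mirrors the dilate_image helper added in postprocessing.py.

# Minimal sketch of the erode-pre / dilate-post approach; erode_mask is a
# hypothetical stand-in for what overlay_masks does with its new `erode`
# argument, dilate_prediction mirrors dilate_image from postprocessing.py.
import numpy as np
from skimage.morphology import binary_erosion, binary_dilation, rectangle


def erode_mask(mask, selem_size):
    # shrink ground-truth instances before training so touching objects separate
    return binary_erosion(mask, rectangle(selem_size, selem_size))


def dilate_prediction(mask, selem_size):
    # grow predicted instances back to roughly their original extent
    return binary_dilation(mask, rectangle(selem_size, selem_size))


mask = np.zeros((10, 10), dtype=bool)
mask[2:8, 2:8] = True
training_target = erode_mask(mask, 5)             # matches erode_selem_size: 5
restored = dilate_prediction(training_target, 5)  # matches dilate_selem_size: 5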

main.py (+19 -9)

@@ -20,7 +20,8 @@
 ctx = neptune.Context()
 params = read_params(ctx)
 
-set_seed(1234)
+seed = 1234
+set_seed(seed)
 
 
 @click.group()
@@ -37,6 +38,7 @@ def prepare_metadata(train_data, valid_data, test_data, public_paths):
     logger.info('creating metadata')
     meta = generate_metadata(data_dir=params.data_dir,
                              masks_overlayed_dir=params.masks_overlayed_dir,
+                             masks_overlayed_eroded_dir=params.masks_overlayed_eroded_dir,
                              competition_stage=params.competition_stage,
                              process_train_data=train_data,
                              process_validation_data=valid_data,
@@ -49,14 +51,22 @@ def prepare_metadata(train_data, valid_data, test_data, public_paths):
 
 
 @action.command()
-def prepare_masks():
+@click.option('-d', '--dev_mode', help='if true only a small sample of data will be used', is_flag=True, required=False)
+def prepare_masks(dev_mode):
+    if params.erode_selem_size > 0:
+        erode = params.erode_selem_size
+        target_dir = params.masks_overlayed_eroded_dir
+    else:
+        erode = 0
+        target_dir = params.masks_overlayed_dir
     for dataset in ["train", "val"]:
         logger.info('Overlaying masks, dataset: {}'.format(dataset))
         overlay_masks(data_dir=params.data_dir,
                       dataset=dataset,
-                      target_dir=params.masks_overlayed_dir,
+                      target_dir=target_dir,
                       category_ids=CATEGORY_IDS,
-                      is_small=False)
+                      erode=erode,
+                      is_small=dev_mode)
 
 
 @action.command()
@@ -77,8 +87,8 @@ def _train(pipeline_name, dev_mode):
     meta_valid = meta[meta['is_valid'] == 1]
 
     if dev_mode:
-        meta_train = meta_train.sample(20, random_state=1234)
-        meta_valid = meta_valid.sample(10, random_state=1234)
+        meta_train = meta_train.sample(20, random_state=seed)
+        meta_valid = meta_valid.sample(10, random_state=seed)
 
     data = {'input': {'meta': meta_train,
                       'meta_valid': meta_valid,
@@ -108,7 +118,7 @@ def _evaluate(pipeline_name, dev_mode, chunk_size):
     meta_valid = meta[meta['is_valid'] == 1]
 
     if dev_mode:
-        meta_valid = meta_valid.sample(30, random_state=1234)
+        meta_valid = meta_valid.sample(30, random_state=seed)
 
     pipeline = PIPELINES[pipeline_name]['inference'](SOLUTION_CONFIG)
     prediction = generate_prediction(meta_valid, pipeline, logger, CATEGORY_IDS, chunk_size)
@@ -146,7 +156,7 @@ def _predict(pipeline_name, dev_mode, submit_predictions, chunk_size):
     meta_test = meta[meta['is_test'] == 1]
 
     if dev_mode:
-        meta_test = meta_test.sample(2, random_state=1234)
+        meta_test = meta_test.sample(2, random_state=seed)
 
     pipeline = PIPELINES[pipeline_name]['inference'](SOLUTION_CONFIG)
     prediction = generate_prediction(meta_test, pipeline, logger, CATEGORY_IDS, chunk_size)
@@ -161,6 +171,7 @@ def _predict(pipeline_name, dev_mode, submit_predictions, chunk_size):
     if submit_predictions:
        _make_submission(submission_filepath)
 
+
 @action.command()
 @click.option('-p', '--pipeline_name', help='pipeline to be trained', required=True)
 @click.option('-s', '--submit_predictions', help='submit predictions if true', is_flag=True, required=False)
@@ -262,4 +273,3 @@ def _generate_prediction_in_chunks(meta_data, pipeline, logger, category_ids, chunk_size):
 
 if __name__ == "__main__":
     action()
-
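The "evaluation in chunks" part of this PR lives in _evaluate/_predict and the _generate_prediction_in_chunks helper referenced in the last hunk (its body is not shown in this diff). The gist, as a hedged sketch that uses a generic predict_fn placeholder rather than the repo's pipeline API:

# Hedged sketch of chunked evaluation: slice the metadata DataFrame into chunks,
# predict each chunk, and concatenate, so the whole validation set never has to
# sit in memory at once. `predict_fn` is a placeholder, not the repo's pipeline API.
def predict_in_chunks(meta_data, predict_fn, chunk_size):
    predictions = []
    for start in range(0, len(meta_data), chunk_size):
        meta_chunk = meta_data.iloc[start:start + chunk_size]  # pandas slice
        predictions.extend(predict_fn(meta_chunk))
    return predictions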

models.py (+45)

@@ -1,3 +1,5 @@
+import torch
+from torch.autograd import Variable
 from torch import optim
 
 from callbacks import NeptuneMonitorSegmentation
@@ -25,6 +27,49 @@ def transform(self, datagen, validation_datagen=None):
             outputs[name] = softmax(prediction, axis=1)
         return outputs
 
+class PyTorchUNetStream(Model):
+    def __init__(self, architecture_config, training_config, callbacks_config):
+        super().__init__(architecture_config, training_config, callbacks_config)
+        self.model = UNet(**architecture_config['model_params'])
+        self.weight_regularization = weight_regularization_unet
+        self.optimizer = optim.Adam(self.weight_regularization(self.model, **architecture_config['regularizer_params']),
+                                    **architecture_config['optimizer_params'])
+        self.loss_function = [('multichannel_map', multiclass_segmentation_loss, 1.0)]
+        self.callbacks = callbacks_unet(self.callbacks_config)
+
+    def transform(self, datagen, validation_datagen=None):
+        if len(self.output_names) == 1:
+            output_generator = self._transform(datagen, validation_datagen)
+            output = {'{}_prediction'.format(self.output_names[0]): output_generator}
+            return output
+        else:
+            raise NotImplementedError
+
+    def _transform(self, datagen, validation_datagen=None):
+        self.model.eval()
+        batch_gen, steps = datagen
+        for batch_id, data in enumerate(batch_gen):
+            if isinstance(data, list):
+                X = data[0]
+            else:
+                X = data
+
+            if torch.cuda.is_available():
+                X = Variable(X, volatile=True).cuda()
+            else:
+                X = Variable(X, volatile=True)
+
+            outputs_batch = self.model(X)
+            outputs_batch = outputs_batch.data.cpu().numpy()
+
+            for output in outputs_batch:
+                output = softmax(output, axis=0)
+                yield output
+
+            if batch_id == steps:
+                break
+        self.model.train()
+
 
 def weight_regularization(model, regularize, weight_decay_conv2d, weight_decay_linear):
     if regularize:
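PyTorchUNetStream returns a generator instead of a materialized prediction array, so downstream steps can consume outputs one image at a time. Variable(X, volatile=True) is the pre-0.4 PyTorch way of disabling autograd at inference time; on PyTorch >= 0.4 the equivalent loop would use torch.no_grad(), roughly as sketched below (scipy's softmax stands in for the softmax helper used in models.py).

# Rough torch.no_grad() equivalent of _transform above (PyTorch >= 0.4 idiom);
# `model` is any nn.Module, `batch_gen`/`steps` come from the datagen tuple,
# and scipy's softmax stands in for the repo's softmax helper.
import torch
from scipy.special import softmax


def stream_predictions(model, batch_gen, steps):
    model.eval()
    with torch.no_grad():
        for batch_id, data in enumerate(batch_gen):
            X = data[0] if isinstance(data, list) else data
            if torch.cuda.is_available():
                X = X.cuda()
            outputs_batch = model(X).cpu().numpy()
            for output in outputs_batch:
                yield softmax(output, axis=0)
            if batch_id == steps:
                break
    model.train()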

neptune.yaml (+10 -8)

@@ -25,20 +25,23 @@ parameters:
 # data_dir: /YOUR_PATH_TO_DATA_ON_CLOUD
 # meta_dir: /YOUR_PATH_TO_DATA_ON_CLOUD
 # masks_overlayed_dir: /YOUR_PATH_TO_DATA_ON_CLOUD/masks_overlayed
+# masks_overlayed_eroded_dir: /YOUR_PATH_TO_DATA_ON_CLOUD/masks_overlayed_eroded/
 # experiment_dir: /YOUR_PATH_TO_OUTPUT_FOLDER_ON_CLOUD/experiments
 
 # Local Environment
-  data_dir: /path/to/data
-  meta_dir: /path/to/data
-  masks_overlayed_dir: /path/to/masks_overlayed
-  experiment_dir: /path/to/work/dir
+  data_dir: /path/to/data
+  meta_dir: /path/to/data
+  masks_overlayed_dir: /path/to/masks_overlayed
+  masks_overlayed_eroded_dir: /path/to/masks_overlayed_eroded
+  experiment_dir: /path/to/work/dir
 
   overwrite: 0
   num_workers: 4
   load_in_memory: 0
   pin_memory: 1
   competition_stage: 1
   api_key: YOUR_CROWDAI_API_KEY
+  stream_mode: False
 
   # General parameters
   image_h: 256
@@ -56,11 +59,8 @@ parameters:
 
   # U-Net loss weights (multi-output)
   mask: 0.3
-  contour: 0.5
-  contour_touching: 0.1
-  center: 0.1
   bce_mask: 1.0
-  dice_mask: 1.0
+  dice_mask: 2.0
 
   # Training schedule
   epochs_nr: 100
@@ -83,6 +83,8 @@ parameters:
   # Postprocessing
   threshold: 0.5
   min_nuclei_size: 20
+  erode_selem_size: 5
+  dilate_selem_size: 5
 
   #Neptune monitor
   unet_outputs_to_plot: '["multichannel_map",]'

pipeline_config.py (+4 -3)

@@ -10,7 +10,7 @@
 
 SIZE_COLUMNS = ['height', 'width']
 X_COLUMNS = ['file_path_image']
-Y_COLUMNS = ['file_path_mask']
+Y_COLUMNS = ['file_path_mask_eroded']
 Y_COLUMNS_SCORING = ['ImageId']
 CATEGORY_IDS = [None, 100]
 
@@ -20,7 +20,8 @@
     'num_classes': 2,
     'img_H-W': (params.image_h, params.image_w),
     'batch_size_train': params.batch_size_train,
-    'batch_size_inference': params.batch_size_inference
+    'batch_size_inference': params.batch_size_inference,
+    'stream_mode': params.stream_mode
 }
 
 SOLUTION_CONFIG = AttrDict({
@@ -95,5 +96,5 @@
         },
     },
     'dropper': {'min_size': params.min_nuclei_size},
-    'postprocessor': {}
+    'postprocessor': {'dilate_selem_size': params.dilate_selem_size}
 })
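For context: the dict gaining 'stream_mode' here is presumably the one pipelines.py reads as config.execution (the enclosing assignment is outside this hunk), and the new 'postprocessor' entry is unpacked straight into the dilation transformer. A small, hedged illustration of that wiring (MaskDilator is the class added in postprocessing.py; the literal 5 mirrors the new neptune.yaml default):

# Hedged illustration of how the new postprocessor config reaches the dilation
# transformer (mirrors the MaskDilator(**config.postprocessor) call in pipelines.py;
# the literal 5 stands in for params.dilate_selem_size).
from postprocessing import MaskDilator

postprocessor_config = {'dilate_selem_size': 5}
dilator = MaskDilator(**postprocessor_config)
# dilator.transform(list_of_masks) returns {'categorized_images': [...]}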

pipelines.py (+20 -6)

@@ -1,11 +1,12 @@
 from functools import partial
 
 import loaders
-from models import PyTorchUNet
-from postprocessing import Resizer, CategoryMapper, MulticlassLabeler
 from steps.base import Step, Dummy
 from steps.preprocessing.misc import XYSplit
 from utils import squeeze_inputs
+from models import PyTorchUNet, PyTorchUNetStream
+from postprocessing import Resizer, CategoryMapper, MulticlassLabeler, MaskDilator, \
+    ResizerStream, CategoryMapperStream, MulticlassLabelerStream, MaskDilatorStream
 
 
 def unet(config, train_mode):
@@ -18,12 +19,24 @@ def unet(config, train_mode):
 
     loader = preprocessing(config, model_type='single', is_train=train_mode)
     unet = Step(name='unet',
-                transformer=PyTorchUNet(**config.unet),
+                transformer=PyTorchUNetStream(**config.unet) if config.execution.stream_mode else PyTorchUNet(
+                    **config.unet),
                 input_steps=[loader],
                 cache_dirpath=config.env.cache_dirpath,
                 save_output=save_output, load_saved_output=load_saved_output)
 
     mask_postprocessed = mask_postprocessing(unet, config, save_output=save_output)
+    if config.postprocessor["dilate_selem_size"] > 0:
+        mask_postprocessed = Step(name='mask_dilation',
+                                  transformer=MaskDilatorStream(
+                                      **config.postprocessor) if config.execution.stream_mode else MaskDilator(
+                                      **config.postprocessor),
+                                  input_steps=[mask_postprocessed],
+                                  adapter={'images': ([(mask_postprocessed.name, 'categorized_images')]),
+                                           },
+                                  cache_dirpath=config.env.cache_dirpath,
+                                  save_output=save_output,
+                                  load_saved_output=False)
     detached = multiclass_object_labeler(mask_postprocessed, config, save_output=save_output)
     output = Step(name='output',
                   transformer=Dummy(),
@@ -45,9 +58,10 @@ def preprocessing(config, model_type, is_train, loader_mode=None):
         raise NotImplementedError
     return loader
 
+
 def multiclass_object_labeler(postprocessed_mask, config, save_output=True):
     labeler = Step(name='labeler',
-                   transformer=MulticlassLabeler(),
+                   transformer=MulticlassLabelerStream() if config.execution.stream_mode else MulticlassLabeler(),
                    input_steps=[postprocessed_mask],
                    adapter={'images': ([(postprocessed_mask.name, 'categorized_images')]),
                             },
@@ -164,7 +178,7 @@ def _preprocessing_multitask_generator(config, is_train, use_patching):
 
 def mask_postprocessing(model, config, save_output=False):
     mask_resize = Step(name='mask_resize',
-                       transformer=Resizer(),
+                       transformer=ResizerStream() if config.execution.stream_mode else Resizer(),
                        input_data=['input'],
                        input_steps=[model],
                        adapter={'images': ([(model.name, 'multichannel_map_prediction')]),
@@ -173,7 +187,7 @@ def mask_postprocessing(model, config, save_output=False):
                        cache_dirpath=config.env.cache_dirpath,
                        save_output=save_output)
     category_mapper = Step(name='category_mapper',
-                           transformer=CategoryMapper(),
+                           transformer=CategoryMapperStream() if config.execution.stream_mode else CategoryMapper(),
                            input_steps=[mask_resize],
                            adapter={'images': ([('mask_resize', 'resized_images')]),
                                     },
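Putting the hunks together, the unet inference pipeline now runs the steps below in order; each transformer is swapped for its *Stream variant when stream_mode is enabled, and mask_dilation is inserted only when dilate_selem_size > 0.

# Step order of the unet pipeline after this PR (names as defined above).
UNET_STEPS = [
    'unet',             # PyTorchUNetStream or PyTorchUNet
    'mask_resize',      # ResizerStream or Resizer
    'category_mapper',  # CategoryMapperStream or CategoryMapper
    'mask_dilation',    # MaskDilatorStream or MaskDilator (only if dilate_selem_size > 0)
    'labeler',          # MulticlassLabelerStream or MulticlassLabeler
    'output',           # Dummy
]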

postprocessing.py (+59)

@@ -1,6 +1,7 @@
 import numpy as np
 from scipy import ndimage as ndi
 from skimage.transform import resize
+from skimage.morphology import binary_dilation, rectangle
 from tqdm import tqdm
 
 from steps.base import BaseTransformer
@@ -34,6 +35,59 @@ def transform(self, images):
         return {'categorized_images': categorized_images}
 
 
+class MaskDilator(BaseTransformer):
+    def __init__(self, dilate_selem_size):
+        self.selem_size = dilate_selem_size
+
+    def transform(self, images):
+        dilated_images = []
+        for image in tqdm(images):
+            dilated_images.append(dilate_image(image, self.selem_size))
+        return {'categorized_images': dilated_images}
+
+
+class MulticlassLabelerStream(BaseTransformer):
+    def transform(self, images):
+        return {'labeled_images': self._transform(images)}
+
+    def _transform(self, images):
+        for i, image in enumerate(images):
+            labeled_image = label_multiclass_image(image)
+            yield labeled_image
+
+
+class ResizerStream(BaseTransformer):
+    def transform(self, images, target_sizes):
+        return {'resized_images': self._transform(images, target_sizes)}
+
+    def _transform(self, images, target_sizes):
+        for image, target_size in tqdm(zip(images, target_sizes)):
+            n_channels = image.shape[0]
+            resized_image = resize(image, (n_channels,) + target_size, mode='constant')
+            yield resized_image
+
+
+class CategoryMapperStream(BaseTransformer):
+    def transform(self, images):
+        return {'categorized_images': self._transform(images)}
+
+    def _transform(self, images):
+        for image in tqdm(images):
+            yield categorize_image(image)
+
+
+class MaskDilatorStream(BaseTransformer):
+    def __init__(self, dilate_selem_size):
+        self.selem_size = dilate_selem_size
+
+    def transform(self, images):
+        return {'categorized_images': self._transform(images)}
+
+    def _transform(self, images):
+        for image in tqdm(images):
+            yield dilate_image(image, self.selem_size)
+
+
 def label(mask):
     labeled, nr_true = ndi.label(mask)
     return labeled
@@ -45,3 +99,8 @@ def label_multiclass_image(mask):
         labeled_channels.append(label(mask == label_nr))
     labeled_image = np.stack(labeled_channels)
     return labeled_image
+
+
+def dilate_image(mask, selem_size):
+    selem = rectangle(selem_size, selem_size)
+    return binary_dilation(mask, selem=selem)
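A quick, self-contained re-statement of the dilate_image helper added above: a single positive pixel dilated with a 3x3 rectangle grows into a 3x3 block. The structuring element is passed positionally here so the snippet also runs on newer scikit-image releases, which rename the selem keyword used in the diff to footprint.

# Self-contained check of the dilation behaviour (positional footprint argument
# for compatibility with newer skimage; the repo's code uses selem=...).
import numpy as np
from skimage.morphology import binary_dilation, rectangle


def dilate_image(mask, selem_size):
    return binary_dilation(mask, rectangle(selem_size, selem_size))


mask = np.zeros((7, 7), dtype=bool)
mask[3, 3] = True
assert dilate_image(mask, 3).sum() == 9  # one pixel grows to a 3x3 block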
