Commit 9269610

ditwrd (Aditya Wardianto) and Aditya Wardianto authored
Hotfix/181 hotfix for supercomp (#182)
* fix: remove agatston zero
* fix: added missing img_arr for extract dcm in ground truth cac calculation
* fix: remove comment for writing tfrecord
* ref: use base net to remove mobilenet backend
* ref: use dice_coef
* ref: speedtest for basic
* fix: remove unused imports
* fix: custom objects
* chore: loss cleanup
* chore: model cleanup

Co-authored-by: Aditya Wardianto <[email protected]>
1 parent ace0bee commit 9269610

File tree

8 files changed: +69, -179 lines changed

notebooks/speed.py

+25-9
@@ -8,7 +8,6 @@
 import pydicom as pdc
 import sklearn.metrics as skm
 import tensorflow as tf
-from keras.utils.layer_utils import count_params
 from tqdm import tqdm

 sys.path.append(pathlib.Path.cwd().parent.as_posix())
@@ -17,8 +16,12 @@
 from src.models.lib.builder import build_unet_pp
 from src.models.lib.config import UNetPPConfig
 from src.models.lib.data_loader import create_dataset, preprocess_img
-from src.models.lib.loss import (dice_coef, dice_coef_nosq, log_cosh_dice_loss,
-                                 log_cosh_dice_loss_nosq)
+from src.models.lib.loss import (
+    dice_coef,
+    dice_coef_nosq,
+    log_cosh_dice_loss,
+    log_cosh_dice_loss_nosq,
+)
 from src.models.lib.utils import loss_dict_gen
 from src.system.pipeline.output import auto_cac, ground_truth_auto_cac

@@ -36,13 +39,13 @@
     selected_model_path,
     custom_objects={
         "log_cosh_dice_loss": loss_func,
-        "dice_coef_nosq": dice_coef_nosq,
+        "dice_coef": dice_coef,
     },
 )

 model_depth = 5
 depth = int(sys.argv[3])
-filter_list = [16, 32, 64, 128, 256]
+# filter_list = [16, 32, 64, 128, 256]


 pruned_model = {}
@@ -53,16 +56,29 @@

 model_config = UNetPPConfig(
     model_name=f"model_d{depth}",
-    upsample_mode="transpose",
-    depth=depth + 1,
     input_dim=[512, 512, 1],
     batch_norm=True,
-    deep_supervision=False,
     model_mode="basic",
+    depth=5,
     n_class={"bin": 1},
-    filter_list=filter_list[: depth + 1],
+    deep_supervision=False,
+    upsample_mode="transpose",
+    filter_list=[32, 64, 128, 256, 512],
 )

+
+# model_config = UNetPPConfig(
+#     model_name=f"model_d{depth}",
+#     upsample_mode="transpose",
+#     depth=depth + 1,
+#     input_dim=[512, 512, 1],
+#     batch_norm=True,
+#     deep_supervision=False,
+#     model_mode="basic",
+#     n_class={"bin": 1},
+#     filter_list=filter_list[: depth + 1],
+# )
+
 model, output_layer_name = build_unet_pp(model_config, custom=True)

 print(f"-- Creating pruned model d{depth}")

src/data/preprocess/lib/image.py

+5-15
@@ -7,20 +7,11 @@
 sys.path.append(pathlib.Path.cwd().as_posix())

 from src.data.preprocess.lib.utils import (  # pylint: disable=wrong-import-position,import-error
-    artery_loc_to_abbr,
-    blacklist_agatston_zero,
-    blacklist_invalid_dicom,
-    blacklist_mislabelled_roi,
-    blacklist_multiple_image_id,
-    blacklist_multiple_image_id_with_roi,
-    blacklist_neg_reverse_index,
-    blacklist_no_image,
-    blacklist_pixel_overlap,
-    convert_abr_to_num,
-    fill_segmentation,
-    string_to_float_tuple,
-    string_to_int_tuple,
-)
+    artery_loc_to_abbr, blacklist_agatston_zero, blacklist_invalid_dicom,
+    blacklist_mislabelled_roi, blacklist_multiple_image_id,
+    blacklist_multiple_image_id_with_roi, blacklist_neg_reverse_index,
+    blacklist_no_image, blacklist_pixel_overlap, convert_abr_to_num,
+    fill_segmentation, string_to_float_tuple, string_to_int_tuple)


 def extract_patient_dicom_path(gated_path: pathlib.Path):
@@ -54,7 +45,6 @@ def extract_patient_dicom_path(gated_path: pathlib.Path):
             or patient_number in blacklist_invalid_dicom()
             or patient_number in blacklist_no_image()
             or patient_number in blacklist_neg_reverse_index()
-            or patient_number in blacklist_agatston_zero()
         ):
             continue


src/data/preprocess/pipeline/tfrecord.py

+30-44
@@ -11,28 +11,16 @@

 sys.path.append(pathlib.Path.cwd().as_posix())

-from src.data.preprocess.lib.tfrecord import (
-    create_example_fn,
-)  # pylint: disable=wrong-import-position,import-error
+from src.data.preprocess.lib.tfrecord import \
+    create_example_fn  # pylint: disable=wrong-import-position,import-error
 from src.data.preprocess.lib.utils import (  # pylint: disable=wrong-import-position,import-error
-    artery_loc_to_abbr,
-    blacklist_agatston_zero,
-    blacklist_invalid_dicom,
-    blacklist_mislabelled_roi,
-    blacklist_multiple_image_id,
-    blacklist_multiple_image_id_with_roi,
-    blacklist_neg_reverse_index,
-    blacklist_no_image,
-    blacklist_pixel_overlap,
-    convert_abr_to_num,
-    fill_segmentation,
-    get_patient_split,
-    get_pos_from_bin_list,
-    get_pos_from_mult_list,
-    split_list,
-    string_to_float_tuple,
-    string_to_int_tuple,
-)
+    artery_loc_to_abbr, blacklist_agatston_zero, blacklist_invalid_dicom,
+    blacklist_mislabelled_roi, blacklist_multiple_image_id,
+    blacklist_multiple_image_id_with_roi, blacklist_neg_reverse_index,
+    blacklist_no_image, blacklist_pixel_overlap, convert_abr_to_num,
+    fill_segmentation, get_patient_split, get_pos_from_bin_list,
+    get_pos_from_mult_list, split_list, string_to_float_tuple,
+    string_to_int_tuple)


 def combine_to_tfrecord(
@@ -188,12 +176,12 @@ def combine_to_tfrecord(
                         + 512 * 512
                         - patient_dict["mult_seg"].shape[0]
                     )
-                    # patient_dict["img"] = indexer[patient_index]["img"][
-                    #     img_index
-                    # ]["img_hu"][:]
-                    #
-                    # example = create_example_fn(patient_dict)
-                    # tf_record_file.write(example.SerializeToString())
+                    patient_dict["img"] = indexer[patient_index]["img"][
+                        img_index
+                    ]["img_hu"][:]
+
+                    example = create_example_fn(patient_dict)
+                    tf_record_file.write(example.SerializeToString())
                 else:
                     log_key = f"{split_mode}-img-non-cac"
                     if split_mode == "train":
@@ -216,29 +204,27 @@ def combine_to_tfrecord(
                             )
                             + 512 * 512
                         )
-                        # patient_dict["img"] = indexer[
-                        #     patient_index
-                        # ]["img"][img_index]["img_hu"][:]
-                        # #
-                        # example = create_example_fn(
-                        #     patient_dict
-                        # )
-                        # tf_record_file.write(
-                        #     example.SerializeToString()
-                        # )
+                        patient_dict["img"] = indexer[patient_index][
+                            "img"
+                        ][img_index]["img_hu"][:]
+                        #
+                        example = create_example_fn(patient_dict)
+                        tf_record_file.write(
+                            example.SerializeToString()
+                        )
                     else:
                         log[log_key] = log.get(log_key, 0) + 1
                         log[log_key + " non_cac_pixel"] = (
                             log.get(log_key + " non_cac_pixel", 0)
                             + 512 * 512
                         )
-                        # patient_dict["img"] = indexer[patient_index][
-                        #     "img"
-                        # ][img_index]["img_hu"][:]
-                        # example = create_example_fn(patient_dict)
-                        # tf_record_file.write(
-                        #     example.SerializeToString()
-                        # )
+                        patient_dict["img"] = indexer[patient_index][
+                            "img"
+                        ][img_index]["img_hu"][:]
+                        example = create_example_fn(patient_dict)
+                        tf_record_file.write(
+                            example.SerializeToString()
+                        )

 # Over sample algorithmm
 # CAC = 2391
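The uncommented blocks above are the actual TFRecord write path: each slice dict is packed into a tf.train.Example and serialized into the record file. A minimal sketch of that pattern, with an illustrative feature schema (create_example_fn's real keys are defined elsewhere in the repo):

    import numpy as np
    import tensorflow as tf


    def create_example_sketch(patient_dict):
        # Illustrative stand-in for create_example_fn: serialize the image
        # array as bytes and keep scalar metadata as int64.
        img = np.asarray(patient_dict["img"], dtype=np.float32)
        feature = {
            "img": tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[img.tobytes()])
            ),
            "idx": tf.train.Feature(
                int64_list=tf.train.Int64List(value=[patient_dict["idx"]])
            ),
        }
        return tf.train.Example(features=tf.train.Features(feature=feature))


    with tf.io.TFRecordWriter("sample.tfrecord") as tf_record_file:
        example = create_example_sketch({"img": np.zeros((512, 512)), "idx": 0})
        tf_record_file.write(example.SerializeToString())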

src/models/lib/base.py

+1-1
@@ -175,7 +175,7 @@ def base_unet_pp(config: UNetPPConfig):
             output_lists[index] for index in non_deep_supervision_output_index[1:]
         ],
     ),
-    output_layer_name[-1]
+    [output_layer_name[-1]]
     if n_head == 1
     else [
         output_layer_name[index] for index in non_deep_supervision_output_index[1:]
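This one-character fix keeps the returned layer names a list in both branches of the conditional expression, so callers can iterate without special-casing the single-head model. A small sketch of the pattern (names and n_head are illustrative):

    output_layer_name = ["out_a", "out_b", "out_c"]
    n_head = 1

    # Before the fix the single-head branch yielded a bare string while the
    # multi-head branch yielded a list; wrapping unifies the return type.
    names = [output_layer_name[-1]] if n_head == 1 else output_layer_name[1:]

    for name in names:  # safe to iterate either way
        print(name)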

src/models/lib/builder.py

+3-1
@@ -2,7 +2,8 @@
 import pathlib
 import sys

-from tensorflow import keras  # pylint: disable=wrong-import-position,import-error
+from tensorflow import \
+    keras  # pylint: disable=wrong-import-position,import-error

 sys.path.append(pathlib.Path.cwd().parent.as_posix())
 from src.models.lib.base import base_unet_pp, unetpp_mobile_backend
@@ -37,6 +38,7 @@ def build_unet_pp(config: UNetPPConfig, custom: bool = False) -> keras.Model:
     else:
         raise ValueError(f"Invalid model mode: {config.model_mode}")

+    # return base_unet_pp(config)
     return unetpp_mobile_backend(config)

     if config.model_mode == "basic":

src/models/lib/loss.py

-105
@@ -4,37 +4,6 @@


 def categorical_focal_loss(alpha=0.25, gamma=2.0):
-    """
-    https://github.com/umbertogriffo/focal-loss-keras
-
-    Softmax version of focal loss.
-    When there is a skew between different categories/labels in your data set, you can try to apply this function as a
-    loss.
-           m
-      FL = ∑ -alpha * (1 - p_o,c)^gamma * y_o,c * log(p_o,c)
-          c=1
-
-    where m = number of classes, c = class and o = observation
-
-    Parameters:
-      alpha -- the same as weighing factor in balanced cross entropy. Alpha is used to specify the weight of different
-      categories/labels, the size of the array needs to be consistent with the number of classes.
-      gamma -- focusing parameter for modulating factor (1-p)
-
-    Default value:
-      gamma -- 2.0 as mentioned in the paper
-      alpha -- 0.25 as mentioned in the paper
-
-    References:
-      Official paper: https://arxiv.org/pdf/1708.02002.pdf
-      https://www.tensorflow.org/api_docs/python/tf/keras/backend/categorical_crossentropy
-
-    Usage:
-      model.compile(loss=[categorical_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam)
-    """
-
-    # def categorical_focal_loss_fixed(y_true, y_pred):
-
     def focal_loss_fixed(y_true, y_pred):
         pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
         pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
@@ -43,33 +12,6 @@ def focal_loss_fixed(y_true, y_pred):
         ) - K.mean((1 - alpha) * K.pow(pt_0, gamma) * K.log(1.0 - pt_0 + K.epsilon()))

     return focal_loss_fixed
-    # """
-    # :param y_true: A tensor of the same shape as `y_pred`
-    # :param y_pred: A tensor resulting from a softmax
-    # :return: Output tensor.
-    # """
-    # y_true = tf.cast(y_true, tf.float32)
-    # # Define epsilon so that the back-propagation will not result in NaN for 0 divisor case
-    # epsilon = K.epsilon()
-    # # Add the epsilon to prediction value
-    # # y_pred = y_pred + epsilon
-    # # Clip the prediciton value
-    # y_pred = K.clip(y_pred, epsilon, 1.0 - epsilon)
-    # # Calculate p_t
-    # p_t = tf.where(K.equal(y_true, 1), y_pred, 1 - y_pred)
-    # # Calculate alpha_t
-    # alpha_factor = K.ones_like(y_true) * alpha
-    # alpha_t = tf.where(K.equal(y_true, 1), alpha_factor, 1 - alpha_factor)
-    # # Calculate cross entropy
-    # cross_entropy = -K.log(p_t)
-    # weight = alpha_t * K.pow((1 - p_t), gamma)
-    # # Calculate focal loss
-    # loss = weight * cross_entropy
-    # # Sum the losses in mini_batch
-    # loss = K.mean(K.sum(loss, axis=-1))
-    # return loss
-
-    # return tf.keras.losses.BinaryFocalCrossentropy(alpha=alpha, gamma=gamma)


 def dice_coef(y_true, y_pred):
@@ -145,50 +87,3 @@ def loss(y_true, y_pred):
         return dice + focal_loss

     return loss
-
-
-def dyn_weighted_bincrossentropy(true, pred):
-    """
-    Calculates weighted binary cross entropy. The weights are determined dynamically
-    by the balance of each category. This weight is calculated for each batch.
-
-    The weights are calculted by determining the number of 'pos' and 'neg' classes
-    in the true labels, then dividing by the number of total predictions.
-
-    For example if there is 1 pos class, and 99 neg class, then the weights are 1/100 and 99/100.
-    These weights can be applied so false negatives are weighted 99/100, while false postives are weighted
-    1/100. This prevents the classifier from labeling everything negative and getting 99% accuracy.
-
-    This can be useful for unbalanced catagories.
-
-    """
-    # get the total number of inputs
-    num_pred = K.sum(K.cast(pred < 0.5, true.dtype)) + K.sum(true)
-
-    # get weight of values in 'pos' category
-    zero_weight = K.sum(true) / num_pred + K.epsilon()
-
-    # get weight of values in 'false' category
-    one_weight = K.sum(K.cast(pred < 0.5, true.dtype)) / num_pred + K.epsilon()
-
-    # calculate the weight vector
-    weights = (1.0 - true) * zero_weight + true * one_weight
-
-    # calculate the binary cross entropy
-    bin_crossentropy = K.binary_crossentropy(true, pred)
-
-    # apply the weights
-    weighted_bin_crossentropy = weights * bin_crossentropy
-
-    return K.mean(weighted_bin_crossentropy)
-
-
-def dice_coef_nosq(y_true, y_pred):
-    smooth = K.epsilon()
-    y_true_f = K.flatten(y_true)
-    y_pred_f = K.flatten(y_pred)
-    intersection = K.sum(y_true_f * y_pred_f)
-    dice = (2.0 * intersection + smooth) / (
-        K.sum(K.square(y_true_f)) + K.sum(K.square(y_pred_f)) + smooth
-    )
-    return dice
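Worth noting as you read the removed dice_coef_nosq: despite its name, its denominator squares both tensors, which is the V-Net-style Dice; squaring is a no-op for hard 0/1 masks but not for soft predictions. A minimal sketch of the two denominator conventions for comparison (the retained dice_coef's body is outside this hunk, so its exact form is assumed):

    from tensorflow.keras import backend as K


    def dice_squared(y_true, y_pred):
        # Squared-denominator soft Dice: 2|X.Y| / (|X|^2 + |Y|^2).
        smooth = K.epsilon()
        t, p = K.flatten(y_true), K.flatten(y_pred)
        return (2.0 * K.sum(t * p) + smooth) / (
            K.sum(K.square(t)) + K.sum(K.square(p)) + smooth
        )


    def dice_plain(y_true, y_pred):
        # Plain-denominator soft Dice; equals dice_squared on binary masks.
        smooth = K.epsilon()
        t, p = K.flatten(y_true), K.flatten(y_pred)
        return (2.0 * K.sum(t * p) + smooth) / (K.sum(t) + K.sum(p) + smooth)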

src/models/train_model.py

+4-3
@@ -20,7 +20,8 @@
 from src.models.lib.loss import (categorical_focal_loss, dice_coef,
                                  dice_coef_nosq, dice_focal, dice_loss,
                                  dice_loss_nosq, dyn_weighted_bincrossentropy,
-                                 log_cosh_dice_focal, log_cosh_dice_loss,log_cosh_dice_loss_nosq)
+                                 log_cosh_dice_focal, log_cosh_dice_loss,
+                                 log_cosh_dice_loss_nosq)
 from src.models.lib.utils import loss_dict_gen, parse_list_string


@@ -71,7 +72,7 @@ def train_model(
         strategy = tf.distribute.MirroredStrategy(devices_name)
         with strategy.scope():
             metrics = [
-                dice_coef_nosq,
+                dice_coef,
             ]
             model, model_layer_name = build_unet_pp(model_config, custom=custom)

@@ -88,7 +89,7 @@ def train_model(
             )
     else:
         metrics = [
-            dice_coef_nosq,
+            dice_coef,
         ]
        model, model_layer_name = build_unet_pp(model_config, custom=custom)
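The metric swap takes effect where the model is compiled under the distribution strategy. A minimal sketch of that pattern, with a toy model and a stand-in metric (both illustrative; the real model comes from build_unet_pp):

    import tensorflow as tf


    def dice_coef(y_true, y_pred):
        # Stand-in for src.models.lib.loss.dice_coef.
        inter = tf.reduce_sum(y_true * y_pred)
        return (2.0 * inter + 1e-7) / (
            tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + 1e-7
        )


    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        # Variables and compiled state must be created inside the scope.
        model = tf.keras.Sequential([tf.keras.layers.Dense(1, activation="sigmoid")])
        model.compile(optimizer="adam", loss="binary_crossentropy",
                      metrics=[dice_coef])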

src/system/pipeline/output.py

+1-1
@@ -142,7 +142,7 @@ def ground_truth_auto_cac(img_dcm_paths, loc_lists, mem_opt=False):
     for index, (img_dcm_path, loc_list) in enumerate(zip(img_dcm_paths, loc_lists)):
         ## Preprocessing
         # Get Image HU and pixel spacing
-        img_hu, pxl_spc = extract_dcm(img_dcm_path)
+        img_hu, pxl_spc, img_arr = extract_dcm(img_dcm_path)

         temp = np.zeros((512, 512))
         temp[tuple(zip(*loc_list))] = 1
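This is the "missing img_arr" fix from the commit message: extract_dcm returns three values, so the old two-value unpacking raised a ValueError. A minimal sketch of what such a helper typically looks like with pydicom (the project's actual implementation is not part of this diff):

    import pydicom as pdc


    def extract_dcm_sketch(img_dcm_path):
        # Read the DICOM, convert raw pixels to Hounsfield units via the
        # rescale tags, and return the raw pixel array alongside.
        ds = pdc.dcmread(img_dcm_path)
        img_arr = ds.pixel_array
        img_hu = img_arr * float(ds.RescaleSlope) + float(ds.RescaleIntercept)
        pxl_spc = tuple(float(v) for v in ds.PixelSpacing)
        return img_hu, pxl_spc, img_arr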
