Skip to content

Commit 13656a5

Browse files
committed
Training again
1 parent a05bcab commit 13656a5

File tree

11 files changed

+602
-68
lines changed

11 files changed

+602
-68
lines changed
Binary file not shown.

Evaluation-Prediction/predict.py

+78-27
Original file line numberDiff line numberDiff line change
@@ -3,49 +3,53 @@
33
import numpy as np
44
import matplotlib.pyplot as plt
55
#%matplotlib inline
6-
import tensorflow as tf
6+
#import tensorflow as tf
77
import keras.backend as K
8-
8+
import keras
99

1010
from keras.models import Model, load_model
11-
from keras.layers import Input, BatchNormalization, Activation, Dense, Dropout,Maximum
12-
from keras.layers.core import Lambda, RepeatVector, Reshape
13-
from keras.layers.convolutional import Conv2D, Conv2DTranspose,Conv3D,Conv3DTranspose
14-
from keras.layers.pooling import MaxPooling2D, GlobalMaxPool2D,MaxPooling3D
15-
from keras.layers.merge import concatenate, add
16-
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
11+
#from keras.layers import Input, BatchNormalization, Activation, Dense, Dropout,Maximum
12+
#from keras.layers.core import Lambda, RepeatVector, Reshape
13+
#from keras.layers.convolutional import Conv2D, Conv2DTranspose,Conv3D,Conv3DTranspose
14+
#from keras.layers.pooling import MaxPooling2D, GlobalMaxPool2D,MaxPooling3D
15+
#from keras.layers.merge import concatenate, add
16+
#from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
1717
from keras.optimizers import Adam
18-
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
18+
#from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
1919

20-
from skimage.io import imread, imshow, concatenate_images
21-
from skimage.transform import resize
20+
#from skimage.io import imread, imshow, concatenate_images
21+
#from skimage.transform import resize
2222

2323
import os
24-
from skimage.io import imread, imshow, concatenate_images
25-
from skimage.transform import resize
24+
#from skimage.io import imread, imshow, concatenate_images
25+
#from skimage.transform import resize
2626
from medpy.io import load
2727
from medpy.io import save
2828
import numpy as np
29+
#import time
30+
#import sys
31+
#sys.path.insert(1, '~/Brain_Segmentation/utils.py')
2932
#import cv2
3033

31-
from ../utils import f1_score,dice_coef,dice_coef_loss
34+
from utils import f1_score,dice_coef,dice_coef_loss,standardize,compute_class_sens_spec,get_sens_spec_df,one_hot_encode
3235
def reverse_encode(a):
3336
return np.argmax(a,axis=-1)
3437

3538

3639

37-
model_to_predict1 = load_model('../Models/survival_pred_240_240.h5',custom_objects={'dice_coef_loss':dice_coef_loss})
38-
model_to_predict2 = load_model('../Models/survival_pred_240_155_1.h5')
39-
model_to_predict3 = load_model('../Models/survival_pred_240_155_2.h5')
40-
path = '../TestData'
40+
41+
model_to_predict1 = load_model('../first_240_155.h5',custom_objects={'dice_coef_loss':dice_coef_loss , 'dice_coef':dice_coef})
42+
#model_to_predict2 = load_model('../Models/survival_pred_240_155_1.h5',custom_objects={'dice_coef_loss':dice_coef_loss , 'f1_score':f1_score})
43+
#model_to_predict3 = load_model('../Models/survival_pred_240_155_2.h5',custom_objects={'dice_coef_loss':dice_coef_loss , 'f1_score':f1_score})
44+
path = '../../Brats17TrainingData/LGG'
4145
all_images = os.listdir(path)
4246
#print(len(all_images))
4347
all_images.sort()
4448

4549
data = np.zeros((240,240,155,4))
4650

47-
for i in range(100,101):
48-
new_image = np.zeros((240,240,155,5))
51+
for i in range(40,42):
52+
new_image = np.zeros((240,240,155,4))
4953
print(i)
5054
x_to = []
5155
y_to = []
@@ -64,24 +68,42 @@ def reverse_encode(a):
6468
print("Entered ground truth")
6569
else:
6670
image_data, image_header = load(image_path);
71+
image_data = standardize(image_data)
6772
data[:,:,:,w] = image_data
6873
print("Entered modality")
6974
w = w+1
7075

7176
print(data.shape)
72-
77+
Y_hat = model_to_predict1.predict(data)
78+
#print(Y_hat.shape)
79+
image_data2[image_data2==4] = 3
80+
Y_hat[Y_hat > 0.6] = 1.0
81+
Y_hat[Y_hat <= 0.6] = 0.0
82+
#print(Y_hat[0,100,100])
83+
#print(len(Y_hat[:,:,:,0]==1))
84+
#print(len(Y_hat[:,:,:,1]==1))
85+
#print(len(Y_hat[:,:,:,2]==1))
86+
#print(len(Y_hat[:,:,:,3]==1))
87+
image_data2 = keras.utils.to_categorical(image_data2, num_classes = 4)
88+
#image_data2 = one_hot_encode(image_data2)
89+
print(get_sens_spec_df(Y_hat,image_data2))
90+
print(model_to_predict1.evaluate(x=data,y=image_data2))
91+
print(model_to_predict1.metrics_names)
7392
#Combining results from all 3 dimensions
74-
93+
'''
7594
for slice_no in range(0,240):
7695
a = slice_no
7796
X = data[slice_no,:,:,:]
7897
X = X.reshape(1,240,155,4)
7998
Y_hat = model_to_predict3.predict(X)
8099
new_image[a,:,:,:] = Y_hat[0,:,:,:]
100+
'''
101+
81102

103+
'''
82104
for slice_no in range(0,155):
83105
a = slice_no
84-
X = data[:,:slice_no,:]
106+
X = data[:,:,slice_no,:]
85107
X = X.reshape(1,240,240,4)
86108
Y_hat = model_to_predict1.predict(X)
87109
new_image[:,:,slice_no,:] += Y_hat[0,:,:,:]
@@ -92,10 +114,39 @@ def reverse_encode(a):
92114
X = X.reshape(1,240,155,4)
93115
Y_hat = model_to_predict2.predict(X)
94116
new_image[:,a,:,:] += Y_hat[0,:,:,:]
95-
117+
'''
96118

97-
new_image = new_image/3 #average of probabilities from 3 directions
119+
120+
#new_image = new_image/3.0
121+
#print(new_image[100,100,100])
122+
#pred = pred.reshape(-1,5)
123+
#pred1 = np.argmax(new_image[:,:,:,1:],axis=3)
124+
#new_image = np.argmax(new_image,axis=3)
125+
#pred1[new_image[:,:,:,0] > 0.56] = 0 #average of probabilities from 3 directions
126+
#pred1 = pred1.astype('int64')
127+
#image_data2 = image_data2.astype('int64')
128+
129+
'''
130+
for slice_no in range(0,155):
131+
print(slice_no)
132+
img = pred1[:,:,slice_no]
133+
imgplot = plt.imshow(img)
134+
plt.show(block=False)
135+
#time.sleep(1)
136+
plt.pause(0.1)
137+
plt.close()
138+
139+
140+
for slice_no in range(0,155):
141+
print(slice_no)
142+
img = image_data2[:,:,slice_no]
143+
imgplot = plt.imshow(img)
144+
plt.show(block=False)
145+
#time.sleep(1)
146+
plt.pause(0.1)
147+
plt.close()
148+
'''
98149

99150

100-
name = '../all_images/VSD.Seg_001.'+ image_id + '.mha'
101-
save(new_image,name)
151+
#name = '../all_images/VSD.Seg_001.'+ image_id + '.mha'
152+
#save(new_image,name)

Evaluation-Prediction/utils.py

+163
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
import random
2+
import pandas as pd
3+
import numpy as np
4+
import matplotlib.pyplot as plt
5+
#%matplotlib inline
6+
import tensorflow as tf
7+
import keras.backend as K
8+
9+
10+
from keras.models import Model, load_model
11+
from keras.layers import Input, BatchNormalization, Activation, Dense, Dropout,Maximum
12+
from keras.layers.core import Lambda, RepeatVector, Reshape
13+
from keras.layers.convolutional import Conv2D, Conv2DTranspose,Conv3D,Conv3DTranspose
14+
from keras.layers.pooling import MaxPooling2D, GlobalMaxPool2D,MaxPooling3D
15+
from keras.layers.merge import concatenate, add
16+
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
17+
from keras.optimizers import Adam
18+
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
19+
20+
from skimage.io import imread, imshow, concatenate_images
21+
from skimage.transform import resize
22+
23+
import os
24+
from skimage.io import imread, imshow, concatenate_images
25+
from skimage.transform import resize
26+
from medpy.io import load
27+
import numpy as np
28+
29+
import cv2
30+
from sklearn import metrics
31+
32+
def f1_score(y_true, y_pred):
33+
34+
# Count positive samples.
35+
c1 = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
36+
c2 = K.sum(K.round(K.clip(y_true, 0, 1)))
37+
c3 = K.sum(K.round(K.clip(y_pred, 0, 1)))
38+
39+
# If there are no true samples, fix the F1 score at 0.
40+
if c3 == 0:
41+
return 0
42+
43+
# How many selected items are relevant?
44+
precision = c1 / c2
45+
46+
# How many relevant items are selected?
47+
recall = c1 / c3
48+
49+
# Calculate f1_score
50+
f1_score = 2 * (precision * recall) / (precision + recall)
51+
return f1_score
52+
53+
54+
def one_hot_encode(a):
55+
m = (np.arange(4) == a[...,None]).astype(int)
56+
return m
57+
58+
def dice_coef(y_true, y_pred, epsilon=0.00001):
59+
"""
60+
Dice = (2*|X & Y|)/ (|X|+ |Y|)
61+
= 2*sum(|A*B|)/(sum(A^2)+sum(B^2))
62+
ref: https://arxiv.org/pdf/1606.04797v1.pdf
63+
64+
"""
65+
axis = (0,1,2)
66+
dice_numerator = 2. * K.sum(y_true * y_pred, axis=axis) + epsilon
67+
dice_denominator = K.sum(y_true*y_true, axis=axis) + K.sum(y_pred*y_pred, axis=axis) + epsilon
68+
return K.mean((dice_numerator)/(dice_denominator))
69+
70+
def dice_coef_loss(y_true, y_pred):
71+
return 1-dice_coef(y_true, y_pred)
72+
73+
74+
def standardize(image):
75+
76+
standardized_image = np.zeros(image.shape)
77+
78+
#
79+
80+
# iterate over the `z` dimension
81+
for z in range(image.shape[2]):
82+
# get a slice of the image
83+
# at channel c and z-th dimension `z`
84+
image_slice = image[:,:,z]
85+
86+
# subtract the mean from image_slice
87+
centered = image_slice - np.mean(image_slice)
88+
89+
# divide by the standard deviation (only if it is different from zero)
90+
if(np.std(centered)!=0):
91+
centered = centered/np.std(centered)
92+
93+
# update the slice of standardized image
94+
# with the scaled centered and scaled image
95+
standardized_image[:, :, z] = centered
96+
97+
### END CODE HERE ###
98+
99+
return standardized_image
100+
101+
102+
def compute_class_sens_spec(pred, label, class_num):
103+
"""
104+
Compute sensitivity and specificity for a particular example
105+
for a given class.
106+
107+
Args:
108+
pred (np.array): binary arrary of predictions, shape is
109+
(num classes, height, width, depth).
110+
label (np.array): binary array of labels, shape is
111+
(num classes, height, width, depth).
112+
class_num (int): number between 0 - (num_classes -1) which says
113+
which prediction class to compute statistics
114+
for.
115+
116+
Returns:
117+
sensitivity (float): precision for given class_num.
118+
specificity (float): recall for given class_num
119+
"""
120+
121+
# extract sub-array for specified class
122+
class_pred = pred[:,:,:,class_num]
123+
class_label = label[:,:,:,class_num]
124+
125+
### START CODE HERE (REPLACE INSTANCES OF 'None' with your code) ###
126+
127+
# compute true positives, false positives,
128+
# true negatives, false negatives
129+
print(np.sum(class_pred==1))
130+
print(np.sum(class_pred==0))
131+
print(np.sum(class_label==1))
132+
print(np.sum(class_label==0))
133+
tp = np.sum((class_pred == 1) & (class_label == 1))
134+
tn = np.sum((class_pred == 0) & (class_label == 0))
135+
fp = np.sum((class_pred == 1) & (class_label == 0))
136+
fn = np.sum((class_pred == 0) & (class_label == 1))
137+
print(tp,tn,fp,fn)
138+
139+
# compute sensitivity and specificity
140+
sensitivity = tp / (tp + fn)
141+
specificity = tn / (tn + fp)
142+
143+
### END CODE HERE ###
144+
145+
return sensitivity, specificity
146+
147+
148+
def get_sens_spec_df(pred, label):
149+
patch_metrics = pd.DataFrame(
150+
columns = ['Nothing',
151+
'Edema',
152+
'Non-Enhancing Tumor',
153+
'Enhancing Tumor'],
154+
index = ['Sensitivity',
155+
'Specificity'])
156+
157+
for i, class_name in enumerate(patch_metrics.columns):
158+
print(i)
159+
sens, spec = compute_class_sens_spec(pred, label, i)
160+
patch_metrics.loc['Sensitivity', class_name] = round(sens,4)
161+
patch_metrics.loc['Specificity', class_name] = round(spec,4)
162+
163+
return patch_metrics
4.28 KB
Binary file not shown.

0 commit comments

Comments
 (0)