-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsave_plot.py
43 lines (29 loc) · 1.35 KB
/
save_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import argparse
import os
import pickle

from utils import confusionMatrixPlot
def savePlot(prediction: str, location: str) -> None:
    """
    Save confusion matrix plots for every cross-validation fold.

    The prediction pickle is read as an iterable of dicts, each mapping a fold
    name to that fold's predictions. Every per-fold entry must provide the keys
    ``'digit_gt'``, ``'digit_predict'``, ``'gen_gt'`` and ``'gen_predict'``.
    Two plots are produced per fold, passed to ``confusionMatrixPlot`` as
    ``<location>/Digit_<fold_name>`` and ``<location>/Gender_<fold_name>``.

    Args:
        prediction (str): Path to the prediction pickle file.
        location (str): Directory in which to save the confusion matrix plots.
            It is created if it does not exist; a trailing path separator is
            optional.

    Note:
        ``pickle.load`` can execute arbitrary code embedded in the file, so only
        pass prediction files produced by this project's own pipeline.
    """
    # Load everything up front so the file handle is released before plotting.
    with open(prediction, 'rb') as f:
        data = pickle.load(f)

    # Make sure the output directory exists. An empty string means "current
    # directory", which always exists and which os.makedirs would reject.
    if location:
        os.makedirs(location, exist_ok=True)

    for fold in data:
        for fold_name, fold_predict in fold.items():
            # Confusion matrix for digit classification.
            confusionMatrixPlot(trueLabel=fold_predict['digit_gt'],
                                predLabel=fold_predict['digit_predict'],
                                location=os.path.join(location, f'Digit_{fold_name}'))
            # Confusion matrix for gender classification.
            confusionMatrixPlot(trueLabel=fold_predict['gen_gt'],
                                predLabel=fold_predict['gen_predict'],
                                location=os.path.join(location, f'Gender_{fold_name}'))
if __name__ == '__main__':
    # Command-line entry point. The defaults reproduce the original hard-coded
    # paths, so running the script without arguments behaves exactly as before.
    parser = argparse.ArgumentParser(
        description='Save confusion matrix plots from cross-validation predictions.')
    parser.add_argument('--prediction',
                        default='Model/BestModelWeight/cv_prediction.pickle',
                        help='Path to the prediction pickle file (default: %(default)s).')
    parser.add_argument('--location',
                        default='Result/ConfusionMatrix/',
                        help='Directory in which to save the plots (default: %(default)s).')
    args = parser.parse_args()

    # Save the plots.
    savePlot(args.prediction, args.location)