Commit 5bb1e70

moved evaluation metrics to separate file
1 parent 5b813b9 commit 5bb1e70

File tree

5 files changed: +71 additions, -65 deletions


zephyr_ml/core.py

Lines changed: 2 additions & 2 deletions
@@ -21,8 +21,8 @@
     "sklearn.metrics.precision_score",
     "sklearn.metrics.f1_score",
     "sklearn.metrics.recall_score",
-    "zephyr_ml.primitives.postprocessing.confusion_matrix",
-    "zephyr_ml.primitives.postprocessing.roc_auc_score_and_curve",
+    "zephyr_ml.primitives.evaluation.confusion_matrix",
+    "zephyr_ml.primitives.evaluation.roc_auc_score_and_curve",
 ]
 
 LOGGER = logging.getLogger(__name__)
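
These dotted names are eventually resolved to Python callables when the pipeline runs. A minimal sketch of that resolution using only the standard library (resolve_primitive is a hypothetical helper; the real lookup is handled by the pipeline framework and is not part of this diff):

    import importlib

    def resolve_primitive(path):
        # Split "zephyr_ml.primitives.evaluation.confusion_matrix" into the
        # module path ("zephyr_ml.primitives.evaluation") and attribute name.
        module_name, _, attr = path.rpartition(".")
        module = importlib.import_module(module_name)
        return getattr(module, attr)

    # After this commit, the metric imports from the new evaluation module:
    conf_matrix_fn = resolve_primitive("zephyr_ml.primitives.evaluation.confusion_matrix")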

zephyr_ml/primitives/evaluation.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+"""
+Evaluation metrics
+"""
+
+import matplotlib.pyplot as plt
+import seaborn as sns
+from sklearn import metrics
+
+
+def confusion_matrix(
+        y_true,
+        y_pred,
+        labels=None,
+        sample_weight=None,
+        normalize=None):
+    conf_matrix = metrics.confusion_matrix(
+        y_true, y_pred, labels=labels, sample_weight=sample_weight, normalize=normalize
+    )
+    fig = plt.figure()
+    ax = fig.add_axes(sns.heatmap(conf_matrix, annot=True, cmap="Blues"))
+
+    ax.set_title("Confusion Matrix\n")
+    ax.set_xlabel("\nPredicted Values")
+    ax.set_ylabel("Actual Values")
+
+    ax.xaxis.set_ticklabels(["False", "True"])
+    ax.yaxis.set_ticklabels(["False", "True"])
+
+    return conf_matrix, fig
+
+
+def roc_auc_score_and_curve(
+    y_true, y_proba, pos_label=None, sample_weight=None, drop_intermediate=True
+):
+    if y_proba.ndim > 1:
+        y_proba = y_proba[:, 1]
+    fpr, tpr, _ = metrics.roc_curve(
+        y_true,
+        y_proba,
+        pos_label=pos_label,
+        sample_weight=sample_weight,
+        drop_intermediate=drop_intermediate,
+    )
+    ns_probs = [0 for _ in range(len(y_true))]
+    ns_fpr, ns_tpr, _ = metrics.roc_curve(
+        y_true,
+        ns_probs,
+        pos_label=pos_label,
+        sample_weight=sample_weight,
+        drop_intermediate=drop_intermediate,
+    )
+
+    auc = metrics.roc_auc_score(y_true, y_proba)
+    fig, ax = plt.subplots(1, 1)
+
+    ax.plot(fpr, tpr, "ro")
+    ax.plot(fpr, tpr)
+    ax.plot(ns_fpr, ns_tpr, linestyle="--", color="green")
+
+    ax.set_ylabel("True Positive Rate")
+    ax.set_xlabel("False Positive Rate")
+    ax.set_title("AUC: %.3f" % auc)
+
+    return auc, fig
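
A quick usage sketch for the two new helpers, assuming toy binary labels and predicted probabilities (all values here are illustrative, not part of the commit):

    import numpy as np
    from zephyr_ml.primitives.evaluation import confusion_matrix, roc_auc_score_and_curve

    # Toy ground truth and predicted positive-class probabilities.
    y_true = np.array([0, 0, 1, 1, 1, 0, 1, 0])
    y_proba = np.array([0.1, 0.4, 0.8, 0.7, 0.9, 0.3, 0.6, 0.2])
    y_pred = (y_proba > 0.5).astype(int)

    conf, conf_fig = confusion_matrix(y_true, y_pred)        # matrix plus heatmap figure
    auc, roc_fig = roc_auc_score_and_curve(y_true, y_proba)  # score plus ROC curve figure
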
Lines changed: 2 additions & 2 deletions
@@ -1,12 +1,12 @@
 {
-    "name": "zephyr_ml.primitives.postprocessing.confusion_matrix",
+    "name": "zephyr_ml.primitives.evaluation.confusion_matrix",
     "contributors": ["Raymond Pan <[email protected]>"],
     "description": "Create and plot confusion matrix.",
     "classifiers": {
         "type": "helper"
     },
     "modalities": [],
-    "primitive": "zephyr_ml.primitives.postprocessing.confusion_matrix",
+    "primitive": "zephyr_ml.primitives.evaluation.confusion_matrix",
     "produce": {
         "args": [
             {

Lines changed: 2 additions & 2 deletions
@@ -1,12 +1,12 @@
 {
-    "name": "zephyr_ml.primitives.postprocessing.roc_auc_score_and_curve",
+    "name": "zephyr_ml.primitives.evaluation.roc_auc_score_and_curve",
     "contributors": ["Raymond Pan <[email protected]>"],
     "description": "Calculate ROC AUC score and plot curve.",
     "classifiers": {
         "type": "helper"
     },
     "modalities": [],
-    "primitive": "zephyr_ml.primitives.postprocessing.roc_auc_score_and_curve",
+    "primitive": "zephyr_ml.primitives.evaluation.roc_auc_score_and_curve",
     "produce": {
         "args": [
             {

zephyr_ml/primitives/postprocessing.py

Lines changed: 1 addition & 59 deletions
@@ -4,14 +4,12 @@
 
 import logging
 
-import matplotlib.pyplot as plt
 import numpy as np
-import seaborn as sns
 import sklearn
-from sklearn import metrics
 
 LOGGER = logging.getLogger(__name__)
 
+
 METRICS = {
     "accuracy": sklearn.metrics.accuracy_score,
     "precision": sklearn.metrics.precision_score,
@@ -85,59 +83,3 @@ def apply_threshold(self, y_proba):
         binary = [1 if x else 0 for x in y_proba > self._threshold]
         return binary, self._threshold, self._scores
 
-
-def confusion_matrix(
-        y_true,
-        y_pred,
-        labels=None,
-        sample_weight=None,
-        normalize=None):
-    conf_matrix = metrics.confusion_matrix(
-        y_true, y_pred, labels=labels, sample_weight=sample_weight, normalize=normalize
-    )
-    fig = plt.figure()
-    ax = fig.add_axes(sns.heatmap(conf_matrix, annot=True, cmap="Blues"))
-
-    ax.set_title("Confusion Matrix\n")
-    ax.set_xlabel("\nPredicted Values")
-    ax.set_ylabel("Actual Values")
-
-    ax.xaxis.set_ticklabels(["False", "True"])
-    ax.yaxis.set_ticklabels(["False", "True"])
-
-    return conf_matrix, fig
-
-
-def roc_auc_score_and_curve(
-    y_true, y_proba, pos_label=None, sample_weight=None, drop_intermediate=True
-):
-    if y_proba.ndim > 1:
-        y_proba = y_proba[:, 1]
-    fpr, tpr, _ = metrics.roc_curve(
-        y_true,
-        y_proba,
-        pos_label=pos_label,
-        sample_weight=sample_weight,
-        drop_intermediate=drop_intermediate,
-    )
-    ns_probs = [0 for _ in range(len(y_true))]
-    ns_fpr, ns_tpr, _ = metrics.roc_curve(
-        y_true,
-        ns_probs,
-        pos_label=pos_label,
-        sample_weight=sample_weight,
-        drop_intermediate=drop_intermediate,
-    )
-
-    auc = metrics.roc_auc_score(y_true, y_proba)
-    fig, ax = plt.subplots(1, 1)
-
-    ax.plot(fpr, tpr, "ro")
-    ax.plot(fpr, tpr)
-    ax.plot(ns_fpr, ns_tpr, linestyle="--", color="green")
-
-    ax.set_ylabel("True Positive Rate")
-    ax.set_xlabel("False Positive Rate")
-    ax.set_title("AUC: %.3f" % auc)
-
-    return auc, fig
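
The apply_threshold context kept above turns probabilities into binary labels with a list comprehension; a vectorized equivalent, assuming a NumPy array of probabilities (illustrative values only):

    import numpy as np

    # Equivalent to: [1 if x else 0 for x in y_proba > threshold]
    y_proba = np.array([0.2, 0.55, 0.7, 0.4])
    threshold = 0.5
    binary = (y_proba > threshold).astype(int)  # array([0, 1, 1, 0])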
