4 | 4 |
5 | 5 | import logging
6 | 6 |
7 | | -import matplotlib.pyplot as plt
8 | 7 | import numpy as np
9 | | -import seaborn as sns
10 | 8 | import sklearn
11 | | -from sklearn import metrics
12 | 9 |
13 | 10 | LOGGER = logging.getLogger(__name__)
14 | 11 |
| 12 | +
15 | 13 | METRICS = {
16 | 14 |     "accuracy": sklearn.metrics.accuracy_score,
17 | 15 |     "precision": sklearn.metrics.precision_score,
@@ -85,59 +83,3 @@ def apply_threshold(self, y_proba):
85 | 83 |         binary = [1 if x else 0 for x in y_proba > self._threshold]
86 | 84 |         return binary, self._threshold, self._scores
87 | 85 |
88 | | -
89 | | -def confusion_matrix(
90 | | -    y_true,
91 | | -    y_pred,
92 | | -    labels=None,
93 | | -    sample_weight=None,
94 | | -    normalize=None):
95 | | -    conf_matrix = metrics.confusion_matrix(
96 | | -        y_true, y_pred, labels=labels, sample_weight=sample_weight, normalize=normalize
97 | | -    )
98 | | -    fig = plt.figure()
99 | | -    ax = fig.add_axes(sns.heatmap(conf_matrix, annot=True, cmap="Blues"))
100 | | -
101 | | -    ax.set_title("Confusion Matrix\n")
102 | | -    ax.set_xlabel("\nPredicted Values")
103 | | -    ax.set_ylabel("Actual Values")
104 | | -
105 | | -    ax.xaxis.set_ticklabels(["False", "True"])
106 | | -    ax.yaxis.set_ticklabels(["False", "True"])
107 | | -
108 | | -    return conf_matrix, fig
109 | | -
110 | | -
111 | | -def roc_auc_score_and_curve(
112 | | -    y_true, y_proba, pos_label=None, sample_weight=None, drop_intermediate=True
113 | | -):
114 | | -    if y_proba.ndim > 1:
115 | | -        y_proba = y_proba[:, 1]
116 | | -    fpr, tpr, _ = metrics.roc_curve(
117 | | -        y_true,
118 | | -        y_proba,
119 | | -        pos_label=pos_label,
120 | | -        sample_weight=sample_weight,
121 | | -        drop_intermediate=drop_intermediate,
122 | | -    )
123 | | -    ns_probs = [0 for _ in range(len(y_true))]
124 | | -    ns_fpr, ns_tpr, _ = metrics.roc_curve(
125 | | -        y_true,
126 | | -        ns_probs,
127 | | -        pos_label=pos_label,
128 | | -        sample_weight=sample_weight,
129 | | -        drop_intermediate=drop_intermediate,
130 | | -    )
131 | | -
132 | | -    auc = metrics.roc_auc_score(y_true, y_proba)
133 | | -    fig, ax = plt.subplots(1, 1)
134 | | -
135 | | -    ax.plot(fpr, tpr, "ro")
136 | | -    ax.plot(fpr, tpr)
137 | | -    ax.plot(ns_fpr, ns_tpr, linestyle="--", color="green")
138 | | -
139 | | -    ax.set_ylabel("True Positive Rate")
140 | | -    ax.set_xlabel("False Positive Rate")
141 | | -    ax.set_title("AUC: %.3f" % auc)
142 | | -
143 | | -    return auc, fig
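Note that only the matplotlib/seaborn plotting helpers are dropped here; the scalar metrics stay available through the METRICS mapping and scikit-learn. As a minimal sketch of the non-plotting part of the removed roc_auc_score_and_curve helper (the example labels and probabilities below are hypothetical, and a two-column probability array is assumed as in the removed code):

    import numpy as np
    import sklearn.metrics

    # Hypothetical example data: binary labels and per-class probabilities.
    y_true = np.array([0, 1, 1, 0, 1])
    y_proba = np.array([[0.8, 0.2], [0.3, 0.7], [0.4, 0.6], [0.9, 0.1], [0.2, 0.8]])

    # Mirror the removed helper: keep only the positive-class column.
    scores = y_proba[:, 1] if y_proba.ndim > 1 else y_proba

    # The AUC value that the removed function returned alongside its figure.
    auc = sklearn.metrics.roc_auc_score(y_true, scores)
    print(f"AUC: {auc:.3f}")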