negate the result and prefix the metric name for error/loss metrics #278
Changes from all commits: 4dcde75, 4cabe71, 932c5c4, 45dfc03, b8c36e7, 1ee40a3

```diff
@@ -16,7 +16,9 @@
 import pandas as pd
 
 from .data import Dataset, DatasetType, Feature
-from .datautils import accuracy_score, confusion_matrix, f1_score, log_loss, balanced_accuracy_score, mean_absolute_error, mean_squared_error, mean_squared_log_error, r2_score, roc_auc_score, read_csv, write_csv, is_data_frame, to_data_frame
+from .datautils import accuracy_score, auc, average_precision_score, balanced_accuracy_score, confusion_matrix, fbeta_score, log_loss, \
+    mean_absolute_error, mean_squared_error, mean_squared_log_error, precision_recall_curve, r2_score, roc_auc_score, \
+    read_csv, write_csv, is_data_frame, to_data_frame
 from .resources import get as rget, config as rconfig, output_dirs
 from .utils import Namespace, backup_file, cached, datetime_iso, json_load, memoize, profile
 
```

```diff
@@ -394,6 +396,10 @@ def do_score(m):
         for metric in metadata.metrics or []:
             scores[metric] = do_score(metric)
         scores.result = scores[scores.metric] if scores.metric in scores else do_score(scores.metric)
+        if not higher_is_better(scores.metric):
+            scores.metric = f"neg_{scores.metric}"
+            scores.result = - scores.result
+
         scores.info = result.info
         if scoring_errors:
             scores.info = "; ".join(filter(lambda it: it, [scores.info, *scoring_errors]))
```

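To make the new behaviour concrete, here is a minimal standalone sketch (the score value is made up, a plain dict stands in for the benchmark's `scores` namespace, and the `higher_is_better` rule is copied from further down in this diff):

```python
import re

def higher_is_better(metric):
    # same rule as the helper added further down in this diff
    return re.fullmatch(r"((pr_)?auc(_\w*)?)|(\w*acc)|(f\d+)|(r2)", metric)

scores = {"metric": "logloss", "result": 0.6931}   # hypothetical logloss value

if not higher_is_better(scores["metric"]):         # logloss is a loss metric -> no match
    scores["metric"] = f"neg_{scores['metric']}"   # reported as "neg_logloss"
    scores["result"] = -scores["result"]           # reported as -0.6931

print(scores)  # {'metric': 'neg_logloss', 'result': -0.6931}
```

In other words, error/loss metrics are reported under a `neg_`-prefixed name with their sign flipped, so a single "higher is better" convention holds across all reported results.
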
```diff
@@ -453,6 +459,8 @@ def __init__(self, error):
 
 class ClassificationResult(Result):
 
+    multi_class_average = 'weighted'  # used by metrics like fbeta or auc
+
     def __init__(self, predictions_df, info=None):
         super().__init__(predictions_df, info)
         self.classes = self.df.columns[:-2].values.astype(str, copy=False)
```

```diff
@@ -464,42 +472,80 @@ def __init__(self, predictions_df, info=None):
         self.labels = self._autoencode(self.classes)
 
     def acc(self):
+        """Accuracy"""
         return float(accuracy_score(self.truth, self.predictions))
 
-    def balacc(self):
-        return float(balanced_accuracy_score(self.truth, self.predictions))
-
     def auc(self):
+        """Array Under (ROC) Curve, computed on probabilities, not on predictions"""
         if self.type != DatasetType.binary:
             # raise ValueError("AUC metric is only supported for binary classification: {}.".format(self.classes))
             log.warning("AUC metric is only supported for binary classification: %s.", self.labels)
+            log.warning("For multiclass problems, please use `auc_ovr` or `auc_ovo` metrics instead of `auc`.")
             return nan
-        return float(roc_auc_score(self.truth, self.probabilities[:, 1], labels=self.labels))
+        return float(roc_auc_score(self.truth, self.probabilities[:, 1]))
 
-    def cm(self):
-        return confusion_matrix(self.truth, self.predictions, labels=self.labels)
+    def auc_ovo(self):
+        """AUC One-vs-One"""
+        return self._auc_multi(mc='ovo')
 
-    def _per_class_errors(self):
-        return [(s-d)/s for s, d in ((sum(r), r[i]) for i, r in enumerate(self.cm()))]
+    def auc_ovr(self):
+        """AUC One-vs-Rest"""
+        return self._auc_multi(mc='ovr')
 
-    def mean_pce(self):
-        """mean per class error"""
-        return statistics.mean(self._per_class_errors())
+    def balacc(self):
+        """Balanced accuracy"""
+        return float(balanced_accuracy_score(self.truth, self.predictions))
 
-    def max_pce(self):
-        """max per class error"""
-        return max(self._per_class_errors())
+    def f05(self):
+        """F-beta 0.5"""
+        return self._fbeta(0.5)
 
     def f1(self):
-        return float(f1_score(self.truth, self.predictions, labels=self.labels))
+        """F-beta 1"""
+        return self._fbeta(1)
+
+    def f2(self):
+        """F-beta 2"""
+        return self._fbeta(2)
 
     def logloss(self):
+        """Log Loss"""
         return float(log_loss(self.truth, self.probabilities, labels=self.labels))
 
+    def max_pce(self):
+        """Max per Class Error"""
+        return max(self._per_class_errors())
+
+    def mean_pce(self):
+        """Mean per Class Error"""
+        return statistics.mean(self._per_class_errors())
+
+    def pr_auc(self):
+        """Precision Recall AUC"""
+        if self.type != DatasetType.binary:
+            log.warning("PR AUC metric is only available for binary problems.")
+            return nan
+        # precision, recall, thresholds = precision_recall_curve(self.truth, self.probabilities[:, 1])
+        # return float(auc(recall, precision))
+        return float(average_precision_score(self.truth, self.probabilities[:, 1]))
+
     def _autoencode(self, vec):
         needs_encoding = not _encode_predictions_and_truth_ or (isinstance(vec[0], str) and not vec[0].isdigit())
         return self.target.label_encoder.transform(vec) if needs_encoding else vec
 
+    def _auc_multi(self, mc='raise'):
+        average = ClassificationResult.multi_class_average
+        return float(roc_auc_score(self.truth, self.probabilities, average=average, labels=self.labels, multi_class=mc))
+
+    def _cm(self):
+        return confusion_matrix(self.truth, self.predictions, labels=self.labels)
+
+    def _fbeta(self, beta):
+        average = ClassificationResult.multi_class_average if self.type == DatasetType.multiclass else 'binary'
+        return float(fbeta_score(self.truth, self.predictions, beta=beta, average=average, labels=self.labels))
+
+    def _per_class_errors(self):
+        return [(s-d)/s for s, d in ((sum(r), r[i]) for i, r in enumerate(self._cm()))]
 
 
 class RegressionResult(Result):
```

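The new `_auc_multi` and `_fbeta` helpers mostly wrap scikit-learn. The sketch below (toy data, not part of the patch) shows the underlying calls, assuming scikit-learn ≥ 0.22, where `roc_auc_score` accepts `multi_class` and `labels`:

```python
# Toy illustration (not from the benchmark) of the scikit-learn calls behind the new helpers.
import numpy as np
from sklearn.metrics import fbeta_score, roc_auc_score

truth = np.array([0, 1, 2, 2, 1, 0])
probabilities = np.array([
    [0.8, 0.1, 0.1],
    [0.2, 0.6, 0.2],
    [0.1, 0.2, 0.7],
    [0.2, 0.3, 0.5],
    [0.3, 0.5, 0.2],
    [0.6, 0.2, 0.2],
])
predictions = probabilities.argmax(axis=1)
labels = [0, 1, 2]

# what `_auc_multi` does for auc_ovr / auc_ovo, with the 'weighted' class average:
auc_ovr = roc_auc_score(truth, probabilities, average='weighted', labels=labels, multi_class='ovr')
auc_ovo = roc_auc_score(truth, probabilities, average='weighted', labels=labels, multi_class='ovo')

# what `_fbeta` does for f05 / f1 / f2 on a multiclass problem:
f2 = fbeta_score(truth, predictions, beta=2, average='weighted', labels=labels)

print(auc_ovr, auc_ovo, f2)
```
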
```diff
@@ -510,24 +556,34 @@ def __init__(self, predictions_df, info=None):
         self.type = DatasetType.regression
 
     def mae(self):
+        """Mean Absolute Error"""
         return float(mean_absolute_error(self.truth, self.predictions))
 
     def mse(self):
+        """Mean Squared Error"""
         return float(mean_squared_error(self.truth, self.predictions))
 
     def msle(self):
+        """Mean Squared Logarithmic Error"""
         return float(mean_squared_log_error(self.truth, self.predictions))
 
     def rmse(self):
+        """Root Mean Square Error"""
         return math.sqrt(self.mse())
 
     def rmsle(self):
+        """Root Mean Square Logarithmic Error"""
         return math.sqrt(self.msle())
 
     def r2(self):
+        """R^2"""
         return float(r2_score(self.truth, self.predictions))
 
 
+def higher_is_better(metric):
+    return re.fullmatch(r"((pr_)?auc(_\w*)?)|(\w*acc)|(f\d+)|(r2)", metric)
+
+
 _encode_predictions_and_truth_ = False
 
 save_predictions = TaskResult.save_predictions
```

Collaborator (on `def higher_is_better(metric)`): This seems a bit hacky. Better to have either a dictionary mapping or metrics as classes (example in AutoGluon).

Author: I can't disagree with you: it IS a bit hacky.

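Purely as an illustration of the reviewer's suggestion (not part of the PR; names are hypothetical), a dictionary mapping could replace the regex along these lines:

```python
# Hypothetical alternative to the regex, as suggested in the review (not part of this PR):
# declare the direction of each supported metric explicitly.
_higher_is_better_ = dict(
    acc=True, balacc=True, auc=True, auc_ovo=True, auc_ovr=True, pr_auc=True,
    f05=True, f1=True, f2=True, r2=True,
    logloss=False, mae=False, mse=False, msle=False, rmse=False, rmsle=False,
    max_pce=False, mean_pce=False,
)

def higher_is_better(metric):
    # unknown metrics could default to False (treated as losses) or raise, depending on taste
    return _higher_is_better_.get(metric, False)
```
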
Reviewer (on the `"""Array Under (ROC) Curve..."""` docstring): nit: area instead of array

Author: oups! will fix