Commit 4fa72a0

adding support for custom evaluation metrics

1 parent 0bd6b39

8 files changed: +346 -37 lines

fiftyone/utils/eval/activitynet.py (+12 -1)

@@ -40,6 +40,8 @@ class ActivityNetEvaluationConfig(DetectionEvaluationConfig):
             that mAP and PR curves can be generated
         iou_threshs (None): a list of IoU thresholds to use when computing mAP
             and PR curves. Only applicable when ``compute_mAP`` is True
+        custom_metrics (None): an optional list of custom metrics to compute
+            or dict mapping metric names to kwargs dicts
     """

     def __init__(
@@ -50,10 +52,16 @@ def __init__(
         classwise=None,
         compute_mAP=False,
         iou_threshs=None,
+        custom_metrics=None,
         **kwargs,
     ):
         super().__init__(
-            pred_field, gt_field, iou=iou, classwise=classwise, **kwargs
+            pred_field,
+            gt_field,
+            iou=iou,
+            classwise=classwise,
+            custom_metrics=custom_metrics,
+            **kwargs,
         )

         if compute_mAP and iou_threshs is None:
@@ -323,6 +331,7 @@ class ActivityNetDetectionResults(DetectionResults):
             ``num_iou_threshs x num_classes x num_recall``
         missing (None): a missing label string. Any unmatched segments are
             given this label for evaluation purposes
+        custom_metrics (None): an optional dict of custom metrics
         backend (None): a :class:`ActivityNetEvaluation` backend
     """

@@ -339,6 +348,7 @@ def __init__(
         classes,
         thresholds=None,
         missing=None,
+        custom_metrics=None,
         backend=None,
     ):
         super().__init__(
@@ -348,6 +358,7 @@ def __init__(
             matches,
             classes=classes,
             missing=missing,
+            custom_metrics=custom_metrics,
             backend=backend,
         )
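For context, a minimal usage sketch (not part of this commit) of how the new ``custom_metrics`` argument could be passed when running an evaluation. It assumes the argument is forwarded through ``evaluate_detections()`` to the evaluation config, and ``@user/my-plugin/my_metric`` is a hypothetical metric operator URI.

import fiftyone.zoo as foz

dataset = foz.load_zoo_dataset("quickstart")

# As a list: each metric operator runs with its default kwargs
results = dataset.evaluate_detections(
    "predictions",
    gt_field="ground_truth",
    eval_key="eval",
    custom_metrics=["@user/my-plugin/my_metric"],  # hypothetical URI
)

# As a dict: metric URIs map to kwargs dicts passed to each operator
results = dataset.evaluate_detections(
    "predictions",
    gt_field="ground_truth",
    eval_key="eval",
    custom_metrics={"@user/my-plugin/my_metric": {"threshold": 0.5}},
)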
fiftyone/utils/eval/base.py (+110 -2)

@@ -6,17 +6,124 @@
 |
 """
 import itertools
+import logging

 import numpy as np
 import sklearn.metrics as skm

 import fiftyone.core.evaluation as foe
 import fiftyone.core.plots as fop
+import fiftyone.core.utils as fou
+
+foo = fou.lazy_import("fiftyone.operators")
+
+
+logger = logging.getLogger(__name__)
+
+
+class BaseEvaluationMethodConfig(foe.EvaluationMethodConfig):
+    """Base class for configuring evaluation methods.
+
+    Args:
+        **kwargs: any leftover keyword arguments after subclasses have done
+            their parsing
+    """
+
+    pass
+
+
+class BaseEvaluationMethod(foe.EvaluationMethod):
+    """Base class for evaluation methods.
+
+    Args:
+        config: an :class:`BaseEvaluationMethodConfig`
+    """
+
+    def _get_custom_metrics(self):
+        if not self.config.custom_metrics:
+            return {}
+
+        if isinstance(self.config.custom_metrics, list):
+            return {m: None for m in self.config.custom_metrics}
+
+        return self.config.custom_metrics
+
+    def compute_custom_metrics(self, samples, eval_key, results):
+        results.custom_metrics = {}
+
+        for metric, kwargs in self._get_custom_metrics().items():
+            try:
+                operator = foo.get_operator(metric)
+                value = operator.compute(
+                    samples, eval_key, results, **kwargs or {}
+                )
+                if value is not None:
+                    results.custom_metrics[operator.config.label] = value
+            except Exception as e:
+                logger.warning(
+                    "Failed to compute metric '%s': Reason: %s",
+                    operator.uri,
+                    e,
+                )
+
+    def get_custom_metric_fields(self, samples, eval_key):
+        fields = []
+
+        for metric in self._get_custom_metrics().keys():
+            try:
+                operator = foo.get_operator(metric)
+                fields.extend(operator.get_fields(samples, eval_key))
+            except Exception as e:
+                logger.warning(
+                    "Failed to get fields for metric '%s': Reason: %s",
+                    operator.uri,
+                    e,
+                )
+
+        return fields
+
+    def rename_custom_metrics(self, samples, eval_key, new_eval_key):
+        for metric in self._get_custom_metrics().keys():
+            try:
+                operator = foo.get_operator(metric)
+                operator.rename(samples, eval_key, new_eval_key)
+            except Exception as e:
+                logger.warning(
+                    "Failed to rename fields for metric '%s': Reason: %s",
+                    operator.uri,
+                    e,
+                )
+
+    def cleanup_custom_metrics(self, samples, eval_key):
+        for metric in self._get_custom_metrics().keys():
+            try:
+                operator = foo.get_operator(metric)
+                operator.cleanup(samples, eval_key)
+            except Exception as e:
+                logger.warning(
+                    "Failed to cleanup metric '%s': Reason: %s",
+                    operator.uri,
+                    e,
+                )


 class BaseEvaluationResults(foe.EvaluationResults):
     """Base class for evaluation results.

+    Args:
+        samples: the :class:`fiftyone.core.collections.SampleCollection` used
+        config: the :class:`BaseEvaluationMethodConfig` used
+        eval_key: the evaluation key
+        backend (None): an :class:`EvaluationMethod` backend
+    """
+
+    pass
+
+
+class BaseClassificationResults(BaseEvaluationResults):
+    """Base class for evaluation results that expose classification metrics
+    like P/R/F1 and confusion matrices.
+
     Args:
         samples: the :class:`fiftyone.core.collections.SampleCollection` used
         config: the :class:`fiftyone.core.evaluation.EvaluationMethodConfig`
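As a rough sketch (not part of this commit), a metric consumed by ``BaseEvaluationMethod`` above would need to expose the operator interface that ``compute_custom_metrics()``, ``get_custom_metric_fields()``, ``rename_custom_metrics()``, and ``cleanup_custom_metrics()`` invoke: ``compute()``, ``get_fields()``, ``rename()``, ``cleanup()``, plus ``config.label`` and ``uri``. The class name, URI, and label below are hypothetical; only the method names are taken from the diff.

import types


class MeanConfidenceMetric(object):
    """Hypothetical custom metric matching the interface used above."""

    # ``uri`` and ``config.label`` are read when logging and storing values
    uri = "@user/my-plugin/mean_confidence"  # hypothetical operator URI
    config = types.SimpleNamespace(label="mean_confidence")

    def compute(self, samples, eval_key, results, pred_field="predictions"):
        # the returned value is stored in ``results.custom_metrics``
        return samples.mean(pred_field + ".detections.confidence")

    def get_fields(self, samples, eval_key):
        # any sample/label fields this metric writes; none in this sketch
        return []

    def rename(self, samples, eval_key, new_eval_key):
        # nothing to rename since no fields are written
        pass

    def cleanup(self, samples, eval_key):
        # nothing to delete since no fields are written
        pass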
@@ -32,8 +139,7 @@ class BaseEvaluationResults(foe.EvaluationResults):
             observed ground truth/predicted labels are used
         missing (None): a missing label string. Any None-valued labels are
             given this label for evaluation purposes
-        samples (None): the :class:`fiftyone.core.collections.SampleCollection`
-            for which the results were computed
+        custom_metrics (None): an optional dict of custom metrics
         backend (None): a :class:`fiftyone.core.evaluation.EvaluationMethod`
             backend
     """
@@ -51,6 +157,7 @@ def __init__(
         ypred_ids=None,
         classes=None,
         missing=None,
+        custom_metrics=None,
         backend=None,
     ):
         super().__init__(samples, config, eval_key, backend=backend)
@@ -72,6 +179,7 @@ def __init__(
         )
         self.classes = np.asarray(classes)
         self.missing = missing
+        self.custom_metrics = custom_metrics

     def report(self, classes=None):
         """Generates a classification report for the results via
