Skip to content

Commit e5b06fa

Browse files
Fix sklearn compat issues
1 parent 4e89e3d commit e5b06fa

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

metric_learn/base_metric.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Base module.
33
"""
44

5-
from sklearn.base import BaseEstimator
5+
from sklearn.base import BaseEstimator, ClassifierMixin
66
from sklearn.utils.extmath import stable_cumsum
77
from sklearn.utils.validation import _is_arraylike, check_is_fitted
88
from sklearn.metrics import roc_auc_score, roc_curve, precision_recall_curve
@@ -464,7 +464,7 @@ def get_mahalanobis_matrix(self):
464464
return self.components_.T.dot(self.components_)
465465

466466

467-
class _PairsClassifierMixin(BaseMetricLearner):
467+
class _PairsClassifierMixin(BaseMetricLearner, ClassifierMixin):
468468
"""Base class for pairs learners.
469469
470470
Attributes
@@ -475,6 +475,7 @@ class _PairsClassifierMixin(BaseMetricLearner):
475475
classified as dissimilar.
476476
"""
477477

478+
classes_ = np.array([0, 1])
478479
_tuple_size = 2 # number of points in a tuple, 2 for pairs
479480

480481
def predict(self, pairs):
@@ -752,11 +753,12 @@ def _validate_calibration_params(strategy='accuracy', min_rate=None,
752753
'Got {} instead.'.format(type(beta)))
753754

754755

755-
class _TripletsClassifierMixin(BaseMetricLearner):
756+
class _TripletsClassifierMixin(BaseMetricLearner, ClassifierMixin):
756757
"""
757758
Base class for triplets learners.
758759
"""
759760

761+
classes_ = np.array([0, 1])
760762
_tuple_size = 3 # number of points in a tuple, 3 for triplets
761763

762764
def predict(self, triplets):
@@ -837,11 +839,12 @@ def score(self, triplets):
837839
return self.predict(triplets).mean() / 2 + 0.5
838840

839841

840-
class _QuadrupletsClassifierMixin(BaseMetricLearner):
842+
class _QuadrupletsClassifierMixin(BaseMetricLearner, ClassifierMixin):
841843
"""
842844
Base class for quadruplets learners.
843845
"""
844846

847+
classes_ = np.array([0, 1])
845848
_tuple_size = 4 # number of points in a tuple, 4 for quadruplets
846849

847850
def predict(self, quadruplets):

metric_learn/sdml.py

+4
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ def _fit(self, pairs, y):
4343
print("SDML will use skggm's graphical lasso solver.")
4444
pairs, y = self._prepare_inputs(pairs, y,
4545
type_of_inputs='tuples')
46+
n_features = pairs.shape[2]
47+
if n_features < 2:
48+
raise ValueError(f"Cannot fit SDML with {n_features} feature(s)")
4649

4750
# set up (the inverse of) the prior M
4851
# if the prior is the default (None), we raise a warning
@@ -83,6 +86,7 @@ def _fit(self, pairs, y):
8386
w_mahalanobis, _ = np.linalg.eigh(M)
8487
not_spd = any(w_mahalanobis < 0.)
8588
not_finite = not np.isfinite(M).all()
89+
# TODO: Narrow this to the specific exceptions we expect.
8690
except Exception as e:
8791
raised_error = e
8892
not_spd = False # not_spd not applicable here so we set to False

0 commit comments

Comments
 (0)