2
2
Base module.
3
3
"""
4
4
5
- from sklearn .base import BaseEstimator
5
+ from sklearn .base import BaseEstimator , ClassifierMixin
6
6
from sklearn .utils .extmath import stable_cumsum
7
7
from sklearn .utils .validation import _is_arraylike , check_is_fitted
8
8
from sklearn .metrics import roc_auc_score , roc_curve , precision_recall_curve
@@ -464,7 +464,7 @@ def get_mahalanobis_matrix(self):
464
464
return self .components_ .T .dot (self .components_ )
465
465
466
466
467
- class _PairsClassifierMixin (BaseMetricLearner ):
467
+ class _PairsClassifierMixin (BaseMetricLearner , ClassifierMixin ):
468
468
"""Base class for pairs learners.
469
469
470
470
Attributes
@@ -475,6 +475,7 @@ class _PairsClassifierMixin(BaseMetricLearner):
475
475
classified as dissimilar.
476
476
"""
477
477
478
+ classes_ = np .array ([0 , 1 ])
478
479
_tuple_size = 2 # number of points in a tuple, 2 for pairs
479
480
480
481
def predict (self , pairs ):
@@ -752,11 +753,12 @@ def _validate_calibration_params(strategy='accuracy', min_rate=None,
752
753
'Got {} instead.' .format (type (beta )))
753
754
754
755
755
- class _TripletsClassifierMixin (BaseMetricLearner ):
756
+ class _TripletsClassifierMixin (BaseMetricLearner , ClassifierMixin ):
756
757
"""
757
758
Base class for triplets learners.
758
759
"""
759
760
761
+ classes_ = np .array ([0 , 1 ])
760
762
_tuple_size = 3 # number of points in a tuple, 3 for triplets
761
763
762
764
def predict (self , triplets ):
@@ -837,11 +839,12 @@ def score(self, triplets):
837
839
return self .predict (triplets ).mean () / 2 + 0.5
838
840
839
841
840
- class _QuadrupletsClassifierMixin (BaseMetricLearner ):
842
+ class _QuadrupletsClassifierMixin (BaseMetricLearner , ClassifierMixin ):
841
843
"""
842
844
Base class for quadruplets learners.
843
845
"""
844
846
847
+ classes_ = np .array ([0 , 1 ])
845
848
_tuple_size = 4 # number of points in a tuple, 4 for quadruplets
846
849
847
850
def predict (self , quadruplets ):
0 commit comments