
Commit 833e186

authored
[MRG] Learning on Triplets (scikit-learn-contrib#279)
* add _TripletsClassifierMixin
* added doc
* remove redundant code
* added tests
* triplets added to doc autosumary
* rephrasing, added docstring and small changes
* small rephrasing
* small flake8 fix
* Handle low number of neighbors for knn triplets
* add tests for knn triplet generation
* fixed typos and rephrasing
* added more tests for knn triplet construction
* sorted triplet & fix test_generate_knntriplets_k
* added over the edge knn triplets test
* multiple small code refactoring
* more refactoring
* Fix & test unlabeled handling triplet generation
* closer unlabeled point
* small clarity enhancement & repmat replacement
1 parent 1276040 commit 833e186

7 files changed

+526
-24
lines changed

doc/metric_learn.rst

+1
@@ -14,6 +14,7 @@ Base Classes
1414
metric_learn.Constraints
1515
metric_learn.base_metric.BaseMetricLearner
1616
metric_learn.base_metric._PairsClassifierMixin
17+
metric_learn.base_metric._TripletsClassifierMixin
1718
metric_learn.base_metric._QuadrupletsClassifierMixin
1819

1920
Supervised Learning Algorithms

doc/weakly_supervised.rst

+119-20
@@ -592,14 +592,122 @@ points, while constrains the sum of distances between dissimilar points:
592592
-with-side-information.pdf>`_. NIPS 2002
593593
.. [2] Adapted from Matlab code http://www.cs.cmu.edu/%7Eepxing/papers/Old_papers/code_Metric_online.tar.gz
594594
595+
.. _learning_on_triplets:
596+
597+
Learning on triplets
598+
====================
599+
600+
Some metric learning algorithms learn on triplets of samples. In this case,
601+
one should provide the algorithm with `n_samples` triplets of points. The
602+
semantic of each triplet is that the first point should be closer to the
603+
second point than to the third one.
604+
605+
Fitting
606+
-------
607+
Here is an example for fitting on triplets (see :ref:`fit_ws` for more
608+
details on the input data format and how to fit, in the general case of
609+
learning on tuples).
610+
611+
>>> from metric_learn import SCML
612+
>>> triplets = np.array([[[1.2, 3.2], [2.3, 5.5], [2.1, 0.6]],
613+
... [[4.5, 2.3], [2.1, 2.3], [7.3, 3.4]]])
614+
>>> scml = SCML(random_state=42)
615+
>>> scml.fit(triplets)
616+
SCML(beta=1e-5, B=None, max_iter=100000, verbose=False,
617+
preprocessor=None, random_state=42)
618+
619+
Or alternatively (using a preprocessor):
620+
621+
>>> X = np.array([[1.2, 3.2],
622+
... [2.3, 5.5],
623+
... [2.1, 0.6],
624+
... [4.5, 2.3],
625+
... [2.1, 2.3],
626+
... [7.3, 3.4]])
627+
>>> triplets_indices = np.array([[0, 1, 2], [3, 4, 5]])
628+
>>> scml = SCML(preprocessor=X, random_state=42)
629+
>>> scml.fit(triplets_indices)
630+
SCML(beta=1e-5, B=None, max_iter=100000, verbose=False,
631+
preprocessor=array([[1.2, 3.2],
632+
[2.3, 5.5],
634+
[2.1, 0.6],
635+
[4.5, 2.3],
636+
[2.1, 2.3],
638+
[7.3, 3.4]]),
639+
random_state=42)
640+
641+
642+
Here, we want to learn a metric that, for each of the two
643+
`triplets`, will make the first point closer to the
644+
second point than to the third one.
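The two fitting examples above are meant to describe the same triplets. A minimal sketch of that correspondence, assuming an array preprocessor simply looks up rows of `X` by index (the name `triplets_from_indices` is hypothetical and not part of metric-learn):

```python
import numpy as np

# Points as they would be passed to `preprocessor=X`.
X = np.array([[1.2, 3.2],
              [2.3, 5.5],
              [2.1, 0.6],
              [4.5, 2.3],
              [2.1, 2.3],
              [7.3, 3.4]])

# Each row holds the indices of (first, second, third) point of a triplet.
triplets_indices = np.array([[0, 1, 2], [3, 4, 5]])

# Fancy indexing turns the 2D index array into a 3D array of points
# with shape (n_triplets, 3, n_features).
triplets_from_indices = X[triplets_indices]

# The explicit 3D form used in the first fitting example.
triplets = np.array([[[1.2, 3.2], [2.3, 5.5], [2.1, 0.6]],
                     [[4.5, 2.3], [2.1, 2.3], [7.3, 3.4]]])
```

Under that assumption, both input forms carry identical data, so either can be used to express the same constraints.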
645+
646+
.. _triplets_predicting:
647+
648+
Prediction
649+
----------
650+
651+
When a triplets learner is fitted, it is also able to predict, for an
652+
upcoming triplet, whether the first point is closer to the second point
653+
than to the third one (+1), or not (-1).
654+
655+
>>> triplets_test = np.array(
656+
... [[[5.6, 5.3], [2.2, 2.1], [1.2, 3.4]],
657+
... [[6.0, 4.2], [4.3, 1.2], [0.1, 7.8]]])
658+
>>> scml.predict(triplets_test)
659+
array([-1., 1.])
660+
661+
.. _triplets_scoring:
662+
663+
Scoring
664+
-------
665+
666+
Triplet metric learners can also return a `decision_function` for a set of triplets,
667+
which corresponds to the distance between the first and last points minus the distance
668+
between the first two points of the triplet (the higher the value, the more
669+
similar the first point to the second point compared to the last one). This "score"
670+
can be interpreted as a measure of likeliness of having a +1 prediction for this
671+
triplet.
672+
673+
>>> scml.decision_function(triplets_test)
674+
array([-1.75700306, 4.98982131])
675+
676+
In the above example, for the first triplet in `triplets_test`, the first
677+
point is predicted less similar to the second point than to the last point
678+
(they are further away in the transformed space).
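A minimal sketch of the sign convention used by `_TripletsClassifierMixin.decision_function` in this commit, written with plain NumPy. Euclidean distance stands in for the learned metric here, which is an assumption: a fitted learner would use its learned distance through `score_pairs`. The helper names `triplet_decision_function` and `triplet_predict` are hypothetical.

```python
import numpy as np


def triplet_decision_function(triplets):
    """Return d(a, c) - d(a, b) for each triplet (a, b, c).

    Euclidean distance is a stand-in for the learned metric (assumption).
    """
    triplets = np.asarray(triplets, dtype=float)
    d_ab = np.linalg.norm(triplets[:, 0] - triplets[:, 1], axis=1)
    d_ac = np.linalg.norm(triplets[:, 0] - triplets[:, 2], axis=1)
    return d_ac - d_ab


def triplet_predict(triplets):
    """Return +1 when the first point is closer to the second, -1 otherwise."""
    return np.sign(triplet_decision_function(triplets))


triplets_demo = np.array([
    [[0.0, 0.0], [1.0, 0.0], [3.0, 0.0]],  # second point is closer
    [[0.0, 0.0], [0.0, 4.0], [0.0, 1.0]],  # third point is closer
])
scores = triplet_decision_function(triplets_demo)
labels = triplet_predict(triplets_demo)
```

A positive score means the third point is further from the first than the second point is, which matches a +1 prediction.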
679+
680+
Unlike pairs learners, triplets learners do not accept a `y` when fitting: we
681+
assume that the ordering of points within triplets is such that the training triplets
682+
are all positive. Therefore, it is not possible to use scikit-learn scoring functions
683+
(such as 'f1_score') for triplets learners.
684+
685+
However, triplets learners do have a default scoring function, which will
686+
basically return the accuracy score on a given test set, i.e. the proportion
687+
of triplets that have the right predicted ordering.
688+
689+
>>> scml.score(triplets_test)
690+
0.5
691+
692+
.. note::
693+
See :ref:`fit_ws` for more details on metric learners functions that are
694+
not specific to learning on triplets, like `transform`, `score_pairs`,
695+
`get_metric` and `get_mahalanobis_matrix`.
696+
697+
698+
699+
700+
Algorithms
701+
----------
702+
595703

596704
.. _learning_on_quadruplets:
597705

598706
Learning on quadruplets
599707
=======================
600708

601709
Some metric learning algorithms learn on quadruplets of samples. In this case,
602-
one should provide the algorithm with `n_samples` quadruplets of points. Th
710+
one should provide the algorithm with `n_samples` quadruplets of points. The
603711
semantic of each quadruplet is that the first two points should be closer
604712
together than the last two points.
605713

@@ -666,14 +774,12 @@ array([-1., 1.])
666774
Scoring
667775
-------
668776

669-
Quadruplet metric learners can also
670-
return a `decision_function` for a set of pairs. This is basically the "score"
671-
which sign will be taken to find the prediction for the pair, which
672-
corresponds to the difference between the distance between the two last points,
673-
and the distance between the two last points of the quadruplet (higher
674-
score means the two last points are more likely to be more dissimilar than
675-
the two first points (i.e. more likely to have a +1 prediction since it's
676-
the right ordering)).
777+
Quadruplet metric learners can also return a `decision_function` for a set of
778+
quadruplets, which corresponds to the distance between the second pair of points minus
779+
the distance between the first pair of points of the quadruplet (the higher the value,
780+
the more similar the first pair is compared to the last pair).
781+
This "score" can be interpreted as a measure of likeliness of having a +1 prediction
782+
for this quadruplet.
677783

678784
>>> lsml.decision_function(quadruplets_test)
679785
array([-1.75700306, 4.98982131])
@@ -682,17 +788,10 @@ In the above example, for the first quadruplet in `quadruplets_test`, the
682788
two first points are predicted less similar than the two last points (they
683789
are further away in the transformed space).
684790

685-
Unlike for pairs learners, quadruplets learners don't allow to give a `y`
686-
when fitting, which does not allow to use scikit-learn scoring functions
687-
like:
688-
689-
>>> from sklearn.model_selection import cross_val_score
690-
>>> cross_val_score(lsml, quadruplets, scoring='f1_score') # this won't work
691-
692-
(This is actually intentional, for more details
693-
about that, see
694-
`this comment <https://github.com/scikit-learn-contrib/metric-learn/pull/168#pullrequestreview-203730742>`_
695-
on github.)
791+
Like triplet learners, quadruplets learners do not accept a `y` when fitting: we
792+
assume that the ordering of points within quadruplets is such that the training quadruplets
793+
are all positive. Therefore, it is not possible to use scikit-learn scoring functions
794+
(such as 'f1_score') for quadruplets learners.
696795

697796
However, quadruplets learners do have a default scoring function, which will
698797
basically return the accuracy score on a given test set, i.e. the proportion

metric_learn/base_metric.py

+84-4
@@ -589,6 +589,90 @@ def _validate_calibration_params(strategy='accuracy', min_rate=None,
589589
'Got {} instead.'.format(type(beta)))
590590

591591

592+
class _TripletsClassifierMixin(BaseMetricLearner):
593+
"""Base class for triplets learners.
594+
"""
595+
596+
_tuple_size = 3 # number of points in a tuple, 3 for triplets
597+
598+
def predict(self, triplets):
599+
"""Predicts the ordering between sample distances in input triplets.
600+
601+
For each triplet, returns 1 if the first element is closer to the second
602+
than to the last and -1 if not.
603+
604+
Parameters
605+
----------
606+
triplets : array-like, shape=(n_triplets, 3, n_features) or (n_triplets, 3)
607+
3D array of triplets to predict, with each row corresponding to three
608+
points, or 2D array of indices of triplets if the metric learner
609+
uses a preprocessor.
610+
611+
Returns
612+
-------
613+
prediction : `numpy.ndarray` of floats, shape=(n_constraints,)
614+
Predictions of the ordering of pairs, for each triplet.
615+
"""
616+
return np.sign(self.decision_function(triplets))
617+
618+
def decision_function(self, triplets):
619+
"""Predicts differences between sample distances in input triplets.
620+
621+
For each triplet (X_a, X_b, X_c) in the samples, computes the difference
622+
between the learned distance of the second pair (X_a, X_c) minus the
623+
learned distance of the first pair (X_a, X_b). The higher it is, the more
624+
probable it is that the pairs in the triplets are presented in the right
625+
order, i.e. that the label of the triplet is 1. The lower it is, the more
626+
probable it is that the label of the triplet is -1.
627+
628+
Parameters
629+
----------
630+
triplets : array-like, shape=(n_triplets, 3, n_features) or \
631+
(n_triplets, 3)
632+
3D array of triplets to predict, with each row corresponding to three
633+
points, or 2D array of indices of triplets if the metric learner
634+
uses a preprocessor.
635+
636+
Returns
637+
-------
638+
decision_function : `numpy.ndarray` of floats, shape=(n_constraints,)
639+
Metric differences.
640+
"""
641+
check_is_fitted(self, 'preprocessor_')
642+
triplets = check_input(triplets, type_of_inputs='tuples',
643+
preprocessor=self.preprocessor_,
644+
estimator=self, tuple_size=self._tuple_size)
645+
return (self.score_pairs(triplets[:, [0, 2]]) -
646+
self.score_pairs(triplets[:, :2]))
647+
648+
def score(self, triplets):
649+
"""Computes score on input triplets.
650+
651+
Returns the accuracy score of the following classification task: a triplet
652+
(X_a, X_b, X_c) is correctly classified if the predicted similarity between
653+
the first pair (X_a, X_b) is higher than that of the second pair (X_a, X_c).
654+
655+
Parameters
656+
----------
657+
triplets : array-like, shape=(n_triplets, 3, n_features) or \
658+
(n_triplets, 3)
659+
3D array of triplets to score, with each row corresponding to three
660+
points, or 2D array of indices of triplets if the metric learner
661+
uses a preprocessor.
662+
663+
Returns
664+
-------
665+
score : float
666+
The triplets score.
667+
"""
668+
# Since the prediction is a vector of values in {-1, +1}, we need to
669+
# rescale them to {0, 1} to compute the accuracy using the mean: after
670+
# rescaling, 1 means a correctly classified result (pairs are in the right
671+
# order), and 0 an incorrectly classified result (pairs are in the
672+
# wrong order).
673+
return self.predict(triplets).mean() / 2 + 0.5
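The rescaling in `score` can be checked on a small hand-written prediction vector. This is a sketch with made-up predictions, not output from a fitted learner. It also shows one edge case worth keeping in mind: `np.sign` returns 0 when the two distances are exactly equal, and the rescaling then counts that triplet as half correct.

```python
import numpy as np

# Hypothetical predictions in {-1, +1}: three correct, one incorrect.
predictions = np.array([1.0, -1.0, 1.0, 1.0])

# Rescaling used by the mixin: mean of {-1, +1} mapped onto [0, 1].
rescaled_score = predictions.mean() / 2 + 0.5

# Direct accuracy: fraction of triplets predicted as correctly ordered.
accuracy = np.mean(predictions == 1)

# Edge case: a tie in distances gives sign 0, which contributes 0.5.
tie_predictions = np.array([1.0, 0.0])
tie_score = tie_predictions.mean() / 2 + 0.5
tie_accuracy = np.mean(tie_predictions == 1)
```

For predictions strictly in {-1, +1} the two quantities agree. They differ only when ties occur.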
674+
675+
592676
class _QuadrupletsClassifierMixin(BaseMetricLearner):
593677
"""Base class for quadruplets learners.
594678
"""
@@ -614,10 +698,6 @@ def predict(self, quadruplets):
614698
prediction : `numpy.ndarray` of floats, shape=(n_constraints,)
615699
Predictions of the ordering of pairs, for each quadruplet.
616700
"""
617-
check_is_fitted(self, 'preprocessor_')
618-
quadruplets = check_input(quadruplets, type_of_inputs='tuples',
619-
preprocessor=self.preprocessor_,
620-
estimator=self, tuple_size=self._tuple_size)
621701
return np.sign(self.decision_function(quadruplets))
622702

623703
def decision_function(self, quadruplets):
