@@ -592,14 +592,122 @@ points, while constrains the sum of distances between dissimilar points:
592
592
-with-side-information.pdf> `_. NIPS 2002
593
593
.. [2 ] Adapted from Matlab code http://www.cs.cmu.edu/%7Eepxing/papers/Old_papers/code_Metric_online.tar.gz
594
594
595
+ .. _learning_on_triplets :
596
+
597
+ Learning on triplets
598
+ ====================
599
+
600
+ Some metric learning algorithms learn on triplets of samples. In this case,
601
+ one should provide the algorithm with `n_samples ` triplets of points. The
602
+ semantic of each triplet is that the first point should be closer to the
603
+ second point than to the third one.
604
+
605
+ Fitting
606
+ -------
607
+ Here is an example for fitting on triplets (see :ref: `fit_ws ` for more
608
+ details on the input data format and how to fit, in the general case of
609
+ learning on tuples).
610
+
611
+ >>> from metric_learn import SCML
612
+ >>> triplets = np.array([[[1.2 , 3.2 ], [2.3 , 5.5 ], [2.1 , 0.6 ]],
613
+ >>> [[4.5 , 2.3 ], [2.1 , 2.3 ], [7.3 , 3.4 ]]])
614
+ >>> scml = SCML(random_state = 42 )
615
+ >>> scml.fit(triplets)
616
+ SCML(beta=1e-5, B=None, max_iter=100000, verbose=False,
617
+ preprocessor=None, random_state=None)
618
+
619
+ Or alternatively (using a preprocessor):
620
+
621
+ >>> X = np.array([[[1.2 , 3.2 ],
622
+ >>> [2.3 , 5.5 ],
623
+ >>> [2.1 , 0.6 ],
624
+ >>> [4.5 , 2.3 ],
625
+ >>> [2.1 , 2.3 ],
626
+ >>> [7.3 , 3.4 ]])
627
+ >>> triplets_indices = np.array([[0 , 1 , 2 ], [3 , 4 , 5 ]])
628
+ >>> scml = SCML(preprocessor = X, random_state = 42 )
629
+ >>> scml.fit(triplets_indices)
630
+ SCML(beta=1e-5, B=None, max_iter=100000, verbose=False,
631
+ preprocessor=array([[1.2, 3.2],
632
+ [2.3, 5.5],
633
+ [2.4, 6.7],
634
+ [2.1, 0.6],
635
+ [4.5, 2.3],
636
+ [2.1, 2.3],
637
+ [0.6, 1.2],
638
+ [7.3, 3.4]]),
639
+ random_state=None)
640
+
641
+
642
+ Here, we want to learn a metric that, for each of the two
643
+ `triplets `, will make the first point closer to the
644
+ second point than to the third one.
645
+
646
+ .. _triplets_predicting :
647
+
648
+ Prediction
649
+ ----------
650
+
651
+ When a triplets learner is fitted, it is also able to predict, for an
652
+ upcoming triplet, whether the first point is closer to the second point
653
+ than to the third one (+1), or not (-1).
654
+
655
+ >>> triplets_test = np.array(
656
+ ... [[[5.6 , 5.3 ], [2.2 , 2.1 ], [1.2 , 3.4 ]],
657
+ ... [[6.0 , 4.2 ], [4.3 , 1.2 ], [0.1 , 7.8 ]]])
658
+ >>> scml.predict(triplets_test)
659
+ array([-1., 1.])
660
+
661
+ .. _triplets_scoring :
662
+
663
+ Scoring
664
+ -------
665
+
666
+ Triplet metric learners can also return a `decision_function ` for a set of triplets,
667
+ which corresponds to the distance between the first two points minus the distance
668
+ between the first and last points of the triplet (the higher the value, the more
669
+ similar the first point to the second point compared to the last one). This "score"
670
+ can be interpreted as a measure of likeliness of having a +1 prediction for this
671
+ triplet.
672
+
673
+ >>> scml.decision_function(triplets_test)
674
+ array([-1.75700306, 4.98982131])
675
+
676
+ In the above example, for the first triplet in `triplets_test `, the first
677
+ point is predicted less similar to the second point than to the last point
678
+ (they are further away in the transformed space).
679
+
680
+ Unlike pairs learners, triplets learners do not allow to give a `y ` when fitting: we
681
+ assume that the ordering of points within triplets is such that the training triplets
682
+ are all positive. Therefore, it is not possible to use scikit-learn scoring functions
683
+ (such as 'f1_score') for triplets learners.
684
+
685
+ However, triplets learners do have a default scoring function, which will
686
+ basically return the accuracy score on a given test set, i.e. the proportion
687
+ of triplets that have the right predicted ordering.
688
+
689
+ >>> scml.score(triplets_test)
690
+ 0.5
691
+
692
+ .. note ::
693
+ See :ref: `fit_ws ` for more details on metric learners functions that are
694
+ not specific to learning on pairs, like `transform `, `score_pairs `,
695
+ `get_metric ` and `get_mahalanobis_matrix `.
696
+
697
+
698
+
699
+
700
+ Algorithms
701
+ ----------
702
+
595
703
596
704
.. _learning_on_quadruplets :
597
705
598
706
Learning on quadruplets
599
707
=======================
600
708
601
709
Some metric learning algorithms learn on quadruplets of samples. In this case,
602
- one should provide the algorithm with `n_samples ` quadruplets of points. Th
710
+ one should provide the algorithm with `n_samples ` quadruplets of points. The
603
711
semantic of each quadruplet is that the first two points should be closer
604
712
together than the last two points.
605
713
@@ -666,14 +774,12 @@ array([-1., 1.])
666
774
Scoring
667
775
-------
668
776
669
- Quadruplet metric learners can also
670
- return a `decision_function ` for a set of pairs. This is basically the "score"
671
- which sign will be taken to find the prediction for the pair, which
672
- corresponds to the difference between the distance between the two last points,
673
- and the distance between the two last points of the quadruplet (higher
674
- score means the two last points are more likely to be more dissimilar than
675
- the two first points (i.e. more likely to have a +1 prediction since it's
676
- the right ordering)).
777
+ Quadruplet metric learners can also return a `decision_function ` for a set of
778
+ quadruplets, which corresponds to the distance between the first pair of points minus
779
+ the distance between the second pair of points of the triplet (the higher the value,
780
+ the more similar the first pair is than the last pair).
781
+ This "score" can be interpreted as a measure of likeliness of having a +1 prediction
782
+ for this quadruplet.
677
783
678
784
>>> lsml.decision_function(quadruplets_test)
679
785
array([-1.75700306, 4.98982131])
@@ -682,17 +788,10 @@ In the above example, for the first quadruplet in `quadruplets_test`, the
682
788
two first points are predicted less similar than the two last points (they
683
789
are further away in the transformed space).
684
790
685
- Unlike for pairs learners, quadruplets learners don't allow to give a `y `
686
- when fitting, which does not allow to use scikit-learn scoring functions
687
- like:
688
-
689
- >>> from sklearn.model_selection import cross_val_score
690
- >>> cross_val_score(lsml, quadruplets, scoring = ' f1_score' ) # this won't work
691
-
692
- (This is actually intentional, for more details
693
- about that, see
694
- `this comment <https://github.com/scikit-learn-contrib/metric-learn/pull/168#pullrequestreview-203730742 >`_
695
- on github.)
791
+ Like triplet learners, quadruplets learners do not allow to give a `y ` when fitting: we
792
+ assume that the ordering of points within triplets is such that the training triplets
793
+ are all positive. Therefore, it is not possible to use scikit-learn scoring functions
794
+ (such as 'f1_score') for triplets learners.
696
795
697
796
However, quadruplets learners do have a default scoring function, which will
698
797
basically return the accuracy score on a given test set, i.e. the proportion
0 commit comments