Skip to content

Commit a22c2e6

Browse files
wdevazelhes authored and
bellet committed
[MRG] Remove shogun dependency (scikit-learn-contrib#216)
* Remove shogun dependency
* Finalize removing of shogun LMNN
* Remove LMNN useless base class
1 parent 999cb5b commit a22c2e6

File tree

7 files changed

+19
-89
lines changed

7 files changed

+19
-89
lines changed

README.rst

-7
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,6 @@ package installed).
4141

4242
See the `sphinx documentation`_ for full documentation about installation, API, usage, and examples.
4343

44-
**Notes**
45-
46-
If a recent version of the Shogun Python modular (``modshogun``) library
47-
is available, the LMNN implementation will use the fast C++ version from
48-
there. The two implementations differ slightly, and the C++ version is
49-
more complete.
50-
5144

5245
.. _sphinx documentation: http://metric-learn.github.io/metric-learn/
5346

bench/benchmarks/iris.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,9 @@
1313
'NCA': metric_learn.NCA(max_iter=700, n_components=2),
1414
'RCA_Supervised': metric_learn.RCA_Supervised(dim=2, num_chunks=30,
1515
chunk_size=2),
16-
'SDML_Supervised': metric_learn.SDML_Supervised(num_constraints=1500),
16+
'SDML_Supervised': metric_learn.SDML_Supervised(num_constraints=1500)
1717
}
1818

19-
try:
20-
from metric_learn.lmnn import python_LMNN
21-
if python_LMNN is not metric_learn.LMNN:
22-
CLASSES['python_LMNN'] = python_LMNN(k=5, learn_rate=1e-6, verbose=False)
23-
except ImportError:
24-
pass
25-
2619

2720
class IrisDataset(object):
2821
params = [sorted(CLASSES)]

doc/getting_started.rst

-8
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,6 @@ Alternately, download the source repository and run:
2323
(install from commit `a0ed406 <https://github.com/skggm/skggm/commit/a0ed406586c4364ea3297a658f415e13b5cbdaf8>`_).
2424
- For running the examples only: matplotlib
2525

26-
**Notes**
27-
28-
If a recent version of the Shogun Python modular (``modshogun``) library
29-
is available, the LMNN implementation will use the fast C++ version from
30-
there. The two implementations differ slightly, and the C++ version is
31-
more complete.
32-
33-
3426
Quick start
3527
===========
3628

doc/supervised.rst

-5
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,6 @@ indicates :math:`\mathbf{x}_{i}, \mathbf{x}_{j}` belong to different class,
8787
lmnn = LMNN(k=5, learn_rate=1e-6)
8888
lmnn.fit(X, Y, verbose=False)
8989

90-
If a recent version of the Shogun Python modular (``modshogun``) library
91-
is available, the LMNN implementation will use the fast C++ version from
92-
there. Otherwise, the included pure-Python version will be used.
93-
The two implementations differ slightly, and the C++ version is more complete.
94-
9590
.. topic:: References:
9691

9792
.. [1] `Distance Metric Learning for Large Margin Nearest Neighbor

metric_learn/lmnn.py

+2-44
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@
2525
from .base_metric import MahalanobisMixin
2626

2727

28-
# commonality between LMNN implementations
29-
class _base_LMNN(MahalanobisMixin, TransformerMixin):
28+
class LMNN(MahalanobisMixin, TransformerMixin):
3029
def __init__(self, init=None, k=3, min_iter=50, max_iter=1000,
3130
learn_rate=1e-7, regularization=0.5, convergence_tol=0.001,
3231
use_pca=True, verbose=False, preprocessor=None,
@@ -114,11 +113,7 @@ def __init__(self, init=None, k=3, min_iter=50, max_iter=1000,
114113
self.n_components = n_components
115114
self.num_dims = num_dims
116115
self.random_state = random_state
117-
super(_base_LMNN, self).__init__(preprocessor)
118-
119-
120-
# slower Python version
121-
class python_LMNN(_base_LMNN):
116+
super(LMNN, self).__init__(preprocessor)
122117

123118
def fit(self, X, y):
124119
if self.num_dims != 'deprecated':
@@ -344,40 +339,3 @@ def _sum_outer_products(data, a_inds, b_inds, weights=None):
344339
if weights is not None:
345340
return np.dot(Xab.T, Xab * weights[:,None])
346341
return np.dot(Xab.T, Xab)
347-
348-
349-
try:
350-
# use the fast C++ version, if available
351-
from modshogun import LMNN as shogun_LMNN
352-
from modshogun import RealFeatures, MulticlassLabels
353-
354-
class LMNN(_base_LMNN):
355-
"""Large Margin Nearest Neighbor (LMNN)
356-
357-
Attributes
358-
----------
359-
n_iter_ : `int`
360-
The number of iterations the solver has run.
361-
362-
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
363-
The learned linear transformation ``L``.
364-
"""
365-
366-
def fit(self, X, y):
367-
X, y = self._prepare_inputs(X, y, dtype=float,
368-
ensure_min_samples=2)
369-
labels = MulticlassLabels(y)
370-
self._lmnn = shogun_LMNN(RealFeatures(X.T), labels, self.k)
371-
self._lmnn.set_maxiter(self.max_iter)
372-
self._lmnn.set_obj_threshold(self.convergence_tol)
373-
self._lmnn.set_regularization(self.regularization)
374-
self._lmnn.set_stepsize(self.learn_rate)
375-
if self.use_pca:
376-
self._lmnn.train()
377-
else:
378-
self._lmnn.train(np.eye(X.shape[1]))
379-
self.transformer_ = self._lmnn.get_linear_transform(X)
380-
return self
381-
382-
except ImportError:
383-
LMNN = python_LMNN

test/metric_learn_test.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
RCA_Supervised, MMC_Supervised, SDML, ITML, LSML)
2424
# Import this specially for testing.
2525
from metric_learn.constraints import wrap_pairs
26-
from metric_learn.lmnn import python_LMNN, _sum_outer_products
26+
from metric_learn.lmnn import _sum_outer_products
2727

2828

2929
def class_separation(X, labels):
@@ -213,14 +213,12 @@ def test_bounds_parameters_invalid(bounds):
213213

214214
class TestLMNN(MetricTestCase):
215215
def test_iris(self):
216-
# Test both impls, if available.
217-
for LMNN_cls in set((LMNN, python_LMNN)):
218-
lmnn = LMNN_cls(k=5, learn_rate=1e-6, verbose=False)
219-
lmnn.fit(self.iris_points, self.iris_labels)
216+
lmnn = LMNN(k=5, learn_rate=1e-6, verbose=False)
217+
lmnn.fit(self.iris_points, self.iris_labels)
220218

221-
csep = class_separation(lmnn.transform(self.iris_points),
222-
self.iris_labels)
223-
self.assertLess(csep, 0.25)
219+
csep = class_separation(lmnn.transform(self.iris_points),
220+
self.iris_labels)
221+
self.assertLess(csep, 0.25)
224222

225223
def test_loss_grad_lbfgs(self):
226224
"""Test gradient of loss function
@@ -336,7 +334,7 @@ def test_convergence_simple_example(capsys):
336334
# LMNN should converge on this simple example, which it did not with
337335
# this issue: https://github.com/metric-learn/metric-learn/issues/88
338336
X, y = make_classification(random_state=0)
339-
lmnn = python_LMNN(verbose=True)
337+
lmnn = LMNN(verbose=True)
340338
lmnn.fit(X, y)
341339
out, _ = capsys.readouterr()
342340
assert "LMNN converged with objective" in out
@@ -346,7 +344,7 @@ def test_no_twice_same_objective(capsys):
346344
# test that the objective function never has twice the same value
347345
# see https://github.com/metric-learn/metric-learn/issues/88
348346
X, y = make_classification(random_state=0)
349-
lmnn = python_LMNN(verbose=True)
347+
lmnn = LMNN(verbose=True)
350348
lmnn.fit(X, y)
351349
out, _ = capsys.readouterr()
352350
lines = re.split("\n+", out)

test/test_base_metric.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,14 @@ def test_covariance(self):
1919
remove_spaces("Covariance(preprocessor=None)"))
2020

2121
def test_lmnn(self):
22-
self.assertRegexpMatches(
23-
str(metric_learn.LMNN()),
24-
r"(python_)?LMNN\(convergence_tol=0.001, init=None, k=3, "
25-
r"learn_rate=1e-07,\s+"
26-
r"max_iter=1000, min_iter=50, n_components=None, "
27-
r"num_dims='deprecated',\s+preprocessor=None, random_state=None, "
28-
r"regularization=0.5,\s+use_pca=True, verbose=False\)")
22+
self.assertEqual(
23+
remove_spaces(str(metric_learn.LMNN())),
24+
remove_spaces(
25+
"LMNN(convergence_tol=0.001, init=None, k=3, "
26+
"learn_rate=1e-07, "
27+
"max_iter=1000, min_iter=50, n_components=None, "
28+
"num_dims='deprecated', preprocessor=None, random_state=None, "
29+
"regularization=0.5, use_pca=True, verbose=False)"))
2930

3031
def test_nca(self):
3132
self.assertEqual(remove_spaces(str(metric_learn.NCA())),

0 commit comments

Comments
 (0)