|
25 | 25 | from .base_metric import MahalanobisMixin
|
26 | 26 |
|
27 | 27 |
|
28 |
| -# commonality between LMNN implementations |
29 |
| -class _base_LMNN(MahalanobisMixin, TransformerMixin): |
| 28 | +class LMNN(MahalanobisMixin, TransformerMixin): |
30 | 29 | def __init__(self, init=None, k=3, min_iter=50, max_iter=1000,
|
31 | 30 | learn_rate=1e-7, regularization=0.5, convergence_tol=0.001,
|
32 | 31 | use_pca=True, verbose=False, preprocessor=None,
|
@@ -114,11 +113,7 @@ def __init__(self, init=None, k=3, min_iter=50, max_iter=1000,
|
114 | 113 | self.n_components = n_components
|
115 | 114 | self.num_dims = num_dims
|
116 | 115 | self.random_state = random_state
|
117 |
| - super(_base_LMNN, self).__init__(preprocessor) |
118 |
| - |
119 |
| - |
120 |
| -# slower Python version |
121 |
| -class python_LMNN(_base_LMNN): |
| 116 | + super(LMNN, self).__init__(preprocessor) |
122 | 117 |
|
123 | 118 | def fit(self, X, y):
|
124 | 119 | if self.num_dims != 'deprecated':
|
@@ -344,40 +339,3 @@ def _sum_outer_products(data, a_inds, b_inds, weights=None):
|
344 | 339 | if weights is not None:
|
345 | 340 | return np.dot(Xab.T, Xab * weights[:,None])
|
346 | 341 | return np.dot(Xab.T, Xab)
|
347 |
| - |
348 |
| - |
349 |
| -try: |
350 |
| - # use the fast C++ version, if available |
351 |
| - from modshogun import LMNN as shogun_LMNN |
352 |
| - from modshogun import RealFeatures, MulticlassLabels |
353 |
| - |
354 |
| - class LMNN(_base_LMNN): |
355 |
| - """Large Margin Nearest Neighbor (LMNN) |
356 |
| -
|
357 |
| - Attributes |
358 |
| - ---------- |
359 |
| - n_iter_ : `int` |
360 |
| - The number of iterations the solver has run. |
361 |
| -
|
362 |
| - transformer_ : `numpy.ndarray`, shape=(n_components, n_features) |
363 |
| - The learned linear transformation ``L``. |
364 |
| - """ |
365 |
| - |
366 |
| - def fit(self, X, y): |
367 |
| - X, y = self._prepare_inputs(X, y, dtype=float, |
368 |
| - ensure_min_samples=2) |
369 |
| - labels = MulticlassLabels(y) |
370 |
| - self._lmnn = shogun_LMNN(RealFeatures(X.T), labels, self.k) |
371 |
| - self._lmnn.set_maxiter(self.max_iter) |
372 |
| - self._lmnn.set_obj_threshold(self.convergence_tol) |
373 |
| - self._lmnn.set_regularization(self.regularization) |
374 |
| - self._lmnn.set_stepsize(self.learn_rate) |
375 |
| - if self.use_pca: |
376 |
| - self._lmnn.train() |
377 |
| - else: |
378 |
| - self._lmnn.train(np.eye(X.shape[1])) |
379 |
| - self.transformer_ = self._lmnn.get_linear_transform(X) |
380 |
| - return self |
381 |
| - |
382 |
| -except ImportError: |
383 |
| - LMNN = python_LMNN |
0 commit comments