Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 35 additions & 11 deletions metric_learn/_util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
from numpy.linalg import LinAlgError
from inspect import signature
from sklearn.datasets import make_spd_matrix
from sklearn.decomposition import PCA
from sklearn.utils import check_array
Expand All @@ -22,6 +23,29 @@ def vector_norm(X):
return np.linalg.norm(X, axis=1)


# Probe the installed scikit-learn once at import time: newer releases renamed
# the `force_all_finite` keyword of `check_array`/`check_X_y` to
# `ensure_all_finite`, so the wrappers below need to know which spelling the
# installed version accepts.
_CHECK_ARRAY_SUPPORTS_FORCE_ALL_FINITE = (
    'force_all_finite' in signature(check_array).parameters
)
_CHECK_X_Y_SUPPORTS_FORCE_ALL_FINITE = (
    'force_all_finite' in signature(check_X_y).parameters
)


def _check_array(*args, **kwargs):
    """Local wrapper around `sklearn.utils.check_array`.

    Deals with the rename of the `force_all_finite` keyword to
    `ensure_all_finite` across scikit-learn versions: callers in this module
    always pass `force_all_finite`, and this wrapper translates it when the
    installed scikit-learn no longer accepts the old name.

    Parameters
    ----------
    *args, **kwargs
        Forwarded verbatim to `sklearn.utils.check_array` (after the keyword
        translation described above).

    Returns
    -------
    The return value of `sklearn.utils.check_array`.
    """
    if (not _CHECK_ARRAY_SUPPORTS_FORCE_ALL_FINITE and
            'force_all_finite' in kwargs):
        # `kwargs` is a fresh dict created by ** packing, so mutating it
        # cannot affect the caller — no defensive copy needed.
        kwargs['ensure_all_finite'] = kwargs.pop('force_all_finite')
    return check_array(*args, **kwargs)

def _check_X_y(*args, **kwargs):
    """Local wrapper around `sklearn.utils.check_X_y`.

    Deals with the rename of the `force_all_finite` keyword to
    `ensure_all_finite` across scikit-learn versions: callers in this module
    always pass `force_all_finite`, and this wrapper translates it when the
    installed scikit-learn no longer accepts the old name.

    Parameters
    ----------
    *args, **kwargs
        Forwarded verbatim to `sklearn.utils.check_X_y` (after the keyword
        translation described above).

    Returns
    -------
    The return value of `sklearn.utils.check_X_y`.
    """
    if (not _CHECK_X_Y_SUPPORTS_FORCE_ALL_FINITE and
            'force_all_finite' in kwargs):
        # `kwargs` is a fresh dict created by ** packing, so mutating it
        # cannot affect the caller — no defensive copy needed.
        kwargs['ensure_all_finite'] = kwargs.pop('force_all_finite')
    return check_X_y(*args, **kwargs)


def check_input(input_data, y=None, preprocessor=None,
type_of_inputs='classic', tuple_size=None, accept_sparse=False,
dtype='numeric', order=None,
Expand Down Expand Up @@ -115,14 +139,14 @@ def check_input(input_data, y=None, preprocessor=None,

# We need to convert input_data into a numpy.ndarray if possible, before
# any further checks or conversions, and deal with y if needed. Therefore
# we use check_array/check_X_y with fixed permissive arguments.
# we use the wrappers _check_array/_check_X_y with fixed permissive arguments.
if y is None:
input_data = check_array(input_data, ensure_2d=False, allow_nd=True,
input_data = _check_array(input_data, ensure_2d=False, allow_nd=True,
copy=False, force_all_finite=False,
accept_sparse=True, dtype=None,
ensure_min_features=0, ensure_min_samples=0)
else:
input_data, y = check_X_y(input_data, y, ensure_2d=False, allow_nd=True,
input_data, y = _check_X_y(input_data, y, ensure_2d=False, allow_nd=True,
copy=False, force_all_finite=False,
accept_sparse=True, dtype=None,
ensure_min_features=0, ensure_min_samples=0,
Expand Down Expand Up @@ -165,9 +189,9 @@ def check_input_tuples(input_data, context, preprocessor, args_for_sk_checks,
make_error_input(420, input_data, context)
else:
make_error_input(200, input_data, context)
input_data = check_array(input_data, allow_nd=True, ensure_2d=False,
input_data = _check_array(input_data, allow_nd=True, ensure_2d=False,
**args_for_sk_checks)
# we need to check num_features because check_array does not check it
# we need to check num_features because _check_array does not check it
# for 3D inputs:
if args_for_sk_checks['ensure_min_features'] > 0:
n_features = input_data.shape[2]
Expand All @@ -180,7 +204,7 @@ def check_input_tuples(input_data, context, preprocessor, args_for_sk_checks,
# normally we don't need to check_tuple_size too because tuple_size
# shouldn't be able to be modified by any preprocessor
if input_data.ndim != 3:
# we have to ensure this because check_array above does not
# we have to ensure this because _check_array above does not
if preprocessor_has_been_applied:
make_error_input(211, input_data, context)
else:
Expand All @@ -205,10 +229,10 @@ def check_input_classic(input_data, context, preprocessor, args_for_sk_checks):
else:
make_error_input(100, input_data, context)

input_data = check_array(input_data, allow_nd=True, ensure_2d=False,
input_data = _check_array(input_data, allow_nd=True, ensure_2d=False,
**args_for_sk_checks)
if input_data.ndim != 2:
# we have to ensure this because check_array above does not
# we have to ensure this because _check_array above does not
if preprocessor_has_been_applied:
make_error_input(111, input_data, context)
else:
Expand Down Expand Up @@ -317,7 +341,7 @@ def __init__(self, X):
# format with arguments in check_input, and only this latter function
# should return the appropriate errors). We do this only to have a numpy
# array object which can be indexed by another numpy array object.
X = check_array(X,
X = _check_array(X,
accept_sparse=True, dtype=None,
force_all_finite=False,
ensure_2d=False, allow_nd=True,
Expand Down Expand Up @@ -514,7 +538,7 @@ def _initialize_components(n_components, input, y=None, init='auto',
if isinstance(init, np.ndarray):
# we copy the array, so that if we update the metric, we don't want to
# update the init
init = check_array(init, copy=True)
init = _check_array(init, copy=True)

# Assert that init.shape[1] = X.shape[1]
if init.shape[1] != n_features:
Expand Down Expand Up @@ -656,7 +680,7 @@ def _initialize_metric_mahalanobis(input, init='identity', random_state=None,
if isinstance(init, np.ndarray):
# we copy the array, so that if we update the metric, we don't want to
# update the init
init = check_array(init, copy=True)
init = _check_array(init, copy=True)

# Assert that init.shape[1] = n_features
if init.shape != (n_features,) * 2:
Expand Down