diff --git a/metric_learn/_util.py b/metric_learn/_util.py index 868ececa..de608218 100644 --- a/metric_learn/_util.py +++ b/metric_learn/_util.py @@ -1,5 +1,6 @@ import numpy as np from numpy.linalg import LinAlgError +from inspect import signature from sklearn.datasets import make_spd_matrix from sklearn.decomposition import PCA from sklearn.utils import check_array @@ -22,6 +23,29 @@ def vector_norm(X): return np.linalg.norm(X, axis=1) +_CHECK_ARRAY_SUPPORTS_FORCE_ALL_FINITE = ( + 'force_all_finite' in signature(check_array).parameters) +_CHECK_X_Y_SUPPORTS_FORCE_ALL_FINITE = ( + 'force_all_finite' in signature(check_X_y).parameters) + + +def _check_array(*args, **kwargs): + """Local wrapper around `sklearn.utils.check_array` to deal with the change + from `force_all_finite` to `ensure_all_finite` in scikit-learn.""" + if not _CHECK_ARRAY_SUPPORTS_FORCE_ALL_FINITE and "force_all_finite" in kwargs: + kwargs = kwargs.copy() + kwargs["ensure_all_finite"] = kwargs.pop("force_all_finite") + return check_array(*args, **kwargs) + +def _check_X_y(*args, **kwargs): + """Local wrapper around `sklearn.utils.check_X_y` to deal with the change + from `force_all_finite` to `ensure_all_finite` in scikit-learn.""" + if not _CHECK_X_Y_SUPPORTS_FORCE_ALL_FINITE and "force_all_finite" in kwargs: + kwargs = kwargs.copy() + kwargs["ensure_all_finite"] = kwargs.pop("force_all_finite") + return check_X_y(*args, **kwargs) + + def check_input(input_data, y=None, preprocessor=None, type_of_inputs='classic', tuple_size=None, accept_sparse=False, dtype='numeric', order=None, @@ -115,14 +139,14 @@ def check_input(input_data, y=None, preprocessor=None, # We need to convert input_data into a numpy.ndarray if possible, before # any further checks or conversions, and deal with y if needed. Therefore - # we use check_array/check_X_y with fixed permissive arguments. + # we use the wrappers _check_array/_check_X_y with fixed permissive arguments. if y is None: - input_data = check_array(input_data, ensure_2d=False, allow_nd=True, + input_data = _check_array(input_data, ensure_2d=False, allow_nd=True, copy=False, force_all_finite=False, accept_sparse=True, dtype=None, ensure_min_features=0, ensure_min_samples=0) else: - input_data, y = check_X_y(input_data, y, ensure_2d=False, allow_nd=True, + input_data, y = _check_X_y(input_data, y, ensure_2d=False, allow_nd=True, copy=False, force_all_finite=False, accept_sparse=True, dtype=None, ensure_min_features=0, ensure_min_samples=0, @@ -165,9 +189,9 @@ def check_input_tuples(input_data, context, preprocessor, args_for_sk_checks, make_error_input(420, input_data, context) else: make_error_input(200, input_data, context) - input_data = check_array(input_data, allow_nd=True, ensure_2d=False, + input_data = _check_array(input_data, allow_nd=True, ensure_2d=False, **args_for_sk_checks) - # we need to check num_features because check_array does not check it + # we need to check num_features because _check_array does not check it # for 3D inputs: if args_for_sk_checks['ensure_min_features'] > 0: n_features = input_data.shape[2] @@ -180,7 +204,7 @@ def check_input_tuples(input_data, context, preprocessor, args_for_sk_checks, # normally we don't need to check_tuple_size too because tuple_size # shouldn't be able to be modified by any preprocessor if input_data.ndim != 3: - # we have to ensure this because check_array above does not + # we have to ensure this because _check_array above does not if preprocessor_has_been_applied: make_error_input(211, input_data, context) else: @@ -205,10 +229,10 @@ def check_input_classic(input_data, context, preprocessor, args_for_sk_checks): else: make_error_input(100, input_data, context) - input_data = check_array(input_data, allow_nd=True, ensure_2d=False, + input_data = _check_array(input_data, allow_nd=True, ensure_2d=False, **args_for_sk_checks) if input_data.ndim != 2: - # we have to ensure this because check_array above does not + # we have to ensure this because _check_array above does not if preprocessor_has_been_applied: make_error_input(111, input_data, context) else: @@ -317,7 +341,7 @@ def __init__(self, X): # format with arguments in check_input, and only this latter function # should return the appropriate errors). We do this only to have a numpy # array object which can be indexed by another numpy array object. - X = check_array(X, + X = _check_array(X, accept_sparse=True, dtype=None, force_all_finite=False, ensure_2d=False, allow_nd=True, @@ -514,7 +538,7 @@ def _initialize_components(n_components, input, y=None, init='auto', if isinstance(init, np.ndarray): # we copy the array, so that if we update the metric, we don't want to # update the init - init = check_array(init, copy=True) + init = _check_array(init, copy=True) # Assert that init.shape[1] = X.shape[1] if init.shape[1] != n_features: @@ -656,7 +680,7 @@ def _initialize_metric_mahalanobis(input, init='identity', random_state=None, if isinstance(init, np.ndarray): # we copy the array, so that if we update the metric, we don't want to # update the init - init = check_array(init, copy=True) + init = _check_array(init, copy=True) # Assert that init.shape[1] = n_features if init.shape != (n_features,) * 2: