diff --git a/onedal/covariance/covariance.py b/onedal/covariance/covariance.py
index 4d34e5db2b..a22cb424d8 100644
--- a/onedal/covariance/covariance.py
+++ b/onedal/covariance/covariance.py
@@ -17,8 +17,7 @@
 import numpy as np
 
-from daal4py.sklearn._utils import daal_check_version, get_dtype
-from onedal.utils import _check_array
+from daal4py.sklearn._utils import daal_check_version
 
 from ..common._base import BaseEstimator
 from ..common.hyperparameters import get_hyperparameters
@@ -94,10 +93,8 @@ def fit(self, X, y=None, queue=None):
             Returns the instance itself.
         """
         policy = self._get_policy(queue, X)
-        X = _check_array(X, dtype=[np.float64, np.float32])
-        X = _convert_to_supported(policy, X)
-        dtype = get_dtype(X)
-        params = self._get_onedal_params(dtype)
+        X_table = to_table(_convert_to_supported(policy, X))
+        params = self._get_onedal_params(X_table.dtype)
         hparams = get_hyperparameters("covariance", "compute")
         if hparams is not None and not hparams.is_default:
             result = self._get_backend(
@@ -107,7 +104,7 @@ def fit(self, X, y=None, queue=None):
                 policy,
                 params,
                 hparams.backend,
-                to_table(X),
+                X_table,
             )
         else:
             result = self._get_backend(
diff --git a/onedal/covariance/incremental_covariance.py b/onedal/covariance/incremental_covariance.py
index 00037ff63b..ed67dd5ec3 100644
--- a/onedal/covariance/incremental_covariance.py
+++ b/onedal/covariance/incremental_covariance.py
@@ -13,12 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ===============================================================================
-import numpy as np
-
-from daal4py.sklearn._utils import daal_check_version, get_dtype
+from daal4py.sklearn._utils import daal_check_version
 
 from ..datatypes import _convert_to_supported, from_table, to_table
-from ..utils import _check_array
 from .covariance import BaseEmpiricalCovariance
 
 
@@ -95,19 +93,15 @@ def partial_fit(self, X, y=None, queue=None):
         self : object
             Returns the instance itself.
         """
-        X = _check_array(X, dtype=[np.float64, np.float32], ensure_2d=True)
-
         self._queue = queue
         policy = self._get_policy(queue, X)
-        X = _convert_to_supported(policy, X)
-
+        X_table = to_table(_convert_to_supported(policy, X))
         if not hasattr(self, "_dtype"):
-            self._dtype = get_dtype(X)
+            self._dtype = X_table.dtype
 
         params = self._get_onedal_params(self._dtype)
-        table_X = to_table(X)
         self._partial_result = self._get_backend(
             "covariance",
             None,
@@ -115,7 +109,7 @@ def partial_fit(self, X, y=None, queue=None):
             policy,
             params,
             self._partial_result,
-            table_X,
+            X_table,
         )
 
         self._need_to_finalize = True
diff --git a/onedal/covariance/tests/test_covariance.py b/onedal/covariance/tests/test_covariance.py
index a55da035ab..5f604bb72c 100644
--- a/onedal/covariance/tests/test_covariance.py
+++ b/onedal/covariance/tests/test_covariance.py
@@ -25,7 +25,7 @@ def test_onedal_import_covariance(queue):
     from onedal.covariance import EmpiricalCovariance
 
-    X = np.array([[0, 1], [0, 1]])
+    X = np.array([[0, 1], [0, 1]], dtype=np.float64)
     result = EmpiricalCovariance().fit(X, queue=queue)
     expected_covariance = np.array([[0, 0], [0, 0]])
     expected_means = np.array([0, 1])
@@ -33,7 +33,7 @@ def test_onedal_import_covariance(queue):
     assert_allclose(expected_covariance, result.covariance_)
     assert_allclose(expected_means, result.location_)
 
-    X = np.array([[1, 2], [3, 6]])
+    X = np.array([[1, 2], [3, 6]], dtype=np.float64)
     result = EmpiricalCovariance().fit(X, queue=queue)
     expected_covariance = np.array([[2, 4], [4, 8]])
     expected_means = np.array([2, 4])
@@ -41,7 +41,7 @@ def test_onedal_import_covariance(queue):
     assert_allclose(expected_covariance, result.covariance_)
     assert_allclose(expected_means, result.location_)
 
-    X = np.array([[1, 2], [3, 6]])
+    X = np.array([[1, 2], [3, 6]], dtype=np.float64)
     result = EmpiricalCovariance(bias=True).fit(X, queue=queue)
     expected_covariance = np.array([[1, 2], [2, 4]])
     expected_means = np.array([2, 4])
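Note on the test changes above: with `_check_array` removed from the oneDAL-side `fit`, integer inputs are presumably no longer coerced to floating point on entry, so the tests now construct float64 data explicitly. For reference, a minimal numpy sketch of the coercion the removed call used to perform (sklearn's `check_array` with a dtype list behaves the same way):

```python
import numpy as np
from sklearn.utils import check_array

X_int = np.array([[0, 1], [0, 1]])
# A dtype list means: keep the dtype if it is listed, otherwise cast to the first entry.
X_float = check_array(X_int, dtype=[np.float64, np.float32])
print(X_float.dtype)  # float64
```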
""" - X = _check_array(X, dtype=[np.float64, np.float32], ensure_2d=True) - self._queue = queue policy = self._get_policy(queue, X) - X = _convert_to_supported(policy, X) - + X_table = to_table(_convert_to_supported(policy, X)) if not hasattr(self, "_dtype"): - self._dtype = get_dtype(X) + self._dtype = X_table.dtype params = self._get_onedal_params(self._dtype) - table_X = to_table(X) self._partial_result = self._get_backend( "covariance", None, @@ -115,7 +109,7 @@ def partial_fit(self, X, y=None, queue=None): policy, params, self._partial_result, - table_X, + X_table, ) self._need_to_finalize = True diff --git a/onedal/covariance/tests/test_covariance.py b/onedal/covariance/tests/test_covariance.py index a55da035ab..5f604bb72c 100644 --- a/onedal/covariance/tests/test_covariance.py +++ b/onedal/covariance/tests/test_covariance.py @@ -25,7 +25,7 @@ def test_onedal_import_covariance(queue): from onedal.covariance import EmpiricalCovariance - X = np.array([[0, 1], [0, 1]]) + X = np.array([[0, 1], [0, 1]], dtype=np.float64) result = EmpiricalCovariance().fit(X, queue=queue) expected_covariance = np.array([[0, 0], [0, 0]]) expected_means = np.array([0, 1]) @@ -33,7 +33,7 @@ def test_onedal_import_covariance(queue): assert_allclose(expected_covariance, result.covariance_) assert_allclose(expected_means, result.location_) - X = np.array([[1, 2], [3, 6]]) + X = np.array([[1, 2], [3, 6]], dtype=np.float64) result = EmpiricalCovariance().fit(X, queue=queue) expected_covariance = np.array([[2, 4], [4, 8]]) expected_means = np.array([2, 4]) @@ -41,7 +41,7 @@ def test_onedal_import_covariance(queue): assert_allclose(expected_covariance, result.covariance_) assert_allclose(expected_means, result.location_) - X = np.array([[1, 2], [3, 6]]) + X = np.array([[1, 2], [3, 6]], dtype=np.float64) result = EmpiricalCovariance(bias=True).fit(X, queue=queue) expected_covariance = np.array([[1, 2], [2, 4]]) expected_means = np.array([2, 4]) diff --git a/sklearnex/covariance/incremental_covariance.py b/sklearnex/covariance/incremental_covariance.py index 89ed92b601..24874a9c40 100644 --- a/sklearnex/covariance/incremental_covariance.py +++ b/sklearnex/covariance/incremental_covariance.py @@ -30,21 +30,17 @@ from onedal.covariance import ( IncrementalEmpiricalCovariance as onedal_IncrementalEmpiricalCovariance, ) -from sklearnex import config_context +from onedal.utils._array_api import _is_numpy_namespace from .._device_offload import dispatch, wrap_output_data from .._utils import IntelEstimator, PatchingConditionsChain, register_hyperparameters from ..metrics import pairwise_distances from ..utils._array_api import get_namespace +from ..utils.validation import validate_data if sklearn_check_version("1.2"): from sklearn.utils._param_validation import Interval -if sklearn_check_version("1.6"): - from sklearn.utils.validation import validate_data -else: - validate_data = BaseEstimator._validate_data - @control_n_jobs(decorated_methods=["partial_fit", "fit", "_onedal_finalize_fit"]) class IncrementalEmpiricalCovariance(IntelEstimator, BaseEstimator): @@ -152,8 +148,9 @@ def _onedal_finalize_fit(self, queue=None): if not daal_check_version((2024, "P", 400)) and self.assume_centered: location = self._onedal_estimator.location_[None, :] - self._onedal_estimator.covariance_ += np.dot(location.T, location) - self._onedal_estimator.location_ = np.zeros_like(np.squeeze(location)) + lp, _ = get_namespace(location) + self._onedal_estimator.covariance_ += lp.dot(location.T, location) + self._onedal_estimator.location_ = 
lp.zeros_like(lp.squeeze(location)) if self.store_precision: self.precision_ = linalg.pinvh( self._onedal_estimator.covariance_, check_finite=False @@ -187,8 +184,8 @@ def _onedal_partial_fit(self, X, queue=None, check_input=True): first_pass = not hasattr(self, "n_samples_seen_") or self.n_samples_seen_ == 0 - # finite check occurs on onedal side if check_input: + xp, _ = get_namespace(X) if sklearn_check_version("1.2"): self._validate_params() @@ -196,17 +193,15 @@ def _onedal_partial_fit(self, X, queue=None, check_input=True): X = validate_data( self, X, - dtype=[np.float64, np.float32], + dtype=[xp.float64, xp.float32], reset=first_pass, copy=self.copy, - force_all_finite=False, ) else: X = check_array( X, - dtype=[np.float64, np.float32], + dtype=[xp.float64, xp.float32], copy=self.copy, - force_all_finite=False, ) onedal_params = { @@ -239,16 +234,16 @@ def score(self, X_test, y=None): X = validate_data( self, X_test, - dtype=[np.float64, np.float32], + dtype=[xp.float64, xp.float32], reset=False, ) else: X = check_array( X_test, - dtype=[np.float64, np.float32], + dtype=[xp.float64, xp.float32], ) - if "numpy" not in xp.__name__: + if not _is_numpy_namespace(xp): location = xp.asarray(location, device=X_test.device) # depending on the sklearn version, check_array # and validate_data will return only numpy arrays @@ -337,19 +332,16 @@ def _onedal_fit(self, X, queue=None): if sklearn_check_version("1.2"): self._validate_params() - # finite check occurs on onedal side + xp, _ = get_namespace(X) if sklearn_check_version("1.0"): X = validate_data( self, X, - dtype=[np.float64, np.float32], + dtype=[xp.float64, xp.float32], copy=self.copy, - force_all_finite=False, ) else: - X = check_array( - X, dtype=[np.float64, np.float32], copy=self.copy, force_all_finite=False - ) + X = check_array(X, dtype=[xp.float64, xp.float32], copy=self.copy) self.n_features_in_ = X.shape[1] self.batch_size_ = self.batch_size if self.batch_size else 5 * self.n_features_in_ @@ -378,8 +370,8 @@ def mahalanobis(self, X): # pairwise_distances will check n_features (via n_feature matching with # self.location_) , and will check for finiteness via check array # check_feature_names will match _validate_data functionally - location = self.location_[np.newaxis, :] - if "numpy" not in xp.__name__: + location = self.location_[None, :] + if not _is_numpy_namespace(xp): # Guarantee that inputs to pairwise_distances match in type and location location = xp.asarray(location, device=X.device) diff --git a/sklearnex/preview/covariance/covariance.py b/sklearnex/preview/covariance/covariance.py index 04bdc0be8d..c75f1bee0b 100644 --- a/sklearnex/preview/covariance/covariance.py +++ b/sklearnex/preview/covariance/covariance.py @@ -16,8 +16,7 @@ import warnings -import numpy as np -from scipy import sparse as sp +import scipy.sparse as sp from sklearn.covariance import EmpiricalCovariance as _sklearn_EmpiricalCovariance from sklearn.utils import check_array @@ -30,11 +29,8 @@ from ..._device_offload import dispatch, wrap_output_data from ..._utils import PatchingConditionsChain, register_hyperparameters - -if sklearn_check_version("1.6"): - from sklearn.utils.validation import validate_data -else: - validate_data = _sklearn_EmpiricalCovariance._validate_data +from ...utils._array_api import get_namespace +from ...utils.validation import validate_data @register_hyperparameters({"fit": get_hyperparameters("covariance", "compute")}) @@ -51,14 +47,23 @@ def _save_attributes(self): assert hasattr(self, "_onedal_estimator") if not 
daal_check_version((2024, "P", 400)) and self.assume_centered: location = self._onedal_estimator.location_[None, :] - self._onedal_estimator.covariance_ += np.dot(location.T, location) - self._onedal_estimator.location_ = np.zeros_like(np.squeeze(location)) + lp, _ = get_namespace(location) + self._onedal_estimator.covariance_ += lp.dot(location.T, location) + self._onedal_estimator.location_ = lp.zeros_like(lp.squeeze(location)) self._set_covariance(self._onedal_estimator.covariance_) self.location_ = self._onedal_estimator.location_ _onedal_covariance = staticmethod(onedal_EmpiricalCovariance) def _onedal_fit(self, X, queue=None): + xp, _ = get_namespace(X) + if sklearn_check_version("1.2"): + self._validate_params() + if sklearn_check_version("1.0"): + X = validate_data(self, X, dtype=[xp.float64, xp.float32]) + else: + X = check_array(X) + if X.shape[0] == 1: warnings.warn( "Only one sample available. You may want to reshape your data array" @@ -93,13 +98,6 @@ def _onedal_supported(self, method_name, *data): _onedal_gpu_supported = _onedal_supported def fit(self, X, y=None): - if sklearn_check_version("1.2"): - self._validate_params() - if sklearn_check_version("0.23"): - X = validate_data(self, X, force_all_finite=False) - else: - X = check_array(X, force_all_finite=False) - dispatch( self, "fit", @@ -113,10 +111,10 @@ def fit(self, X, y=None): return self # expose sklearnex pairwise_distances if mahalanobis distance eventually supported - @wrap_output_data def mahalanobis(self, X): + xp, _ = get_namespace(X) if sklearn_check_version("1.0"): - X = validate_data(self, X, reset=False) + X = validate_data(self, X, reset=False, dtype=[xp.float64, xp.float32]) else: X = check_array(X) @@ -124,10 +122,10 @@ def mahalanobis(self, X): with config_context(assume_finite=True): # compute mahalanobis distances dist = pairwise_distances( - X, self.location_[np.newaxis, :], metric="mahalanobis", VI=precision + X, self.location_[None, :], metric="mahalanobis", VI=precision ) - return np.reshape(dist, (len(X),)) ** 2 + return xp.reshape(dist, (len(X),)) ** 2 error_norm = wrap_output_data(_sklearn_EmpiricalCovariance.error_norm) score = wrap_output_data(_sklearn_EmpiricalCovariance.score) diff --git a/sklearnex/spmd/covariance/covariance.py b/sklearnex/spmd/covariance/covariance.py index 3b2f704932..f103932457 100644 --- a/sklearnex/spmd/covariance/covariance.py +++ b/sklearnex/spmd/covariance/covariance.py @@ -14,8 +14,11 @@ # limitations under the License. # ============================================================================== -from onedal.spmd.covariance import EmpiricalCovariance +from onedal.spmd.covariance import EmpiricalCovariance as onedal_EmpiricalCovariance -# TODO: -# Currently it uses `onedal` module interface. -# Add sklearnex dispatching. 
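The `mahalanobis` hunk above (and the preview estimator's version below) reduces to a plain numpy computation when no device offload is involved. A small self-contained sketch of the squared-distance math, with made-up data standing in for `self.location_` and `self.get_precision()` (illustration only, not the estimator code):

```python
import numpy as np
from scipy import linalg
from sklearn.metrics import pairwise_distances

X = np.array([[1.0, 2.0], [3.0, 6.0], [2.0, 4.0]])
location = X.mean(axis=0)                          # plays the role of self.location_
precision = linalg.pinvh(np.cov(X.T, bias=True))   # plays the role of self.get_precision()
dist = pairwise_distances(X, location[None, :], metric="mahalanobis", VI=precision)
print(np.reshape(dist, (len(X),)) ** 2)            # squared Mahalanobis distances
```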
diff --git a/sklearnex/preview/covariance/covariance.py b/sklearnex/preview/covariance/covariance.py
index 04bdc0be8d..c75f1bee0b 100644
--- a/sklearnex/preview/covariance/covariance.py
+++ b/sklearnex/preview/covariance/covariance.py
@@ -16,8 +16,7 @@
 
 import warnings
 
-import numpy as np
-from scipy import sparse as sp
+import scipy.sparse as sp
 from sklearn.covariance import EmpiricalCovariance as _sklearn_EmpiricalCovariance
 from sklearn.utils import check_array
 
@@ -30,11 +29,8 @@
 
 from ..._device_offload import dispatch, wrap_output_data
 from ..._utils import PatchingConditionsChain, register_hyperparameters
-
-if sklearn_check_version("1.6"):
-    from sklearn.utils.validation import validate_data
-else:
-    validate_data = _sklearn_EmpiricalCovariance._validate_data
+from ...utils._array_api import get_namespace
+from ...utils.validation import validate_data
 
 
 @register_hyperparameters({"fit": get_hyperparameters("covariance", "compute")})
@@ -51,14 +47,23 @@ def _save_attributes(self):
         assert hasattr(self, "_onedal_estimator")
 
         if not daal_check_version((2024, "P", 400)) and self.assume_centered:
             location = self._onedal_estimator.location_[None, :]
-            self._onedal_estimator.covariance_ += np.dot(location.T, location)
-            self._onedal_estimator.location_ = np.zeros_like(np.squeeze(location))
+            lp, _ = get_namespace(location)
+            self._onedal_estimator.covariance_ += lp.dot(location.T, location)
+            self._onedal_estimator.location_ = lp.zeros_like(lp.squeeze(location))
         self._set_covariance(self._onedal_estimator.covariance_)
         self.location_ = self._onedal_estimator.location_
 
     _onedal_covariance = staticmethod(onedal_EmpiricalCovariance)
 
     def _onedal_fit(self, X, queue=None):
+        xp, _ = get_namespace(X)
+        if sklearn_check_version("1.2"):
+            self._validate_params()
+
+        if sklearn_check_version("1.0"):
+            X = validate_data(self, X, dtype=[xp.float64, xp.float32])
+        else:
+            X = check_array(X)
+
         if X.shape[0] == 1:
             warnings.warn(
                 "Only one sample available. You may want to reshape your data array"
@@ -93,13 +98,6 @@ def _onedal_supported(self, method_name, *data):
 
     _onedal_gpu_supported = _onedal_supported
 
     def fit(self, X, y=None):
-        if sklearn_check_version("1.2"):
-            self._validate_params()
-
-        if sklearn_check_version("0.23"):
-            X = validate_data(self, X, force_all_finite=False)
-        else:
-            X = check_array(X, force_all_finite=False)
-
         dispatch(
             self,
             "fit",
@@ -113,10 +111,10 @@ def fit(self, X, y=None):
         return self
 
     # expose sklearnex pairwise_distances if mahalanobis distance eventually supported
-    @wrap_output_data
     def mahalanobis(self, X):
+        xp, _ = get_namespace(X)
         if sklearn_check_version("1.0"):
-            X = validate_data(self, X, reset=False)
+            X = validate_data(self, X, reset=False, dtype=[xp.float64, xp.float32])
         else:
             X = check_array(X)
 
@@ -124,10 +122,10 @@ def mahalanobis(self, X):
         with config_context(assume_finite=True):
             # compute mahalanobis distances
             dist = pairwise_distances(
-                X, self.location_[np.newaxis, :], metric="mahalanobis", VI=precision
+                X, self.location_[None, :], metric="mahalanobis", VI=precision
             )
 
-        return np.reshape(dist, (len(X),)) ** 2
+        return xp.reshape(dist, (len(X),)) ** 2
 
     error_norm = wrap_output_data(_sklearn_EmpiricalCovariance.error_norm)
     score = wrap_output_data(_sklearn_EmpiricalCovariance.score)
diff --git a/sklearnex/spmd/covariance/covariance.py b/sklearnex/spmd/covariance/covariance.py
index 3b2f704932..f103932457 100644
--- a/sklearnex/spmd/covariance/covariance.py
+++ b/sklearnex/spmd/covariance/covariance.py
@@ -14,8 +14,11 @@
 # limitations under the License.
 # ==============================================================================
 
-from onedal.spmd.covariance import EmpiricalCovariance
+from onedal.spmd.covariance import EmpiricalCovariance as onedal_EmpiricalCovariance
 
-# TODO:
-# Currently it uses `onedal` module interface.
-# Add sklearnex dispatching.
+from ...preview.covariance import EmpiricalCovariance as EmpiricalCovariance_Batch
+
+
+class EmpiricalCovariance(EmpiricalCovariance_Batch):
+    __doc__ = EmpiricalCovariance_Batch.__doc__
+    _onedal_covariance = staticmethod(onedal_EmpiricalCovariance)
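The SPMD wrapper above is now just an inheritance shim: validation and dispatching come from the preview batch estimator, and only the oneDAL backend class is swapped. A toy sketch of the pattern with invented names (not the real classes):

```python
class _BatchBackend:
    def compute(self, X):
        return "single-node result"


class _SpmdBackend:
    def compute(self, X):
        return "distributed result"


class BatchEstimator:
    _onedal_covariance = staticmethod(_BatchBackend)

    def fit(self, X):
        # validation and dispatching would live here, shared by both variants
        self._result = self._onedal_covariance().compute(X)
        return self


class SpmdEstimator(BatchEstimator):
    # only the backend changes; everything else is inherited
    _onedal_covariance = staticmethod(_SpmdBackend)


assert SpmdEstimator().fit(None)._result == "distributed result"
```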
diff --git a/sklearnex/tests/test_memory_usage.py b/sklearnex/tests/test_memory_usage.py
index aa92df1d6a..2d52a545cf 100644
--- a/sklearnex/tests/test_memory_usage.py
+++ b/sklearnex/tests/test_memory_usage.py
@@ -35,10 +35,14 @@
     get_dataframes_and_queues,
 )
 from onedal.tests.utils._device_selection import get_queues, is_dpctl_device_available
-from onedal.utils._array_api import _get_sycl_namespace
 from onedal.utils._dpep_helpers import dpctl_available, dpnp_available
 from sklearnex import config_context
-from sklearnex.tests.utils import PATCHED_FUNCTIONS, PATCHED_MODELS, SPECIAL_INSTANCES
+from sklearnex.tests.utils import (
+    PATCHED_FUNCTIONS,
+    PATCHED_MODELS,
+    SPECIAL_INSTANCES,
+    DummyEstimator,
+)
 from sklearnex.utils._array_api import get_namespace
 
 if dpctl_available:
@@ -131,41 +135,6 @@ def gen_functions(functions):
 ORDER_DICT = {"F": np.asfortranarray, "C": np.ascontiguousarray}
 
 
-if _is_dpc_backend:
-
-    from sklearn.utils.validation import check_is_fitted
-
-    from onedal.datatypes import from_table, to_table
-
-    class DummyEstimatorWithTableConversions(BaseEstimator):
-
-        def fit(self, X, y=None):
-            sua_iface, xp, _ = _get_sycl_namespace(X)
-            X_table = to_table(X)
-            y_table = to_table(y)
-            # The presence of the fitted attributes (ending with a trailing
-            # underscore) is required for the correct check. The cleanup of
-            # the memory will occur at the estimator instance deletion.
-            self.x_attr_ = from_table(
-                X_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp
-            )
-            self.y_attr_ = from_table(
-                y_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp
-            )
-            return self
-
-        def predict(self, X):
-            # Checks if the estimator is fitted by verifying the presence of
-            # fitted attributes (ending with a trailing underscore).
-            check_is_fitted(self)
-            sua_iface, xp, _ = _get_sycl_namespace(X)
-            X_table = to_table(X)
-            returned_X = from_table(
-                X_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp
-            )
-            return returned_X
-
-
 def gen_clsf_data(n_samples, n_features, dtype=None):
     data, label = make_classification(
         n_classes=2, n_samples=n_samples, n_features=n_features, random_state=777
@@ -369,7 +338,7 @@ def test_table_conversions_memory_leaks(dataframe, queue, order, data_shape, dty
         pytest.skip("SYCL device memory leak check requires the level zero sysman")
 
     _kfold_function_template(
-        DummyEstimatorWithTableConversions,
+        DummyEstimator,
         dataframe,
         data_shape,
         queue,
diff --git a/sklearnex/tests/utils/__init__.py b/sklearnex/tests/utils/__init__.py
index 60ca67fa37..db728fe913 100644
--- a/sklearnex/tests/utils/__init__.py
+++ b/sklearnex/tests/utils/__init__.py
@@ -21,6 +21,7 @@
     SPECIAL_INSTANCES,
     UNPATCHED_FUNCTIONS,
     UNPATCHED_MODELS,
+    DummyEstimator,
     _get_processor_info,
     call_method,
     gen_dataset,
@@ -39,6 +40,7 @@
     "gen_models_info",
     "gen_dataset",
     "sklearn_clone_dict",
+    "DummyEstimator",
 ]
 
 _IS_INTEL = "GenuineIntel" in _get_processor_info()
diff --git a/sklearnex/tests/utils/base.py b/sklearnex/tests/utils/base.py
index 1949519585..706de39a91 100755
--- a/sklearnex/tests/utils/base.py
+++ b/sklearnex/tests/utils/base.py
@@ -32,8 +32,11 @@
 )
 from sklearn.datasets import load_diabetes, load_iris
 from sklearn.neighbors._base import KNeighborsMixin
+from sklearn.utils.validation import check_is_fitted
 
+from onedal.datatypes import from_table, to_table
 from onedal.tests.utils._dataframes_support import _convert_to_dataframe
+from onedal.utils._array_api import _get_sycl_namespace
 from sklearnex import get_patch_map, patch_sklearn, sklearn_is_patched, unpatch_sklearn
 from sklearnex.basic_statistics import BasicStatistics, IncrementalBasicStatistics
 from sklearnex.linear_model import LogisticRegression
@@ -369,3 +372,41 @@ def _get_processor_info():
         )
 
     return proc
+
+
+class DummyEstimator(BaseEstimator):
+
+    def fit(self, X, y=None):
+        sua_iface, xp, _ = _get_sycl_namespace(X)
+        X_table = to_table(X)
+        y_table = to_table(y)
+        # The presence of the fitted attributes (ending with a trailing
+        # underscore) is required for the correct check. The cleanup of
+        # the memory will occur at the estimator instance deletion.
+        if sua_iface:
+            self.x_attr_ = from_table(
+                X_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp
+            )
+            self.y_attr_ = from_table(
+                y_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp
+            )
+        else:
+            self.x_attr = from_table(X_table)
+            self.y_attr = from_table(y_table)
+
+        return self
+
+    def predict(self, X):
+        # Checks if the estimator is fitted by verifying the presence of
+        # fitted attributes (ending with a trailing underscore).
+        check_is_fitted(self)
+        sua_iface, xp, _ = _get_sycl_namespace(X)
+        X_table = to_table(X)
+        if sua_iface:
+            returned_X = from_table(
+                X_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp
+            )
+        else:
+            returned_X = from_table(X_table)
+
+        return returned_X
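`DummyEstimator` now lives in the shared test utilities and handles both SYCL and numpy inputs. Note that only the SYCL branch stores trailing-underscore attributes, so `check_is_fitted` inside `predict` appears to pass only for SYCL-path fits. A hypothetical numpy-side usage (illustration only):

```python
import numpy as np
from sklearnex.tests.utils import DummyEstimator

X, y = np.ones((4, 2)), np.zeros(4)
est = DummyEstimator().fit(X, y)
# numpy path: attributes are plain to_table/from_table round-trips
assert isinstance(est.x_attr, np.ndarray)
```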
diff --git a/sklearnex/utils/__init__.py b/sklearnex/utils/__init__.py
index 4c3fe21154..686e089adf 100755
--- a/sklearnex/utils/__init__.py
+++ b/sklearnex/utils/__init__.py
@@ -14,6 +14,6 @@
 # limitations under the License.
 # ===============================================================================
 
-from .validation import _assert_all_finite
+from .validation import assert_all_finite
 
-__all__ = ["_assert_all_finite"]
+__all__ = ["assert_all_finite"]
diff --git a/sklearnex/utils/tests/test_finite.py b/sklearnex/utils/tests/test_finite.py
deleted file mode 100644
index 7d83667699..0000000000
--- a/sklearnex/utils/tests/test_finite.py
+++ /dev/null
@@ -1,89 +0,0 @@
-# ==============================================================================
-# Copyright 2024 Intel Corporation
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-import time
-
-import numpy as np
-import numpy.random as rand
-import pytest
-from numpy.testing import assert_raises
-
-from sklearnex.utils import _assert_all_finite
-
-
-@pytest.mark.parametrize("dtype", [np.float32, np.float64])
-@pytest.mark.parametrize(
-    "shape",
-    [
-        [16, 2048],
-        [
-            2**16 + 3,
-        ],
-        [1000, 1000],
-    ],
-)
-@pytest.mark.parametrize("allow_nan", [False, True])
-def test_sum_infinite_actually_finite(dtype, shape, allow_nan):
-    X = np.empty(shape, dtype=dtype)
-    X.fill(np.finfo(dtype).max)
-    _assert_all_finite(X, allow_nan=allow_nan)
-
-
-@pytest.mark.parametrize("dtype", [np.float32, np.float64])
-@pytest.mark.parametrize(
-    "shape",
-    [
-        [16, 2048],
-        [
-            65539,  # 2**16 + 3,
-        ],
-        [1000, 1000],
-    ],
-)
-@pytest.mark.parametrize("allow_nan", [False, True])
-@pytest.mark.parametrize("check", ["inf", "NaN", None])
-@pytest.mark.parametrize("seed", [0, int(time.time())])
-def test_assert_finite_random_location(dtype, shape, allow_nan, check, seed):
-    rand.seed(seed)
-    X = rand.uniform(high=np.finfo(dtype).max, size=shape).astype(dtype)
-
-    if check:
-        loc = rand.randint(0, X.size - 1)
-        X.reshape((-1,))[loc] = float(check)
-
-    if check is None or (allow_nan and check == "NaN"):
-        _assert_all_finite(X, allow_nan=allow_nan)
-    else:
-        assert_raises(ValueError, _assert_all_finite, X, allow_nan=allow_nan)
-
-
-@pytest.mark.parametrize("dtype", [np.float32, np.float64])
-@pytest.mark.parametrize("allow_nan", [False, True])
-@pytest.mark.parametrize("check", ["inf", "NaN", None])
-@pytest.mark.parametrize("seed", [0, int(time.time())])
-def test_assert_finite_random_shape_and_location(dtype, allow_nan, check, seed):
-    lb, ub = 32768, 1048576  # lb is a patching condition, ub 2^20
-    rand.seed(seed)
-    X = rand.uniform(high=np.finfo(dtype).max, size=rand.randint(lb, ub)).astype(dtype)
-
-    if check:
-        loc = rand.randint(0, X.size - 1)
-        X[loc] = float(check)
-
-    if check is None or (allow_nan and check == "NaN"):
-        _assert_all_finite(X, allow_nan=allow_nan)
-    else:
-        assert_raises(ValueError, _assert_all_finite, X, allow_nan=allow_nan)
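The `_assert_all_finite` tests deleted above are superseded by the `validate_data` tests added next; the finite check itself is now public as `assert_all_finite` (defined in `sklearnex/utils/validation.py` below). A minimal usage sketch, assuming a patched sklearnex install:

```python
import numpy as np
from sklearnex.utils import assert_all_finite

X = np.full((1000, 100), np.finfo(np.float64).max)  # large enough for the oneDAL path
assert_all_finite(X)                  # large but finite values pass
X[0, 0] = np.nan
assert_all_finite(X, allow_nan=True)  # NaN is tolerated when allow_nan=True
try:
    assert_all_finite(X)              # NaN rejected by default
except ValueError as exc:
    print(exc)
```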
diff --git a/sklearnex/utils/tests/test_validation.py b/sklearnex/utils/tests/test_validation.py
new file mode 100644
index 0000000000..70da28dbce
--- /dev/null
+++ b/sklearnex/utils/tests/test_validation.py
@@ -0,0 +1,236 @@
+# ==============================================================================
+# Copyright 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import time
+
+import numpy as np
+import numpy.random as rand
+import pytest
+
+from daal4py.sklearn._utils import sklearn_check_version
+from onedal.tests.utils._dataframes_support import (
+    _convert_to_dataframe,
+    get_dataframes_and_queues,
+)
+from sklearnex import config_context
+from sklearnex.tests.utils import DummyEstimator, gen_dataset
+from sklearnex.utils.validation import _check_sample_weight, validate_data
+
+# array_api support starts in sklearn 1.2, and array_api_strict conformance starts in sklearn 1.3
+_dataframes_supported = (
+    "numpy,pandas"
+    + (",dpctl" if sklearn_check_version("1.2") else "")
+    + (",array_api" if sklearn_check_version("1.3") else "")
+)
+
+
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+@pytest.mark.parametrize(
+    "shape",
+    [
+        [16, 2048],
+        [2**16 + 3],
+        [1000, 1000],
+    ],
+)
+@pytest.mark.parametrize("ensure_all_finite", ["allow-nan", True])
+def test_sum_infinite_actually_finite(dtype, shape, ensure_all_finite):
+    est = DummyEstimator()
+    X = np.empty(shape, dtype=dtype)
+    X.fill(np.finfo(dtype).max)
+    X = np.atleast_2d(X)
+    X_array = validate_data(est, X, ensure_all_finite=ensure_all_finite)
+    assert type(X_array) == type(X)
+
+
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+@pytest.mark.parametrize(
+    "shape",
+    [
+        [16, 2048],
+        [2**16 + 3],
+        [1000, 1000],
+    ],
+)
+@pytest.mark.parametrize("ensure_all_finite", ["allow-nan", True])
+@pytest.mark.parametrize("check", ["inf", "NaN", None])
+@pytest.mark.parametrize("seed", [0, int(time.time())])
+@pytest.mark.parametrize(
+    "dataframe, queue",
+    get_dataframes_and_queues(_dataframes_supported),
+)
+def test_validate_data_random_location(
+    dataframe, queue, dtype, shape, ensure_all_finite, check, seed
+):
+    est = DummyEstimator()
+    rand.seed(seed)
+    X = rand.uniform(high=np.finfo(dtype).max, size=shape).astype(dtype)
+
+    if check:
+        loc = rand.randint(0, X.size - 1)
+        X.reshape((-1,))[loc] = float(check)
+
+    # Column-heavy pandas inputs are very slow in sklearn's check_array even
+    # without the finite check; transpose the inputs to guarantee fast
+    # processing in these tests
+    X = _convert_to_dataframe(
+        np.atleast_2d(X).T,
+        target_df=dataframe,
+        sycl_queue=queue,
+    )
+
+    dispatch = {}
+    if sklearn_check_version("1.2") and dataframe != "pandas":
+        dispatch["array_api_dispatch"] = True
+
+    with config_context(**dispatch):
+
+        allow_nan = ensure_all_finite == "allow-nan"
+        if check is None or (allow_nan and check == "NaN"):
+            validate_data(est, X, ensure_all_finite=ensure_all_finite)
+        else:
+            type_err = "infinity" if allow_nan else "[NaN|infinity]"
+            msg_err = f"Input X contains {type_err}"
+            with pytest.raises(ValueError, match=msg_err):
+                validate_data(est, X, ensure_all_finite=ensure_all_finite)
+
+
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+@pytest.mark.parametrize("ensure_all_finite", ["allow-nan", True])
+@pytest.mark.parametrize("check", ["inf", "NaN", None])
+@pytest.mark.parametrize("seed", [0, int(time.time())])
+@pytest.mark.parametrize(
+    "dataframe, queue",
+    get_dataframes_and_queues(_dataframes_supported),
+)
+def test_validate_data_random_shape_and_location(
+    dataframe, queue, dtype, ensure_all_finite, check, seed
+):
+    est = DummyEstimator()
+    lb, ub = 32768, 1048576  # lb is a patching condition, ub 2^20
+    rand.seed(seed)
+    X = rand.uniform(high=np.finfo(dtype).max, size=rand.randint(lb, ub)).astype(dtype)
+
+    if check:
+        loc = rand.randint(0, X.size - 1)
+        X[loc] = float(check)
+
+    X = _convert_to_dataframe(
+        np.atleast_2d(X).T,
+        target_df=dataframe,
+        sycl_queue=queue,
+    )
+
+    dispatch = {}
+    if sklearn_check_version("1.2") and dataframe != "pandas":
+        dispatch["array_api_dispatch"] = True
+
+    with config_context(**dispatch):
+
+        allow_nan = ensure_all_finite == "allow-nan"
+        if check is None or (allow_nan and check == "NaN"):
+            validate_data(est, X, ensure_all_finite=ensure_all_finite)
+        else:
+            type_err = "infinity" if allow_nan else "[NaN|infinity]"
+            msg_err = f"Input X contains {type_err}."
+            with pytest.raises(ValueError, match=msg_err):
+                validate_data(est, X, ensure_all_finite=ensure_all_finite)
+
+
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+@pytest.mark.parametrize("check", ["inf", "NaN", None])
+@pytest.mark.parametrize("seed", [0, int(time.time())])
+@pytest.mark.parametrize(
+    "dataframe, queue",
+    get_dataframes_and_queues(_dataframes_supported),
+)
+def test__check_sample_weight_random_shape_and_location(
+    dataframe, queue, dtype, check, seed
+):
+    # This test assumes that array_api inputs to validate_data occur only with
+    # sklearn array_api support, which began in sklearn 1.2; for earlier versions,
+    # dpnp, dpctl, or array_api inputs are assumed to have been converted to numpy
+    # somewhere upstream of the validate_data call.
+    lb, ub = 32768, 1048576  # lb is a patching condition, ub 2^20
+    rand.seed(seed)
+    shape = (rand.randint(lb, ub), 2)
+    X = rand.uniform(high=np.finfo(dtype).max, size=shape).astype(dtype)
+    sample_weight = rand.uniform(high=np.finfo(dtype).max, size=shape[0]).astype(dtype)
+
+    if check:
+        loc = rand.randint(0, shape[0] - 1)
+        sample_weight[loc] = float(check)
+
+    X = _convert_to_dataframe(
+        X,
+        target_df=dataframe,
+        sycl_queue=queue,
+    )
+    sample_weight = _convert_to_dataframe(
+        sample_weight,
+        target_df=dataframe,
+        sycl_queue=queue,
+    )
+
+    dispatch = {}
+    if sklearn_check_version("1.2") and dataframe != "pandas":
+        dispatch["array_api_dispatch"] = True
+
+    with config_context(**dispatch):
+
+        if check is None:
+            X_out = _check_sample_weight(sample_weight, X)
+            if dispatch:
+                assert type(X_out) == type(X)
+            else:
+                assert isinstance(X_out, np.ndarray)
+        else:
+            msg_err = "Input sample_weight contains [NaN|infinity]"
+            with pytest.raises(ValueError, match=msg_err):
+                X_out = _check_sample_weight(sample_weight, X)
+
+
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+@pytest.mark.parametrize(
+    "dataframe, queue",
+    get_dataframes_and_queues(_dataframes_supported),
+)
+def test_validate_data_output(dtype, dataframe, queue):
+    # This test assumes that array_api inputs to validate_data occur only with
+    # sklearn array_api support, which began in sklearn 1.2; for earlier versions,
+    # dpnp, dpctl, or array_api inputs are assumed to have been converted to numpy
+    # somewhere upstream of the validate_data call.
+    est = DummyEstimator()
+    X, y = gen_dataset(est, queue=queue, target_df=dataframe, dtype=dtype)[0]
+
+    dispatch = {}
+    if sklearn_check_version("1.2") and dataframe != "pandas":
+        dispatch["array_api_dispatch"] = True
+
+    with config_context(**dispatch):
+        X_out, y_out = validate_data(est, X, y)
+        # check sklearn validate_data operations work underneath
+        X_array = validate_data(est, X, reset=False)
+
+    if dispatch:
+        assert type(X) == type(
+            X_array
+        ), f"validate_data converted {type(X)} to {type(X_array)}"
+        assert type(X) == type(X_out), f"from_array converted {type(X)} to {type(X_out)}"
+    else:
+        # array_api_strict from sklearn < 1.2 and pandas will convert to numpy arrays
+        assert isinstance(X_array, np.ndarray)
+        assert isinstance(X_out, np.ndarray)
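`validate_data`, defined below, keeps sklearn's validation but disables its finite check and reroutes it to the accelerated `assert_all_finite`. A hypothetical call, with the `Toy` estimator invented for illustration:

```python
import numpy as np
from sklearn.base import BaseEstimator
from sklearnex.utils.validation import validate_data


class Toy(BaseEstimator):
    pass


X = np.random.default_rng(0).standard_normal((50000, 4))
# sklearn validates with its finite check disabled; assert_all_finite then runs
# on the validated output (oneDAL-backed for large float32/float64 arrays)
X_out = validate_data(Toy(), X, dtype=[np.float64, np.float32])
assert X_out.dtype == np.float64
```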
diff --git a/sklearnex/utils/validation.py b/sklearnex/utils/validation.py
index b2d1898643..c2ba2c1dc5 100755
--- a/sklearnex/utils/validation.py
+++ b/sklearnex/utils/validation.py
@@ -14,4 +14,162 @@
 # limitations under the License.
 # ===============================================================================
 
-from daal4py.sklearn.utils.validation import _assert_all_finite
+import numbers
+
+import scipy.sparse as sp
+from sklearn.utils.validation import _assert_all_finite as _sklearn_assert_all_finite
+from sklearn.utils.validation import _num_samples, check_array, check_non_negative
+
+from daal4py.sklearn._utils import daal_check_version, sklearn_check_version
+
+from ._array_api import get_namespace
+
+if sklearn_check_version("1.6"):
+    from sklearn.utils.validation import validate_data as _sklearn_validate_data
+
+    _finite_keyword = "ensure_all_finite"
+
+else:
+    from sklearn.base import BaseEstimator
+
+    _sklearn_validate_data = BaseEstimator._validate_data
+    _finite_keyword = "force_all_finite"
+
+
+if daal_check_version((2024, "P", 700)):
+    from onedal.utils.validation import _assert_all_finite as _onedal_assert_all_finite
+
+    def _onedal_supported_format(X, xp=None):
+        # array_api does not have a `strides` or `flags` attribute for testing memory
+        # order. When dlpack support is brought in for oneDAL, general support for
+        # array_api can be enabled and the hasattr check can be removed.
+        # _onedal_supported_format is therefore conservative in verifying attributes and
+        # does not support array_api. This will block onedal_assert_all_finite from being
+        # used for array_api inputs but will allow dpnp ndarrays and dpctl tensors.
+        return X.dtype in [xp.float32, xp.float64] and hasattr(X, "flags")
+
+else:
+    from daal4py.utils.validation import _assert_all_finite as _onedal_assert_all_finite
+    from onedal.utils._array_api import _is_numpy_namespace
+
+    def _onedal_supported_format(X, xp=None):
+        # daal4py _assert_all_finite only supports numpy namespaces; use the internally
+        # defined check to validate inputs, otherwise offload to sklearn
+        return X.dtype in [xp.float32, xp.float64] and _is_numpy_namespace(xp)
+
+
+def _sklearnex_assert_all_finite(
+    X,
+    *,
+    allow_nan=False,
+    input_name="",
+):
+    # The size check is an initial match to daal4py for performance reasons; it can
+    # be optimized later
+    xp, _ = get_namespace(X)
+    if X.size < 32768 or not _onedal_supported_format(X, xp):
+        if sklearn_check_version("1.1"):
+            _sklearn_assert_all_finite(X, allow_nan=allow_nan, input_name=input_name)
+        else:
+            _sklearn_assert_all_finite(X, allow_nan=allow_nan)
+    else:
+        _onedal_assert_all_finite(X, allow_nan=allow_nan, input_name=input_name)
+
+
+def assert_all_finite(
+    X,
+    *,
+    allow_nan=False,
+    input_name="",
+):
+    _sklearnex_assert_all_finite(
+        X.data if sp.issparse(X) else X,
+        allow_nan=allow_nan,
+        input_name=input_name,
+    )
+
+
+def validate_data(
+    _estimator,
+    /,
+    X="no_validation",
+    y="no_validation",
+    **kwargs,
+):
+    # Force the finite check to not occur in sklearn (the default is True).
+    # `ensure_all_finite` is the most up-to-date keyword name in sklearn;
+    # `_finite_keyword` provides backward compatibility for `force_all_finite`.
+    ensure_all_finite = kwargs.pop("ensure_all_finite", True)
+    kwargs[_finite_keyword] = False
+
+    out = _sklearn_validate_data(
+        _estimator,
+        X=X,
+        y=y,
+        **kwargs,
+    )
+    if ensure_all_finite:
+        # run local finite check
+        allow_nan = ensure_all_finite == "allow-nan"
+        arg = iter(out if isinstance(out, tuple) else (out,))
+        if not isinstance(X, str) or X != "no_validation":
+            assert_all_finite(next(arg), allow_nan=allow_nan, input_name="X")
+        if not (y is None or isinstance(y, str) and y == "no_validation"):
+            assert_all_finite(next(arg), allow_nan=allow_nan, input_name="y")
+    return out
+
+
+def _check_sample_weight(
+    sample_weight, X, dtype=None, copy=False, only_non_negative=False
+):
+
+    n_samples = _num_samples(X)
+    xp, _ = get_namespace(X)
+
+    if dtype is not None and dtype not in [xp.float32, xp.float64]:
+        dtype = xp.float64
+
+    if sample_weight is None:
+        if hasattr(X, "device"):
+            sample_weight = xp.ones(n_samples, dtype=dtype, device=X.device)
+        else:
+            sample_weight = xp.ones(n_samples, dtype=dtype)
+    elif isinstance(sample_weight, numbers.Number):
+        if hasattr(X, "device"):
+            sample_weight = xp.full(
+                n_samples, sample_weight, dtype=dtype, device=X.device
+            )
+        else:
+            sample_weight = xp.full(n_samples, sample_weight, dtype=dtype)
+    else:
+        if dtype is None:
+            dtype = [xp.float64, xp.float32]
+
+        params = {
+            "accept_sparse": False,
+            "ensure_2d": False,
+            "dtype": dtype,
+            "order": "C",
+            "copy": copy,
+            _finite_keyword: False,
+        }
+        if sklearn_check_version("1.1"):
+            params["input_name"] = "sample_weight"
+
+        sample_weight = check_array(sample_weight, **params)
+        assert_all_finite(sample_weight, input_name="sample_weight")
+
+        if sample_weight.ndim != 1:
+            raise ValueError("Sample weights must be 1D array or scalar")
+
+        if sample_weight.shape != (n_samples,):
+            raise ValueError(
+                "sample_weight.shape == {}, expected {}!".format(
+                    sample_weight.shape, (n_samples,)
+                )
+            )
+
+    if only_non_negative:
+        check_non_negative(sample_weight, "`sample_weight`")
+
+    return sample_weight
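For reference, a quick sketch of the three `_check_sample_weight` branches above with numpy inputs (the device-aware array-API branches behave analogously; assumes a patched install):

```python
import numpy as np
from sklearnex.utils.validation import _check_sample_weight

X = np.ones((5, 2))
print(_check_sample_weight(None, X, dtype=np.float64))  # None -> ones of shape (5,)
print(_check_sample_weight(3.0, X))                     # scalar -> full((5,), 3.0)
print(_check_sample_weight(np.arange(5), X))            # array -> validated, cast to float64
```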