From 68ffc45f7c0664871eb6dd39e9ec9dce62b82c39 Mon Sep 17 00:00:00 2001 From: "Faust, Ian" Date: Thu, 28 Nov 2024 09:02:45 +0100 Subject: [PATCH] attempt at fixing --- .../basic_statistics/basic_statistics.py | 6 ++- .../incremental_basic_statistics.py | 47 +++++++------------ 2 files changed, 22 insertions(+), 31 deletions(-) diff --git a/sklearnex/basic_statistics/basic_statistics.py b/sklearnex/basic_statistics/basic_statistics.py index 092bc0974d..035a2160f6 100644 --- a/sklearnex/basic_statistics/basic_statistics.py +++ b/sklearnex/basic_statistics/basic_statistics.py @@ -130,7 +130,11 @@ def _save_attributes(self): setattr(self, option + "_", getattr(self._onedal_estimator, option)) def __getattr__(self, attr): - is_deprecated_attr = attr in self._onedal_estimator.options + is_deprecated_attr = ( + attr in self._onedal_estimator.options + if hasattr(self, "_onedal_estimator") + else False + ) if is_deprecated_attr: warnings.warn( "Result attributes without a trailing underscore were deprecated in version 2025.1 and will be removed in 2026.0" diff --git a/sklearnex/basic_statistics/incremental_basic_statistics.py b/sklearnex/basic_statistics/incremental_basic_statistics.py index bafd0d8a57..e0c0717142 100644 --- a/sklearnex/basic_statistics/incremental_basic_statistics.py +++ b/sklearnex/basic_statistics/incremental_basic_statistics.py @@ -17,7 +17,6 @@ import numpy as np from sklearn.base import BaseEstimator from sklearn.utils import check_array, gen_batches -from sklearn.utils.validation import _check_sample_weight from daal4py.sklearn._n_jobs_support import control_n_jobs from daal4py.sklearn._utils import sklearn_check_version @@ -34,10 +33,8 @@ import numbers import warnings -if sklearn_check_version("1.6"): - from sklearn.utils.validation import validate_data -else: - validate_data = BaseEstimator._validate_data +from ..utils._array_api import get_namespace +from ..utils.validation import _check_sample_weight, validate_data @control_n_jobs(decorated_methods=["partial_fit", "_onedal_finalize_fit"]) @@ -153,12 +150,7 @@ class IncrementalBasicStatistics(IntelEstimator, BaseEstimator): } def __init__(self, result_options="all", batch_size=None): - if result_options == "all": - self.result_options = ( - self._onedal_incremental_basic_statistics.get_all_result_options() - ) - else: - self.result_options = result_options + self.result_options = result_options self._need_to_finalize = False self.batch_size = batch_size @@ -171,14 +163,6 @@ def _onedal_supported(self, method_name, *data): _onedal_cpu_supported = _onedal_supported _onedal_gpu_supported = _onedal_supported - def _get_onedal_result_options(self, options): - if isinstance(options, list): - onedal_options = "|".join(self.result_options) - else: - onedal_options = options - assert isinstance(onedal_options, str) - return options - def _onedal_finalize_fit(self, queue=None): assert hasattr(self, "_onedal_estimator") self._onedal_estimator.finalize_fit(queue=queue) @@ -188,6 +172,7 @@ def _onedal_partial_fit(self, X, sample_weight=None, queue=None, check_input=Tru first_pass = not hasattr(self, "n_samples_seen_") or self.n_samples_seen_ == 0 if check_input: + xp, _ = get_namespace(X) if sklearn_check_version("1.0"): X = validate_data( self, @@ -210,27 +195,28 @@ def _onedal_partial_fit(self, X, sample_weight=None, queue=None, check_input=Tru else: self.n_samples_seen_ += X.shape[0] - onedal_params = { - "result_options": self._get_onedal_result_options(self.result_options) - } if not hasattr(self, "_onedal_estimator"): self._onedal_estimator = self._onedal_incremental_basic_statistics( - **onedal_params + result_options=self.result_options ) - self._onedal_estimator.partial_fit(X, weights=sample_weight, queue=queue) + + self._onedal_estimator.partial_fit(X, sample_weight=sample_weight, queue=queue) self._need_to_finalize = True def _onedal_fit(self, X, sample_weight=None, queue=None): if sklearn_check_version("1.2"): self._validate_params() + xp, _ = get_namespace(X) if sklearn_check_version("1.0"): - X = validate_data(self, X, dtype=[np.float64, np.float32]) + X = validate_data(self, X, dtype=[xp.float64, xp.float32]) else: - X = check_array(X, dtype=[np.float64, np.float32]) + X = check_array(X, dtype=[xp.float64, xp.float32]) if sample_weight is not None: - sample_weight = _check_sample_weight(sample_weight, X) + sample_weight = _check_sample_weight( + sample_weight, X, dtype=[xp.float64, xp.float32] + ) n_samples, n_features = X.shape if self.batch_size is None: @@ -256,11 +242,12 @@ def _onedal_fit(self, X, sample_weight=None, queue=None): return self def __getattr__(self, attr): - result_options = self.__dict__["result_options"] sattr = attr.removesuffix("_") is_statistic_attr = ( - isinstance(result_options, str) and (sattr == result_options) - ) or (isinstance(result_options, list) and (sattr in result_options)) + sattr in self._onedal_estimator.options + if hasattr(self, "_onedal_estimator") + else False + ) if is_statistic_attr: if self._need_to_finalize: self._onedal_finalize_fit()