Skip to content

Commit

Permalink
attempt at fixing
Browse files Browse the repository at this point in the history
  • Loading branch information
icfaust committed Nov 28, 2024
1 parent 05ef656 commit 68ffc45
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 31 deletions.
6 changes: 5 additions & 1 deletion sklearnex/basic_statistics/basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,11 @@ def _save_attributes(self):
setattr(self, option + "_", getattr(self._onedal_estimator, option))

def __getattr__(self, attr):
is_deprecated_attr = attr in self._onedal_estimator.options
is_deprecated_attr = (
attr in self._onedal_estimator.options
if hasattr(self, "_onedal_estimator")
else False
)
if is_deprecated_attr:
warnings.warn(
"Result attributes without a trailing underscore were deprecated in version 2025.1 and will be removed in 2026.0"
Expand Down
47 changes: 17 additions & 30 deletions sklearnex/basic_statistics/incremental_basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.utils import check_array, gen_batches
from sklearn.utils.validation import _check_sample_weight

from daal4py.sklearn._n_jobs_support import control_n_jobs
from daal4py.sklearn._utils import sklearn_check_version
Expand All @@ -34,10 +33,8 @@
import numbers
import warnings

if sklearn_check_version("1.6"):
from sklearn.utils.validation import validate_data
else:
validate_data = BaseEstimator._validate_data
from ..utils._array_api import get_namespace
from ..utils.validation import _check_sample_weight, validate_data


@control_n_jobs(decorated_methods=["partial_fit", "_onedal_finalize_fit"])
Expand Down Expand Up @@ -153,12 +150,7 @@ class IncrementalBasicStatistics(IntelEstimator, BaseEstimator):
}

def __init__(self, result_options="all", batch_size=None):
if result_options == "all":
self.result_options = (
self._onedal_incremental_basic_statistics.get_all_result_options()
)
else:
self.result_options = result_options
self.result_options = result_options
self._need_to_finalize = False
self.batch_size = batch_size

Expand All @@ -171,14 +163,6 @@ def _onedal_supported(self, method_name, *data):
_onedal_cpu_supported = _onedal_supported
_onedal_gpu_supported = _onedal_supported

def _get_onedal_result_options(self, options):
if isinstance(options, list):
onedal_options = "|".join(self.result_options)
else:
onedal_options = options
assert isinstance(onedal_options, str)
return options

def _onedal_finalize_fit(self, queue=None):
assert hasattr(self, "_onedal_estimator")
self._onedal_estimator.finalize_fit(queue=queue)
Expand All @@ -188,6 +172,7 @@ def _onedal_partial_fit(self, X, sample_weight=None, queue=None, check_input=Tru
first_pass = not hasattr(self, "n_samples_seen_") or self.n_samples_seen_ == 0

if check_input:
xp, _ = get_namespace(X)
if sklearn_check_version("1.0"):
X = validate_data(
self,
Expand All @@ -210,27 +195,28 @@ def _onedal_partial_fit(self, X, sample_weight=None, queue=None, check_input=Tru
else:
self.n_samples_seen_ += X.shape[0]

onedal_params = {
"result_options": self._get_onedal_result_options(self.result_options)
}
if not hasattr(self, "_onedal_estimator"):
self._onedal_estimator = self._onedal_incremental_basic_statistics(
**onedal_params
result_options=self.result_options
)
self._onedal_estimator.partial_fit(X, weights=sample_weight, queue=queue)

self._onedal_estimator.partial_fit(X, sample_weight=sample_weight, queue=queue)
self._need_to_finalize = True

def _onedal_fit(self, X, sample_weight=None, queue=None):
if sklearn_check_version("1.2"):
self._validate_params()

xp, _ = get_namespace(X)
if sklearn_check_version("1.0"):
X = validate_data(self, X, dtype=[np.float64, np.float32])
X = validate_data(self, X, dtype=[xp.float64, xp.float32])
else:
X = check_array(X, dtype=[np.float64, np.float32])
X = check_array(X, dtype=[xp.float64, xp.float32])

if sample_weight is not None:
sample_weight = _check_sample_weight(sample_weight, X)
sample_weight = _check_sample_weight(
sample_weight, X, dtype=[xp.float64, xp.float32]
)

n_samples, n_features = X.shape
if self.batch_size is None:
Expand All @@ -256,11 +242,12 @@ def _onedal_fit(self, X, sample_weight=None, queue=None):
return self

def __getattr__(self, attr):
result_options = self.__dict__["result_options"]
sattr = attr.removesuffix("_")
is_statistic_attr = (
isinstance(result_options, str) and (sattr == result_options)
) or (isinstance(result_options, list) and (sattr in result_options))
sattr in self._onedal_estimator.options
if hasattr(self, "_onedal_estimator")
else False
)
if is_statistic_attr:
if self._need_to_finalize:
self._onedal_finalize_fit()
Expand Down

0 comments on commit 68ffc45

Please sign in to comment.