Skip to content

Commit

Permalink
fix issues
Browse files Browse the repository at this point in the history
  • Loading branch information
icfaust committed Nov 28, 2024
1 parent cfeb2c5 commit d3a69c6
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
10 changes: 5 additions & 5 deletions onedal/basic_statistics/basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,14 @@ def fit(self, data, sample_weight=None, queue=None):

is_csr = _is_csr(data)

data, sample_weight = _convert_to_supported(policy, data, sample_weight)
is_single_dim = data.ndim == 1
data_table, weights_table = to_table(data, sample_weight)
data, sample_weight = to_table(
*_convert_to_supported(policy, data, sample_weight)
)

dtype = data_table.dtype
module = self._get_backend("basic_statistics")
params = self._get_onedal_params(is_csr, data_table.dtype)
result = module.compute(policy, params, data_table, weights_table)
params = self._get_onedal_params(is_csr, data.dtype)
result = module.compute(policy, params, data, sample_weight)

for opt in self.options:
value = from_table(getattr(result, opt))[:, 0] # two-dimensional table [n, 1]
Expand Down
2 changes: 1 addition & 1 deletion onedal/basic_statistics/incremental_basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def partial_fit(self, X, sample_weight=None, queue=None):
"""
self._queue = queue
policy = self._get_policy(queue, X)
X, sample_weight = to_table(_convert_to_supported(policy, X, sample_weight))
X, sample_weight = to_table(*_convert_to_supported(policy, X, sample_weight))

if not hasattr(self, "_onedal_params"):
self._onedal_params = self._get_onedal_params(False, dtype=X.dtype)
Expand Down

0 comments on commit d3a69c6

Please sign in to comment.