Skip to content

Commit

Permalink
FEAT(forest): hyperparameter support for RF clf inference (#1751)
Browse files Browse the repository at this point in the history
* expand forest classes with hyperparams

* fix build errors

* add some version checks

* black format

* version check for hyperparameters backend

* fixup

* fixup

* fixup

* fixup

* add version check for hyperparams

* Add daal_require_version_wrapper

* simplify obtaining hparams

* fix version checks

* Fix ONEDAL_VERSION in comment

Co-authored-by: Victoriya Fedotova <[email protected]>

* Add version comment to endif

Co-authored-by: Victoriya Fedotova <[email protected]>

* add block_size_multiplier hyperparam

* fix argument name

* add hyperparameter tests

* review comments

* fixup for register_hyperparameters

* simplify get_hyperparameters() definition

* rename set_block_size

* remove block size multiplier from testing

* modify hyperparameters getattr/setattr for better debug output

* instantiate hyperparams for entire task_list

* rename df_infer_hp

* instantiate hp only for classification

* Retrieve hparams only for classification

* change hyperparam values to %8 values

* [workaround] use older artifacts to avoid pipeline error - revert before merge

* Revert "[workaround] use older artifacts to avoid pipeline error - revert before merge"

This reverts commit 294396c.

* improve comment on hyperparameters __getattr__

* simplify hyperparameters retrieval

* Revert adding of daal_require_version_wrapper

* clean up imports

* clean up unused type

* fixup after renaming predict->infer

* move get_hyperparameters API to sklearnex

* move tests to sklearnex

* fix test

---------

Co-authored-by: Victoriya Fedotova <[email protected]>
  • Loading branch information
ahuber21 and Vika-F authored Oct 24, 2024
1 parent 0f0348a commit d9db8cc
Show file tree
Hide file tree
Showing 7 changed files with 155 additions and 23 deletions.
2 changes: 1 addition & 1 deletion daal4py/sklearn/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import os
import sys
import warnings
from typing import Any, Callable, Tuple
from typing import Any, Tuple

import numpy as np
from numpy.lib.recfunctions import require_fields
Expand Down
28 changes: 28 additions & 0 deletions onedal/common/dispatch_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,34 @@ struct infer_ops {
Ops ops;
};

#if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240300

/// Callable wrapper that runs oneDAL inference with explicit hyperparameters.
///
/// Holds copies of the execution policy, the inference input, the functor that
/// turns a Python parameter dict into an algorithm descriptor, and the
/// hyperparameters object. Invoking it builds the descriptor and forwards
/// everything to `dal::infer`.
///
/// @tparam Policy      Execution policy type passed to `dal::infer`.
/// @tparam Input       Inference input type; must expose `task_t`.
/// @tparam Ops         Functor type providing a templated `operator()` that
///                     accepts the parameter dict and returns a descriptor.
/// @tparam Hyperparams Hyperparameters type passed to `dal::infer`.
template <typename Policy, typename Input, typename Ops, typename Hyperparams>
struct infer_ops_with_hyperparams {
    using Task = typename Input::task_t;

    /// Stores copies of all four arguments for use at call time.
    infer_ops_with_hyperparams(const Policy& policy_arg,
                               const Input& input_arg,
                               const Ops& ops_arg,
                               const Hyperparams& hyperparams_arg)
            : policy(policy_arg),
              input(input_arg),
              ops(ops_arg),
              hyperparams(hyperparams_arg) {}

    /// Builds the descriptor for the given Float/Method (and the stored Task)
    /// from `params`, then runs inference with the stored hyperparameters.
    ///
    /// @param params Python dictionary with the algorithm parameters.
    /// @return Whatever `dal::infer` returns for this descriptor and input.
    template <typename Float, typename Method, typename... Args>
    auto operator()(const pybind11::dict& params) {
        // Kept as a named, non-const lvalue so it is forwarded to dal::infer
        // with the same value category as before.
        auto descriptor = ops.template operator()<Float, Method, Task, Args...>(params);
        return dal::infer(policy, descriptor, hyperparams, input);
    }

    Policy policy;
    Input input;
    Ops ops;
    Hyperparams hyperparams;
};

#endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240300

template <typename Policy, typename Input, typename Ops>
struct partial_compute_ops {
using Task = typename Input::task_t;
Expand Down
35 changes: 22 additions & 13 deletions onedal/common/hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@
# ==============================================================================

import logging
from typing import Any, Dict, Tuple
from warnings import warn

from daal4py.sklearn._utils import daal_check_version
from onedal import _backend

if daal_check_version((2024, "P", 0)):
if not daal_check_version((2024, "P", 0)):
warn("Hyperparameters are supported in oneDAL starting from 2024.0.0 version.")
hyperparameters_map = {}
else:
_hparams_reserved_words = [
"algorithm",
"op",
Expand Down Expand Up @@ -56,9 +60,16 @@ def __getattribute__(self, __name):
return super().__getattribute__(__name)
elif __name in self.getters.keys():
return self.getters[__name]()
else:
raise ValueError(
f"Unknown '{__name}' name in "
try:
# try to return attribute from base class
# required to read builtin attributes like __class__, __doc__, etc.
# which are used in debuggers
return super().__getattribute__(__name)
except AttributeError:
# raise an AttributeError with a hyperparameter-specific message
# for easier debugging
raise AttributeError(
f"Unknown attribute '{__name}' in "
f"'{self.algorithm}.{self.op}' hyperparameters"
)

Expand All @@ -70,7 +81,7 @@ def __setattr__(self, __name, __value):
self.setters[__name](__value)
else:
raise ValueError(
f"Unknown '{__name}' name in "
f"Unknown attribute '{__name}' in "
f"'{self.algorithm}.{self.op}' hyperparameters"
)

Expand All @@ -83,13 +94,16 @@ def get_methods_with_prefix(obj, prefix):
for method in filter(lambda f: f.startswith(prefix), dir(obj))
}

hyperparameters_backend = {
hyperparameters_backend: Dict[Tuple[str, str], Any] = {
(
"linear_regression",
"train",
): _backend.linear_model.regression.train_hyperparameters(),
("covariance", "compute"): _backend.covariance.compute_hyperparameters(),
}
if daal_check_version((2024, "P", 300)):
df_infer_hp = _backend.decision_forest.infer_hyperparameters
hyperparameters_backend[("decision_forest", "infer")] = df_infer_hp()
hyperparameters_map = {}

for (algorithm, op), hyperparameters in hyperparameters_backend.items():
Expand All @@ -106,11 +120,6 @@ def get_methods_with_prefix(obj, prefix):
algorithm, op, setters, getters, hyperparameters
)

def get_hyperparameters(algorithm, op):
return hyperparameters_map[(algorithm, op)]

else:

def get_hyperparameters(algorithm, op):
warn("Hyperparameters are supported in oneDAL starting from 2024.0.0 version.")
return None
def get_hyperparameters(algorithm, op):
    """Look up the hyperparameters registered for an algorithm operation.

    Parameters
    ----------
    algorithm : str
        Algorithm name used as the first element of the lookup key.
    op : str
        Operation name used as the second element of the lookup key.

    Returns
    -------
    The entry stored in ``hyperparameters_map`` under ``(algorithm, op)``,
    or ``None`` when no such entry exists.
    """
    key = (algorithm, op)
    # dict.get falls back to None when the key is missing
    return hyperparameters_map.get(key)
72 changes: 71 additions & 1 deletion onedal/ensemble/forest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,13 +227,32 @@ void init_train_ops(py::module_& m) {
using namespace decision_forest;
using input_t = train_input<Task>;

train_ops ops(policy, input_t{ data, responses}, params2desc{});
train_ops ops(policy, input_t{ data, responses }, params2desc{});
return fptype2t{ method2t{ Task{}, ops } }(params);
});
}

template <typename Policy, typename Task>
void init_infer_ops(py::module_& m) {
#if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240300
using infer_hyperparams_t = decision_forest::detail::infer_parameters<Task>;
m.def("infer",
[](const Policy& policy,
const py::dict& params,
const infer_hyperparams_t& hyperparams,
const decision_forest::model<Task>& model,
const table& data) {
using namespace decision_forest;
using input_t = infer_input<Task>;

infer_ops_with_hyperparams ops(policy,
input_t{ model, data },
params2desc{},
hyperparams);
return fptype2t{ method2t{ Task{}, ops } }(params);
});
#endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240300

m.def("infer",
[](const Policy& policy,
const py::dict& params,
Expand Down Expand Up @@ -309,6 +328,49 @@ void init_infer_result(py::module_& m) {
}
}

#if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240300
/// Registers the decision forest inference hyperparameters class for `Task`
/// in module `m` under the Python name "infer_hyperparameters".
///
/// The Python class is default-constructible and exposes a `set_*` / `get_*`
/// pair for each of: block_size, min_trees_for_threading,
/// min_number_of_rows_for_vect_seq_compute and
/// scale_factor_for_vect_parallel_compute.
///
/// @tparam Task Decision forest task the hyperparameters type is bound to.
/// @param m     Module that receives the class definition.
template <typename Task>
void init_infer_hyperparameters(py::module_& m) {
    using namespace dal::decision_forest::detail;
    using infer_hyperparams_t = infer_parameters<Task>;

    py::class_<infer_hyperparams_t> cls(m, "infer_hyperparameters");
    cls.def(py::init());

    // Each accessor is wrapped in a lambda with an explicit signature so the
    // Python-visible argument and return types are fixed here.
    cls.def("set_block_size", [](infer_hyperparams_t& hp, std::int64_t value) {
        hp.set_block_size(value);
    });
    cls.def("get_block_size", [](const infer_hyperparams_t& hp) -> std::int64_t {
        return hp.get_block_size();
    });

    cls.def("set_min_trees_for_threading", [](infer_hyperparams_t& hp, std::int64_t value) {
        hp.set_min_trees_for_threading(value);
    });
    cls.def("get_min_trees_for_threading", [](const infer_hyperparams_t& hp) -> std::int64_t {
        return hp.get_min_trees_for_threading();
    });

    cls.def("set_min_number_of_rows_for_vect_seq_compute",
            [](infer_hyperparams_t& hp, std::int64_t value) {
                hp.set_min_number_of_rows_for_vect_seq_compute(value);
            });
    cls.def("get_min_number_of_rows_for_vect_seq_compute",
            [](const infer_hyperparams_t& hp) -> std::int64_t {
                return hp.get_min_number_of_rows_for_vect_seq_compute();
            });

    cls.def("set_scale_factor_for_vect_parallel_compute",
            [](infer_hyperparams_t& hp, double value) {
                hp.set_scale_factor_for_vect_parallel_compute(value);
            });
    cls.def("get_scale_factor_for_vect_parallel_compute",
            [](const infer_hyperparams_t& hp) -> double {
                return hp.get_scale_factor_for_vect_parallel_compute();
            });
}
#endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240300

ONEDAL_PY_TYPE2STR(decision_forest::task::classification, "classification");
ONEDAL_PY_TYPE2STR(decision_forest::task::regression, "regression");

Expand All @@ -317,6 +379,9 @@ ONEDAL_PY_DECLARE_INSTANTIATOR(init_train_result);
ONEDAL_PY_DECLARE_INSTANTIATOR(init_infer_result);
ONEDAL_PY_DECLARE_INSTANTIATOR(init_train_ops);
ONEDAL_PY_DECLARE_INSTANTIATOR(init_infer_ops);
#if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240300
ONEDAL_PY_DECLARE_INSTANTIATOR(init_infer_hyperparameters);
#endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240300

ONEDAL_PY_INIT_MODULE(ensemble) {
using namespace decision_forest;
Expand All @@ -335,6 +400,11 @@ ONEDAL_PY_INIT_MODULE(ensemble) {
ONEDAL_PY_INSTANTIATE(init_model, sub, task_list);
ONEDAL_PY_INSTANTIATE(init_train_result, sub, task_list);
ONEDAL_PY_INSTANTIATE(init_infer_result, sub, task_list);

#if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240300
ONEDAL_PY_INSTANTIATE(init_infer_hyperparameters, sub, task::classification);
#endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240300

#endif // ONEDAL_DATA_PARALLEL_SPMD
}

Expand Down
17 changes: 12 additions & 5 deletions onedal/ensemble/forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@
import warnings
from abc import ABCMeta, abstractmethod
from math import ceil
from numbers import Number

import numpy as np
from sklearn.ensemble import BaseEnsemble
from sklearn.utils import check_random_state

from daal4py.sklearn._utils import daal_check_version
from onedal import _backend
from sklearnex import get_hyperparameters

from ..common._base import BaseEstimator
from ..common._estimator_checks import _check_is_fitted
Expand Down Expand Up @@ -346,7 +345,7 @@ def _create_model(self, module):
# update error msg.
raise NotImplementedError("Creating model is not supported.")

def _predict(self, X, module, queue):
def _predict(self, X, module, queue, hparams=None):
_check_is_fitted(self)
X = _check_array(
X, dtype=[np.float64, np.float32], force_all_finite=True, accept_sparse=False
Expand All @@ -357,7 +356,11 @@ def _predict(self, X, module, queue):
model = self._onedal_model
X = _convert_to_supported(policy, X)
params = self._get_onedal_params(X)
result = module.infer(policy, params, model, to_table(X))
if hparams is not None and not hparams.is_default:
result = module.infer(policy, params, hparams.backend, model, to_table(X))
else:
result = module.infer(policy, params, model, to_table(X))

y = from_table(result.responses)
return y

Expand Down Expand Up @@ -458,8 +461,12 @@ def fit(self, X, y, sample_weight=None, queue=None):
)

def predict(self, X, queue=None):
hparams = get_hyperparameters("decision_forest", "infer")
pred = super()._predict(
X, self._get_backend("decision_forest", "classification", None), queue
X,
self._get_backend("decision_forest", "classification", None),
queue,
hparams,
)

return np.take(self.classes_, pred.ravel().astype(np.int64, casting="unsafe"))
Expand Down
3 changes: 3 additions & 0 deletions sklearnex/ensemble/_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
from onedal.ensemble import RandomForestRegressor as onedal_RandomForestRegressor
from onedal.primitives import get_tree_state_cls, get_tree_state_reg
from onedal.utils import _num_features, _num_samples
from sklearnex import get_hyperparameters
from sklearnex._utils import register_hyperparameters

from .._device_offload import dispatch, wrap_output_data
from .._utils import PatchingConditionsChain
Expand Down Expand Up @@ -1197,6 +1199,7 @@ def score(self, X, y, sample_weight=None):
score.__doc__ = _sklearn_ForestRegressor.score.__doc__


@register_hyperparameters({"infer": get_hyperparameters("decision_forest", "infer")})
@control_n_jobs(decorated_methods=["fit", "predict", "predict_proba", "score"])
class RandomForestClassifier(ForestClassifier):
__doc__ = _sklearn_RandomForestClassifier.__doc__
Expand Down
21 changes: 18 additions & 3 deletions sklearnex/ensemble/tests/test_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# limitations under the License.
# ===============================================================================

import numpy as np
import pytest
from numpy.testing import assert_allclose
from sklearn.datasets import make_classification, make_regression
Expand All @@ -26,9 +25,19 @@
get_dataframes_and_queues,
)

# Hyperparameter combinations fed to test_sklearnex_import_rf_classifier.
# Each tuple is ordered as (block, trees, rows, scale), matching the names in
# the parametrize decorator, and is assigned to block_size,
# min_trees_for_threading, min_number_of_rows_for_vect_seq_compute and
# scale_factor_for_vect_parallel_compute respectively.
# The all-None entry makes the test skip the assignments, so the
# hyperparameters are left as they were.
hparam_values = [
(None, None, None, None),
(8, 100, 32, 0.3),
(16, 100, 32, 0.3),
(32, 100, 32, 0.3),
(64, 10, 32, 0.1),
(128, 100, 1000, 1.0),
]

@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
def test_sklearnex_import_rf_classifier(dataframe, queue):

@pytest.mark.parametrize("dataframe, queue", get_dataframes_and_queues())
@pytest.mark.parametrize("block, trees, rows, scale", hparam_values)
def test_sklearnex_import_rf_classifier(dataframe, queue, block, trees, rows, scale):
from sklearnex.ensemble import RandomForestClassifier

X, y = make_classification(
Expand All @@ -42,6 +51,12 @@ def test_sklearnex_import_rf_classifier(dataframe, queue):
X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
y = _convert_to_dataframe(y, sycl_queue=queue, target_df=dataframe)
rf = RandomForestClassifier(max_depth=2, random_state=0).fit(X, y)
hparams = rf.get_hyperparameters("infer")
if hparams and block is not None:
hparams.block_size = block
hparams.min_trees_for_threading = trees
hparams.min_number_of_rows_for_vect_seq_compute = rows
hparams.scale_factor_for_vect_parallel_compute = scale
assert "sklearnex" in rf.__module__
assert_allclose([1], _as_numpy(rf.predict([[0, 0, 0, 0]])))

Expand Down

0 comments on commit d9db8cc

Please sign in to comment.