[enhancement] WIP new finite checking in SVM algorithms #2209

Draft · wants to merge 28 commits into base: main
4 changes: 2 additions & 2 deletions onedal/svm/__init__.py
@@ -14,6 +14,6 @@
 # limitations under the License.
 # ==============================================================================
 
-from .svm import SVC, SVR, NuSVC, NuSVR, SVMtype
+from .svm import SVC, SVR, NuSVC, NuSVR
 
-__all__ = ["SVC", "SVR", "NuSVC", "NuSVR", "SVMtype"]
+__all__ = ["SVC", "SVR", "NuSVC", "NuSVR"]
194 changes: 36 additions & 158 deletions onedal/svm/svm.py
@@ -15,13 +15,13 @@
 # ==============================================================================
 
 from abc import ABCMeta, abstractmethod
-from enum import Enum
 
 import numpy as np
 from scipy import sparse as sp
 
 from onedal import _backend
 
+from ..common._base import BaseEstimator
 from ..common._estimator_checks import _check_is_fitted
 from ..common._mixin import ClassifierMixin, RegressorMixin
 from ..common._policy import _get_policy
@@ -35,14 +35,7 @@
 )
 
 
-class SVMtype(Enum):
-    c_svc = 0
-    epsilon_svr = 1
-    nu_svc = 2
-    nu_svr = 3
-
-
-class BaseSVM(metaclass=ABCMeta):
+class BaseSVM(BaseEstimator, metaclass=ABCMeta):
     @abstractmethod
     def __init__(
         self,
@@ -63,8 +56,6 @@ def __init__(
         decision_function_shape,
         break_ties,
         algorithm,
-        svm_type=None,
-        **kwargs,
     ):
         self.C = C
         self.nu = nu
@@ -82,21 +73,20 @@
         self.decision_function_shape = decision_function_shape
         self.break_ties = break_ties
         self.algorithm = algorithm
-        self.svm_type = svm_type
 
     def _validate_targets(self, y, dtype):
         self.class_weight_ = None
         self.classes_ = None
         return _column_or_1d(y, warn=True).astype(dtype, copy=False)
 
-    def _get_onedal_params(self, data):
+    def _get_onedal_params(self, dtype):
         max_iter = 10000 if self.max_iter == -1 else self.max_iter
         # TODO: remove this workaround
         # when oneDAL SVM starts support of 'n_iterations' result
        self.n_iter_ = 1 if max_iter < 1 else max_iter
         class_count = 0 if self.classes_ is None else len(self.classes_)
         return {
-            "fptype": data.dtype,
+            "fptype": dtype,
             "method": self.algorithm,
             "kernel": self.kernel,
             "c": self.C,
@@ -129,6 +119,7 @@ def _fit(self, X, y, sample_weight, module, queue):
             force_all_finite=True,
             accept_sparse="csr",
         )
+        # hard work remains: target validation should move out of onedal
         y = self._validate_targets(y, X.dtype)
         if sample_weight is not None and len(sample_weight) > 0:
             sample_weight = _check_array(
@@ -154,29 +145,12 @@ def _fit(self, X, y, sample_weight, module, queue):
             self._scale_, self._sigma_ = 1.0, 1.0
             self.coef0 = 0.0
         else:
-            if isinstance(self.gamma, str):
-                if self.gamma == "scale":
-                    if sp.issparse(X):
-                        # var = E[X^2] - E[X]^2
-                        X_sc = (X.multiply(X)).mean() - (X.mean()) ** 2
-                    else:
-                        X_sc = X.var()
-                    _gamma = 1.0 / (X.shape[1] * X_sc) if X_sc != 0 else 1.0
-                elif self.gamma == "auto":
-                    _gamma = 1.0 / X.shape[1]
-                else:
-                    raise ValueError(
-                        "When 'gamma' is a string, it should be either 'scale' or "
-                        "'auto'. Got '{}' instead.".format(self.gamma)
-                    )
-            else:
-                _gamma = self.gamma
-            self._scale_, self._sigma_ = _gamma, np.sqrt(0.5 / _gamma)
+            self._scale_, self._sigma_ = self.gamma, np.sqrt(0.5 / self.gamma)
 
         policy = _get_policy(queue, *data)
-        X = _convert_to_supported(policy, X)
-        params = self._get_onedal_params(X)
-        result = module.train(policy, params, *to_table(*data))
+        data_t = to_table(*_convert_to_supported(policy, *data))
+        params = self._get_onedal_params(data_t[0].dtype)
+        result = module.train(policy, params, *data_t)
 
         if self._sparse:
             self.dual_coef_ = sp.csr_matrix(from_table(result.coeffs).T)
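
The `"scale"`/`"auto"` string handling for `gamma` is removed here, so `_fit` now assumes `gamma` is already numeric; resolving the string values presumably moves up to the sklearnex layer. A hedged sketch of that resolution, mirroring the removed branch (`_resolve_gamma` is a hypothetical helper, not part of this PR):

```python
from scipy import sparse as sp

def _resolve_gamma(gamma, X):
    # Hypothetical caller-side helper mirroring the logic removed above.
    if gamma == "scale":
        # var = E[X^2] - E[X]^2 for sparse input, plain variance otherwise
        X_sc = (X.multiply(X)).mean() - (X.mean()) ** 2 if sp.issparse(X) else X.var()
        return 1.0 / (X.shape[1] * X_sc) if X_sc != 0 else 1.0
    if gamma == "auto":
        return 1.0 / X.shape[1]
    return float(gamma)  # already numeric
```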
@@ -190,6 +164,7 @@ def _fit(self, X, y, sample_weight, module, queue):
         self.n_features_in_ = X.shape[1]
         self.shape_fit_ = X.shape
 
+        # _n_support is not used in this object; it will be moved to sklearnex
         if getattr(self, "classes_", None) is not None:
             indices = y.take(self.support_, axis=0)
             self._n_support = np.array(
@@ -206,128 +181,37 @@ def _create_model(self, module):
         m.support_vectors = to_table(self.support_vectors_)
         m.coeffs = to_table(self.dual_coef_.T)
         m.biases = to_table(self.intercept_)
-
-        if self.svm_type is SVMtype.c_svc or self.svm_type is SVMtype.nu_svc:
-            m.first_class_response, m.second_class_response = 0, 1
         return m
 
-    def _predict(self, X, module, queue):
+    def _infer(self, X, module, queue):
         _check_is_fitted(self)
-        if self.break_ties and self.decision_function_shape == "ovo":
-            raise ValueError(
-                "break_ties must be False when decision_function_shape is 'ovo'"
-            )
-
-        if module in [_backend.svm.classification, _backend.svm.nu_classification]:
-            sv = self.support_vectors_
-            if not self._sparse and sv.size > 0 and self._n_support.sum() != sv.shape[0]:
-                raise ValueError(
-                    "The internal representation "
-                    f"of {self.__class__.__name__} was altered"
-                )
-
-        if (
-            self.break_ties
-            and self.decision_function_shape == "ovr"
-            and len(self.classes_) > 2
-        ):
-            y = np.argmax(self.decision_function(X), axis=1)
-        else:
         X = _check_array(
             X,
             dtype=[np.float64, np.float32],
             force_all_finite=True,
             accept_sparse="csr",
         )
         _check_n_features(self, X, False)
 
-        if self._sparse and not sp.isspmatrix(X):
-            X = sp.csr_matrix(X)
-        if self._sparse:
-            X.sort_indices()
+        if self._sparse:
+            if not sp.isspmatrix(X):
+                X = sp.csr_matrix(X)
+            X.sort_indices()
 
         if sp.issparse(X) and not self._sparse and not callable(self.kernel):
             raise ValueError(
                 "cannot use sparse input in %r trained on dense data"
                 % type(self).__name__
             )
 
-        policy = _get_policy(queue, X)
-        X = _convert_to_supported(policy, X)
-        params = self._get_onedal_params(X)
-
-        if hasattr(self, "_onedal_model"):
-            model = self._onedal_model
-        else:
-            model = self._create_model(module)
-        result = module.infer(policy, params, model, to_table(X))
-        y = from_table(result.responses)
-        return y
-
-    def _ovr_decision_function(self, predictions, confidences, n_classes):
-        n_samples = predictions.shape[0]
-        votes = np.zeros((n_samples, n_classes))
-        sum_of_confidences = np.zeros((n_samples, n_classes))
-
-        k = 0
-        for i in range(n_classes):
-            for j in range(i + 1, n_classes):
-                sum_of_confidences[:, i] -= confidences[:, k]
-                sum_of_confidences[:, j] += confidences[:, k]
-                votes[predictions[:, k] == 0, i] += 1
-                votes[predictions[:, k] == 1, j] += 1
-                k += 1
-
-        transformed_confidences = sum_of_confidences / (
-            3 * (np.abs(sum_of_confidences) + 1)
-        )
-        return votes + transformed_confidences
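
The removed `_ovr_decision_function` implements sklearn-style OvR tie-breaking: the summed pairwise confidences are squashed into (-1/3, 1/3) before being added to the vote counts, so they can separate classes with equal votes but can never overturn a whole vote. A standalone check of that invariant (illustrative, not part of this PR):

```python
import numpy as np

s = np.array([-7.3, 0.4, 2.1])    # arbitrary sums of pairwise confidences
t = s / (3 * (np.abs(s) + 1))     # the transform used in the removed code
assert np.all(np.abs(t) < 1 / 3)  # strictly less than one whole vote
```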

-    def _decision_function(self, X, module, queue):
-        _check_is_fitted(self)
-        X = _check_array(
-            X, dtype=[np.float64, np.float32], force_all_finite=True, accept_sparse="csr"
-        )
-        _check_n_features(self, X, False)
-
-        if self._sparse and not sp.isspmatrix(X):
-            X = sp.csr_matrix(X)
-        if self._sparse:
-            X.sort_indices()
-
-        if sp.issparse(X) and not self._sparse and not callable(self.kernel):
-            raise ValueError(
-                "cannot use sparse input in %r trained on dense data"
-                % type(self).__name__
-            )
-
-        if module in [_backend.svm.classification, _backend.svm.nu_classification]:
-            sv = self.support_vectors_
-            if not self._sparse and sv.size > 0 and self._n_support.sum() != sv.shape[0]:
-                raise ValueError(
-                    "The internal representation "
-                    f"of {self.__class__.__name__} was altered"
-                )
 
         policy = _get_policy(queue, X)
-        X = _convert_to_supported(policy, X)
-        params = self._get_onedal_params(X)
+        X = to_table(_convert_to_supported(policy, X))
+        params = self._get_onedal_params(X.dtype)
 
         if hasattr(self, "_onedal_model"):
             model = self._onedal_model
         else:
             model = self._create_model(module)
-        result = module.infer(policy, params, model, to_table(X))
-        decision_function = from_table(result.decision_function)
+        return module.infer(policy, params, model, X)
 
-        if len(self.classes_) == 2:
-            decision_function = decision_function.ravel()
+    def _predict(self, X, module, queue):
+        return from_table(self._infer(X, module, queue).responses)
 
-        if self.decision_function_shape == "ovr" and len(self.classes_) > 2:
-            decision_function = self._ovr_decision_function(
-                decision_function < 0, -decision_function, len(self.classes_)
-            )
-        return decision_function
+    def _decision_function(self, X, module, queue):
+        return from_table(self._infer(X, module, queue).decision_function)
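
`_predict` and `_decision_function` are now thin wrappers over the shared `_infer`, which returns the raw backend result; the `classes_.take` label mapping, binary `ravel`, and OvR post-processing that used to live here are gone, presumably moving to the sklearnex layer. A hedged sketch of calling the refactored estimator directly, assuming the now-required numeric `gamma`:

```python
import numpy as np
from onedal.svm import SVC

X = np.array([[0.0, 0.0], [1.0, 1.0], [0.0, 1.0], [1.0, 0.0]], dtype=np.float64)
y = np.array([0.0, 1.0, 0.0, 1.0])

clf = SVC(kernel="rbf", gamma=0.5).fit(X, y)
raw = clf.predict(X)               # raw backend responses, no label mapping
scores = clf.decision_function(X)  # raw decision values, no OvR transform
```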


class SVR(RegressorMixin, BaseSVM):
@@ -350,7 +234,6 @@ def __init__(
         max_iter=-1,
         tau=1e-12,
         algorithm="thunder",
-        **kwargs,
     ):
         super().__init__(
             C=C,
@@ -370,14 +253,12 @@ def __init__(
             break_ties=False,
             algorithm=algorithm,
         )
-        self.svm_type = SVMtype.epsilon_svr
 
     def fit(self, X, y, sample_weight=None, queue=None):
         return super()._fit(X, y, sample_weight, _backend.svm.regression, queue)
 
     def predict(self, X, queue=None):
-        y = super()._predict(X, _backend.svm.regression, queue)
-        return y.ravel()
+        return super()._predict(X, _backend.svm.regression, queue)
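
`SVR.predict` now returns the `from_table` output untouched, so any flattening of a column-shaped result falls to the caller; the `(n_samples, 1)` shape is an assumption inferred from the `ravel()` removed above. A minimal sketch:

```python
import numpy as np
from onedal.svm import SVR

X = np.array([[0.0], [1.0], [2.0], [3.0]], dtype=np.float64)
y = np.array([0.0, 0.8, 2.1, 2.9])

reg = SVR(C=1.0, epsilon=0.1, kernel="rbf", gamma=1.0).fit(X, y)
y_pred = np.asarray(reg.predict(X)).ravel()  # caller flattens the column
```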


class SVC(ClassifierMixin, BaseSVM):
@@ -402,7 +283,6 @@ def __init__(
         decision_function_shape="ovr",
         break_ties=False,
         algorithm="thunder",
-        **kwargs,
     ):
         super().__init__(
             C=C,
@@ -422,7 +302,11 @@ def __init__(
             break_ties=break_ties,
             algorithm=algorithm,
         )
-        self.svm_type = SVMtype.c_svc
+
+    def _create_model(self, module):
+        m = super()._create_model(module)
+        m.first_class_response, m.second_class_response = 0, 1
+        return m
 
     def _validate_targets(self, y, dtype):
         y, self.class_weight_, self.classes_ = _validate_targets(
@@ -434,10 +318,7 @@ def fit(self, X, y, sample_weight=None, queue=None):
         return super()._fit(X, y, sample_weight, _backend.svm.classification, queue)
 
     def predict(self, X, queue=None):
-        y = super()._predict(X, _backend.svm.classification, queue)
-        if len(self.classes_) == 2:
-            y = y.ravel()
-        return self.classes_.take(np.asarray(y, dtype=np.intp)).ravel()
+        return super()._predict(X, _backend.svm.classification, queue)
 
     def decision_function(self, X, queue=None):
         return super()._decision_function(X, _backend.svm.classification, queue)
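
Since `SVC.predict` (and `NuSVC.predict` below) no longer maps responses through `classes_`, converting the backend's class indices back to the original labels is now presumably the caller's job, mirroring the lines removed above:

```python
import numpy as np

# Hypothetical caller-side mapping, reusing `clf` and `X` from the earlier sketch.
responses = clf.predict(X)  # class indices from the backend
labels = clf.classes_.take(np.asarray(responses, dtype=np.intp)).ravel()
```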
@@ -463,7 +344,6 @@ def __init__(
         max_iter=-1,
         tau=1e-12,
         algorithm="thunder",
-        **kwargs,
     ):
         super().__init__(
             C=C,
@@ -483,14 +363,12 @@ def __init__(
             break_ties=False,
             algorithm=algorithm,
         )
-        self.svm_type = SVMtype.nu_svr
 
     def fit(self, X, y, sample_weight=None, queue=None):
         return super()._fit(X, y, sample_weight, _backend.svm.nu_regression, queue)
 
     def predict(self, X, queue=None):
-        y = super()._predict(X, _backend.svm.nu_regression, queue)
-        return y.ravel()
+        return super()._predict(X, _backend.svm.nu_regression, queue)


class NuSVC(ClassifierMixin, BaseSVM):
@@ -515,7 +393,6 @@ def __init__(
         decision_function_shape="ovr",
         break_ties=False,
         algorithm="thunder",
-        **kwargs,
     ):
         super().__init__(
             C=1.0,
@@ -535,7 +412,11 @@ def __init__(
             break_ties=break_ties,
             algorithm=algorithm,
         )
-        self.svm_type = SVMtype.nu_svc
+
+    def _create_model(self, module):
+        m = super()._create_model(module)
+        m.first_class_response, m.second_class_response = 0, 1
+        return m
 
     def _validate_targets(self, y, dtype):
         y, self.class_weight_, self.classes_ = _validate_targets(
@@ -547,10 +428,7 @@ def fit(self, X, y, sample_weight=None, queue=None):
         return super()._fit(X, y, sample_weight, _backend.svm.nu_classification, queue)
 
     def predict(self, X, queue=None):
-        y = super()._predict(X, _backend.svm.nu_classification, queue)
-        if len(self.classes_) == 2:
-            y = y.ravel()
-        return self.classes_.take(np.asarray(y, dtype=np.intp)).ravel()
+        return super()._predict(X, _backend.svm.nu_classification, queue)
 
     def decision_function(self, X, queue=None):
         return super()._decision_function(X, _backend.svm.nu_classification, queue)