Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

ci: further fp32 GPU green CI enabling #2187

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
6 changes: 3 additions & 3 deletions onedal/linear_model/linear_model.py
Original file line number Diff line number Diff line change
@@ -198,6 +198,9 @@ def fit(self, X, y, queue=None):
if not isinstance(X, np.ndarray):
X = np.asarray(X)

policy = self._get_policy(queue, X, y)
X, y = _convert_to_supported(policy, X, y)

dtype = get_dtype(X)
if dtype not in [np.float32, np.float64]:
dtype = np.float64
@@ -207,11 +210,8 @@ def fit(self, X, y, queue=None):

X, y = _check_X_y(X, y, force_all_finite=False, accept_2d_y=True)

policy = self._get_policy(queue, X, y)

self.n_features_in_ = _num_features(X, fallback_1d=True)

X, y = _convert_to_supported(policy, X, y)
params = self._get_onedal_params(get_dtype(X))
X_table, y_table = to_table(X, y)

2 changes: 1 addition & 1 deletion onedal/linear_model/tests/test_linear_regression.py
Original file line number Diff line number Diff line change
@@ -248,7 +248,7 @@ def test_multioutput_regression(queue, dtype, fit_intercept, problem_type):

pred = model.predict(X, queue=queue)
expected_pred = X @ model.coef_.T + model.intercept_.reshape((1, -1))
tol = 1e-5 if pred.dtype == np.float32 else 1e-7
tol = 2e-5 if pred.dtype == np.float32 else 1e-7
assert_allclose(pred, expected_pred, rtol=tol)

# check that it also works when 'y' is a list of lists
2 changes: 1 addition & 1 deletion onedal/linear_model/tests/test_logistic_regression.py
Original file line number Diff line number Diff line change
@@ -89,7 +89,7 @@ def test_csr(queue, dtype, dims):
model_sp.fit(X_sp, y, queue=queue)
pred_sp = model_sp.predict(X_sp, queue=queue)

rtol = 2e-4
rtol = 2e-3
assert_allclose(pred, pred_sp, rtol=rtol)
assert_allclose(model.coef_, model_sp.coef_, rtol=rtol)
assert_allclose(model.intercept_, model_sp.intercept_, rtol=rtol)
6 changes: 3 additions & 3 deletions sklearnex/linear_model/tests/test_incremental_linear.py
Original file line number Diff line number Diff line change
@@ -49,7 +49,7 @@ def test_sklearnex_fit_on_gold_data(dataframe, queue, fit_intercept, macro_block
y_pred = inclin.predict(X_df)
np_y_pred = _as_numpy(y_pred)

tol = 5e-5 if dtype == np.float32 else 1e-7
tol = 5e-5 if y_pred.dtype == np.float32 else 1e-7
assert_allclose(inclin.coef_, [1], atol=tol)
if fit_intercept:
assert_allclose(inclin.intercept_, [0], atol=tol)
@@ -89,7 +89,7 @@ def test_sklearnex_partial_fit_on_gold_data(
np_y_pred = _as_numpy(y_pred)

assert inclin.n_features_in_ == 1
tol = 1e-5 if dtype == np.float32 else 1e-7
tol = 1e-5 if y_pred.dtype == np.float32 else 1e-7
assert_allclose(inclin.coef_, [[1]], atol=tol)
if fit_intercept:
assert_allclose(inclin.intercept_, 3, atol=tol)
@@ -131,7 +131,7 @@ def test_sklearnex_partial_fit_multitarget_on_gold_data(

assert inclin.n_features_in_ == 2
tol = 1e-7
if dtype == np.float32:
if y_pred.dtype == np.float32:
tol = 7e-6 if _IS_INTEL else 2e-5

assert_allclose(inclin.coef_, [1.0, 2.0], atol=tol)
15 changes: 11 additions & 4 deletions sklearnex/linear_model/tests/test_incremental_ridge.py
Original file line number Diff line number Diff line change
@@ -76,10 +76,13 @@ def test_inc_ridge_fit_coefficients(
coefficients_manual, intercept_manual = _compute_ridge_coefficients(
X, y, alpha, fit_intercept
)

tol = 2e-4 if inc_ridge.coef_.dtype == np.float32 else 1e-6

if fit_intercept:
assert_allclose(inc_ridge.intercept_, intercept_manual, rtol=1e-6, atol=1e-6)
assert_allclose(inc_ridge.intercept_, intercept_manual, rtol=tol, atol=tol)

assert_allclose(inc_ridge.coef_, coefficients_manual, rtol=1e-6, atol=1e-6)
assert_allclose(inc_ridge.coef_, coefficients_manual, rtol=tol, atol=tol)

@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
@pytest.mark.parametrize("batch_size", [2, 5])
@@ -106,8 +109,10 @@ def test_inc_ridge_partial_fit_coefficients(dataframe, queue, alpha, batch_size)
inverse_term = np.linalg.inv(np.dot(X.T, X) + lambda_identity)
xt_y = np.dot(X.T, y)
coefficients_manual = np.dot(inverse_term, xt_y)

tol = 5e-3 if inc_ridge.coef_.dtype == np.float32 else 1e-6

assert_allclose(inc_ridge.coef_, coefficients_manual, rtol=1e-6, atol=1e-6)
assert_allclose(inc_ridge.coef_, coefficients_manual, rtol=tol, atol=tol)

def test_inc_ridge_score_before_fit():
X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
@@ -149,5 +154,7 @@ def test_inc_ridge_predict_after_fit(dataframe, queue, fit_intercept):
y_pred_manual = np.dot(X, coefficients_manual)
if fit_intercept:
y_pred_manual += intercept_manual

tol = 1e-5 if inc_ridge.coef_.dtype == np.float32 else 1e-6

assert_allclose(_as_numpy(y_pred), y_pred_manual, rtol=1e-6, atol=1e-6)
assert_allclose(_as_numpy(y_pred), y_pred_manual, rtol=tol, atol=tol)
2 changes: 1 addition & 1 deletion sklearnex/linear_model/tests/test_logreg.py
Original file line number Diff line number Diff line change
@@ -127,7 +127,7 @@ def test_csr(queue, dtype, dims):
pred_sp = model_sp.predict(X_sp)
prob_sp = model_sp.predict_proba(X_sp)

rtol = 2e-4
rtol = 1e-3 if (dtype == np.float32 or not queue.sycl_device.has_aspect_fp64) else 2e-4
assert_allclose(pred, pred_sp, rtol=rtol)
assert_allclose(prob, prob_sp, rtol=rtol)
assert_allclose(model.coef_, model_sp.coef_, rtol=rtol)
6 changes: 4 additions & 2 deletions sklearnex/linear_model/tests/test_ridge.py
Original file line number Diff line number Diff line change
@@ -129,8 +129,10 @@ def test_ridge_coefficients(
X, y, alpha, fit_intercept=fit_intercept
)

assert_allclose(ridge_reg.coef_, coefficients_manual, rtol=1e-6, atol=1e-6)
assert_allclose(ridge_reg.intercept_, intercept_manual, rtol=1e-6, atol=1e-6)
tol = 1e-5 if ridge_reg.coef_.dtype == np.float32 else 1e-6

assert_allclose(ridge_reg.coef_, coefficients_manual, rtol=tol, atol=tol)
assert_allclose(ridge_reg.intercept_, intercept_manual, rtol=tol, atol=tol)


@pytest.mark.skipif(
Original file line number Diff line number Diff line change
@@ -200,7 +200,7 @@ def test_sklearnex_partial_fit_on_gold_data(dataframe, queue, whiten, num_blocks

X_df = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
transformed_data = incpca.transform(X_df)
check_pca_on_gold_data(incpca, dtype, whiten, transformed_data)
check_pca_on_gold_data(incpca, transformed_data.dtype, whiten, transformed_data)


@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
@@ -217,7 +217,7 @@ def test_sklearnex_fit_on_gold_data(dataframe, queue, whiten, num_blocks, dtype)
incpca.fit(X_df)
transformed_data = incpca.transform(X_df)

check_pca_on_gold_data(incpca, dtype, whiten, transformed_data)
check_pca_on_gold_data(incpca, transformed_data.dtype, whiten, transformed_data)


@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
@@ -235,7 +235,7 @@ def test_sklearnex_fit_transform_on_gold_data(
X_df = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
transformed_data = incpca.fit_transform(X_df)

check_pca_on_gold_data(incpca, dtype, whiten, transformed_data)
check_pca_on_gold_data(incpca, transformed_data.dtype, whiten, transformed_data)


@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
@@ -263,7 +263,7 @@ def test_sklearnex_partial_fit_on_random_data(

X_df = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
transformed_data = incpca.transform(X_df)
check_pca(incpca, dtype, whiten, X, transformed_data)
check_pca(incpca, transformed_data.dtype, whiten, X, transformed_data)


@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
Original file line number Diff line number Diff line change
@@ -65,8 +65,9 @@ def test_basic_stats_spmd_gold(dataframe, queue):
spmd_result = BasicStatistics_SPMD().fit(local_dpt_data)
batch_result = BasicStatistics_Batch().fit(data)

tol = 1e-7 if queue.sycl_device.has_aspect_fp64 else 1e-6
for option in options_and_tests:
assert_allclose(getattr(spmd_result, option), getattr(batch_result, option))
assert_allclose(getattr(spmd_result, option), getattr(batch_result, option), rtol=tol)


@pytest.mark.skipif(
@@ -97,7 +98,7 @@ def test_basic_stats_spmd_synthetic(n_samples, n_features, dataframe, queue, dty
spmd_result = BasicStatistics_SPMD().fit(local_dpt_data)
batch_result = BasicStatistics_Batch().fit(data)

tol = 1e-5 if dtype == np.float32 else 1e-7
tol = 1e-5 if (dtype == np.float32 or not queue.sycl_device.has_aspect_fp64) else 1e-7
for option in options_and_tests:
assert_allclose(
getattr(spmd_result, option),
Original file line number Diff line number Diff line change
@@ -260,7 +260,7 @@ def test_incremental_basic_statistics_partial_fit_spmd_synthetic(
IncrementalBasicStatistics as IncrementalBasicStatistics_SPMD,
)

tol = 2e-3 if dtype == np.float32 else 1e-7
tol = 2e-3 if (dtype == np.float32 or not queue.sycl_device.has_aspect_fp64) else 1e-7

# Create gold data and process into dpt
data = _generate_statistic_data(n_samples, n_features, dtype=dtype)
2 changes: 1 addition & 1 deletion sklearnex/spmd/cluster/tests/test_dbscan_spmd.py
Original file line number Diff line number Diff line change
@@ -43,7 +43,7 @@ def test_dbscan_spmd_gold(dataframe, queue):
from sklearnex.cluster import DBSCAN as DBSCAN_Batch
from sklearnex.spmd.cluster import DBSCAN as DBSCAN_SPMD

data = np.array([[1, 2], [2, 2], [2, 3], [8, 7], [8, 8], [25, 80]])
data = np.array([[1., 2.], [2., 2.], [2., 3.], [8., 7.], [8., 8.], [25., 80.]])

local_dpt_data = _convert_to_dataframe(
_get_local_tensor(data), sycl_queue=queue, target_df=dataframe
29 changes: 14 additions & 15 deletions sklearnex/spmd/cluster/tests/test_kmeans_spmd.py
Original file line number Diff line number Diff line change
@@ -47,22 +47,21 @@ def test_kmeans_spmd_gold(dataframe, queue):

X_train = np.array(
[
[1, 2],
[2, 2],
[2, 3],
[8, 7],
[8, 8],
[25, 80],
[5, 65],
[2, 8],
[1, 3],
[2, 2],
[1, 3],
[2, 2],
[1., 2.],
[2., 2.],
[2., 3.],
[8., 7.],
[8., 8.],
[25., 80.],
[5., 65.],
[2., 8.],
[1., 3.],
[2., 2.],
[1., 3.],
[2., 2.],
],
dtype=np.float64,
)
X_test = np.array([[0, 0], [12, 3], [2, 2], [7, 8]], dtype=np.float64)
X_test = np.array([[0., 0.], [12., 3.], [2., 2.], [7., 8.]])

local_dpt_X_train = _convert_to_dataframe(
_get_local_tensor(X_train), sycl_queue=queue, target_df=dataframe
@@ -146,7 +145,7 @@ def test_kmeans_spmd_synthetic(
n_clusters=n_clusters, init=spmd_model_init.cluster_centers_, random_state=0
).fit(X_train)

atol = 1e-5 if dtype == np.float32 else 1e-7
atol = 1e-5 if (dtype == np.float32 or not queue.sycl_device.has_aspect_fp64) else 1e-7
_assert_unordered_allclose(
spmd_model.cluster_centers_, batch_model.cluster_centers_, atol=atol
)
7 changes: 4 additions & 3 deletions sklearnex/spmd/covariance/tests/test_covariance_spmd.py
Original file line number Diff line number Diff line change
@@ -64,8 +64,9 @@ def test_covariance_spmd_gold(dataframe, queue):
spmd_result = EmpiricalCovariance_SPMD().fit(local_dpt_data)
batch_result = EmpiricalCovariance_Batch().fit(data)

assert_allclose(spmd_result.covariance_, batch_result.covariance_)
assert_allclose(spmd_result.location_, batch_result.location_)
atol = 1e-7 if queue.sycl_device.has_aspect_fp64 else 1e-5
assert_allclose(spmd_result.covariance_, batch_result.covariance_, atol=atol)
assert_allclose(spmd_result.location_, batch_result.location_, atol=atol)


@pytest.mark.skipif(
@@ -102,6 +103,6 @@ def test_covariance_spmd_synthetic(
)
batch_result = EmpiricalCovariance_Batch(assume_centered=assume_centered).fit(data)

atol = 1e-5 if dtype == np.float32 else 1e-7
atol = 1e-5 if (dtype == np.float32 or not queue.sycl_device.has_aspect_fp64) else 1e-7
assert_allclose(spmd_result.covariance_, batch_result.covariance_, atol=atol)
assert_allclose(spmd_result.location_, batch_result.location_, atol=atol)
Original file line number Diff line number Diff line change
@@ -178,7 +178,7 @@ def test_incremental_covariance_partial_fit_spmd_synthetic(

inccov.fit(dpt_data)

tol = 1e-7
tol = 1e-7 if queue.sycl_device.has_aspect_fp64 else 1e-6

assert_allclose(inccov_spmd.covariance_, inccov.covariance_, atol=tol)
assert_allclose(inccov_spmd.location_, inccov.location_, atol=tol)
Original file line number Diff line number Diff line change
@@ -174,7 +174,7 @@ def test_incremental_pca_fit_spmd_random(
from sklearnex.spmd.decomposition import IncrementalPCA as IncrementalPCA_SPMD

# The larger test dataset requires a higher tolerance than the other tests use
tol = 7e-5 if dtype == np.float32 else 1e-7
tol = 3e-2 if (dtype == np.float32 or not queue.sycl_device.has_aspect_fp64) else 1e-7

# Create data and process into dpt
X = _generate_statistic_data(num_samples, num_features, dtype)
@@ -233,7 +233,7 @@ def test_incremental_pca_partial_fit_spmd_random(
from sklearnex.preview.decomposition import IncrementalPCA
from sklearnex.spmd.decomposition import IncrementalPCA as IncrementalPCA_SPMD

tol = 3e-4 if dtype == np.float32 else 1e-7
tol = 3e-4 if (dtype == np.float32 or not queue.sycl_device.has_aspect_fp64) else 1e-7

# Create data and process into dpt
X = _generate_statistic_data(num_samples, num_features, dtype)
10 changes: 6 additions & 4 deletions sklearnex/spmd/decomposition/tests/test_pca_spmd.py
Original file line number Diff line number Diff line change
@@ -65,16 +65,18 @@ def test_pca_spmd_gold(dataframe, queue):
spmd_result = PCA_SPMD(n_components=2).fit(local_dpt_data)
batch_result = PCA_Batch(n_components=2).fit(data)

tol = 1e-7 if queue.sycl_device.has_aspect_fp64 else 1e-5

assert_allclose(spmd_result.mean_, batch_result.mean_)
assert_allclose(spmd_result.components_, batch_result.components_)
assert_allclose(spmd_result.singular_values_, batch_result.singular_values_)
assert_allclose(spmd_result.singular_values_, batch_result.singular_values_, rtol=tol)
assert_allclose(
spmd_result.noise_variance_,
batch_result.noise_variance_,
atol=1e-7,
atol=tol,
)
assert_allclose(
spmd_result.explained_variance_ratio_, batch_result.explained_variance_ratio_
spmd_result.explained_variance_ratio_, batch_result.explained_variance_ratio_, rtol=tol
)


@@ -116,7 +118,7 @@ def test_pca_spmd_synthetic(
spmd_result = PCA_SPMD(n_components=n_components, whiten=whiten).fit(local_dpt_data)
batch_result = PCA_Batch(n_components=n_components, whiten=whiten).fit(data)

tol = 1e-3 if dtype == np.float32 else 1e-7
tol = 1e-3 if (dtype == np.float32 or not queue.sycl_device.has_aspect_fp64) else 1e-7
assert_allclose(spmd_result.mean_, batch_result.mean_, atol=tol)
assert_allclose(spmd_result.components_, batch_result.components_, atol=tol, rtol=tol)
assert_allclose(spmd_result.singular_values_, batch_result.singular_values_, atol=tol)
Original file line number Diff line number Diff line change
@@ -208,7 +208,7 @@ def test_incremental_linear_regression_fit_spmd_random(
IncrementalLinearRegression as IncrementalLinearRegression_SPMD,
)

tol = 5e-3 if dtype == np.float32 else 1e-7
tol = 5e-3 if (dtype == np.float32 or not queue.sycl_device.has_aspect_fp64) else 1e-7

# Generate random data and process into dpt
X_train, X_test, y_train, _ = _generate_regression_data(
@@ -279,7 +279,7 @@ def test_incremental_linear_regression_partial_fit_spmd_random(
IncrementalLinearRegression as IncrementalLinearRegression_SPMD,
)

tol = 5e-3 if dtype == np.float32 else 1e-7
tol = 5e-3 if (dtype == np.float32 or not queue.sycl_device.has_aspect_fp64) else 1e-7

# Generate random data and process into dpt
X_train, X_test, y_train, _ = _generate_regression_data(
Original file line number Diff line number Diff line change
@@ -82,14 +82,16 @@ def test_linear_spmd_gold(dataframe, queue):
spmd_model = LinearRegression_SPMD().fit(local_dpt_X_train, local_dpt_y_train)
batch_model = LinearRegression_Batch().fit(X_train, y_train)

assert_allclose(spmd_model.coef_, batch_model.coef_)
assert_allclose(spmd_model.intercept_, batch_model.intercept_)
tol = 1e-7 if queue.sycl_device.has_aspect_fp64 else 1e-5

assert_allclose(spmd_model.coef_, batch_model.coef_, rtol=tol)
assert_allclose(spmd_model.intercept_, batch_model.intercept_, rtol=tol)

# ensure predictions of batch algo match spmd
spmd_result = spmd_model.predict(local_dpt_X_test)
batch_result = batch_model.predict(X_test)

_spmd_assert_allclose(spmd_result, batch_result)
_spmd_assert_allclose(spmd_result, batch_result, rtol=tol)


@pytest.mark.skipif(
@@ -134,7 +136,7 @@ def test_linear_spmd_synthetic(n_samples, n_features, dataframe, queue, dtype):
spmd_model = LinearRegression_SPMD().fit(local_dpt_X_train, local_dpt_y_train)
batch_model = LinearRegression_Batch().fit(X_train, y_train)

tol = 1e-3 if dtype == np.float32 else 1e-7
tol = 1e-3 if (dtype == np.float32 or not queue.sycl_device.has_aspect_fp64) else 1e-7
assert_allclose(spmd_model.coef_, batch_model.coef_, rtol=tol, atol=tol)
assert_allclose(spmd_model.intercept_, batch_model.intercept_, rtol=tol, atol=tol)

6 changes: 3 additions & 3 deletions sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py
Original file line number Diff line number Diff line change
@@ -155,7 +155,7 @@ def test_knncls_spmd_synthetic(
batch_result = batch_model.predict(X_test)

tol = 1e-4
if dtype == np.float64:
if dtype == np.float64 and queue.sycl_device.has_aspect_fp64:
_assert_unordered_allclose(spmd_indcs, batch_indcs, localize=True)
_assert_unordered_allclose(
spmd_dists, batch_dists, localize=True, rtol=tol, atol=tol
@@ -279,8 +279,8 @@ def test_knnreg_spmd_synthetic(
spmd_result = spmd_model.predict(local_dpt_X_test)
batch_result = batch_model.predict(X_test)

tol = 0.005 if dtype == np.float32 else 1e-4
if dtype == np.float64:
tol = 0.005 if (dtype == np.float32 or not queue.sycl_device.has_aspect_fp64) else 1e-4
if dtype == np.float64 and queue.sycl_device.has_aspect_fp64:
_assert_unordered_allclose(spmd_indcs, batch_indcs, localize=True)
_assert_unordered_allclose(
spmd_dists, batch_dists, localize=True, rtol=tol, atol=tol
16 changes: 16 additions & 0 deletions tests/run_examples.py
Original file line number Diff line number Diff line change
@@ -79,8 +79,24 @@
available_devices = ["cpu"]

gpu_available = False
import site
path_to_env = site.getsitepackages()[0]
path_to_libs = os.path.join(path_to_env, "Library", "bin")
try:
os.add_dll_directory(path_to_libs)
except FileNotFoundError:
print("FILENOTFOUNDERROR sklearnex")

import dpctl
import dpctl.tensor as dpt
print("dpctl available: {}".format(dpctl_available))
print("dpctl had gpu devices: {}".format(dpctl.has_gpu_devices()))
if dpctl.has_gpu_devices():
gpu_available = True
available_devices.append("gpu")
if dpctl_available:
import dpctl
print("dpctl had gpu devices: {}".format(dpctl.has_gpu_devices()))

if dpctl.has_gpu_devices():
gpu_available = True