Skip to content

Commit

Permalink
fix testing issues
Browse files Browse the repository at this point in the history
  • Loading branch information
icfaust committed Nov 21, 2024
1 parent 5355039 commit b5b8442
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions sklearnex/utils/tests/test_finite.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
@pytest.mark.parametrize("ensure_all_finite", ["allow-nan", True])
def test_sum_infinite_actually_finite(dtype, shape, ensure_all_finite):
est = DummyEstimator()
X = np.array(shape, dtype=dtype)
X = np.empty(shape, dtype=dtype)
X.fill(np.finfo(dtype).max)
X = np.atleast_2d(X)
X_array = validate_data(est, X, ensure_all_finite=ensure_all_finite)
Expand Down Expand Up @@ -120,7 +120,7 @@ def test_validate_data_random_shape_and_location(

allow_nan = ensure_all_finite == "allow-nan"
if check is None or (allow_nan and check == "NaN"):
validate_data(est, X)
validate_data(est, X, ensure_all_finite=ensure_all_finite)
else:
type_err = "infinity" if allow_nan else "NaN, infinity"
msg_err = f"Input X contains {type_err}."
Expand All @@ -129,26 +129,25 @@ def test_validate_data_random_shape_and_location(


@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize("array_api_dispatch", [True, False])
@pytest.mark.parametrize("array_api_dispatch", [True, False] if sklearn_check_version("1.2") else [False])
@pytest.mark.parametrize("dataframe, queue", get_dataframes_and_queues())
def test_validate_data_output(array_api_dispatch, dtype, dataframe, queue):
est = DummyEstimator()
X, y = gen_dataset(est, queue=queue, target_df=dataframe, dtype=dtype)[0]

dispatch = {}
if sklearn_check_version("1.2"):
if array_api_dispatch:
pytest.skip(dataframe == "pandas", "pandas inputs do not work with sklearn's array_api_dispatch")
dispatch["array_api_dispatch"] = array_api_dispatch

with config_context(**dispatch):
validate_data(est, X, y)
est.fit(X, y)
X_out, y_out = validate_data(est, X, y)
# check sklearn validate_data operations work underneath
X_array = validate_data(est, X, reset=False)
X_out = est.predict(X)

if dataframe == "pandas" or (
dataframe == "array_api"
and not (sklearn_check_version("1.2") and array_api_dispatch)
):
and not array_api_dispatch):
# array_api_strict from sklearn < 1.2 and pandas will convert to numpy arrays
assert isinstance(X_array, np.ndarray)
assert isinstance(X_out, np.ndarray)
Expand Down

0 comments on commit b5b8442

Please sign in to comment.