Skip to content

Commit

Permalink
Add freeze test.
Browse files Browse the repository at this point in the history
  • Loading branch information
jtilly committed Dec 26, 2023
1 parent 0bf4f10 commit fc01885
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 1 deletion.
3 changes: 3 additions & 0 deletions environment-win.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ dependencies:
- ruff
- setuptools_scm

# for test data
- scikit-learn

# build tools
- c-compiler
- cxx-compiler
Expand Down
5 changes: 4 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ dependencies:
- ruff
- setuptools_scm

# for test data
- scikit-learn

# build tools
- c-compiler
- cxx-compiler
Expand All @@ -31,4 +34,4 @@ dependencies:
- seaborn-base
- sphinx
- sphinx_rtd_theme
- sphinxcontrib-apidoc
- sphinxcontrib-apidoc
21 changes: 21 additions & 0 deletions tests/test_categorical_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pandas as pd
import pytest

import tabmat as tm
from tabmat.categorical_matrix import CategoricalMatrix


Expand Down Expand Up @@ -71,3 +72,23 @@ def test_categorical_indexing(drop_first):
mat = CategoricalMatrix(catvec, drop_first=drop_first)
expected = pd.get_dummies(catvec, drop_first=drop_first).to_numpy()[:, [0, 1]]
np.testing.assert_allclose(mat[:, [0, 1]].A, expected)


# The only categorical variable (zipcode) has 70 levels in the data; the test
# fails when we set the cat treshold to 70.
@pytest.mark.parametrize("cat_threshold", [70, 71])
def test_freeze(cat_threshold):
pytest.importorskip("sklearn")

from sklearn.datasets import fetch_openml

house_data = fetch_openml(name="house_sales", version=3, as_frame=True)
X = house_data.data

weights = np.ones(len(X)) / len(X)

mat = tm.from_pandas(X, cat_threshold=cat_threshold)

np.testing.assert_array_equal(
mat.transpose_matvec(weights), mat.transpose_matvec(weights)
)

0 comments on commit fc01885

Please sign in to comment.