|
| 1 | +"""Unit tests for aeon classifier compatability with sklearn interfaces.""" |
| 2 | + |
| 3 | +__maintainer__ = [] |
| 4 | +__all__ = [ |
| 5 | + "test_sklearn_cross_validation", |
| 6 | + "test_sklearn_cross_validation_iterators", |
| 7 | + "test_sklearn_parameter_tuning", |
| 8 | + "test_sklearn_composite_classifiers", |
| 9 | +] |
| 10 | + |
| 11 | +import numpy as np |
| 12 | +import pytest |
| 13 | +from sklearn.calibration import CalibratedClassifierCV |
| 14 | +from sklearn.ensemble import VotingClassifier |
| 15 | +from sklearn.experimental import enable_halving_search_cv # noqa |
| 16 | +from sklearn.model_selection import ( |
| 17 | + GridSearchCV, |
| 18 | + GroupKFold, |
| 19 | + GroupShuffleSplit, |
| 20 | + HalvingGridSearchCV, |
| 21 | + HalvingRandomSearchCV, |
| 22 | + KFold, |
| 23 | + LeaveOneOut, |
| 24 | + LeavePGroupsOut, |
| 25 | + LeavePOut, |
| 26 | + RandomizedSearchCV, |
| 27 | + RepeatedKFold, |
| 28 | + ShuffleSplit, |
| 29 | + StratifiedKFold, |
| 30 | + StratifiedShuffleSplit, |
| 31 | + TimeSeriesSplit, |
| 32 | + cross_val_score, |
| 33 | +) |
| 34 | +from sklearn.pipeline import Pipeline |
| 35 | + |
| 36 | +from tsml.dummy import DummyClassifier |
| 37 | +from tsml.transformations import PeriodogramTransformer |
| 38 | +from tsml.utils.testing import generate_3d_test_data |
| 39 | + |
| 40 | +# StratifiedGroupKFold(n_splits=2), removed because it is not available in sklearn 0.24 |
| 41 | +CROSS_VALIDATION_METHODS = [ |
| 42 | + KFold(n_splits=2), |
| 43 | + RepeatedKFold(n_splits=2, n_repeats=2), |
| 44 | + LeaveOneOut(), |
| 45 | + LeavePOut(p=5), |
| 46 | + ShuffleSplit(n_splits=2, test_size=0.25), |
| 47 | + StratifiedKFold(n_splits=2), |
| 48 | + StratifiedShuffleSplit(n_splits=2, test_size=0.25), |
| 49 | + GroupKFold(n_splits=2), |
| 50 | + LeavePGroupsOut(n_groups=5), |
| 51 | + GroupShuffleSplit(n_splits=2, test_size=0.25), |
| 52 | + TimeSeriesSplit(n_splits=2), |
| 53 | +] |
| 54 | +PARAMETER_TUNING_METHODS = [ |
| 55 | + GridSearchCV, |
| 56 | + RandomizedSearchCV, |
| 57 | + HalvingGridSearchCV, |
| 58 | + HalvingRandomSearchCV, |
| 59 | +] |
| 60 | +COMPOSITE_ESTIMATORS = [ |
| 61 | + Pipeline( |
| 62 | + [ |
| 63 | + ("transform", PeriodogramTransformer()), |
| 64 | + ("clf", DummyClassifier()), |
| 65 | + ] |
| 66 | + ), |
| 67 | + VotingClassifier( |
| 68 | + estimators=[ |
| 69 | + ("clf1", DummyClassifier()), |
| 70 | + ("clf2", DummyClassifier()), |
| 71 | + ("clf3", DummyClassifier()), |
| 72 | + ] |
| 73 | + ), |
| 74 | + CalibratedClassifierCV( |
| 75 | + estimator=DummyClassifier(), |
| 76 | + cv=2, |
| 77 | + ), |
| 78 | +] |
| 79 | + |
| 80 | + |
| 81 | +def test_sklearn_cross_validation(): |
| 82 | + """Test sklearn cross-validation works with tsml data and classifiers.""" |
| 83 | + clf = DummyClassifier() |
| 84 | + X, y = generate_3d_test_data(n_samples=20, n_channels=2, series_length=30) |
| 85 | + scores = cross_val_score(clf, X, y=y, cv=KFold(n_splits=2)) |
| 86 | + assert isinstance(scores, np.ndarray) |
| 87 | + |
| 88 | + |
| 89 | +@pytest.mark.parametrize("cross_validation_method", CROSS_VALIDATION_METHODS) |
| 90 | +def test_sklearn_cross_validation_iterators(cross_validation_method): |
| 91 | + """Test if sklearn cross-validation iterators can handle tsml data.""" |
| 92 | + X, y = generate_3d_test_data(n_samples=20, n_channels=2, series_length=30) |
| 93 | + groups = [1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10] |
| 94 | + |
| 95 | + for train, test in cross_validation_method.split(X=X, y=y, groups=groups): |
| 96 | + assert isinstance(train, np.ndarray) and isinstance(test, np.ndarray) |
| 97 | + |
| 98 | + |
| 99 | +@pytest.mark.parametrize("parameter_tuning_method", PARAMETER_TUNING_METHODS) |
| 100 | +def test_sklearn_parameter_tuning(parameter_tuning_method): |
| 101 | + """Test if sklearn parameter tuners can handle tsml data and classifiers.""" |
| 102 | + clf = DummyClassifier() |
| 103 | + param_grid = {"strategy": ["prior", "constant"], "constant": [0, 1]} |
| 104 | + X, y = generate_3d_test_data(n_samples=20, n_channels=2, series_length=30) |
| 105 | + |
| 106 | + parameter_tuning_method = parameter_tuning_method( |
| 107 | + clf, param_grid, cv=KFold(n_splits=2) |
| 108 | + ) |
| 109 | + parameter_tuning_method.fit(X, y) |
| 110 | + assert isinstance(parameter_tuning_method.best_estimator_, DummyClassifier) |
| 111 | + |
| 112 | + |
| 113 | +@pytest.mark.parametrize("composite_classifier", COMPOSITE_ESTIMATORS) |
| 114 | +def test_sklearn_composite_classifiers(composite_classifier): |
| 115 | + """Test if sklearn composite classifiers can handle tsml data and classifiers.""" |
| 116 | + X, y = generate_3d_test_data(n_samples=20, n_channels=2, series_length=30) |
| 117 | + composite_classifier.fit(X, y) |
| 118 | + preds = composite_classifier.predict(X=X) |
| 119 | + assert isinstance(preds, np.ndarray) |
0 commit comments