Skip to content

Commit e569c36

Browse files
committed
Try completing test with Keras
Also adds a classification test probably this will be overkill for GH actions but try it.
1 parent 367196e commit e569c36

File tree

2 files changed

+47
-35
lines changed

2 files changed

+47
-35
lines changed

tests/test_keras/keras_cases.py

Lines changed: 36 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
55
"""
66

7-
import os
87
from abc import ABC, abstractmethod
98

9+
import tensorflow as tf
1010
import keras
1111

1212

@@ -32,25 +32,11 @@ def __init__(
3232
self,
3333
dataset,
3434
):
35-
self.basedir = os.path.join(os.path.dirname(__file__), "..", "predictors")
3635
self.dataset = dataset
3736

3837
# Filled with get data if needed
3938
self._data = None
4039

41-
keras_version_file = f"{dataset}_keras_version"
42-
43-
try:
44-
with open(os.path.join(self.basedir, keras_version_file)) as file_in:
45-
version = file_in.read().strip()
46-
except FileNotFoundError:
47-
version = None
48-
if version != keras.__version__:
49-
print(f"Keras version changed. Regenerate predictors for {dataset}")
50-
self.build_all_predictors()
51-
with open(os.path.join(self.basedir, keras_version_file), "w") as file_out:
52-
print(keras.__version__, file=file_out)
53-
5440
def __iter__(self):
5541
return self.all_tested_layers.__iter__()
5642

@@ -69,39 +55,25 @@ def data(self):
6955
self.load_data()
7056
return self._data
7157

72-
def predictor_file(self, predictor):
73-
return f"{self.dataset}_{layers_as_string(predictor)}.keras"
74-
75-
def build_predictor(self, layers):
58+
def get_case(self, layers):
7659
"""Build model for one predictor"""
7760
X, y = self.data
7861
predictor = self.compile(layers)
7962
predictor.fit(X, y)
8063

81-
predictor.save(self.predictor_file(layers))
8264
return predictor
8365

84-
def build_all_predictors(self):
85-
"""Build all the predictor for this case.
86-
(Done when we have a new sklearn version)"""
87-
for predictor in self:
88-
self.build_predictor(predictor)
89-
90-
def get_case(self, predictor):
91-
filename = self.predictor_file(predictor)
92-
try:
93-
return keras.saving.load_model(os.path.join(self.basedir, filename))
94-
except ValueError:
95-
return self.build_predictor(predictor)
96-
9766

9867
class HousingCases(Cases):
9968
"""Base class to have cases for testing regression models on diabetes set
10069
10170
This is appropriate for testing a regression with a single output."""
10271

10372
def __init__(self):
104-
self.all_tested_layers = [[keras.layers.Dense(16, activation="relu")]]
73+
self.all_tested_layers = [
74+
[keras.layers.Dense(16, activation="relu")],
75+
[keras.layers.Dense(16, activation="sigmoid")],
76+
]
10577
super().__init__("housing")
10678
self.load_data()
10779

@@ -117,3 +89,33 @@ def compile(self, layers):
11789
)
11890
nn.compile(loss="mean_squared_error", optimizer="adam")
11991
return nn
92+
93+
94+
class MNISTCases(Cases):
95+
"""Base class to have cases for testing regression models on diabetes set
96+
97+
This is appropriate for testing a regression with a single output."""
98+
99+
def __init__(self):
100+
self.all_tested_layers = [
101+
[keras.layers.Dense(20, activation="relu")],
102+
[keras.layers.Dense(20, activation="sigmoid")],
103+
]
104+
super().__init__("housing")
105+
self.load_data()
106+
107+
def load_data(self):
108+
(X_train, y_train), (_, _) = keras.datasets.fashion_mnist.load_data()
109+
X_train = tf.reshape(tf.cast(X_train, tf.float32) / 255.0, [-1, 28 * 28])
110+
self._data = (X_train, y_train)
111+
112+
def compile(self, layers):
113+
nn = keras.models.Sequential(
114+
[keras.layers.InputLayer((28 * 28,))] + layers + [keras.layers.Dense(10)]
115+
)
116+
nn.compile(
117+
optimizer="adam",
118+
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
119+
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
120+
)
121+
return nn

tests/test_keras/test_keras_formulations.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from joblib import load
55

66
from ..fixed_formulation import FixedRegressionModel
7-
from .keras_cases import HousingCases
7+
from .keras_cases import HousingCases, MNISTCases
88

99
VERBOSE = False
1010

@@ -45,3 +45,13 @@ def test_housing_keras(self):
4545
onecase = {"predictor": regressor, "nonconvex": 0}
4646
self.do_one_case(onecase, X, 5, "all")
4747
self.do_one_case(onecase, X, 6, "pairs")
48+
49+
def test_mnist_keras(self):
50+
cases = MNISTCases()
51+
52+
X = cases._data[0].numpy()
53+
for case in cases:
54+
regressor = cases.get_case(case)
55+
onecase = {"predictor": regressor, "nonconvex": 0}
56+
self.do_one_case(onecase, X, 5, "all")
57+
self.do_one_case(onecase, X, 6, "pairs")

0 commit comments

Comments
 (0)