44
55"""
66
7- import os
87from abc import ABC , abstractmethod
98
9+ import tensorflow as tf
1010import keras
1111
1212
@@ -32,25 +32,11 @@ def __init__(
3232 self ,
3333 dataset ,
3434 ):
35- self .basedir = os .path .join (os .path .dirname (__file__ ), ".." , "predictors" )
3635 self .dataset = dataset
3736
3837 # Filled with get data if needed
3938 self ._data = None
4039
41- keras_version_file = f"{ dataset } _keras_version"
42-
43- try :
44- with open (os .path .join (self .basedir , keras_version_file )) as file_in :
45- version = file_in .read ().strip ()
46- except FileNotFoundError :
47- version = None
48- if version != keras .__version__ :
49- print (f"Keras version changed. Regenerate predictors for { dataset } " )
50- self .build_all_predictors ()
51- with open (os .path .join (self .basedir , keras_version_file ), "w" ) as file_out :
52- print (keras .__version__ , file = file_out )
53-
5440 def __iter__ (self ):
5541 return self .all_tested_layers .__iter__ ()
5642
@@ -69,39 +55,25 @@ def data(self):
6955 self .load_data ()
7056 return self ._data
7157
72- def predictor_file (self , predictor ):
73- return f"{ self .dataset } _{ layers_as_string (predictor )} .keras"
74-
75- def build_predictor (self , layers ):
58+ def get_case (self , layers ):
7659 """Build model for one predictor"""
7760 X , y = self .data
7861 predictor = self .compile (layers )
7962 predictor .fit (X , y )
8063
81- predictor .save (self .predictor_file (layers ))
8264 return predictor
8365
84- def build_all_predictors (self ):
85- """Build all the predictor for this case.
86- (Done when we have a new sklearn version)"""
87- for predictor in self :
88- self .build_predictor (predictor )
89-
90- def get_case (self , predictor ):
91- filename = self .predictor_file (predictor )
92- try :
93- return keras .saving .load_model (os .path .join (self .basedir , filename ))
94- except ValueError :
95- return self .build_predictor (predictor )
96-
9766
9867class HousingCases (Cases ):
9968 """Base class to have cases for testing regression models on diabetes set
10069
10170 This is appropriate for testing a regression with a single output."""
10271
10372 def __init__ (self ):
104- self .all_tested_layers = [[keras .layers .Dense (16 , activation = "relu" )]]
73+ self .all_tested_layers = [
74+ [keras .layers .Dense (16 , activation = "relu" )],
75+ [keras .layers .Dense (16 , activation = "sigmoid" )],
76+ ]
10577 super ().__init__ ("housing" )
10678 self .load_data ()
10779
@@ -117,3 +89,33 @@ def compile(self, layers):
11789 )
11890 nn .compile (loss = "mean_squared_error" , optimizer = "adam" )
11991 return nn
92+
93+
94+ class MNISTCases (Cases ):
95+ """Base class to have cases for testing regression models on diabetes set
96+
97+ This is appropriate for testing a regression with a single output."""
98+
99+ def __init__ (self ):
100+ self .all_tested_layers = [
101+ [keras .layers .Dense (20 , activation = "relu" )],
102+ [keras .layers .Dense (20 , activation = "sigmoid" )],
103+ ]
104+ super ().__init__ ("housing" )
105+ self .load_data ()
106+
107+ def load_data (self ):
108+ (X_train , y_train ), (_ , _ ) = keras .datasets .fashion_mnist .load_data ()
109+ X_train = tf .reshape (tf .cast (X_train , tf .float32 ) / 255.0 , [- 1 , 28 * 28 ])
110+ self ._data = (X_train , y_train )
111+
112+ def compile (self , layers ):
113+ nn = keras .models .Sequential (
114+ [keras .layers .InputLayer ((28 * 28 ,))] + layers + [keras .layers .Dense (10 )]
115+ )
116+ nn .compile (
117+ optimizer = "adam" ,
118+ loss = tf .keras .losses .SparseCategoricalCrossentropy (),
119+ metrics = [tf .keras .metrics .SparseCategoricalAccuracy ()],
120+ )
121+ return nn
0 commit comments