From 61dd925d0614b0e649d624293f8110425a332aca Mon Sep 17 00:00:00 2001 From: Victor Bittorf Date: Tue, 22 Jun 2021 15:42:42 -0700 Subject: [PATCH] Hacked up Main File Example This doesn't run, but it most of the way there hopefully! --- mnist/build/mnist.py | 76 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 67 insertions(+), 9 deletions(-) diff --git a/mnist/build/mnist.py b/mnist/build/mnist.py index 86e13ec..117fd04 100644 --- a/mnist/build/mnist.py +++ b/mnist/build/mnist.py @@ -24,6 +24,7 @@ class Task(str, Enum): DownloadData = 'download' Train = 'train' + Evaluate = 'evaluate' def create_directory(path: str) -> None: @@ -73,9 +74,13 @@ def train(task_args: List[str]) -> None: """ parser = argparse.ArgumentParser() parser.add_argument('--data_dir', '--data-dir', type=str, default=None, help="Dataset path.") + parser.add_argument('--model_in', '--model-in', type=str, default=None, help="Model output directory.") parser.add_argument('--model_dir', '--model-dir', type=str, default=None, help="Model output directory.") parser.add_argument('--parameters_file', '--parameters-file', type=str, default=None, help="Parameters default values.") + parser.add_argument('--metrics', '--metrics', type=str, default=None, + help="Parameters default values.") + args = parser.parse_args(args=task_args) with open(args.parameters_file, 'r') as stream: @@ -89,12 +94,20 @@ def train(task_args: List[str]) -> None: x_train, x_test = x_train / 255.0, x_test / 255.0 logger.info("Dataset has been loaded (%s).", dataset_file) - model = tf.keras.models.Sequential([ - tf.keras.layers.Flatten(input_shape=(28, 28)), - tf.keras.layers.Dense(128, activation='relu'), - tf.keras.layers.Dropout(0.2), - tf.keras.layers.Dense(10, activation='softmax') - ]) + + + if args.model_in != '': + # Load from checkpoint; TODO confirm this API + model = tf.keras.models.load_model(os.path.join(args.model_in, 'mnist_model')) + else: + # if no model given on CLI, create a new one + model = tf.keras.models.Sequential([ + tf.keras.layers.Flatten(input_shape=(28, 28)), + tf.keras.layers.Dense(128, activation='relu'), + tf.keras.layers.Dropout(0.2), + tf.keras.layers.Dense(10, activation='softmax') + ]) + logger.info("Model has been built.") model.compile( @@ -105,20 +118,63 @@ def train(task_args: List[str]) -> None: logger.info("Model has been compiled.") # Train and evaluate - model.fit( + history = model.fit( x_train, y_train, batch_size=parameters.get('batch_size', 32), + # TODO we will want to rename train_epochs probably in the file epochs=parameters.get('train_epochs', 5) ) logger.info("Model has been trained.") - model.evaluate(x_test, y_test, verbose=2) - logger.info("Model has been evaluated.") + # No evaluate in training + # model.evaluate(x_test, y_test, verbose=2) + # logger.info("Model has been evaluated.") + + with open(args.metrics, 'w') as f: + # TODO import json at the top + # TODO may need to modify this + f.write(json.dumps({'loss': history[-1]})) os.makedirs(args.model_dir, exist_ok=True) model.save(os.path.join(args.model_dir, 'mnist_model')) logger.info("Model has been saved.") + + + +def evaluate(task_args: List[str]) -> None: + """ Task: train. + Input parameters: + --data_dir, --log_dir, --model_dir, --parameters_file + """ + parser = argparse.ArgumentParser() + parser.add_argument('--data_dir', '--data-dir', type=str, default=None, help="Dataset path.") + parser.add_argument('--model_in', '--model-in', type=str, default=None, help="Model output directory.") + parser.add_argument('--parameters_file', '--parameters-file', type=str, default=None, + help="Parameters default values.") + args = parser.parse_args(args=task_args) + + with open(args.parameters_file, 'r') as stream: + parameters = yaml.load(stream, Loader=yaml.FullLoader) + logger.info("Parameters have been read (%s).", args.parameters_file) + + dataset_file = os.path.join(args.data_dir, 'mnist.npz') + with np.load(dataset_file, allow_pickle=True) as f: + x_train, y_train = f['x_train'], f['y_train'] + x_test, y_test = f['x_test'], f['y_test'] + x_train, x_test = x_train / 255.0, x_test / 255.0 + logger.info("Dataset has been loaded (%s).", dataset_file) + + model = tf.keras.models.load_model(os.path.join(args.model_in, 'mnist_model')) + + eval_result = model.evaluate(x_test, y_test, verbose=2) + + with open(args.metrics, 'w') as f: + # TODO import json at the top + # TODO may need to modify this + f.write(json.dumps({'accuracy': eval_result})) + + logger.info("Model has been evaluated.") def main(): @@ -159,6 +215,8 @@ def main(): download(task_args) elif mlcube_args.mlcube_task == Task.Train: train(task_args) + elif mlcube_args.mlcube_task == Task.Evaluate: + evaluate(task_args) else: raise ValueError(f"Unknown task: {task_args}") except Exception as err: