Skip to content

Hacked up Main File Example #23

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 67 additions & 9 deletions mnist/build/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
class Task(str, Enum):
DownloadData = 'download'
Train = 'train'
Evaluate = 'evaluate'


def create_directory(path: str) -> None:
Expand Down Expand Up @@ -73,9 +74,13 @@ def train(task_args: List[str]) -> None:
"""
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', '--data-dir', type=str, default=None, help="Dataset path.")
parser.add_argument('--model_in', '--model-in', type=str, default=None, help="Model output directory.")
parser.add_argument('--model_dir', '--model-dir', type=str, default=None, help="Model output directory.")
parser.add_argument('--parameters_file', '--parameters-file', type=str, default=None,
help="Parameters default values.")
parser.add_argument('--metrics', '--metrics', type=str, default=None,
help="Parameters default values.")

args = parser.parse_args(args=task_args)

with open(args.parameters_file, 'r') as stream:
Expand All @@ -89,12 +94,20 @@ def train(task_args: List[str]) -> None:
x_train, x_test = x_train / 255.0, x_test / 255.0
logger.info("Dataset has been loaded (%s).", dataset_file)

model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])


if args.model_in != '':
# Load from checkpoint; TODO confirm this API
model = tf.keras.models.load_model(os.path.join(args.model_in, 'mnist_model'))
else:
# if no model given on CLI, create a new one
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])

logger.info("Model has been built.")

model.compile(
Expand All @@ -105,20 +118,63 @@ def train(task_args: List[str]) -> None:
logger.info("Model has been compiled.")

# Train and evaluate
model.fit(
history = model.fit(
x_train,
y_train,
batch_size=parameters.get('batch_size', 32),
# TODO we will want to rename train_epochs probably in the file
epochs=parameters.get('train_epochs', 5)
)
logger.info("Model has been trained.")

model.evaluate(x_test, y_test, verbose=2)
logger.info("Model has been evaluated.")
# No evaluate in training
# model.evaluate(x_test, y_test, verbose=2)
# logger.info("Model has been evaluated.")

with open(args.metrics, 'w') as f:
# TODO import json at the top
# TODO may need to modify this
f.write(json.dumps({'loss': history[-1]}))

os.makedirs(args.model_dir, exist_ok=True)
model.save(os.path.join(args.model_dir, 'mnist_model'))
logger.info("Model has been saved.")



def evaluate(task_args: List[str]) -> None:
""" Task: train.
Input parameters:
--data_dir, --log_dir, --model_dir, --parameters_file
"""
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', '--data-dir', type=str, default=None, help="Dataset path.")
parser.add_argument('--model_in', '--model-in', type=str, default=None, help="Model output directory.")
parser.add_argument('--parameters_file', '--parameters-file', type=str, default=None,
help="Parameters default values.")
args = parser.parse_args(args=task_args)

with open(args.parameters_file, 'r') as stream:
parameters = yaml.load(stream, Loader=yaml.FullLoader)
logger.info("Parameters have been read (%s).", args.parameters_file)

dataset_file = os.path.join(args.data_dir, 'mnist.npz')
with np.load(dataset_file, allow_pickle=True) as f:
x_train, y_train = f['x_train'], f['y_train']
x_test, y_test = f['x_test'], f['y_test']
x_train, x_test = x_train / 255.0, x_test / 255.0
logger.info("Dataset has been loaded (%s).", dataset_file)

model = tf.keras.models.load_model(os.path.join(args.model_in, 'mnist_model'))

eval_result = model.evaluate(x_test, y_test, verbose=2)

with open(args.metrics, 'w') as f:
# TODO import json at the top
# TODO may need to modify this
f.write(json.dumps({'accuracy': eval_result}))

logger.info("Model has been evaluated.")


def main():
Expand Down Expand Up @@ -159,6 +215,8 @@ def main():
download(task_args)
elif mlcube_args.mlcube_task == Task.Train:
train(task_args)
elif mlcube_args.mlcube_task == Task.Evaluate:
evaluate(task_args)
else:
raise ValueError(f"Unknown task: {task_args}")
except Exception as err:
Expand Down