66from mrt import conf , utils
77
88import numpy as np
9+ import argparse
10+
11+ parser = argparse .ArgumentParser (description = 'Mnist Traning' )
12+ parser .add_argument ('--cpu' , default = False , action = 'store_true' ,
13+ help = 'whether enable cpu (default use gpu)' )
14+ parser .add_argument ('--gpu-id' , type = int , default = 0 ,
15+ help = 'gpu device id' )
16+ parser .add_argument ('--net' , type = str , default = '' ,
17+ help = 'choose available networks, optional: lenet, mlp' )
18+
19+ args = parser .parse_args ()
920
1021def load_fname (version , suffix = None , with_ext = False ):
1122 suffix = "." + suffix if suffix is not None else ""
12- prefix = "{}/mnist_{}{}" .format (conf .MRT_MODEL_ROOT , version , suffix )
23+ version = "_" + version if version is not None else ""
24+ prefix = "{}/mnist{}{}" .format (conf .MRT_MODEL_ROOT , version , suffix )
1325 return utils .extend_fname (prefix , with_ext )
1426
1527def data_xform (data ):
@@ -25,10 +37,12 @@ def data_xform(data):
2537train_loader = mx .gluon .data .DataLoader (train_data , shuffle = True , batch_size = batch_size )
2638val_loader = mx .gluon .data .DataLoader (val_data , shuffle = False , batch_size = batch_size )
2739
28- version = ''
40+ version = args .net
41+ print ("Training {} Mnist" .format (version ))
2942
3043# Set the gpu device id
31- ctx = mx .gpu (0 )
44+ ctx = mx .cpu () if args .cpu else mx .gpu (args .gpu_id )
45+ print ("Using device: {}" .format (ctx ))
3246
3347def train_mnist ():
3448 # Select a fixed random seed for reproducibility
@@ -70,6 +84,8 @@ def train_mnist():
7084 nn .Dense (64 , activation = 'relu' ),
7185 nn .Dense (10 , activation = None ) # loss function includes softmax already, see below
7286 )
87+ else :
88+ assert False
7389
7490 # Random initialize all the mnist model parameters
7591 net .initialize (mx .init .Xavier (), ctx = ctx )
@@ -118,5 +134,4 @@ def train_mnist():
118134 fout .write (sym .tojson ())
119135 net .collect_params ().save (param_file )
120136
121- print ("Test mnist" , version )
122137train_mnist ()
0 commit comments