@@ -36,7 +36,7 @@ def get_datasets_and_generator(args, no_target=False):
     else:
         normal_dataset = datasets.NormalRVDataset(num_samples=args.num_samples,
                                                   shape=args.out_shape,
-                                                  static_sample=args.static_sample)
+                                                  static_sample=not args.dynamic_sample)
         normal_dataloader = torch.utils.data.DataLoader(normal_dataset,
                                                         batch_size=args.batch_size)
     return uniform_dataloader, normal_dataloader, generator
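For context, `static_sample` controls whether the normal samples are drawn once and reused across epochs or redrawn on every access. The real `NormalRVDataset` lives elsewhere in the repository and is not shown here; the class below is only a minimal, hypothetical sketch of that distinction.

```python
import torch
from torch.utils.data import Dataset

class NormalSampleSketch(Dataset):
    """Hypothetical stand-in (not the repo's NormalRVDataset): illustrates
    static vs. dynamic sampling only."""

    def __init__(self, num_samples, shape, static_sample=True):
        self.num_samples = num_samples
        self.shape = tuple(shape)
        self.static_sample = static_sample
        if static_sample:
            # Draw everything once; every epoch sees the same tensors.
            self.samples = torch.randn(num_samples, *self.shape)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        if self.static_sample:
            return self.samples[idx]
        # Dynamic: a fresh N(0, 1) draw on every access.
        return torch.randn(*self.shape)
```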
@@ -51,7 +51,7 @@ def parse_cli(parser, train_func, generate_func):
     subparsers = parser.add_subparsers()
     train_parser = subparsers.add_parser('train')
     train_parser.add_argument('--epochs', default=5, type=int)
-    train_parser.add_argument('--static-sample', action='store_true')
+    train_parser.add_argument('--dynamic-sample', action='store_true')
    train_parser.add_argument('--learning-rate', default=1E-3, type=float)
     train_parser.set_defaults(func=train_func)
     generate_parser = subparsers.add_parser('generate')
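Because `store_true` flags default to `False`, this swap flips the default behaviour: the normal dataset is now sampled statically unless `--dynamic-sample` is passed, whereas previously `--static-sample` opted in to static sampling. A quick sketch of that inversion, with the parser reduced to the one flag:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dynamic-sample', action='store_true')

# No flag: dynamic_sample is False, so static_sample = not args.dynamic_sample is True.
args = parser.parse_args([])
print(not args.dynamic_sample)  # True -> static sampling by default

# With the flag: static_sample becomes False, i.e. resample on every access.
args = parser.parse_args(['--dynamic-sample'])
print(not args.dynamic_sample)  # False
```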
@@ -71,4 +71,7 @@ def generate(args):
     for input_ in uniform_dataloader:
         # Model forward pass.
         output = generator(input_.float())
-        print('\t'.join(map(str, output.squeeze().tolist())))
+        if len(output) > 1:
+            print('\t'.join(map(str, output.squeeze().tolist())))
+        else:
+            print(output.item())
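The new branch appears to guard the single-sample case: squeezing a one-element batch yields a 0-d tensor whose `.tolist()` is a plain Python float, which `'\t'.join(map(str, ...))` cannot consume. A small check of that behaviour (shapes chosen for illustration):

```python
import torch

batch = torch.randn(4, 1)                 # batch of four scalar outputs
print('\t'.join(map(str, batch.squeeze().tolist())))  # four tab-separated values

single = torch.randn(1, 1)                # batch of one scalar output
print(single.squeeze().tolist())          # a plain float, not a list
print(single.item())                      # same scalar, printed directly
# '\t'.join(map(str, single.squeeze().tolist())) would raise TypeError,
# since a float is not iterable.
```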