Skip to content

Commit 955d169

Browse files
committed
FIX: Do not join by tabs if output is size 1
1 parent bc01493 commit 955d169

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

utils.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def get_datasets_and_generator(args, no_target=False):
3636
else:
3737
normal_dataset = datasets.NormalRVDataset(num_samples=args.num_samples,
3838
shape=args.out_shape,
39-
static_sample=args.static_sample)
39+
static_sample=not args.dynamic_sample)
4040
normal_dataloader = torch.utils.data.DataLoader(normal_dataset,
4141
batch_size=args.batch_size)
4242
return uniform_dataloader, normal_dataloader, generator
@@ -51,7 +51,7 @@ def parse_cli(parser, train_func, generate_func):
5151
subparsers = parser.add_subparsers()
5252
train_parser = subparsers.add_parser('train')
5353
train_parser.add_argument('--epochs', default=5, type=int)
54-
train_parser.add_argument('--static-sample', action='store_true')
54+
train_parser.add_argument('--dynamic-sample', action='store_true')
5555
train_parser.add_argument('--learning-rate', default=1E-3, type=float)
5656
train_parser.set_defaults(func=train_func)
5757
generate_parser = subparsers.add_parser('generate')
@@ -71,4 +71,6 @@ def generate(args):
7171
for input_ in uniform_dataloader:
7272
# Model forward pass.
7373
output = generator(input_.float())
74-
print('\t'.join(map(str, output.squeeze().tolist())))
74+
if len(output) > 1:
75+
print('\t'.join(map(str, output.squeeze().tolist())))
76+
print(output.item())

0 commit comments

Comments
 (0)