forked from simonalexanderson/StyleGestures
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_moglow.py
61 lines (51 loc) · 1.8 KB
/
train_moglow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""Train script.
Usage:
train_moglow.py <hparams> <dataset>
"""
import os
import motion
import numpy as np
import datetime
from docopt import docopt
from torch.utils.data import DataLoader, Dataset
from glow.builder import build
from glow.trainer import Trainer
from glow.config import JsonConfig
def format_run_stamp(now):
    """Return *now* formatted as ``YYYYMMDD_HHMM`` for naming a run's log dir.

    Produces the same string the script previously built by slicing
    ``str(datetime)`` at the last ':' and stripping '-', ':' and ' ':
    seconds/microseconds are dropped and date and time are joined by '_'.
    """
    return now.strftime("%Y%m%d_%H%M")


def create_log_dir(log_root, now=None):
    """Create (if needed) and return ``<log_root>/log_<YYYYMMDD_HHMM>``.

    Args:
        log_root: directory under which per-run log directories are created.
        now: timestamp used for the directory name; defaults to the current
            local time (``datetime.datetime.now()``).

    Returns:
        The path of the log directory, which exists when this returns.

    Note: the name has minute resolution, so two runs started within the
    same minute share one directory.
    """
    if now is None:
        now = datetime.datetime.now()
    log_dir = os.path.join(log_root, "log_" + format_run_stamp(now))
    # exist_ok avoids the race between an os.path.exists check and makedirs.
    os.makedirs(log_dir, exist_ok=True)
    return log_dir


if __name__ == "__main__":
    args = docopt(__doc__)
    hparams_path = args["<hparams>"]
    dataset_name = args["<dataset>"]

    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    # SystemExit(str) prints the message to stderr and exits with status 1.
    if dataset_name not in motion.Datasets:
        raise SystemExit("`{}` is not supported, use `{}`".format(
            dataset_name, list(motion.Datasets.keys())))
    if not os.path.exists(hparams_path):
        raise SystemExit(
            "Failed to find hparams json `{}`".format(hparams_path))

    hparams = JsonConfig(hparams_path)
    dataset = motion.Datasets[dataset_name]

    log_dir = create_log_dir(hparams.Dir.log_root)
    print("log_dir:" + str(log_dir))

    data = dataset(hparams)
    x_channels, cond_channels = data.get_train_dataset().n_channels()

    # An empty `Infer.pre_trained` selects training; any other value selects
    # sample generation (build() then receives False — presumably so it
    # restores the pre-trained weights; confirm in glow.builder).
    is_training = hparams.Infer.pre_trained == ""

    # build graph
    built = build(x_channels, cond_channels, hparams, is_training)

    # build trainer
    trainer = Trainer(**built, data=data, log_dir=log_dir, hparams=hparams)

    if is_training:
        # train model
        trainer.train()
    else:
        # generate from pre-trained model; temperature defaults to 1 when
        # `Infer.temperature` is not configured.
        if "temperature" in hparams.Infer:
            temp = hparams.Infer.temperature
        else:
            temp = 1
        n_samples = 5
        for i in range(n_samples):
            trainer.generate_sample(eps_std=temp, counter=i)