-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathrun.py
129 lines (107 loc) · 4.26 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#!/usr/bin/env python
import os
import json
import pprint as pp
import torch
import torch.optim as optim
import torch.nn as nn
from options import get_options
from train import train_epoch, validate, get_inner_model
from nets.attention_model import AttentionModel
from utils import torch_load_cpu, load_problem
def run(opts):
    """Build the model and optimizer from ``opts``, then validate or train.

    Steps, in order: print and persist the options, pick the device, load the
    problem definition, optionally load a checkpoint, construct the model,
    set up the optimizer (full training or fine-tuning of ``contextual_emb``
    only), create the learning-rate scheduler and validation dataset,
    optionally restore RNG state when resuming, and finally either run a
    single validation pass or the training loop.

    Args:
        opts: Options namespace (as returned by ``get_options``). Attributes
            read here: ``seed``, ``save_dir``, ``use_cuda``, ``problem``,
            ``load_path``, ``resume``, ``model``, ``embedding_dim``,
            ``hidden_dim``, ``n_encode_layers``, ``normalization``,
            ``tanh_clipping``, ``checkpoint_encoder``, ``shrink_size``,
            ``ft``, ``lr_model``, ``lr_decay``, ``graph_size``, ``val_size``,
            ``val_dataset``, ``data_distribution``, ``eval_only``,
            ``epoch_start`` and ``n_epochs``. This function also sets
            ``opts.device`` and, when resuming, overwrites
            ``opts.epoch_start``.

    Raises:
        AssertionError: If both ``opts.load_path`` and ``opts.resume`` are
            given, or if ``opts.model`` is not a known model name.
        FileExistsError: If ``opts.save_dir`` already exists
            (``os.makedirs`` is called without ``exist_ok``).
    """
    # Pretty print the run args
    pp.pprint(vars(opts))

    # Set the random seed
    torch.manual_seed(opts.seed)

    os.makedirs(opts.save_dir)
    # Save arguments so exact configuration can always be found.
    # This is done before ``opts.device`` is attached below, so only the
    # options as they were passed in are serialized.
    with open(os.path.join(opts.save_dir, "args.json"), 'w') as f:
        json.dump(vars(opts), f, indent=True)

    # Set the device
    opts.device = torch.device("cuda:0" if opts.use_cuda else "cpu")

    # Figure out what's the problem
    problem = load_problem(opts.problem)

    # Load data from load_path (or from the checkpoint being resumed)
    load_data = {}
    assert opts.load_path is None or opts.resume is None, "Only one of load path and resume can be given"
    load_path = opts.load_path if opts.load_path is not None else opts.resume
    if load_path is not None:
        print(' [*] Loading data from {}'.format(load_path))
        load_data = torch_load_cpu(load_path)

    # Initialize model
    model_class = {
        'attention': AttentionModel
    }.get(opts.model, None)
    # Report the requested name: model_class itself is always None when this fires.
    assert model_class is not None, "Unknown model: {}".format(opts.model)
    model = model_class(
        opts.embedding_dim,
        opts.hidden_dim,
        problem,
        n_encode_layers=opts.n_encode_layers,
        mask_inner=True,
        mask_logits=True,
        normalization=opts.normalization,
        tanh_clipping=opts.tanh_clipping,
        checkpoint_encoder=opts.checkpoint_encoder,
        shrink_size=opts.shrink_size,
        ft=opts.ft
    ).to(opts.device)

    # NOTE: multi-GPU wrapping is currently disabled.
    # if opts.use_cuda and torch.cuda.device_count() > 1:
    #     model = torch.nn.DataParallel(model)

    # Overwrite model parameters by parameters to load; keys missing from the
    # checkpoint keep the freshly initialized values from the model itself.
    model_ = get_inner_model(model)
    model_.load_state_dict({**model_.state_dict(), **load_data.get('model', {})})

    if opts.ft == "N":
        # Initialize optimizer over every model parameter
        optimizer = optim.Adam(
            [{'params': model.parameters(), 'lr': opts.lr_model}]
        )
    else:
        # Fine-tuning: freeze all existing parameters ...
        for p in model.parameters():
            p.requires_grad = False
        # ... then attach a fresh two-layer MLP as ``contextual_emb``; its
        # newly created parameters are the only ones left trainable.
        model.contextual_emb = nn.Sequential(
            nn.Linear(opts.embedding_dim, 8 * opts.embedding_dim, bias=False),
            nn.ReLU(),
            nn.Linear(8 * opts.embedding_dim, opts.embedding_dim, bias=False)
        )
        # Move the newly added layers onto the target device as well
        model = model.to(opts.device)
        optimizer = optim.Adam(
            [{'params': model.contextual_emb.parameters(), 'lr': opts.lr_model}]
        )

    # NOTE: restoring the optimizer state from the checkpoint is currently disabled.
    # # Load optimizer state
    # if 'optimizer' in load_data:
    #     optimizer.load_state_dict(load_data['optimizer'])
    #     for state in optimizer.state.values():
    #         for k, v in state.items():
    #             # if isinstance(v, torch.Tensor):
    #             if torch.is_tensor(v):
    #                 state[k] = v.to(opts.device)

    # Initialize learning rate scheduler, decay by lr_decay once per epoch!
    lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: opts.lr_decay ** epoch)

    # Build the validation dataset
    val_dataset = problem.make_dataset(
        size=opts.graph_size, num_samples=opts.val_size, filename=opts.val_dataset, distribution=opts.data_distribution)

    if opts.resume:
        # The epoch number is the second "-"-separated field of the checkpoint
        # file name with its extension removed.
        epoch_resume = int(os.path.splitext(os.path.split(opts.resume)[-1])[0].split("-")[1])

        # Set the random states
        torch.set_rng_state(load_data['rng_state'])
        if opts.use_cuda:
            torch.cuda.set_rng_state_all(load_data['cuda_rng_state'])
        print("Resuming after {}".format(epoch_resume))
        opts.epoch_start = epoch_resume + 1

    if opts.eval_only:
        validate(model, val_dataset, opts)
    else:
        # Start the actual training loop
        for epoch in range(opts.epoch_start, opts.epoch_start + opts.n_epochs):
            train_epoch(
                model,
                optimizer,
                lr_scheduler,
                epoch,
                val_dataset,
                problem,
                opts
            )
if __name__ == "__main__":
    # Script entry point: collect the run options, then hand them to run().
    cli_opts = get_options()
    run(cli_opts)