-
Notifications
You must be signed in to change notification settings - Fork 23
/
train.py
110 lines (96 loc) · 4.12 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# script to train interactive bots in toy world
# author: satwik kottur
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import itertools, pdb, random, os
import numpy as np
from chatbots import Team
from dataloader import Dataloader
import options
from time import gmtime, strftime

# read the command line options
# NOTE: this rebinds the name `options` from the module to the parsed
# option dict returned by options.read(); kept for compatibility with
# the rest of this script, which indexes it like a dict.
options = options.read()
#------------------------------------------------------------------------
# setup experiment and dataset
#------------------------------------------------------------------------
data = Dataloader(options)
numInst = data.getInstCount()

params = data.params
# append options from options to params so a single dict configures the team
for key, value in options.items():
    params[key] = value

#------------------------------------------------------------------------
# build agents, and setup optmizer
#------------------------------------------------------------------------
team = Team(params)
team.train()
# one Adam parameter group per bot, both at the same learning rate
optimizer = optim.Adam([{'params': team.aBot.parameters(), \
                         'lr': params['learningRate']},
                        {'params': team.qBot.parameters(), \
                         'lr': params['learningRate']}])
#------------------------------------------------------------------------
# train agents
#------------------------------------------------------------------------
# begin training
numIterPerEpoch = int(np.ceil(numInst['train'] / params['batchSize']))
numIterPerEpoch = max(1, numIterPerEpoch)  # guard: at least one iter per epoch
count = 0
savePath = 'models/tasks_inter_%dH_%.4flr_%r_%d_%d.tar' %\
            (params['hiddenSize'], params['learningRate'], params['remember'],\
             options['aOutVocab'], options['qOutVocab'])

matches = {}    # per-split boolean tensors: instances both bots got right
accuracy = {}   # per-split accuracy (percent)
bestAccuracy = 0
for iterId in range(params['numEpochs'] * numIterPerEpoch):
    epoch = float(iterId) / numIterPerEpoch

    # get double attribute tasks; after the first evaluation pass,
    # oversample previously-mismatched instances (negative mining)
    if 'train' not in matches:
        batchImg, batchTask, batchLabels \
                            = data.getBatch(params['batchSize'])
    else:
        batchImg, batchTask, batchLabels \
                = data.getBatchSpecial(params['batchSize'], matches['train'],\
                                       params['negFraction'])

    # forward pass
    team.forward(Variable(batchImg), Variable(batchTask))
    # backward pass (computes REINFORCE-style gradients from batchLabels)
    batchReward = team.backward(optimizer, batchLabels, epoch)

    # take a step by optimizer
    optimizer.step()
    #--------------------------------------------------------------------------
    # switch to evaluate
    team.evaluate()

    for dtype in ['train', 'test']:
        # get the entire batch
        img, task, labels = data.getCompleteData(dtype)
        # evaluate on the train dataset, using greedy policy
        guess, _, _ = team.forward(Variable(img), Variable(task))
        # compute accuracy for color, shape, and both: an instance counts
        # as correct only when BOTH predicted attributes match the labels
        firstMatch = guess[0].data == labels[:, 0].long()
        secondMatch = guess[1].data == labels[:, 1].long()
        matches[dtype] = firstMatch & secondMatch
        accuracy[dtype] = 100 * torch.sum(matches[dtype])\
                                / float(matches[dtype].size(0))
    # switch to train
    team.train()

    # break if train accuracy reaches 100%
    if accuracy['train'] == 100: break

    # save for every 5k epochs
    # if iterId > 0 and iterId % (10000*numIterPerEpoch) == 0:
    if iterId >= 0 and iterId % (10000 * numIterPerEpoch) == 0:
        team.saveModel(savePath, optimizer, params)

    # log progress only every 100 iterations
    if iterId % 100 != 0: continue

    time = strftime("%a, %d %b %Y %X", gmtime())
    print('[%s][Iter: %d][Ep: %.2f][R: %.4f][Tr: %.2f Te: %.2f]' % \
          (time, iterId, epoch, team.totalReward,\
           accuracy['train'], accuracy['test']))
#------------------------------------------------------------------------
# save final model with a time stamp
timeStamp = strftime("%a-%d-%b-%Y-%X", gmtime())
replaceWith = 'final_%s' % timeStamp
finalSavePath = savePath.replace('inter', replaceWith)
print('Saving : ' + finalSavePath)
team.saveModel(finalSavePath, optimizer, params)
#------------------------------------------------------------------------