-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
executable file
·65 lines (44 loc) · 1.55 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import string
import pickle
import numpy as np
import torch
from torch.autograd import Variable
def init_weights(modules, gru_init_range=0.01):
    """Initialise the parameters of Linear and GRU layers in place.

    Args:
        modules: either an ``nn.Module`` (every module yielded by its
            ``.modules()`` -- itself included -- is visited) or an iterable
            of modules.
        gru_init_range: GRU weights are drawn from
            ``U(-gru_init_range, gru_init_range)``; defaults to the
            previously hard-coded 0.01.

    Linear layers get Xavier-normal weights and a zeroed bias (when the layer
    has one). GRU layers get small uniform weights and zeroed biases. Other
    module types are left untouched.
    """
    # ``xavier_normal_`` is the in-place name (torch >= 0.4); the
    # underscore-less alias is deprecated but is the only one on older torch.
    xavier_normal = (getattr(torch.nn.init, 'xavier_normal_', None)
                     or torch.nn.init.xavier_normal)
    if isinstance(modules, torch.nn.Module):
        modules = modules.modules()
    for m in modules:
        if isinstance(m, torch.nn.Linear):
            xavier_normal(m.weight.data)
            # ``Linear(..., bias=False)`` stores ``bias = None``.
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, torch.nn.GRU):
            # Iterate the actual parameter tensors: ``m._all_weights`` only
            # holds per-layer lists of parameter *names*, so testing
            # ``'weight' in <list>`` never matched and nothing was initialised.
            for name, param in m.named_parameters():
                if 'weight' in name:
                    param.data.uniform_(-gru_init_range, gru_init_range)
                elif 'bias' in name:
                    param.data.zero_()
def argmax(inputs, dim=-1):
    """Return the indices of the maximum entries of *inputs* along *dim*.

    ``inputs.max(dim=...)`` yields a ``(values, indices)`` pair; only the
    index half (position 1) is handed back to the caller.
    """
    return inputs.max(dim=dim)[1]
def cuda(obj):
    """Return *obj* moved to the default CUDA device, if CUDA is available.

    When no CUDA device is available the very same object is returned.
    """
    return obj.cuda() if torch.cuda.is_available() else obj
def variable(obj, volatile=False):
    """Wrap *obj* in an autograd ``Variable``, placed on the GPU when possible.

    Args:
        obj: a tensor, a NumPy array, or a (possibly nested) list/tuple of
            them. Lists and tuples are processed element-wise and always come
            back as a ``list``.
        volatile: forwarded unchanged to ``Variable``. NOTE(review): newer
            PyTorch releases ignore this flag in favour of ``torch.no_grad()``
            -- confirm which version this project targets.

    NumPy arrays go through ``torch.from_numpy`` before being passed to
    ``cuda``; everything else is passed to ``cuda`` as-is.
    """
    if not isinstance(obj, (list, tuple)):
        tensor = torch.from_numpy(obj) if isinstance(obj, np.ndarray) else obj
        return Variable(cuda(tensor), volatile=volatile)
    return [variable(item, volatile=volatile) for item in obj]
def get_sequence_from_indices(indices, id2token):
tokens = [id2token[idx] for idx in indices]
tokens = [
' ' + t if i != 0 and not t.startswith("'") and not t.startswith("n'") and t not in string.punctuation else t
for i, t in enumerate(tokens)
]
sequence = ''.join(tokens)
return sequence