# -*- encoding: utf-8 -*-
'''
@File : train.py
@Time : 2021/02/20 10:25:13
@Author : Fei gao
@Contact : [email protected]
BNU, Beijing, China
'''
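
# End-to-end training script: load the temporal graph snapshots, train DySAT
# with an unsupervised context-pair loss, and evaluate link prediction AUC on
# the final snapshot.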
import argparse
import os

import dill
import networkx as nx
import numpy as np
import pickle as pkl
import scipy
import torch
from torch.utils.data import DataLoader

from utils.preprocess import load_graphs, get_context_pairs, get_evaluation_data
from utils.minibatch import MyDataset
from utils.utilities import to_device
from eval.link_prediction import evaluate_classifier
from models.model import DySAT
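
# Anomaly detection gives precise stack traces for NaN/inf gradients but
# slows training noticeably; consider disabling it once the model is stable.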
torch.autograd.set_detect_anomaly(True)


def inductive_graph(graph_former, graph_later):
    """Create the adj_train so that it includes nodes from (t+1)
    but only edges from t: this is for the purpose of inductive testing.

    Args:
        graph_former (nx.Graph): snapshot at time t, supplying the edges.
        graph_later (nx.Graph): snapshot at time t+1, supplying the nodes.
    """
    newG = nx.MultiGraph()
    newG.add_nodes_from(graph_later.nodes(data=True))
    newG.add_edges_from(graph_former.edges(data=False))
    return newG
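
# A toy illustration (hypothetical snapshots): if G_t contains the edge (0, 1)
# and G_{t+1} introduces node 2, inductive_graph(G_t, G_{t+1}) returns a graph
# with nodes {0, 1, 2} but only the edge (0, 1), so node 2 is unseen in training.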

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Experimental settings.
    parser.add_argument('--time_steps', type=int, nargs='?', default=16,
                        help="total time steps used for train, eval and test")
    parser.add_argument('--dataset', type=str, nargs='?', default='Enron',
                        help='dataset name')
    parser.add_argument('--GPU_ID', type=int, nargs='?', default=0,
                        help='GPU_ID (0/1 etc.)')
    parser.add_argument('--epochs', type=int, nargs='?', default=200,
                        help='# epochs')
    parser.add_argument('--val_freq', type=int, nargs='?', default=1,
                        help='Validation frequency (in epochs)')
    parser.add_argument('--test_freq', type=int, nargs='?', default=1,
                        help='Testing frequency (in epochs)')
    parser.add_argument('--batch_size', type=int, nargs='?', default=512,
                        help='Batch size (# nodes)')
    # One-hot encoding is input as a sparse matrix, hence no scalability issue
    # for large datasets. NOTE: argparse's type=bool treats any non-empty
    # string as True, so this flag is effectively always on.
    parser.add_argument('--featureless', type=bool, nargs='?', default=True,
                        help='True if one-hot encoding.')
    parser.add_argument("--early_stop", type=int, default=10,
                        help="Patience (in epochs) for early stopping")
    # Tunable hyper-params.
    # TODO: Implementation has not been verified; performance may not be good.
    parser.add_argument('--residual', type=bool, nargs='?', default=True,
                        help='Use residual connections')
    # Number of negative samples per positive pair.
    parser.add_argument('--neg_sample_size', type=int, nargs='?', default=10,
                        help='# negative samples per positive')
    # Walk length for random walk sampling.
    parser.add_argument('--walk_len', type=int, nargs='?', default=20,
                        help='Walk length for random walk sampling')
    # Weight for negative samples in the binary cross-entropy loss function.
    parser.add_argument('--neg_weight', type=float, nargs='?', default=1.0,
                        help='Weight for negative samples')
    parser.add_argument('--learning_rate', type=float, nargs='?', default=0.01,
                        help='Initial learning rate for self-attention model.')
    parser.add_argument('--spatial_drop', type=float, nargs='?', default=0.1,
                        help='Spatial (structural) attention dropout (1 - keep probability).')
    parser.add_argument('--temporal_drop', type=float, nargs='?', default=0.5,
                        help='Temporal attention dropout (1 - keep probability).')
    parser.add_argument('--weight_decay', type=float, nargs='?', default=0.0005,
                        help='Weight decay for the AdamW optimizer.')
    # Architecture params.
    parser.add_argument('--structural_head_config', type=str, nargs='?', default='16,8,8',
                        help='Encoder layer config: # attention heads in each GAT layer')
    parser.add_argument('--structural_layer_config', type=str, nargs='?', default='128',
                        help='Encoder layer config: # units in each GAT layer')
    parser.add_argument('--temporal_head_config', type=str, nargs='?', default='16',
                        help='Encoder layer config: # attention heads in each temporal layer')
    parser.add_argument('--temporal_layer_config', type=str, nargs='?', default='128',
                        help='Encoder layer config: # units in each temporal layer')
    parser.add_argument('--position_ffn', type=str, nargs='?', default='True',
                        help='Use position-wise feedforward')
    parser.add_argument('--window', type=int, nargs='?', default=-1,
                        help='Window for temporal attention (default: -1 => full)')
    args = parser.parse_args()
    print(args)
    # graphs, feats, adjs = load_graphs(args.dataset)
    graphs, adjs = load_graphs(args.dataset)
    assert args.time_steps <= len(adjs), "time_steps cannot exceed the number of graph snapshots"

    if args.featureless:
        # One-hot features: each snapshot takes the first N_t rows of the
        # identity matrix sized to the final snapshot, so node indices stay
        # aligned across time steps.
        feats = [scipy.sparse.identity(adjs[args.time_steps - 1].shape[0]).tocsr()[range(0, x.shape[0]), :]
                 for x in adjs if x.shape[0] <= adjs[args.time_steps - 1].shape[0]]
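
    # Context pairs are positive node pairs co-occurring on truncated random
    # walks over each snapshot (length set by --walk_len); they supply the
    # positives for the unsupervised loss.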
    context_pairs_train = get_context_pairs(graphs, adjs)

    # Load evaluation data for link prediction.
    train_edges_pos, train_edges_neg, val_edges_pos, val_edges_neg, \
        test_edges_pos, test_edges_neg = get_evaluation_data(graphs)
    print("No. Train: Pos={}, Neg={}\nNo. Val: Pos={}, Neg={}\nNo. Test: Pos={}, Neg={}".format(
        len(train_edges_pos), len(train_edges_neg), len(val_edges_pos), len(val_edges_neg),
        len(test_edges_pos), len(test_edges_neg)))

    # Create the adj_train so that it includes nodes from (t+1) but only edges
    # from t: this is for the purpose of inductive testing.
    new_G = inductive_graph(graphs[args.time_steps - 2], graphs[args.time_steps - 1])
    graphs[args.time_steps - 1] = new_G
    adjs[args.time_steps - 1] = nx.adjacency_matrix(new_G)
    # Build dataloader and model. Respect --GPU_ID and fall back to CPU when
    # CUDA is unavailable (the original hardcoded "cuda:0").
    device = torch.device("cuda:{}".format(args.GPU_ID) if torch.cuda.is_available() else "cpu")
    dataset = MyDataset(args, graphs, feats, adjs, context_pairs_train)
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=10,
                            collate_fn=MyDataset.collate_fn)
    # dataloader = NodeMinibatchIterator(args, graphs, feats, adjs, context_pairs_train, device)
    model = DySAT(args, feats[0].shape[1], args.time_steps).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
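
    # model.get_loss computes the binary cross-entropy objective over the
    # sampled context pairs, with --neg_sample_size negatives per positive
    # weighted by --neg_weight (see models/model.py).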
    # Training loop with early stopping on validation AUC.
    best_epoch_val = 0
    patience = 0
    os.makedirs("./model_checkpoints", exist_ok=True)
    for epoch in range(args.epochs):
        model.train()
        epoch_loss = []
        for feed_dict in dataloader:
            feed_dict = to_device(feed_dict, device)
            opt.zero_grad()
            loss = model.get_loss(feed_dict)
            loss.backward()
            opt.step()
            epoch_loss.append(loss.item())

        model.eval()
        with torch.no_grad():
            # Embeddings from the second-to-last time step score links in the
            # final (held-out) step.
            emb = model(feed_dict["graphs"])[:, -2, :].detach().cpu().numpy()
        val_results, test_results, _, _ = evaluate_classifier(train_edges_pos,
                                                              train_edges_neg,
                                                              val_edges_pos,
                                                              val_edges_neg,
                                                              test_edges_pos,
                                                              test_edges_neg,
                                                              emb,
                                                              emb)
        epoch_auc_val = val_results["HAD"][1]
        epoch_auc_test = test_results["HAD"][1]

        print("Epoch {:<3}, Loss = {:.3f}, Val AUC {:.3f} Test AUC {:.3f}".format(epoch,
                                                                                  np.mean(epoch_loss),
                                                                                  epoch_auc_val,
                                                                                  epoch_auc_test))

        if epoch_auc_val > best_epoch_val:
            best_epoch_val = epoch_auc_val
            torch.save(model.state_dict(), "./model_checkpoints/model.pt")
            patience = 0
        else:
            patience += 1
            if patience > args.early_stop:
                break
    # Test the best model.
    model.load_state_dict(torch.load("./model_checkpoints/model.pt"))
    model.eval()
    with torch.no_grad():
        emb = model(feed_dict["graphs"])[:, -2, :].detach().cpu().numpy()
    val_results, test_results, _, _ = evaluate_classifier(train_edges_pos,
                                                          train_edges_neg,
                                                          val_edges_pos,
                                                          val_edges_neg,
                                                          test_edges_pos,
                                                          test_edges_neg,
                                                          emb,
                                                          emb)
    auc_val = val_results["HAD"][1]
    auc_test = test_results["HAD"][1]
    print("Best Val AUC = {:.3f}, Best Test AUC = {:.3f}".format(auc_val, auc_test))