# -*- encoding: utf-8 -*-
'''
@File : main_lgnn.py
@Time : 2021/03/25 10:49:06
@Author : Fei gao
@Contact : [email protected]
BNU, Beijing, China
'''
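
# Trains a Line Graph Neural Network (LGNN) for multiclass community
# detection on graphs sampled from the Stochastic Block Model (SBM), then
# evaluates accuracy on freshly sampled test graphs. The LGNN/SBM helper
# modules under model/ and utils/ are assumed to follow the interfaces
# used below.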
import argparse
import os
import pickle as pkl
import time

import numpy as np
import torch
from torch.utils.data import DataLoader

from model.LGNN import LGNN
from utils.dataset import SBM_dataset
from utils.losses import compute_loss_multiclass, compute_accuracy_multiclass
from utils.utils import data_convert

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # dataset setting
    parser.add_argument("--num_examples_train", type=int, default=6000)
    parser.add_argument("--num_examples_test", type=int, default=100)
    parser.add_argument("--p_SBM", type=float, default=0.0)
    parser.add_argument("--q_SBM", type=float, default=0.045)
    parser.add_argument("--N_train", type=int, default=400, help="number of nodes per training graph")
    parser.add_argument("--N_test", type=int, default=400, help="number of nodes per test graph")
    parser.add_argument("--n_classes", type=int, default=5)
    parser.add_argument("--save_path_root", type=str, default="./data/")
    # GNN parameters
    parser.add_argument("--J", type=int, default=2)
    parser.add_argument("--hid_dim", type=int, default=8)
    parser.add_argument("--num_layers", type=int, default=30)
    parser.add_argument("--lr", type=float, default=0.004)
    # pytorch setting
    parser.add_argument("--cuda_id", type=int, default=0, help="-1 for cpu")
    parser.add_argument("--num_workers", type=int, default=15, help="number of workers for the data loaders")
    parser.add_argument("--torch_seed", type=int, default=42)
    parser.add_argument("--clip_grad_norm", nargs="?", const=1, type=float, default=40.0)
    # experiments setting
    parser.add_argument("--show_freq", type=int, default=10)
    parser.add_argument("--model_save_path", type=str, default="./model_ckp/")
    parser.add_argument("--results_save_path", type=str, default="./results/")
    args = parser.parse_args()
    # Device and reproducibility
    if torch.cuda.is_available() and args.cuda_id >= 0:
        device = torch.device("cuda:{}".format(args.cuda_id))
    else:
        device = torch.device("cpu")
    torch.manual_seed(args.torch_seed)
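    # Create output directories up front so the checkpoint/result writes below
    # do not fail on a fresh checkout (save_path_root is where SBM_dataset
    # presumably caches generated graphs when save_data=True).
    os.makedirs(args.model_save_path, exist_ok=True)
    os.makedirs(args.results_save_path, exist_ok=True)
    os.makedirs(args.save_path_root, exist_ok=True)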
    # Dataset
    dataset_train = SBM_dataset(args.p_SBM,
                                args.q_SBM,
                                graph_size=args.N_train,
                                n_classes=args.n_classes,
                                num_graphs=args.num_examples_train,
                                J=args.J,
                                train=True,
                                path_root=args.save_path_root,
                                save_data=True)
    dataloader_train = DataLoader(dataset_train,
                                  batch_size=1,
                                  num_workers=args.num_workers,
                                  shuffle=True,
                                  collate_fn=dataset_train.collate_fn)
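    # batch_size=1: each SBM graph is a single training example, so every
    # optimizer step below consumes exactly one graph.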
    # Model
    model = LGNN(hid_dim=args.hid_dim,
                 num_layers=args.num_layers,
                 J=args.J,
                 num_classes=args.n_classes,
                 device=device)
    # Optimizer
    optimizer = torch.optim.Adamax(model.parameters(), lr=args.lr)
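    # Community IDs in the SBM are only defined up to relabeling, so the
    # multiclass loss/accuracy helpers presumably score predictions against
    # the best permutation of class labels, as is standard for SBM benchmarks.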
    # Training
    losses = []
    acces = []
    t0 = time.time()
    for cnt, data in enumerate(dataloader_train):
        optimizer.zero_grad()
        inputs = data_convert(data, device)
        pred = model(inputs)                           # [N, n_classes]
        pred = pred[None, :, :]                        # [1, N, n_classes]
        label = torch.Tensor(data["label"])[None, :]   # [1, N]
        loss = compute_loss_multiclass(pred, label, args.n_classes)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
        optimizer.step()
        acc = compute_accuracy_multiclass(pred, label, args.n_classes)
        losses.append(loss.item())
        acces.append(acc.item())
        t1 = time.time()
        if (cnt + 1) % args.show_freq == 0:
            ave_loss = np.mean(losses[-args.show_freq:])
            ave_acc = np.mean(acces[-args.show_freq:])
            print("Sample : {:<3}, Loss={:.3f}, ACC={:.4f}, cum_Time={:.1f}s, iter_Time={:.1f}s".format(
                cnt + 1, ave_loss, ave_acc, t1 - t0, (t1 - t0) / (cnt + 1)))
    with open(args.results_save_path + "loss_acc.pkl", "wb") as f:
        pkl.dump([losses, acces], f)
    torch.save(model.state_dict(), args.model_save_path + "model.pt")
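    # To reuse the checkpoint later, rebuild the model with the same
    # constructor arguments and load the saved weights, e.g.:
    #   model = LGNN(hid_dim=args.hid_dim, num_layers=args.num_layers, J=args.J,
    #                num_classes=args.n_classes, device=device)
    #   model.load_state_dict(torch.load(args.model_save_path + "model.pt"))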
    # Testing
    dataset_test = SBM_dataset(args.p_SBM,
                               args.q_SBM,
                               graph_size=args.N_test,
                               n_classes=args.n_classes,
                               num_graphs=args.num_examples_test,
                               J=args.J,
                               train=False,
                               path_root=args.save_path_root,
                               save_data=True)
    dataloader_test = DataLoader(dataset_test,
                                 batch_size=1,
                                 num_workers=args.num_workers,
                                 shuffle=False,
                                 collate_fn=dataset_test.collate_fn)
    model.eval()
    acces_test = []
    t0 = time.time()
    with torch.no_grad():  # no gradients needed for evaluation
        for cnt, data in enumerate(dataloader_test):
            inputs = data_convert(data, device)
            pred = model(inputs)                           # [N, n_classes]
            pred = pred[None, :, :]                        # [1, N, n_classes]
            label = torch.Tensor(data["label"])[None, :]   # [1, N]
            acc = compute_accuracy_multiclass(pred, label, args.n_classes)
            acces_test.append(acc.item())
            t1 = time.time()
            if (cnt + 1) % args.show_freq == 0:
                ave_acc = np.mean(acces_test[-args.show_freq:])
                print("Test : {:<3}, ACC={:.4f}, cum_Time={:.1f}s, iter_Time={:.1f}s".format(
                    cnt + 1, ave_acc, t1 - t0, (t1 - t0) / (cnt + 1)))
    with open(args.results_save_path + "test_acc.pkl", "wb") as f:
        pkl.dump(acces_test, f)
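
# Example invocation (all flags optional; defaults are defined in the parser above):
#   python main_lgnn.py --p_SBM 0.0 --q_SBM 0.045 --n_classes 5 --N_train 400 --cuda_id 0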