@@ -25,8 +25,8 @@ def train(args):
     cora = CoraDataLoader(verbose=True)

     # To account for the initial CUDA Context object for pynvml
-    tmp = StaticGraph([(0,0)], [1], 1)
-
+    tmp = StaticGraph([(0, 0)], [1], 1)
+
     features = torch.FloatTensor(cora.get_all_features())
     labels = torch.LongTensor(cora.get_all_targets())
     train_mask = cora.get_train_mask()
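
The throwaway one-edge StaticGraph above exists only to force CUDA context creation before any GPU memory readings are taken; otherwise the context's own allocation would pollute the first per-epoch measurement. For orientation, a minimal sketch of the kind of pynvml query this sets up (device index 0 and the placement are assumptions; the script's actual measurement helper is not part of this hunk):

    import pynvml

    # Query used GPU memory via NVML. The CUDA context must already exist,
    # or its allocation shows up in the first delta measured around a step.
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # assumed: GPU 0
    used_bytes = pynvml.nvmlDeviceGetMemoryInfo(handle).used
    print("used: {:.2f} MB".format(used_bytes / (1024 ** 2)))
    pynvml.nvmlShutdown()
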
@@ -49,15 +49,15 @@ def train(args):

     assert train_mask.shape[0] == num_nodes

-    print('dataset {}'.format("Cora"))
-    print('# of edges : {}'.format(num_edges))
-    print('# of nodes : {}'.format(num_nodes))
-    print('# of features : {}'.format(num_feats))
+    print("dataset {}".format("Cora"))
+    print("# of edges : {}".format(num_edges))
+    print("# of nodes : {}".format(num_nodes))
+    print("# of features : {}".format(num_feats))

     features = torch.FloatTensor(features)
     labels = torch.LongTensor(labels)

-    if hasattr(torch, 'BoolTensor'):
+    if hasattr(torch, "BoolTensor"):
         train_mask = torch.BoolTensor(train_mask)

     else:
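
The hasattr guard is a version check: torch.BoolTensor only exists from PyTorch 1.2 onward, and the else branch (truncated in this hunk) presumably falls back to the older uint8 mask type. A self-contained sketch of that pattern (the ByteTensor fallback is an assumption about the elided branch):

    import torch

    mask = [1, 0, 1]  # 1 = train node, 0 = held out
    if hasattr(torch, "BoolTensor"):
        train_mask = torch.BoolTensor(mask)
    else:
        # Assumed fallback: pre-1.2 PyTorch used uint8 tensors for mask indexing.
        train_mask = torch.ByteTensor(mask)
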
@@ -74,17 +74,19 @@ def train(args):

     # create model
     heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads]
-    model = GAT(g,
-                args.num_layers,
-                num_feats,
-                args.num_hidden,
-                n_classes,
-                heads,
-                F.elu,
-                args.in_drop,
-                args.attn_drop,
-                args.negative_slope,
-                args.residual)
+    model = GAT(
+        g,
+        args.num_layers,
+        num_feats,
+        args.num_hidden,
+        n_classes,
+        heads,
+        F.elu,
+        args.in_drop,
+        args.attn_drop,
+        args.negative_slope,
+        args.residual,
+    )
     print(model)
     if args.early_stop:
         stopper = EarlyStopping(patience=100)
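
The heads list gives every hidden layer the same number of attention heads while letting the output layer differ. With the defaults registered below (num_heads=8, num_layers=1, num_out_heads=1) it works out to [8, 1]:

    num_heads, num_layers, num_out_heads = 8, 1, 1
    heads = ([num_heads] * num_layers) + [num_out_heads]
    assert heads == [8, 1]  # 8 heads per hidden layer, 1 head on the output layer
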
@@ -94,7 +96,8 @@ def train(args):

     # use optimizer
     optimizer = torch.optim.Adam(
-        model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
+    )

     # initialize graph
     dur = []
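
One note on this optimizer call: weight_decay in torch.optim.Adam applies classic coupled L2 regularization (the decay term is added to the gradient), not the decoupled decay of AdamW. A minimal sketch with the 5e-4 default from the flags below (the toy parameter list stands in for model.parameters()):

    import torch

    params = [torch.nn.Parameter(torch.randn(4, 4))]  # stand-in for model.parameters()
    optimizer = torch.optim.Adam(params, lr=0.005, weight_decay=5e-4)
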
@@ -103,8 +106,8 @@ def train(args):
     Used_memory = 0

     for epoch in range(args.num_epochs):
-        #print('epoch = ', epoch)
-        #print('mem0 = {}'.format(mem0))
+        # print('epoch = ', epoch)
+        # print('mem0 = {}'.format(mem0))
         torch.cuda.synchronize()
         tf = time.time()
         model.train()
@@ -120,7 +123,7 @@ def train(args):
         torch.cuda.synchronize()
         loss.backward()
         optimizer.step()
-        t2 = time.time()
+        t2 = time.time()
         run_time_this_epoch = t2 - tf

         if epoch >= 3:
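
The synchronize-before-timestamp pattern bracketing this hunk is what makes the per-epoch timings meaningful: CUDA kernels launch asynchronously, so without a torch.cuda.synchronize() on both sides of the timed region, time.time() would mostly measure launch overhead. A standalone sketch of the pattern (assumes a CUDA device is available):

    import time

    import torch

    x = torch.randn(4096, 4096, device="cuda")

    torch.cuda.synchronize()  # drain any queued work before starting the clock
    t0 = time.time()
    y = x @ x                 # asynchronous kernel launch
    torch.cuda.synchronize()  # block until the matmul actually finishes
    print("elapsed: {:.4f}s".format(time.time() - t0))
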
@@ -131,56 +134,77 @@ def train(args):

         train_acc = accuracy(logits[train_mask], labels[train_mask])

-        #log for each step
-        print('Epoch {:05d} | Time(s) {:.4f} | train_acc {:.6f} | Used_Memory {:.6f} mb'.format(
-            epoch, run_time_this_epoch, train_acc, (now_mem * 1.0 / (1024 ** 2))
-        ))
+        # log for each step
+        print(
+            "Epoch {:05d} | Time(s) {:.4f} | train_acc {:.6f} | Used_Memory {:.6f} mb".format(
+                epoch, run_time_this_epoch, train_acc, (now_mem * 1.0 / (1024**2))
+            )
+        )

     if args.early_stop:
-        model.load_state_dict(torch.load('es_checkpoint.pt'))
+        model.load_state_dict(torch.load("es_checkpoint.pt"))

-    #OUTPUT we need
-    avg_run_time = avg_run_time * 1. / record_time
-    Used_memory /= (1024 ** 3)
-    print('^^^{:6f}^^^{:6f}'.format(Used_memory, avg_run_time))
+    # OUTPUT we need
+    avg_run_time = avg_run_time * 1.0 / record_time
+    Used_memory /= 1024**3
+    print("^^^{:6f}^^^{:6f}".format(Used_memory, avg_run_time))

-if __name__ == '__main__':

-    parser = argparse.ArgumentParser(description='GAT')
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="GAT")

     # COMMENT IF SNOOP IS TO BE ENABLED
     snoop.install(enabled=False)

-    parser.add_argument("--gpu", type=int, default=0,
-                        help="which GPU to use. Set -1 to use CPU.")
-    parser.add_argument("--num_epochs", type=int, default=200,
-                        help="number of training epochs")
-    parser.add_argument("--num_heads", type=int, default=8,
-                        help="number of hidden attention heads")
-    parser.add_argument("--num_out_heads", type=int, default=1,
-                        help="number of output attention heads")
-    parser.add_argument("--num_layers", type=int, default=1,
-                        help="number of hidden layers")
-    parser.add_argument("--num_hidden", type=int, default=32,
-                        help="number of hidden units")
-    parser.add_argument("--residual", action="store_true", default=False,
-                        help="use residual connection")
-    parser.add_argument("--in_drop", type=float, default=.6,
-                        help="input feature dropout")
-    parser.add_argument("--attn_drop", type=float, default=.6,
-                        help="attention dropout")
-    parser.add_argument("--lr", type=float, default=0.005,
-                        help="learning rate")
-    parser.add_argument('--weight_decay', type=float, default=5e-4,
-                        help="weight decay")
-    parser.add_argument('--negative_slope', type=float, default=0.2,
-                        help="the negative slope of leaky relu")
-    parser.add_argument('--early_stop', action='store_true', default=False,
-                        help="indicates whether to use early stop or not")
-    parser.add_argument('--fastmode', action="store_true", default=False,
-                        help="skip re-evaluate the validation set")
+    parser.add_argument(
+        "--gpu", type=int, default=0, help="which GPU to use. Set -1 to use CPU."
+    )
+    parser.add_argument(
+        "--num_epochs", type=int, default=200, help="number of training epochs"
+    )
+    parser.add_argument(
+        "--num_heads", type=int, default=8, help="number of hidden attention heads"
+    )
+    parser.add_argument(
+        "--num_out_heads", type=int, default=1, help="number of output attention heads"
+    )
+    parser.add_argument(
+        "--num_layers", type=int, default=1, help="number of hidden layers"
+    )
+    parser.add_argument(
+        "--num_hidden", type=int, default=32, help="number of hidden units"
+    )
+    parser.add_argument(
+        "--residual", action="store_true", default=False, help="use residual connection"
+    )
+    parser.add_argument(
+        "--in_drop", type=float, default=0.6, help="input feature dropout"
+    )
+    parser.add_argument(
+        "--attn_drop", type=float, default=0.6, help="attention dropout"
+    )
+    parser.add_argument("--lr", type=float, default=0.005, help="learning rate")
+    parser.add_argument("--weight_decay", type=float, default=5e-4, help="weight decay")
+    parser.add_argument(
+        "--negative_slope",
+        type=float,
+        default=0.2,
+        help="the negative slope of leaky relu",
+    )
+    parser.add_argument(
+        "--early_stop",
+        action="store_true",
+        default=False,
+        help="indicates whether to use early stop or not",
+    )
+    parser.add_argument(
+        "--fastmode",
+        action="store_true",
+        default=False,
+        help="skip re-evaluate the validation set",
+    )
     args = parser.parse_args()

     print(args)
-
+
     train(args)
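
For reference, a typical invocation under the flags defined above (the filename is a placeholder for whatever this script is called in the repo):

    python train_gat.py --gpu 0 --num_epochs 200 --num_hidden 32 --early_stop

The final ^^^{:6f}^^^{:6f} line then prints used memory in GB and the average per-epoch time in seconds, apparently accumulated from epoch 3 onward per the epoch >= 3 guard; the "# OUTPUT we need" comment marks it as the line of interest.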