-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_wandb.py
More file actions
executable file
·297 lines (267 loc) · 16.4 KB
/
train_wandb.py
File metadata and controls
executable file
·297 lines (267 loc) · 16.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
import os
import argparse
import logging
import torch
import pickle
from scipy.sparse import SparseEfficiencyWarning
import sys
from subgraph_extraction.datasets import SubgraphDataset, generate_subgraph_datasets
from utils.initialization_utils import initialize_experiment, initialize_model
from utils.graph_utils import collate_dgl, move_batch_to_device_dgl, move_batch_to_device_dgl_ddi2
from model.dgl.graph_classifier import GraphClassifier as dgl_model
from warnings import simplefilter
from nodefeaturing.get_dt_embedding import generate_protein_feature, generate_drug_feature
''' wandb'''
from managers.evaluator import Evaluator, Evaluator_ddi2
from managers.trainer_wandb import Trainer
# from managers.trainer import Trainer
import numpy as np
# os.environ['WANDB_DIR'] = os.getcwd()
# os.environ['WANDB_CACHE_DIR'] = os.getcwd() + "/.cache/"
# os.environ['WANDB_CONFIG_DIR'] = os.getcwd() + "/.config/"
# Set at import time, before any CUDA work is issued by this script.
# CUDA_LAUNCH_BLOCKING=1 is the CUDA runtime switch for synchronous kernel
# launches (a debugging aid); NOTE(review): it usually slows training, so
# confirm it is still wanted for regular runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
def _load_pickle(path):
    """Return the object unpickled from the file at *path*."""
    # NOTE(security): pickle can execute arbitrary code while loading; only
    # point this at feature files produced by this project.
    with open(path, 'rb') as f:
        return pickle.load(f, encoding='utf-8')


def _build_row_index(encodings, size):
    """Build a lookup table mapping an encoded id to its row in *encodings*.

    Ids that are not strictly positive are skipped, so their slot keeps the
    initial value 0.  An integer dtype is used so the table can be wrapped in
    a ``torch.LongTensor`` without a dtype mismatch.
    """
    index = np.zeros(size, dtype=np.int64)
    for row, encoded_id in enumerate(encodings):
        if 0 < encoded_id:
            index[int(encoded_id)] = row
    return index


def _load_vec_features(params):
    """Load drug and protein features plus their id -> row lookup tables.

    Missing feature files are generated first via ``generate_drug_feature`` /
    ``generate_protein_feature``.

    Returns:
        Tuple ``(mfeat, pfeat, dind, pind)``: drug feature rows, protein
        feature rows, drug index table and protein index table.

    Raises:
        FileNotFoundError: if the drug feature file is still missing after
            the generation step.
    """
    # Table sizes are hard-coded in the original script; presumably upper
    # bounds on the encoded drug / protein ids -- TODO confirm.
    num_drug_slots = 1710
    num_protein_slots = 34123

    drug_path = f'data/{params.dataset}/VEC_drug_feats_{params.drug_embedding_method}.pkl'
    if not os.path.exists(drug_path):
        generate_drug_feature(params)
    if not os.path.exists(drug_path):
        raise FileNotFoundError("Feature file have not created yet")
    drug_data = _load_pickle(drug_path)
    drug_key = f'{params.drug_embedding_method.upper()}_Features'
    mfeat = [[float(value) for value in row] for row in drug_data[drug_key]]
    dind = _build_row_index(drug_data["Drug_enco"], num_drug_slots)

    protein_path = f'data/{params.dataset}/VEC_target_feats_{params.protein_embedding_method}.pkl'
    if not os.path.exists(protein_path):
        generate_protein_feature(params)
    protein_data = _load_pickle(protein_path)
    protein_key = f'{params.protein_embedding_method.upper()}_Features'
    pfeat = list(protein_data[protein_key])
    pind = _build_row_index(protein_data["Gene_enco"], num_protein_slots)

    return mfeat, pfeat, dind, pind


def _load_drugbank_features(params):
    """Load drug features for the 'drugbank' dataset and set ``params.feat_dim``.

    Returns:
        The drug feature rows, or ``None`` when ``params.feat`` is not one of
        'morgan', 'pca' or 'pretrained'.
    """
    if params.feat == 'morgan':
        data = _load_pickle('data/{}/DB_molecular_feats.pkl'.format(params.dataset))
        mfeat = list(data['Morgan_Features'])
        params.feat_dim = 1024
    elif params.feat == 'pca':
        mfeat = np.loadtxt('data/{}/PCA.txt'.format(params.dataset))
        params.feat_dim = 200
    elif params.feat == 'pretrained':
        mfeat = np.loadtxt('data/{}/pretrained.txt'.format(params.dataset))
        params.feat_dim = 200
    else:
        mfeat = None
    return mfeat


def _load_biosnap_features(params):
    """Load Morgan drug features for the 'BioSNAP' dataset and set ``params.feat_dim``."""
    data = _load_pickle('data/{}/id2drug_feat.pkl'.format(params.dataset))
    mfeat = [data[drug_id]['Morgan'] for drug_id in data]
    params.feat_dim = 1024
    return mfeat


def main(params):
    """Build the datasets and the model, attach node features, then train.

    Steps, in order: generate the subgraph database if its directory is
    missing, create the train/valid/test ``SubgraphDataset`` objects, copy
    dataset statistics onto ``params``, initialise the graph classifier, load
    and register the node features, create one evaluator per split and run
    ``Trainer.train()``.

    Args:
        params: parsed command-line namespace.  It is mutated in place
            (``db_path``, ``num_rels``, ``aug_num_rels``, ``inp_dim``,
            ``train_rels``, ``num_nodes``, ``max_label_value`` and, for some
            datasets, ``feat_dim``).

    Raises:
        FileNotFoundError: if a required feature file cannot be found.
        ValueError: if no drug features exist for the requested combination
            of ``params.dataset`` and ``params.feat``.
    """
    simplefilter(action='ignore', category=UserWarning)
    simplefilter(action='ignore', category=SparseEfficiencyWarning)

    params.db_path = os.path.join(params.main_dir, f'data/{params.dataset}/subgraphs_en_{params.enclosing_sub_graph}_neg_{params.num_neg_samples_per_link}_hop_{params.hop}')
    if not os.path.isdir(params.db_path):
        generate_subgraph_datasets(params)

    # Keyword arguments shared by all three splits.
    common_kwargs = dict(
        add_traspose_rels=params.add_traspose_rels,
        num_neg_samples_per_link=params.num_neg_samples_per_link,
        use_kge_embeddings=params.use_kge_embeddings,
        dataset=params.dataset,
        kge_model=params.kge_model,
    )
    train = SubgraphDataset(params.db_path, 'train_pos', 'train_neg', params.file_paths,
                            file_name=params.train_file, **common_kwargs)

    # Validation and test reuse the graph and id mappings of the training split.
    shared_from_train = dict(
        ssp_graph=train.ssp_graph,
        id2entity=train.id2entity,
        id2relation=train.id2relation,
        rel=train.num_rels,
        graph=train.graph,
    )
    valid = SubgraphDataset(params.db_path, 'valid_pos', 'valid_neg', params.file_paths,
                            file_name=params.valid_file, **common_kwargs, **shared_from_train)
    test = SubgraphDataset(params.db_path, 'test_pos', 'test_neg', params.file_paths,
                           file_name=params.test_file, **common_kwargs, **shared_from_train)

    params.num_rels = train.num_rels
    params.aug_num_rels = train.aug_num_rels
    params.inp_dim = train.n_feat_dim
    params.train_rels = 200 if params.dataset == 'BioSNAP' else params.num_rels
    params.num_nodes = 35000

    # Log the max label value to save it in the model. This will be used to cap the labels generated on test set.
    params.max_label_value = train.max_n_label

    logging.info(f"Device: {params.device}")
    logging.info(f"Input dim : {params.inp_dim}, # Relations : {params.num_rels}, # Augmented relations : {params.aug_num_rels}")

    graph_classifier = initialize_model(params, dgl_model, params.load_model)

    # Only the first branch provides protein features and index tables; the
    # other datasets provide drug features only.
    mfeat = pfeat = dind = pind = None
    if params.dataset in ['vec', 'mydrugbank', 'davis'] and params.feat == 'morganprotbert':
        mfeat, pfeat, dind, pind = _load_vec_features(params)
    elif params.dataset == 'drugbank':
        mfeat = _load_drugbank_features(params)
    elif params.dataset == 'BioSNAP':
        mfeat = _load_biosnap_features(params)

    if mfeat is None:
        # Previously this surfaced as an UnboundLocalError further down.
        raise ValueError(
            f"No drug features available for dataset={params.dataset!r} with feat={params.feat!r}"
        )

    graph_classifier.drug_feat(torch.FloatTensor(np.array(mfeat)).to(params.device))
    if pfeat is not None:
        graph_classifier.pro_feat(torch.FloatTensor(np.array(pfeat)).to(params.device))
        graph_classifier.pro_ind(torch.LongTensor(np.array(pind)).to(params.device))
        graph_classifier.drug_ind(torch.LongTensor(np.array(dind)).to(params.device))
    else:
        # NOTE(review): these datasets used to crash here because the protein
        # variables were never bound.  Confirm the model is usable without
        # protein features for them.
        logging.warning("No protein features loaded for dataset %s; skipping protein feature registration",
                        params.dataset)

    evaluator_cls = Evaluator if params.dataset in ['drugbank', 'vec', 'davis'] else Evaluator_ddi2
    train_evaluator = evaluator_cls(params, graph_classifier, train)
    valid_evaluator = evaluator_cls(params, graph_classifier, valid)
    test_evaluator = evaluator_cls(params, graph_classifier, test)

    trainer = Trainer(params, graph_classifier, train, valid, test,
                      train_evaluator, valid_evaluator, test_evaluator)

    # NOTE: torchsummary's summary(graph_classifier, [(128, 128)]) was tried
    # here but raised KeyError: 'h' because the subgraph is not connected.

    logging.info('Starting training with full batch...')
    trainer.train()
def str2bool(value):
    """Convert a command-line string such as 'True' or 'false' to a bool.

    ``type=bool`` in argparse treats every non-empty string (including
    'False') as true; this converter parses the text instead.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean
            spelling.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays false, as it was with type=bool.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_parser():
    """Return the argument parser holding every command-line option of this script."""
    parser = argparse.ArgumentParser(description='TransE model')

    # Experiment setup params
    parser.add_argument("--experiment_name", "-e", type=str, default="default1",
                        help="A folder with this name would be created to dump saved models and log files")
    parser.add_argument("--dataset", "-d", type=str,
                        help="Dataset string")
    parser.add_argument("--gpu", type=int, default=2,
                        help="Which GPU to use?")
    parser.add_argument('--disable_cuda', action='store_true',
                        help='Disable CUDA')
    parser.add_argument('--load_model', type=str2bool, default=False,
                        help='Load existing model?')
    parser.add_argument("--train_file", "-tf", type=str, default="train",
                        help="Name of file containing training triplets")
    parser.add_argument("--valid_file", "-vf", type=str, default="dev",
                        help="Name of file containing validation triplets")
    parser.add_argument("--test_file", "-ttf", type=str, default="test",
                        help="Name of file containing test triplets")

    # Training regime params
    parser.add_argument("--num_epochs", "-ne", type=int, default=300,
                        help="Number of epochs to train for")
    parser.add_argument("--eval_every", type=int, default=1,
                        help="Interval of epochs to evaluate the model?")
    # Default 97 comes from len(train) / batch_size : 12331 / 128
    parser.add_argument("--eval_every_iter", type=int, default=97,
                        help="Interval of iterations to evaluate the model?")
    parser.add_argument("--save_every", type=int, default=10,
                        help="Interval of epochs to save a checkpoint of the model?")
    parser.add_argument("--early_stop", type=int, default=100,
                        help="Early stopping patience")
    parser.add_argument("--optimizer", type=str, default="Adam",
                        help="Which optimizer to use?")
    parser.add_argument("--lr", type=float, default=1e-5,
                        help="Learning rate of the optimizer")
    parser.add_argument("--lr_scheduling", type=str2bool, default=False,
                        help="Whether to use CosineLRScheduler")
    parser.add_argument("--clip", type=int, default=500,
                        help="Maximum gradient norm allowed")
    parser.add_argument("--l2", type=float, default=1e-5,
                        help="Regularization constant for GNN weights")

    # Data processing pipeline params
    parser.add_argument("--max_links", type=int, default=250000,
                        help="Set maximum number of train links (to fit into memory)")
    parser.add_argument("--hop", type=int, default=2,
                        help="Enclosing subgraph hop number")
    parser.add_argument("--max_nodes_per_hop", "-max_h", type=int, default=200,
                        help="if > 0, upper bound the # nodes per hop by subsampling")
    parser.add_argument("--use_kge_embeddings", "-kge", type=str2bool, default=False,
                        help='whether to use pretrained KGE embeddings')
    parser.add_argument("--kge_model", type=str, default="TransE",
                        help="Which KGE model to load entity embeddings from")
    parser.add_argument('--model_type', '-m', type=str, choices=['ssp', 'dgl'], default='dgl',
                        help='what format to store subgraphs in for model')
    parser.add_argument('--constrained_neg_prob', '-cn', type=float, default=0.0,
                        help='with what probability to sample constrained heads/tails while neg sampling')
    parser.add_argument("--batch_size", type=int, default=128,
                        help="Batch size")
    parser.add_argument("--num_neg_samples_per_link", '-neg', type=int, default=0,
                        help="Number of negative examples to sample per positive link")
    parser.add_argument("--num_workers", type=int, default=18,
                        help="Number of dataloading processes")
    parser.add_argument('--add_traspose_rels', '-tr', type=str2bool, default=False,
                        help='whether to append adj matrix list with symmetric relations')
    parser.add_argument('--enclosing_sub_graph', '-en', type=str2bool, default=True,
                        help='whether to only consider enclosing subgraph')
    parser.add_argument('--protein_embedding_method', '-pem', type=str,
                        choices=['prot_t5_xl_bfd', 'prot_t5_xl_uniref50', 'prot_bert_bfd', 'prot_bert', 'prot_bert_average', 'prostt5'], default='prot_bert',
                        help='what protein embedding to use')
    parser.add_argument('--drug_embedding_method', '-dem', type=str,
                        choices=['morgan', 'map4'], default='morgan',
                        help='what drug embedding to use')
    parser.add_argument('--protein_embedding_replace', '-per', type=str2bool, default=True,
                        help='whether to replace U, Z, O, B with X in protein sequence')

    # Model params
    parser.add_argument("--rel_emb_dim", "-r_dim", type=int, default=16,
                        help="Relation embedding size")
    parser.add_argument("--attn_rel_emb_dim", "-ar_dim", type=int, default=16,
                        help="Relation embedding size for attention")
    parser.add_argument("--emb_dim", "-dim", type=int, default=16,
                        help="Entity embedding size")
    parser.add_argument("--num_gcn_layers", "-l", type=int, default=1,
                        help="Number of GCN layers")
    parser.add_argument("--num_bases", "-b", type=int, default=4,
                        help="Number of basis functions to use for GCN weights")
    parser.add_argument("--dropout", type=float, default=0.3,
                        help="Dropout rate in GNN layers")
    parser.add_argument("--edge_dropout", type=float, default=0.4,
                        help="Dropout rate in edges of the subgraphs")
    parser.add_argument('--gnn_agg_type', '-a', type=str, choices=['sum', 'mlp', 'gru'], default='sum',
                        help='what type of aggregation to do in gnn msg passing')
    parser.add_argument('--add_ht_emb', '-ht', type=str2bool, default=True,
                        help='whether to concatenate head/tail embedding with pooled graph representation')
    parser.add_argument('--add_sb_emb', '-sb', type=str2bool, default=True,
                        help='whether to concatenate subgraph embedding with pooled graph representation')
    parser.add_argument('--has_attn', '-attn', type=str2bool, default=True,
                        help='whether to have attn in model or not')
    parser.add_argument('--has_kg', '-kg', type=str2bool, default=True,
                        help='whether to have kg in model or not')
    parser.add_argument('--feat', '-f', type=str, default='morganprotbert',
                        help='the type of the feature we use in molecule modeling')
    parser.add_argument('--feat_dim', type=int, default=1024,
                        help='the dimension of the feature')
    parser.add_argument('--pfeat_dim', type=int, default=1024,
                        help='the dimension of the feature')
    parser.add_argument('--add_feat_emb', '-feat', type=str2bool, default=True,
                        help='whether to morgan feature embedding in model or not')
    parser.add_argument('--add_transe_emb', type=str2bool, default=True,
                        help='whether to have knowledge graph embedding in model or not')
    parser.add_argument('--add_pfeat_emb', '-pfeat', type=str2bool, default=True,
                        help='whether to protein feature embedding in model or not')
    parser.add_argument('--gamma', type=float, default=0.2,
                        help='The threshold for attention')
    return parser


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    params = build_parser().parse_args()
    initialize_experiment(params, __file__)

    # params.main_dir is not a command-line option; it is expected to be set
    # on params by initialize_experiment above.
    params.file_paths = {
        'train': os.path.join(params.main_dir, 'data/{}/{}.txt'.format(params.dataset, params.train_file)),
        'valid': os.path.join(params.main_dir, 'data/{}/{}.txt'.format(params.dataset, params.valid_file)),
        'test': os.path.join(params.main_dir, 'data/{}/{}.txt'.format(params.dataset, params.test_file))
    }

    # Always assign a device: main() logs params.device, which used to be
    # missing on CPU-only machines.  --disable_cuda is honoured as well.
    if not params.disable_cuda and torch.cuda.is_available():
        params.device = torch.device('cuda:{}'.format(params.gpu))
    else:
        print("running on CPU")
        params.device = torch.device('cpu')

    params.collate_fn = collate_dgl
    # Same dataset split as the evaluator choice in main(): these three
    # datasets use the default batch mover, every other one the ddi2 variant.
    if params.dataset in ('drugbank', 'davis', 'vec'):
        params.move_batch_to_device = move_batch_to_device_dgl
    else:
        params.move_batch_to_device = move_batch_to_device_dgl_ddi2

    main(params)