Skip to content

Commit ef85107

Browse files
author
刘宇
committed
重置了trainer into utils
1 parent dc3ba44 commit ef85107

File tree

7 files changed

+64
-83
lines changed

7 files changed

+64
-83
lines changed

utils/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Re-export Trainer at package level. Python 3 has no implicit relative
# imports, so the sibling module must be referenced explicitly with a dot.
from .trainer import Trainer

utils/trainer.py

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import torch.nn as nn
2+
import torch
3+
from sklearn.model_selection import StratifiedKFold, KFold
4+
import transformers
5+
from tqdm import tqdm
6+
transformers.logging.set_verbosity_error()
7+
8+
9+
class Train(object):
    """Minimal training/evaluation loop around a ``torch.nn.Module``.

    ``train`` feeds each batch to the model as keyword arguments and reads the
    ``"loss"`` entry of the returned mapping, so the model must support
    ``model(**batch)["loss"]``. Batches must be mappings whose values expose
    ``.to(device)``.
    """

    def __init__(self, model: nn.Module, epochs=20, lr=1e-5, weight_decay=0,
                 show_batch=50, use_cuda=True, compute_metrics=None):
        """Store hyper-parameters, move the model to its device, build AdamW.

        Args:
            model: Module to optimise.
            epochs: Number of passes over the training data.
            lr: Learning rate handed to ``torch.optim.AdamW``.
            weight_decay: Weight decay handed to ``torch.optim.AdamW``.
            show_batch: Print the loss every ``show_batch`` batches; a falsy
                value disables the periodic print.
            use_cuda: Use ``cuda:0`` when it is available; ``False`` forces
                the CPU.
            compute_metrics: Optional callable invoked as
                ``compute_metrics(batch)`` during evaluation; its results are
                averaged, so they must support ``sum`` and division.
        """
        self.model = model
        # Honour use_cuda: previously the flag was accepted but ignored.
        self.device = torch.device(
            "cuda:0" if use_cuda and torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.epochs = epochs
        self.lr = lr
        self.show_batch = show_batch
        self.weight_decay = weight_decay
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        self.compute_metrics = compute_metrics

    def train(self, dataset_train, dataset_eval=None):
        """Run the optimisation loop, evaluating after each epoch if possible.

        Args:
            dataset_train: Sized iterable of batches (mappings of tensors).
            dataset_eval: Optional sized iterable of evaluation batches.
                Evaluation is skipped when this or ``compute_metrics`` is
                ``None``.
        """
        can_evaluate = dataset_eval is not None and self.compute_metrics is not None
        for epoch in range(self.epochs):
            self.model.train()
            for idx, batch in tqdm(enumerate(dataset_train), total=len(dataset_train)):
                batch = {k: v.to(self.device) for k, v in batch.items()}
                loss = self.model(**batch)["loss"]
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                # Guard against show_batch == 0 (modulo by zero).
                if self.show_batch and idx % self.show_batch == 0:
                    print(
                        'Epoch [{}/{}],batch:{} Loss: {:.4f}'.format(
                            epoch + 1, self.epochs, idx, loss.item()))
            if can_evaluate:
                self.evaluation(dataset_eval, epoch)

    def evaluation(self, dataset_eval, epoch):
        """Average ``compute_metrics`` over ``dataset_eval`` and print it.

        Args:
            dataset_eval: Sized iterable of batches (mappings of tensors).
            epoch: Zero-based epoch index, used only for the printed message.

        Returns:
            The mean metric multiplied by 100, or ``None`` when
            ``dataset_eval`` yields no batches.

        Raises:
            TypeError: If no ``compute_metrics`` callable was configured.
        """
        if self.compute_metrics is None:
            raise TypeError("evaluation requires a compute_metrics callable")
        print("evaluation.....")
        self.model.eval()
        score_list = []
        # Disable gradient tracking during evaluation, including when this
        # method is called directly rather than from train().
        with torch.no_grad():
            for batch in tqdm(dataset_eval, total=len(dataset_eval)):
                batch = {k: v.to(self.device) for k, v in batch.items()}
                score_list.append(self.compute_metrics(batch))
        if not score_list:
            # Nothing to average; avoid ZeroDivisionError.
            print("evaluation skipped: empty evaluation set")
            return None
        score = sum(score_list) / len(score_list) * 100
        print(
            'Epoch [{}/{}], score: {:.4f} %'.format(epoch + 1, self.epochs, score))
        return score


# Call sites import this class under the name ``Trainer``; keep both names valid.
Trainer = Train
50+

深度学习/nlp/MRC_阅读理解/main.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -57,19 +57,23 @@ def forward(self, input_ids, attention_mask, token_type_ids, start_positions, en
5757

5858

5959
import sys
60+
from os.path import dirname as rn
61+
import os
6062

61-
sys.path.append("..")
62-
from trainer import Trainer
63-
63+
path = rn(rn(rn(rn(__file__))))
64+
print(path)
65+
sys.path.append(path)
66+
from utils.trainer import Trainer
67+
from utils import Trainer
6468
if __name__ == '__main__':
65-
# train_dataset = get_squad_dataset(data_dir="./data/cmrc2018_public", filename="train.json")
66-
# dev_dataset = get_squad_dataset(data_dir="./data/cmrc2018_public", filename="dev.json")
69+
train_dataset = get_squad_dataset(data_dir="./data/cmrc2018_public", filename="train.json")
70+
dev_dataset = get_squad_dataset(data_dir="./data/cmrc2018_public", filename="dev.json")
6771
import pickle
6872

6973
# pickle.dump(train_dataset, open("train.pt", "wb"))
7074
# pickle.dump(dev_dataset, open("dev.pt", "wb"))
71-
train_dataset = pickle.load(open("train.pt", "rb"))
72-
dev_dataset = pickle.load(open("dev.pt", "rb"))
75+
# train_dataset = pickle.load(open("train.pt", "rb"))
76+
# dev_dataset = pickle.load(open("dev.pt", "rb"))
7377
batch_size = 8
7478
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
7579
dev_dataloader = DataLoader(dataset=dev_dataset, batch_size=batch_size, shuffle=True)
File renamed without changes.

深度学习/nlp/trainer.py

-72
This file was deleted.

项目实战/医疗诊疗对话意图识别挑战赛/BERT-DAC/run.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,4 @@ def f(batch):
6969
model = BertForSequenceClassification.from_pretrained(check_point, num_labels=16)
7070

7171

72-
# train
73-
# # 'dac_predictions.npy'
74-
train(model, data_loader, dev_loader, test_loader, args)
72+
train(model, data_loader, dev_loader, test_loader, args)

项目实战/医疗诊疗对话意图识别挑战赛/BERT-DAC/train_eval.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def train(model, train_iter, dev_iter, test_iter, args):
8686
break
8787
if flag:
8888
break
89-
test(config, model, test_iter, args)
89+
test(model, test_iter, args)
9090

9191

9292
def test(config, model, test_iter, args):

0 commit comments

Comments
 (0)