import json
import os
import argparse
import tensorflow.compat.v1 as tf
from tfbert import (
    Trainer, Dataset,
    SequenceClassification,
    CONFIGS, TOKENIZERS, devices, set_seed)
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import pandas as pd
from typing import Dict
import numpy as np


def create_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_type', default='bert', type=str, choices=CONFIGS.keys())
    parser.add_argument('--optimizer_type', default='adamw', type=str, help="Optimizer type")
    parser.add_argument('--model_dir', default='model_path', type=str,
                        help="Directory holding the pretrained model; the checkpoint is expected to be "
                             "named model.ckpt, the config file config.json and the vocab file vocab.txt")

    parser.add_argument('--config_path', default=None, type=str,
                        help="Set this if the config file does not use the default name")
    parser.add_argument('--vocab_path', default=None, type=str,
                        help="Set this if the vocab file does not use the default name")
    parser.add_argument('--pretrained_checkpoint_path', default=None, type=str,
                        help="Set this if the checkpoint file does not use the default name")
    parser.add_argument('--output_dir', default='output/classification', type=str, help="Output directory")
    parser.add_argument('--export_dir', default='output/classification/pb', type=str,
                        help="Export directory for the pb model")

    parser.add_argument('--labels', default='体育,娱乐,家居,房产,教育', type=str,
                        help="Comma-separated text classification labels")
    parser.add_argument('--train_file', default='data/classification/train.csv', type=str, help="Training data file")
    parser.add_argument('--dev_file', default='data/classification/dev.csv', type=str, help="Dev data file")
    parser.add_argument('--test_file', default='data/classification/test.csv', type=str, help="Test data file")

    parser.add_argument("--num_train_epochs", default=3, type=int, help="Number of training epochs")
    parser.add_argument("--max_seq_length", default=32, type=int, help="Maximum sequence length")
    parser.add_argument("--batch_size", default=32, type=int, help="Training batch size")
    parser.add_argument("--gradient_accumulation_steps", default=1, type=int, help="Gradient accumulation steps")
    parser.add_argument("--learning_rate", default=2e-5, type=float, help="Learning rate")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for.")
    parser.add_argument("--weight_decay", default=0.01, type=float, help="Weight decay if we apply some.")

    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument("--do_predict", action="store_true", help="Whether to run test on the test set.")
    parser.add_argument("--evaluate_during_training", action="store_true",
                        help="Whether to evaluate on the dev set while training")
    parser.add_argument("--do_export", action="store_true", help="Export the model in pb format.")

    parser.add_argument("--logging_steps", default=1000, type=int, help="Evaluate every this many training steps")
    parser.add_argument("--saving_steps", default=1000, type=int, help="Save a checkpoint every this many training steps")
    parser.add_argument("--random_seed", default=42, type=int, help="Random seed")
    parser.add_argument("--threads", default=8, type=int, help="Number of data-processing processes")
    parser.add_argument("--max_checkpoints", default=1, type=int,
                        help="Maximum number of checkpoints to keep; by default only one is kept")
    parser.add_argument("--single_device", action="store_true",
                        help="Use a single device only; by default all available devices are used for training")
    parser.add_argument("--use_xla", action="store_true", help="Whether to use XLA acceleration")
    parser.add_argument(
        "--mixed_precision", action="store_true",
        help="Mixed-precision training; under TF it only brings a speed-up when combined with XLA, "
             "but the initial compilation is slow")
    args = parser.parse_args()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    if not args.single_device:
        # With multiple devices, scale the per-device batch size up to a global batch size.
        args.batch_size = args.batch_size * len(devices())

    args.labels = args.labels.split(',')
    return args


def create_dataset(set_type, tokenizer, args):
    filename_map = {
        'train': args.train_file, 'dev': args.dev_file, 'test': args.test_file
    }
    features = []
    datas = pd.read_csv(filename_map[set_type], encoding='utf-8', sep='\t').values.tolist()
    label_map = {label: i for i, label in enumerate(args.labels)}
    # glyce_bert is not yet supported by the built-in data-processing code, so the encoding is done here.
    for data in tqdm(datas):
        encoded = tokenizer(data[1],
                            max_length=args.max_seq_length,  # maximum length
                            padding="max_length",  # pad every example to the maximum length
                            truncation=True)
        encoded['label_ids'] = label_map[data[0]]
        features.append(encoded)
    dataset = Dataset(features,
                      is_training=bool(set_type == 'train'),
                      batch_size=args.batch_size,
                      drop_last=bool(set_type == 'train'),
                      buffer_size=len(features),
                      max_length=args.max_seq_length)
    dataset.format_as(['input_ids', 'pinyin_ids', 'attention_mask', 'token_type_ids', 'label_ids'])
    return dataset


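# For reference, a sketch of the input files create_dataset expects (an assumption read off the
# code above, not original repo documentation): tab-separated rows with the label in the first
# column and the raw text in the second, e.g.
#
#     体育<TAB>some sports news text ...
#     娱乐<TAB>some entertainment news text ...
#
# Note that pandas' read_csv treats the first row as a header by default, so the files are
# expected to start with a header row.

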
def get_model_fn(config, args):
    def model_fn(inputs, is_training):
        model = SequenceClassification(
            model_type=args.model_type, config=config,
            num_classes=len(args.labels), is_training=is_training,
            **inputs)

        outputs = {'outputs': {'logits': model.logits, 'label_ids': inputs['label_ids']}}
        if model.loss is not None:
            # Scale the loss so that gradients accumulated over several steps average correctly.
            loss = model.loss / args.gradient_accumulation_steps
            outputs['loss'] = loss
        return outputs

    return model_fn


def get_serving_fn(config, args):
    def serving_fn():
        input_ids = tf.placeholder(shape=[None, args.max_seq_length], dtype=tf.int64, name='input_ids')
        # pinyin ids for glyce_bert: 8 ids per token.
        pinyin_ids = tf.placeholder(shape=[None, args.max_seq_length, 8], dtype=tf.int64, name='pinyin_ids')
        attention_mask = tf.placeholder(shape=[None, args.max_seq_length], dtype=tf.int64, name='attention_mask')
        token_type_ids = tf.placeholder(shape=[None, args.max_seq_length], dtype=tf.int64, name='token_type_ids')
        model = SequenceClassification(
            model_type=args.model_type, config=config,
            num_classes=len(args.labels), is_training=False,
            input_ids=input_ids,
            pinyin_ids=pinyin_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        inputs = {
            'input_ids': input_ids, 'pinyin_ids': pinyin_ids,
            'attention_mask': attention_mask, 'token_type_ids': token_type_ids}
        outputs = {'logits': model.logits}
        return inputs, outputs

    return serving_fn


def metric_fn(outputs: Dict) -> Dict:
    """
    The evaluation function is defined here.
    :param outputs: the predictions returned by trainer.evaluate; it contains whichever
                    fields the model_fn's outputs dict contains
    :return: a dict of metric results
    """
    predictions = np.argmax(outputs['logits'], -1)
    score = accuracy_score(outputs['label_ids'], predictions)
    return {'accuracy': score}


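# If additional metrics are needed, metric_fn can be extended in the same way; a minimal
# sketch using sklearn's f1_score (not part of the original example):
#
#     from sklearn.metrics import f1_score
#     macro_f1 = f1_score(outputs['label_ids'], predictions, average='macro')
#     return {'accuracy': score, 'macro_f1': macro_f1}

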
def main():
    args = create_args()
    set_seed(args.random_seed)

    config = CONFIGS[args.model_type].from_pretrained(
        args.model_dir if args.config_path is None else args.config_path)

    tokenizer = TOKENIZERS[args.model_type].from_pretrained(
        args.model_dir if args.vocab_path is None else args.vocab_path, do_lower_case=True)

    train_dataset, dev_dataset, predict_dataset = None, None, None
    if args.do_train:
        train_dataset = create_dataset('train', tokenizer, args)

    if args.do_eval:
        dev_dataset = create_dataset('dev', tokenizer, args)

    if args.do_predict:
        predict_dataset = create_dataset('test', tokenizer, args)

    output_types, output_shapes = (train_dataset or dev_dataset or predict_dataset).output_types_and_shapes()
    trainer = Trainer(
        train_dataset=train_dataset,
        eval_dataset=dev_dataset,
        output_types=output_types,
        output_shapes=output_shapes,
        metric_fn=metric_fn,
        use_xla=args.use_xla,
        optimizer_type=args.optimizer_type,
        learning_rate=args.learning_rate,
        num_train_epochs=args.num_train_epochs,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        max_checkpoints=args.max_checkpoints,
        max_grad=1.0,
        warmup_proportion=args.warmup_proportion,
        mixed_precision=args.mixed_precision,
        single_device=args.single_device,
        logging=True
    )
    trainer.build_model(model_fn=get_model_fn(config, args))
    if args.do_train and train_dataset is not None:
        trainer.compile()
        trainer.from_pretrained(
            args.model_dir if args.pretrained_checkpoint_path is None else args.pretrained_checkpoint_path)

        trainer.train(
            output_dir=args.output_dir,
            evaluate_during_training=args.evaluate_during_training,
            logging_steps=args.logging_steps,
            saving_steps=args.saving_steps,
            greater_is_better=True, metric_for_best_model='accuracy')
        config.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

    if args.do_eval and dev_dataset is not None:
        trainer.from_pretrained(args.output_dir)
        eval_outputs = trainer.evaluate()
        print(json.dumps(
            eval_outputs, ensure_ascii=False, indent=4
        ))

    if args.do_predict and predict_dataset is not None:
        trainer.from_pretrained(args.output_dir)
        outputs = trainer.predict('test', ['logits'], dataset=predict_dataset)
        label_ids = np.argmax(outputs['logits'], axis=-1)
        labels = list(map(lambda x: args.labels[x], label_ids))
        with open(
                os.path.join(args.output_dir, 'prediction.txt'), 'w', encoding='utf-8'
        ) as writer:
            writer.write("\n".join(labels))

    if args.do_export:
        trainer.export(
            get_serving_fn(config, args),
            args.output_dir,
            args.export_dir
        )


if __name__ == '__main__':
    main()
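

# Example invocation, as a sketch: the script name run_classification.py is hypothetical, the
# paths below are just the argparse defaults, and model_type is assumed to be glyce_bert
# because the dataset and serving_fn above feed pinyin_ids.
#
#   python run_classification.py \
#       --model_type glyce_bert \
#       --model_dir model_path \
#       --do_train --do_eval --do_predict --do_export \
#       --train_file data/classification/train.csv \
#       --dev_file data/classification/dev.csv \
#       --test_file data/classification/test.csv \
#       --output_dir output/classification \
#       --export_dir output/classification/pb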