
Commit ad23f37

Author: hh125
Committed: Jul 30, 2021
Add ChineseBert
1 parent 9661641 commit ad23f37

14 files changed: +801 −36 lines
 

‎README.md

+8 −1

@@ -22,7 +22,7 @@ config and tokenizer follow the transformers implementation.
 Built-in example dataset: [Baidu Netdisk, extraction code rhxk](https://pan.baidu.com/s/1lYy7BJdadT0LJfMSsKz6AA)
 ## Supported models
 
-bert, electra, albert, nezha, wobert
+bert, electra, albert, nezha, wobert, ChineseBert (GlyceBert)
 
 ## requirements
 ```

@@ -82,8 +82,15 @@ CUDA_VISIBLE_DEVICES=1,2 python run.py
 | **`ELECTRA, Chinese`** | **[Chinese-ELECTRA](https://github.com/ymcui/Chinese-ELECTRA)**|
 | **`ERNIE 1.0.1, Chinese`** | **[Baidu Netdisk (xrku)](https://pan.baidu.com/s/13eRD6uVnr4xeUfYXk8XKIw)**|
 | **`ERNIE gram base, Chinese`** | **[Baidu Netdisk (7xet)](https://pan.baidu.com/s/1qzIuduI2ZRJDZSnNqTfscw)**|
+| **`ChineseBert, Chinese`** | **[base (sxhj)](https://pan.baidu.com/s/1ehO52PQd6TFVhOu5RiRtZA)** **[large (zi0r)](https://pan.baidu.com/s/1IifQuRFhpwWzLJHvMR9gOQ)**|
+
 
 ## **Changelog**
+- 2021/7/31 Added ChineseBert, open-sourced by Shannon AI, to the built-in models; see [glyce_bert](tfbert/models/glyce_bert.py). The official release is torch-only.
+  The model adds glyph and pinyin features to its embedding representation and reaches results close to MacBERT; see the official [ChineseBert](https://github.com/ShannonAI/ChineseBert) repo. The tf weights have already been converted and can be downloaded.
+  The built-in data processing does not add pinyin features yet, so only a simple [text classification example](run_classifier_glyce_bert.py) is provided for now; it will be fleshed out later.
+
 - 2021/5/19 Added a machine reading comprehension example, using the dureader2021 competition data; it should be compatible with most squad-format data.
   Also updated the tokenizer code to follow the transformers interface; most of it directly integrates the transformers tokenizers.
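As a quick orientation for the new entry above, here is a minimal, hedged sketch of how the glyce_bert tokenizer is meant to be used once a converted checkpoint is on disk; the directory name is a placeholder, and the field shapes follow from the tokenizer code added in this commit.

```python
# Minimal sketch (assumptions: a converted ChineseBert checkpoint lives in ./chinese_bert_base
# and contains vocab.txt plus the extra config/ folder with pinyin_map.json etc.).
from tfbert import TOKENIZERS

tokenizer = TOKENIZERS['glyce_bert'].from_pretrained('chinese_bert_base', do_lower_case=True)
encoded = tokenizer("今天天气很好", max_length=16, padding="max_length", truncation=True)

print(len(encoded['input_ids']))      # 16 token ids
print(len(encoded['pinyin_ids'][0]))  # 8 pinyin-character slots per token
print(encoded['attention_mask'])      # standard BERT-style mask
```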

‎convert_bert_torch_to_tf.py

+15 −1

@@ -13,13 +13,17 @@
 
 def convert_pytorch_checkpoint_to_tf(pt_weight_file, pt_config_file, pt_vocab_file, save_dir: str):
     tensors_to_transpose = (
-        "dense.weight", "attention.self.query", "attention.self.key", "attention.self.value")
+        "dense.weight", "attention.self.query", "attention.self.key", "attention.self.value", "glyph_map.weight",
+        "map_fc.weight")
+    glyce_bert_conv_tensors = ("conv.weight",)
 
     var_map = (
         ("layer.", "layer_"),
         ("word_embeddings.weight", "word_embeddings"),
         ("position_embeddings.weight", "position_embeddings"),
         ("token_type_embeddings.weight", "token_type_embeddings"),
+        ("pinyin_embeddings.embedding.weight", "pinyin_embeddings/embeddings"),
+        ("glyph_embeddings.embedding.weight", "glyph_embeddings/embeddings"),
         (".", "/"),
         ("LayerNorm/weight", "LayerNorm/gamma"),
         ("LayerNorm/bias", "LayerNorm/beta"),

@@ -50,6 +54,10 @@ def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
         torch_tensor = state_dict[var_name].numpy()
         if any([x in var_name for x in tensors_to_transpose]):
             torch_tensor = torch_tensor.T
+        if any([x in var_name for x in glyce_bert_conv_tensors]):
+            torch_tensor = torch_tensor.T
+            torch_tensor = np.expand_dims(torch_tensor, axis=2)
+
         tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session)
         tf.keras.backend.set_value(tf_var, torch_tensor)
         tf_weight = session.run(tf_var)

@@ -70,6 +78,12 @@ def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
     if pt_vocab_file is not None and os.path.exists(pt_vocab_file):
         shutil.copyfile(pt_vocab_file, os.path.join(save_dir, 'vocab.txt'))
 
+    config_path = os.path.join(os.path.split(pt_config_file)[0], 'config')
+    target_dir = os.path.join(save_dir, 'config')
+    if os.path.isdir(config_path) and not os.path.exists(target_dir):
+        os.makedirs(target_dir)
+        shutil.copytree(config_path, target_dir)
+
 
 def main():
     parser = argparse.ArgumentParser()
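The conv-weight branch above is easiest to follow as a shape exercise: PyTorch's Conv1d stores its weight as [out_channels, in_channels, kernel_size], while the TF kernel built in tfbert/models/embeddings.py is [kernel_size, embedding_size, 1, pinyin_out_dim]. A small illustration follows; the concrete sizes are assumptions for the base model, not values from this commit.

```python
import numpy as np

# Assumed base-model sizes: pinyin CharCNN with 128-dim input, 768-dim output, kernel width 2.
pt_conv_weight = np.zeros((768, 128, 2), dtype=np.float32)   # torch Conv1d weight: [out, in, kernel]
tf_kernel = np.expand_dims(pt_conv_weight.T, axis=2)         # same transpose + expand_dims as above

print(tf_kernel.shape)  # (2, 128, 1, 768) -> [kernel_size, embedding_size, 1, pinyin_out_dim]
```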

‎run_classifier_glyce_bert.py

+226 (new file)

@@ -0,0 +1,226 @@
import json
import os
import argparse
import tensorflow.compat.v1 as tf
from tfbert import (
    Trainer, Dataset,
    SequenceClassification,
    CONFIGS, TOKENIZERS, devices, set_seed)
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import pandas as pd
from typing import Dict
import numpy as np


def create_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_type', default='bert', type=str, choices=CONFIGS.keys())
    parser.add_argument('--optimizer_type', default='adamw', type=str, help="optimizer type")
    parser.add_argument('--model_dir', default='model_path', type=str,
                        help="pretrained model directory; the ckpt file should be named model.ckpt, "
                             "the config file config.json and the vocab file vocab.txt")

    parser.add_argument('--config_path', default=None, type=str, help="set this if the config file name is not the default")
    parser.add_argument('--vocab_path', default=None, type=str, help="set this if the vocab file name is not the default")
    parser.add_argument('--pretrained_checkpoint_path', default=None, type=str, help="set this if the checkpoint file name is not the default")
    parser.add_argument('--output_dir', default='output/classification', type=str, help="")
    parser.add_argument('--export_dir', default='output/classification/pb', type=str, help="")

    parser.add_argument('--labels', default='体育,娱乐,家居,房产,教育', type=str, help="text classification labels")
    parser.add_argument('--train_file', default='data/classification/train.csv', type=str, help="")
    parser.add_argument('--dev_file', default='data/classification/dev.csv', type=str, help="")
    parser.add_argument('--test_file', default='data/classification/test.csv', type=str, help="")

    parser.add_argument("--num_train_epochs", default=3, type=int, help="number of training epochs")
    parser.add_argument("--max_seq_length", default=32, type=int, help="maximum sequence length")
    parser.add_argument("--batch_size", default=32, type=int, help="training batch size")
    parser.add_argument("--gradient_accumulation_steps", default=1, type=int, help="gradient accumulation steps")
    parser.add_argument("--learning_rate", default=2e-5, type=float, help="learning rate")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for.")
    parser.add_argument("--weight_decay", default=0.01, type=float, help="Weight decay if we apply some.")

    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument("--do_predict", action="store_true", help="Whether to run test on the test set.")
    parser.add_argument("--evaluate_during_training", action="store_true", help="whether to evaluate while training")
    parser.add_argument("--do_export", action="store_true", help="export the model in pb format.")

    parser.add_argument("--logging_steps", default=1000, type=int, help="evaluate every N training steps")
    parser.add_argument("--saving_steps", default=1000, type=int, help="save every N training steps")
    parser.add_argument("--random_seed", default=42, type=int, help="random seed")
    parser.add_argument("--threads", default=8, type=int, help="number of data-processing processes")
    parser.add_argument("--max_checkpoints", default=1, type=int, help="maximum number of checkpoints to keep, default 1")
    parser.add_argument("--single_device", action="store_true", help="use a single device only; by default all devices are used for training")
    parser.add_argument("--use_xla", action="store_true", help="whether to use XLA acceleration")
    parser.add_argument(
        "--mixed_precision", action="store_true",
        help="mixed precision training; in tf it only shows a speedup when combined with xla, but initial compilation is slow")
    args = parser.parse_args()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    if not args.single_device:
        args.batch_size = args.batch_size * len(devices())

    args.labels = args.labels.split(',')
    return args


def create_dataset(set_type, tokenizer, args):
    filename_map = {
        'train': args.train_file, 'dev': args.dev_file, 'test': args.test_file
    }
    features = []
    datas = pd.read_csv(filename_map[set_type], encoding='utf-8', sep='\t').values.tolist()
    label_map = {label: i for i, label in enumerate(args.labels)}
    # glyce_bert is not yet supported by the built-in data processing code, so this part is handled here
    for data in tqdm(datas):
        encoded = tokenizer(data[1],
                            max_length=args.max_seq_length,  # maximum length
                            padding="max_length",  # whether to pad the sentence to the maximum length
                            truncation=True)
        encoded['label_ids'] = label_map[data[0]]
        features.append(encoded)
    dataset = Dataset(features,
                      is_training=bool(set_type == 'train'),
                      batch_size=args.batch_size,
                      drop_last=bool(set_type == 'train'),
                      buffer_size=len(features),
                      max_length=args.max_seq_length)
    dataset.format_as(['input_ids', 'pinyin_ids', 'attention_mask', 'token_type_ids', 'label_ids'])
    return dataset


def get_model_fn(config, args):
    def model_fn(inputs, is_training):
        model = SequenceClassification(
            model_type=args.model_type, config=config,
            num_classes=len(args.labels), is_training=is_training,
            **inputs)

        outputs = {'outputs': {'logits': model.logits, 'label_ids': inputs['label_ids']}}
        if model.loss is not None:
            loss = model.loss / args.gradient_accumulation_steps
            outputs['loss'] = loss
        return outputs

    return model_fn


def get_serving_fn(config, args):
    def serving_fn():
        input_ids = tf.placeholder(shape=[None, args.max_seq_length], dtype=tf.int64, name='input_ids')
        pinyin_ids = tf.placeholder(shape=[None, args.max_seq_length, 8], dtype=tf.int64, name='pinyin_ids')
        attention_mask = tf.placeholder(shape=[None, args.max_seq_length], dtype=tf.int64, name='attention_mask')
        token_type_ids = tf.placeholder(shape=[None, args.max_seq_length], dtype=tf.int64, name='token_type_ids')
        model = SequenceClassification(
            model_type=args.model_type, config=config,
            num_classes=len(args.labels), is_training=False,
            input_ids=input_ids,
            pinyin_ids=pinyin_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        inputs = {
            'input_ids': input_ids, 'pinyin_ids': pinyin_ids,
            'attention_mask': attention_mask, 'token_type_ids': token_type_ids}
        outputs = {'logits': model.logits}
        return inputs, outputs

    return serving_fn


def metric_fn(outputs: Dict) -> Dict:
    """
    Define the evaluation function here.
    :param outputs: predictions returned by trainer.evaluate; it contains whatever fields the model fn's outputs contain
    :return: must return a dict
    """
    predictions = np.argmax(outputs['logits'], -1)
    score = accuracy_score(outputs['label_ids'], predictions)
    return {'accuracy': score}


def main():
    args = create_args()
    set_seed(args.random_seed)

    config = CONFIGS[args.model_type].from_pretrained(
        args.model_dir if args.config_path is None else args.config_path)

    tokenizer = TOKENIZERS[args.model_type].from_pretrained(
        args.model_dir if args.vocab_path is None else args.vocab_path, do_lower_case=True)

    train_dataset, dev_dataset, predict_dataset = None, None, None
    if args.do_train:
        train_dataset = create_dataset('train', tokenizer, args)

    if args.do_eval:
        dev_dataset = create_dataset('dev', tokenizer, args)

    if args.do_predict:
        predict_dataset = create_dataset('test', tokenizer, args)

    output_types, output_shapes = (train_dataset or dev_dataset or predict_dataset).output_types_and_shapes()
    trainer = Trainer(
        train_dataset=train_dataset,
        eval_dataset=dev_dataset,
        output_types=output_types,
        output_shapes=output_shapes,
        metric_fn=metric_fn,
        use_xla=args.use_xla,
        optimizer_type=args.optimizer_type,
        learning_rate=args.learning_rate,
        num_train_epochs=args.num_train_epochs,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        max_checkpoints=1,
        max_grad=1.0,
        warmup_proportion=args.warmup_proportion,
        mixed_precision=args.mixed_precision,
        single_device=args.single_device,
        logging=True
    )
    trainer.build_model(model_fn=get_model_fn(config, args))
    if args.do_train and train_dataset is not None:
        trainer.compile()
        trainer.from_pretrained(
            args.model_dir if args.pretrained_checkpoint_path is None else args.pretrained_checkpoint_path)

        trainer.train(
            output_dir=args.output_dir,
            evaluate_during_training=args.evaluate_during_training,
            logging_steps=args.logging_steps,
            saving_steps=args.saving_steps,
            greater_is_better=True, metric_for_best_model='accuracy')
        config.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

    if args.do_eval and dev_dataset is not None:
        trainer.from_pretrained(args.output_dir)
        eval_outputs = trainer.evaluate()
        print(json.dumps(
            eval_outputs, ensure_ascii=False, indent=4
        ))

    if args.do_predict and predict_dataset is not None:
        trainer.from_pretrained(args.output_dir)
        outputs = trainer.predict('test', ['logits'], dataset=predict_dataset)
        label_ids = np.argmax(outputs['logits'], axis=-1)
        labels = list(map(lambda x: args.labels[x], label_ids))
        open(
            os.path.join(args.output_dir, 'prediction.txt'), 'w', encoding='utf-8'
        ).write("\n".join(labels))

    if args.do_export:
        trainer.export(
            get_serving_fn(config, args),
            args.output_dir,
            args.export_dir
        )


if __name__ == '__main__':
    main()
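create_dataset above reads a tab-separated file whose first column is the label and second column is the text. A tiny sketch of producing such a file; the column names and example rows are illustrative, not part of the commit.

```python
import pandas as pd

# Two toy rows in the assumed "label<TAB>text" layout consumed by create_dataset
# (first column label, second column text, header row consumed by pandas).
df = pd.DataFrame({
    "label": ["体育", "娱乐"],
    "text": ["a sports headline", "an entertainment headline"],
})
df.to_csv("data/classification/train.csv", sep="\t", index=False, encoding="utf-8")
```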

‎tfbert/__init__.py

+4 −3

@@ -11,16 +11,17 @@
 
 from .models import (
     BertModel, ALBertModel, ElectraModel,
-    NezhaModel, WoBertModel,
+    NezhaModel, WoBertModel, GlyceBertModel,
     SequenceClassification, MODELS, crf,
     TokenClassification, MultiLabelClassification,
     MaskedLM, PretrainingLM, QuestionAnswering)
 from .config import (
     BaseConfig, BertConfig, ALBertConfig,
-    ElectraConfig, NeZhaConfig, WoBertConfig, CONFIGS)
+    ElectraConfig, NeZhaConfig, WoBertConfig, GlyceBertConfig, CONFIGS)
 from .tokenizer import (
     BasicTokenizer, BertTokenizer, WoBertTokenizer,
-    ALBertTokenizer, ElectraTokenizer, NeZhaTokenizer, TOKENIZERS)
+    ALBertTokenizer, ElectraTokenizer, NeZhaTokenizer,
+    GlyceBertTokenizer, TOKENIZERS)
 
 from .utils import (
     devices, init_checkpoints,

‎tfbert/config/__init__.py

+2 −2

@@ -7,12 +7,12 @@
 
 from .base import BaseConfig
 from .ptm import (
-    BertConfig, ALBertConfig, ElectraConfig)
+    BertConfig, ALBertConfig, ElectraConfig, GlyceBertConfig)
 from .ptm import BertConfig as NeZhaConfig
 from .ptm import BertConfig as WoBertConfig
 
 CONFIGS = {
     'bert': BertConfig, 'albert': ALBertConfig,
     'nezha': NeZhaConfig, 'electra': ElectraConfig,
-    'wobert': WoBertConfig
+    'wobert': WoBertConfig, 'glyce_bert': GlyceBertConfig
 }

‎tfbert/config/ptm.py

+80

@@ -7,6 +7,8 @@
 from . import BaseConfig
 import re
 import tensorflow.compat.v1 as tf
+import os
+import shutil
 
 
 class BertConfig(BaseConfig):

@@ -150,3 +152,81 @@ def from_checkpoint(cls, checkpoint_path, **kwargs):
         param['num_attention_heads'] = max(1, param["hidden_size"] // 64)
 
         return cls(**param, **kwargs)
+
+
+class GlyceBertConfig(BaseConfig):
+    def __init__(self,
+                 vocab_size,
+                 embedding_size=None,
+                 hidden_size=768,
+                 num_hidden_layers=12,
+                 num_attention_heads=12,
+                 intermediate_size=3072,
+                 hidden_act="gelu",
+                 hidden_dropout_prob=0.1,
+                 attention_probs_dropout_prob=0.1,
+                 max_position_embeddings=512,
+                 type_vocab_size=16,
+                 initializer_range=0.02,
+                 config_path="",
+                 **kwargs
+                 ):
+        super().__init__(**kwargs)
+
+        self.vocab_size = vocab_size
+        self.embedding_size = embedding_size if embedding_size is not None else hidden_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.hidden_act = hidden_act
+        self.intermediate_size = intermediate_size
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.max_position_embeddings = max_position_embeddings
+        self.type_vocab_size = type_vocab_size
+        self.initializer_range = initializer_range
+        self.config_path = config_path
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        '''
+        Load a config from a directory or file.
+        :param pretrained_model_name_or_path:
+        :param kwargs:
+        :return:
+        '''
+
+        if os.path.isdir(pretrained_model_name_or_path):
+            config_file = os.path.join(pretrained_model_name_or_path, cls.filename)
+            config_path = os.path.join(pretrained_model_name_or_path, "config")
+        elif os.path.isfile(pretrained_model_name_or_path):
+            config_file = pretrained_model_name_or_path
+            dir_ = os.path.split(config_file)[0]
+            config_path = os.path.join(dir_, 'config')
+        else:
+            raise ValueError('Config path should be a directory or file')
+
+        config_dict = cls._dict_from_json_file(config_file)
+        kwargs['config_path'] = config_path
+        return cls.from_dict(config_dict, **kwargs)
+
+    def save_pretrained(self, save_dir_or_file):
+        if os.path.isdir(save_dir_or_file):
+            output_config_file = os.path.join(save_dir_or_file, self.filename)
+            config_path = os.path.join(save_dir_or_file, 'config')
+        else:
+            output_config_file = save_dir_or_file
+            config_path = os.path.join(os.path.split(save_dir_or_file)[0], "config")
+        if not os.path.exists(config_path):
+            os.makedirs(config_path)
+
+        filenames = os.listdir(self.config_path)
+        if len(filenames) > 0:
+            for filename in filenames:
+                if filename.endswith('.npy'):
+                    shutil.copyfile(
+                        os.path.join(self.config_path, filename), os.path.join(config_path, filename)
+                    )
+        self.save_to_json_file(output_config_file)
+        tf.logging.info(' Configuration saved in {}'.format(output_config_file))
+        return output_config_file
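GlyceBertConfig and GlyceBertTokenizer both resolve an extra config/ folder next to the usual checkpoint files, so a converted ChineseBert directory presumably looks like the sketch below; the layout is inferred from this commit and the directory name is a placeholder.

```python
# Assumed on-disk layout for a converted ChineseBert checkpoint:
#
# chinese_bert_base/
#   model.ckpt.*           converted TF weights
#   config.json            GlyceBertConfig
#   vocab.txt              wordpiece vocab
#   config/
#     pinyin_map.json      pinyin char2idx / idx2char
#     id2pinyin.json
#     pinyin2tensor.json
#     *.npy                glyph (font) arrays
import os

base = 'chinese_bert_base'  # placeholder path
print(sorted(os.listdir(os.path.join(base, 'config'))))  # should list the json/npy files above
```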

‎tfbert/data/__init__.py

+2

@@ -89,6 +89,8 @@ def fn(values):
             return tf.int32
         elif (isinstance(values, list) and isinstance(values[0], str)) or isinstance(values, str):
             return tf.string
+        elif isinstance(values, list) and isinstance(values[0], list):
+            return fn(values[0])
         else:
             raise ValueError(f"values={values} has dtype {values.dtype}, which cannot be supported")
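The new elif covers nested list features such as pinyin_ids; a tiny illustration of the kind of value it has to handle (how the surrounding Dataset uses fn is assumed, only the inference rule itself is shown).

```python
# One encoded example as produced by GlyceBertTokenizer: flat int lists plus a
# list-of-lists feature whose dtype must be inferred from its inner list.
feature = {
    'input_ids': [101, 2769, 102],
    'pinyin_ids': [[0] * 8, [23, 9, 1, 0, 0, 0, 0, 0], [0] * 8],
}
# fn(feature['input_ids'])   -> tf.int32 (flat ints, as before)
# fn(feature['pinyin_ids'])  -> recurses into the first inner list -> tf.int32
```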

‎tfbert/models/__init__.py

+3 −1

@@ -8,6 +8,7 @@
 from .albert import ALBertModel
 from .electra import ElectraModel
 from .nezha import NezhaModel
+from .glyce_bert import GlyceBertModel
 from .model_utils import (
     dropout, layer_norm_and_dropout, layer_norm,
     create_weight, get_shape_list, gather_indexes, create_initializer)

@@ -21,7 +22,8 @@
     'albert': ALBertModel,
     'electra': ElectraModel,
     'wobert': WoBertModel,
-    'nezha': NezhaModel
+    'nezha': NezhaModel,
+    'glyce_bert': GlyceBertModel
 }
 
 from .for_task import (

‎tfbert/models/embeddings.py

+52 −1

@@ -2,8 +2,10 @@
 # @FileName :embeddings.py
 # @Time :2021/1/31 15:32
 # @Author :huanghui
+
+import numpy as np
 import tensorflow.compat.v1 as tf
-from . import model_utils
+from . import model_utils, layers
 
 
 def create_word_embeddings(

@@ -78,3 +80,52 @@ def create_position_embeddings(
     # position_embeddings = tf.nn.embedding_lookup(full_position_embeddings, tf.range(0, seq_len))
 
     return position_embeddings
+
+
+def create_pinyin_embeddings(pinyin_ids, embedding_size: int, pinyin_out_dim: int, initializer_range,
+                             pinyin_vocab_size):
+    """Pinyin embeddings for ChineseBert."""
+    input_shape = model_utils.get_shape_list(pinyin_ids)  # bs, seq_len, pinyin_locs
+    pinyin_table = model_utils.create_weight(
+        shape=[pinyin_vocab_size, embedding_size],
+        var_name='pinyin_embeddings/embeddings',
+        initializer_range=initializer_range
+    )
+    flat_pinyin_ids = tf.reshape(pinyin_ids, [-1])
+    pinyin_embeddings = tf.gather(pinyin_table, flat_pinyin_ids)
+    pinyin_embeddings = tf.reshape(pinyin_embeddings,
+                                   [input_shape[0] * input_shape[1], input_shape[2],
+                                    embedding_size])  # bs * seq_len, pinyin_locs, embed_size
+    pinyin_embeddings = tf.expand_dims(pinyin_embeddings, -1)  # bs * seq_len, pinyin_locs, embed_size, 1
+    with tf.variable_scope("pinyin_embeddings/conv"):
+        # followed by a CharCNN
+        filter_shape = [2, embedding_size, 1, pinyin_out_dim]
+        pinyin_embeddings = layers.conv2d_layer(
+            pinyin_embeddings, filter_shape, padding="VALID", act=None,
+            initializer_range=0.1)  # bs * seq_len, pinyin_locs - 2 + 1, 1, pinyin_out_dim
+        pinyin_embeddings = layers.max_pooling_layer(
+            pinyin_embeddings, ksize=[1, input_shape[2] - 2 + 1, 1, 1])  # bs * seq_len, 1, 1, pinyin_out_dim
+    pinyin_embeddings = tf.reshape(pinyin_embeddings, input_shape[:2] + [pinyin_out_dim])
+    return pinyin_embeddings
+
+
+def create_glyph_embeddings(input_ids, font_npy_files):
+    font_arrays = [
+        np.load(np_file).astype(np.float32) for np_file in font_npy_files
+    ]
+    vocab_size = font_arrays[0].shape[0]
+    font_num = len(font_arrays)
+    font_size = font_arrays[0].shape[-1]
+    font_array = np.stack(font_arrays, axis=1)
+    glyph_table = tf.get_variable(
+        name="glyph_embeddings/embeddings",
+        shape=[vocab_size, font_size ** 2 * font_num],
+        initializer=tf.constant_initializer(font_array.reshape([vocab_size, -1])))
+
+    flat_input_ids = tf.reshape(input_ids, [-1])
+    output = tf.gather(glyph_table, flat_input_ids)
+    input_shape = model_utils.get_shape_list(input_ids)
+
+    output = tf.reshape(output,
+                        input_shape + [font_size ** 2 * font_num])
+    return output
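The shape bookkeeping in create_pinyin_embeddings is the trickiest part of the port, so here is a standalone sketch of the same conv-then-maxpool pattern with illustrative sizes (batch 2, sequence 16, 8 pinyin slots, 128-dim pinyin embedding, 768-dim output); only plain TF ops are used, not the repo's layer helpers.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

bs, seq_len, pinyin_locs, embed_size, out_dim = 2, 16, 8, 128, 768
x = tf.zeros([bs * seq_len, pinyin_locs, embed_size, 1])                 # one "image" per token
kernel = tf.zeros([2, embed_size, 1, out_dim])                           # width-2 conv over the 8 slots
conv = tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding="VALID")    # [32, 7, 1, 768]
pooled = tf.nn.max_pool(conv, ksize=[1, pinyin_locs - 2 + 1, 1, 1],
                        strides=[1, 1, 1, 1], padding="VALID")           # [32, 1, 1, 768]
token_pinyin = tf.reshape(pooled, [bs, seq_len, out_dim])

print(token_pinyin.shape)  # (2, 16, 768): one pinyin vector per token
```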

‎tfbert/models/for_task.py

+48 −22

@@ -15,6 +15,7 @@ def __init__(self,
                  num_classes,
                  is_training,
                  input_ids,
+                 pinyin_ids=None,
                  attention_mask=None,
                  token_type_ids=None,
                  label_ids=None,

@@ -38,13 +39,17 @@ def __init__(self,
         if model_type not in MODELS:
             raise ValueError("Unsupported model option: {}, "
                              "you can choose one of {}".format(model_type, "、".join(MODELS.keys())))
+        kwargs = {
+            'input_ids': input_ids,
+            'attention_mask': attention_mask,
+            'token_type_ids': token_type_ids}
+        if model_type == 'glyce_bert':
+            kwargs['pinyin_ids'] = pinyin_ids
 
         model = MODELS[model_type](
             config,
             is_training=is_training,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
+            **kwargs,
             compute_type=compute_type
         )
         pooled_output = model.get_pooled_output()

@@ -69,6 +74,7 @@ def __init__(self,
                  num_classes,
                  is_training,
                  input_ids,
+                 pinyin_ids=None,
                  attention_mask=None,
                  token_type_ids=None,
                  label_ids=None,

@@ -93,14 +99,17 @@ def __init__(self,
         if model_type not in MODELS:
             raise ValueError("Unsupported model option: {}, "
                              "you can choose one of {}".format(model_type, "、".join(MODELS.keys())))
+        kwargs = {
+            'input_ids': input_ids,
+            'attention_mask': attention_mask,
+            'token_type_ids': token_type_ids}
+        if model_type == 'glyce_bert':
+            kwargs['pinyin_ids'] = pinyin_ids
 
         model = MODELS[model_type](
             config,
             is_training=is_training,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            return_pool=False,
+            **kwargs,
             compute_type=compute_type
         )
         sequence_output = model.get_sequence_output()

@@ -141,6 +150,7 @@ def __init__(self,
                  num_classes,
                  is_training,
                  input_ids,
+                 pinyin_ids=None,
                  attention_mask=None,
                  token_type_ids=None,
                  label_ids=None,

@@ -164,13 +174,17 @@ def __init__(self,
         if model_type not in MODELS:
             raise ValueError("Unsupported model option: {}, "
                              "you can choose one of {}".format(model_type, "、".join(MODELS.keys())))
+        kwargs = {
+            'input_ids': input_ids,
+            'attention_mask': attention_mask,
+            'token_type_ids': token_type_ids}
+        if model_type == 'glyce_bert':
+            kwargs['pinyin_ids'] = pinyin_ids
 
         model = MODELS[model_type](
             config,
             is_training=is_training,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
+            **kwargs,
             compute_type=compute_type
         )
         pooled_output = model.get_pooled_output()

@@ -196,6 +210,7 @@ def __init__(self,
                  config,
                  is_training,
                  input_ids,
+                 pinyin_ids=None,
                  attention_mask=None,
                  token_type_ids=None,
                  start_position=None,

@@ -219,14 +234,17 @@ def __init__(self,
         if model_type not in MODELS:
             raise ValueError("Unsupported model option: {}, "
                              "you can choose one of {}".format(model_type, "、".join(MODELS.keys())))
+        kwargs = {
+            'input_ids': input_ids,
+            'attention_mask': attention_mask,
+            'token_type_ids': token_type_ids}
+        if model_type == 'glyce_bert':
+            kwargs['pinyin_ids'] = pinyin_ids
 
         model = MODELS[model_type](
             config,
             is_training=is_training,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            return_pool=False,
+            **kwargs,
             compute_type=compute_type
         )
         sequence_output = model.get_sequence_output()

@@ -270,6 +288,7 @@ def __init__(
             config,
             is_training,
             input_ids,
+            pinyin_ids=None,
             attention_mask=None,
             token_type_ids=None,
             masked_lm_ids=None,

@@ -294,14 +313,17 @@ def __init__(
         if model_type not in MODELS:
             raise ValueError("Unsupported model option: {}, "
                              "you can choose one of {}".format(model_type, "、".join(MODELS.keys())))
+        kwargs = {
+            'input_ids': input_ids,
+            'attention_mask': attention_mask,
+            'token_type_ids': token_type_ids}
+        if model_type == 'glyce_bert':
+            kwargs['pinyin_ids'] = pinyin_ids
 
         model = MODELS[model_type](
             config,
             is_training=is_training,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            return_pool=False,
+            **kwargs,
             compute_type=compute_type
         )
         sequence_output = model.get_sequence_output()

@@ -322,6 +344,7 @@ def __init__(
             config,
             is_training,
             input_ids,
+            pinyin_ids=None,
             attention_mask=None,
             token_type_ids=None,
             masked_lm_ids=None,

@@ -334,14 +357,17 @@ def __init__(
         if model_type not in MODELS:
             raise ValueError("Unsupported model option: {}, "
                              "you can choose one of {}".format(model_type, "、".join(MODELS.keys())))
+        kwargs = {
+            'input_ids': input_ids,
+            'attention_mask': attention_mask,
+            'token_type_ids': token_type_ids}
+        if model_type == 'glyce_bert':
+            kwargs['pinyin_ids'] = pinyin_ids
 
         model = MODELS[model_type](
             config,
             is_training=is_training,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            return_pool=True,
+            **kwargs,
             compute_type=compute_type
         )
         sequence_output = model.get_sequence_output()

‎tfbert/models/glyce_bert.py

+136 (new file)

@@ -0,0 +1,136 @@
# -*- coding:utf-8 -*-
# @FileName :glyce_bert.py
# @Time :2021/7/29 14:11
# @Author :huanghui
import os
import json
import tensorflow.compat.v1 as tf
from . import embeddings, layers, model_utils
from .base import BaseModel
from .bert import bert_encoder


def glyph_bert_embeddings(
        config,
        input_ids,
        pinyin_ids,
        token_type_ids=None
):
    (word_embeddings, embedding_table) = embeddings.create_word_embeddings(
        input_ids=input_ids,
        vocab_size=config.vocab_size,
        embedding_size=config.embedding_size,
        initializer_range=config.initializer_range,
        word_embedding_name="word_embeddings"
    )

    with open(os.path.join(config.config_path, 'pinyin_map.json')) as fin:
        pinyin_dict = json.load(fin)
    pinyin_embeddings = embeddings.create_pinyin_embeddings(
        pinyin_ids,
        embedding_size=128,
        pinyin_out_dim=config.embedding_size,
        initializer_range=config.initializer_range,
        pinyin_vocab_size=len(pinyin_dict['idx2char']))

    font_files = []
    for file in os.listdir(config.config_path):
        if file.endswith(".npy"):
            font_files.append(os.path.join(config.config_path, file))
    glyph_embeddings = embeddings.create_glyph_embeddings(
        input_ids, font_files
    )
    glyph_embeddings = layers.dense(glyph_embeddings, config.embedding_size, name="glyph_map")

    # fusion layer
    concat_embeddings = tf.concat([word_embeddings, pinyin_embeddings, glyph_embeddings], axis=2)
    inputs_embeds = layers.dense(concat_embeddings, config.embedding_size, name='map_fc')

    token_type_embeddings = embeddings.create_token_type_embeddings(
        token_type_ids=token_type_ids,
        embedding_size=config.embedding_size,
        token_type_vocab_size=config.type_vocab_size,
        token_type_embedding_name='token_type_embeddings',
        initializer_range=config.initializer_range
    )

    position_embeddings = embeddings.create_position_embeddings(
        seq_len=model_utils.get_shape_list(input_ids)[1],
        embedding_size=config.embedding_size,
        position_embedding_name='position_embeddings',
        initializer_range=config.initializer_range,
        max_position_embeddings=config.max_position_embeddings
    )

    embedding_output = inputs_embeds + position_embeddings + token_type_embeddings
    embedding_output = model_utils.layer_norm_and_dropout(
        embedding_output,
        config.hidden_dropout_prob
    )

    return embedding_output, embedding_table


class GlyceBertModel(BaseModel):
    def __init__(
            self,
            config,
            is_training,
            input_ids,
            pinyin_ids,
            attention_mask=None,
            token_type_ids=None,
            return_pool=True,
            scope=None,
            reuse=False,
            compute_type=tf.float32
    ):
        super().__init__(config, is_training)

        input_shape = model_utils.get_shape_list(input_ids, expected_rank=2)
        batch_size = input_shape[0]
        seq_length = input_shape[1]

        if attention_mask is None:
            attention_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int64)

        if token_type_ids is None:
            token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int64)

        with tf.variable_scope(
                scope, default_name="bert",
                reuse=tf.AUTO_REUSE if reuse else None,
                custom_getter=model_utils.get_custom_getter(compute_type)):
            with tf.variable_scope("embeddings"):
                self.embedding_output, self.embedding_table = glyph_bert_embeddings(
                    config=self.config,
                    input_ids=input_ids,
                    pinyin_ids=pinyin_ids,
                    token_type_ids=token_type_ids
                )

            with tf.variable_scope("encoder"):
                attention_mask = model_utils.create_bert_mask(
                    input_ids, attention_mask)
                if model_utils.get_shape_list(self.embedding_output)[-1] != self.config.hidden_size:
                    self.embedding_output = layers.dense(
                        self.embedding_output, self.config.hidden_size,
                        'embedding_hidden_mapping_in', initializer_range=self.config.initializer_range
                    )
                encoder_outputs = bert_encoder(
                    input_tensor=tf.saturate_cast(self.embedding_output, compute_type),
                    attention_mask=attention_mask,
                    config=self.config,
                    use_relative_position=False
                )
            if return_pool:
                with tf.variable_scope("pooler"):
                    pooled_output = layers.pooler_layer(
                        sequence_output=encoder_outputs[0],
                        hidden_size=self.config.hidden_size,
                        initializer_range=self.config.initializer_range
                    )
            else:
                pooled_output = None
        # (pooled output, sequence output, all layer outputs, all layer att probs)
        self.outputs = (pooled_output,) + encoder_outputs
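To summarize glyph_bert_embeddings above, the three per-token representations are concatenated and mapped back down with map_fc. A numpy-only sketch of the shapes; the 24x24, three-font glyph size is an assumption based on the upstream ChineseBert release, not something fixed by this commit.

```python
import numpy as np

batch, seq, embed = 2, 16, 768
word = np.zeros((batch, seq, embed), dtype=np.float32)                  # word_embeddings
piny = np.zeros((batch, seq, embed), dtype=np.float32)                  # CharCNN output (pinyin_out_dim == embedding_size)
glyph = np.zeros((batch, seq, 24 * 24 * 3), dtype=np.float32)           # flattened glyph bitmaps (assumed 3 fonts, 24x24)

glyph = glyph @ np.zeros((24 * 24 * 3, embed), dtype=np.float32)        # stands in for the "glyph_map" dense layer
fused = np.concatenate([word, piny, glyph], axis=-1)                    # [2, 16, 2304]
inputs_embeds = fused @ np.zeros((3 * embed, embed), dtype=np.float32)  # stands in for "map_fc"

print(inputs_embeds.shape)  # (2, 16, 768); position/token-type embeddings and LayerNorm follow
```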

‎tfbert/models/layers.py

+15 −4

@@ -4,6 +4,7 @@
 # @Author :huanghui
 
 import tensorflow.compat.v1 as tf
+from tensorflow.python.ops import gen_nn_ops
 import math
 from . import model_utils, activations
 from . import crf

@@ -292,20 +293,21 @@ def conv2d_layer(
     if strides is None:
         strides = [1, 1, 1, 1]
     W = tf.get_variable(
-        name='weight', shape=filter_shape,
+        name='kernel', shape=filter_shape,
         initializer=model_utils.create_initializer(initializer_range))
     b = tf.get_variable(
         name='bias', shape=[filter_shape[-1]],
         initializer=model_utils.create_initializer(initializer_range))
     output = tf.nn.conv2d(input_tensor, W, strides=strides, padding=padding)
+    output = tf.nn.bias_add(output, b)
     act_fn = activations.get_activation(act)
     if act_fn is not None:
-        output = act_fn(tf.nn.bias_add(output, b))
+        output = act_fn(output)
     return output
 
 
 def max_pooling_layer(
-        input_tensor, ksize: List[int],
+        input_tensor, ksize,
         strides=None, padding="VALID",
         name='max_pool'):
     """

@@ -319,13 +321,22 @@
     """
     if strides is None:
         strides = [1, 1, 1, 1]
-    output = tf.nn.max_pool(
+
+    # supports pooling with a dynamic kernel size
+    output = gen_nn_ops.max_pool_v2(
         input_tensor,
         ksize=ksize,
         strides=strides,
         padding=padding,
         name=name
     )
+    # output = tf.nn.max_pool(
+    #     input_tensor,
+    #     ksize=ksize,
+    #     strides=strides,
+    #     padding=padding,
+    #     name=name
+    # )
     return output
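The switch to gen_nn_ops.max_pool_v2 above is needed because tf.nn.max_pool in TF1 only accepts a static Python list for ksize, while the pinyin pooling window depends on a runtime dimension. A small, hedged sketch of what the op accepts:

```python
import tensorflow.compat.v1 as tf
from tensorflow.python.ops import gen_nn_ops

tf.disable_eager_execution()

x = tf.placeholder(tf.float32, [None, None, 1, 768])
ksize = tf.stack([1, tf.shape(x)[1], 1, 1])   # kernel size built from a dynamic dimension
pooled = gen_nn_ops.max_pool_v2(x, ksize=ksize, strides=[1, 1, 1, 1], padding="VALID")

print(pooled.shape)  # spatial dimensions stay unknown until the graph is run
```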

‎tfbert/tokenizer/__init__.py

+2 −1

@@ -9,9 +9,10 @@
 from .bert import BertTokenizer as NeZhaTokenizer
 from .bert import BertTokenizer as ElectraTokenizer
 from .wobert import WoBertTokenizer
+from .glyce_bert import GlyceBertTokenizer
 
 TOKENIZERS = {
     'bert': BertTokenizer, 'albert': ALBertTokenizer,
     'nezha': NeZhaTokenizer, 'electra': ElectraTokenizer,
-    'wobert': WoBertTokenizer
+    'wobert': WoBertTokenizer, 'glyce_bert': GlyceBertTokenizer
 }

‎tfbert/tokenizer/glyce_bert.py

+208 (new file)

@@ -0,0 +1,208 @@
# -*- coding:utf-8 -*-
# @FileName :glyce_bert.py
# @Time :2021/7/29 18:19
# @Author :huanghui
import os
import tensorflow.compat.v1 as tf
import json
from .tokenization_base import convert_to_unicode, PaddingStrategy, TruncationStrategy
from .bert import BertTokenizer
from typing import List, Union, Tuple, Optional


class GlyceBertTokenizer(BertTokenizer):
    def __init__(self, config_path, **kwargs):
        super(GlyceBertTokenizer, self).__init__(**kwargs)
        # load pinyin map dict
        with open(os.path.join(config_path, 'pinyin_map.json'), encoding='utf8') as fin:
            self.pinyin_dict = json.load(fin)
        # load char id map tensor
        with open(os.path.join(config_path, 'id2pinyin.json'), encoding='utf8') as fin:
            self.id2pinyin = json.load(fin)
        # load pinyin map tensor
        with open(os.path.join(config_path, 'pinyin2tensor.json'), encoding='utf8') as fin:
            self.pinyin2tensor = json.load(fin)

    def save_pretrained(self, save_directory):

        if os.path.isdir(save_directory):
            vocab_file = os.path.join(save_directory, 'vocab.txt')
            config_path = os.path.join(save_directory, 'config')
        else:
            vocab_file = save_directory
            config_path = os.path.join(os.path.split(save_directory)[0], "config")

        if not os.path.exists(config_path):
            os.makedirs(config_path)

        with open(os.path.join(config_path, 'pinyin_map.json'), "w", encoding='utf8') as fin:
            fin.write(json.dumps(self.pinyin_dict, ensure_ascii=False))

        with open(os.path.join(config_path, 'id2pinyin.json'), "w", encoding='utf8') as fin:
            fin.write(json.dumps(self.id2pinyin, ensure_ascii=False))

        with open(os.path.join(config_path, 'pinyin2tensor.json'), "w", encoding='utf8') as fin:
            fin.write(json.dumps(self.pinyin2tensor, ensure_ascii=False))

        with open(vocab_file, 'w', encoding='utf-8') as writer:
            for token, index in self.vocab.items():
                writer.write(token.strip() + '\n')
        tf.logging.info(" Tokenizer vocab saved in {}".format(vocab_file))
        return vocab_file

    @classmethod
    def from_pretrained(cls, vocab_dir_or_file, **kwargs):
        do_lower_case = kwargs.pop('do_lower_case', True)
        if os.path.isdir(vocab_dir_or_file):
            filename = 'vocab.txt'
            vocab_file = os.path.join(vocab_dir_or_file, filename)
            config_path = os.path.join(vocab_dir_or_file, "config")
        else:
            vocab_file = vocab_dir_or_file
            config_path = os.path.join(os.path.split(vocab_dir_or_file)[0], "config")

        return cls(config_path=config_path, vocab_file=vocab_file, do_lower_case=do_lower_case, **kwargs)

    def convert_token_ids_to_pinyin_ids(self, ids):
        from pypinyin import pinyin, Style

        tokens = self.convert_ids_to_tokens(ids)
        offsets = []
        pos = 0
        sentence = ""
        for token in tokens:
            token = token.replace("##", "").strip()

            if len(token) == 0:
                token = " "
            if token in self.all_special_tokens:
                token = " "
                offsets.append((0, 0))
            else:
                offsets.append((pos, pos + len(token)))
            pos += len(token)
            sentence += token

        pinyin_list = pinyin(sentence, style=Style.TONE3, heteronym=True, errors=lambda x: [['not chinese'] for _ in x])
        pinyin_locs = {}
        # get pinyin of each location
        for index, item in enumerate(pinyin_list):
            pinyin_string = item[0]
            # not a Chinese character, pass
            if pinyin_string == "not chinese":
                continue
            if pinyin_string in self.pinyin2tensor:
                pinyin_locs[index] = self.pinyin2tensor[pinyin_string]
            else:
                ids = [0] * 8
                for i, p in enumerate(pinyin_string):
                    if p not in self.pinyin_dict["char2idx"]:
                        ids = [0] * 8
                        break
                    ids[i] = self.pinyin_dict["char2idx"][p]
                pinyin_locs[index] = ids

        # find chinese character location, and generate pinyin ids
        pinyin_ids = []
        for idx, offset in enumerate(offsets):
            if offset[1] - offset[0] != 1:
                pinyin_ids.append([0] * 8)
                continue
            if offset[0] in pinyin_locs:
                pinyin_ids.append(pinyin_locs[offset[0]])
            else:
                pinyin_ids.append([0] * 8)

        return pinyin_ids

    def _encode_plus(
            self,
            text: Union[str, List[str], List[int]],
            text_pair: Optional[Union[str, List[str], List[int]]] = None,
            add_special_tokens: bool = True,
            padding_strategy: Union[bool, str, PaddingStrategy] = PaddingStrategy.DO_NOT_PAD,
            truncation_strategy: Union[bool, str, TruncationStrategy] = TruncationStrategy.DO_NOT_TRUNCATE,
            max_length: Optional[int] = None,
            stride: int = 0,
            return_token_type_ids: Optional[bool] = None,
            return_attention_mask: Optional[bool] = None,
            return_overflowing_tokens: bool = False,
            return_special_tokens_mask: bool = False,
            return_length: bool = False,
    ):
        first_ids = self.get_input_ids(text)
        second_ids = self.get_input_ids(text_pair) if text_pair is not None else None
        encoded = self.prepare_for_model(
            first_ids,
            pair_ids=second_ids,
            add_special_tokens=add_special_tokens,
            padding=padding_strategy,
            truncation=truncation_strategy,
            max_length=max_length,
            stride=stride,
            return_attention_mask=return_attention_mask,
            return_token_type_ids=return_token_type_ids,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_length=return_length
        )
        pinyin_ids = self.convert_token_ids_to_pinyin_ids(encoded['input_ids'])
        assert len(pinyin_ids) == len(encoded['input_ids'])
        encoded['pinyin_ids'] = pinyin_ids
        return encoded

    def _batch_encode_plus(
            self,
            batch_text_or_text_pairs: Union[
                List[str],
                List[Tuple[str, str]],
                List[Tuple[List[str], List[str]]],
                List[Tuple[str, str]],
                List[List[int]],
                List[Tuple[List[int], List[int]]],
            ],
            add_special_tokens: bool = True,
            padding_strategy: Union[bool, str, PaddingStrategy] = PaddingStrategy.DO_NOT_PAD,
            truncation_strategy: Union[bool, str, TruncationStrategy] = TruncationStrategy.DO_NOT_TRUNCATE,
            max_length: Optional[int] = None,
            stride: int = 0,
            is_split_into_words: bool = False,
            return_token_type_ids: Optional[bool] = None,
            return_attention_mask: Optional[bool] = None,
            return_overflowing_tokens: bool = False,
            return_special_tokens_mask: bool = False,
            return_length: bool = False,
    ):
        input_ids = []
        for ids_or_pair_ids in batch_text_or_text_pairs:
            if not isinstance(ids_or_pair_ids, (list, tuple)):
                ids, pair_ids = ids_or_pair_ids, None
            elif is_split_into_words and not isinstance(ids_or_pair_ids[0], (list, tuple)):
                ids, pair_ids = ids_or_pair_ids, None
            else:
                ids, pair_ids = ids_or_pair_ids

            first_ids = self.get_input_ids(ids)
            second_ids = self.get_input_ids(pair_ids) if pair_ids is not None else None
            input_ids.append((first_ids, second_ids))

        batch_outputs = self._batch_prepare_for_model(
            input_ids,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            return_attention_mask=return_attention_mask,
            return_token_type_ids=return_token_type_ids,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_length=return_length
        )
        batch_pinyin_ids = []
        for i in batch_outputs['input_ids']:
            pinyin_ids = self.convert_token_ids_to_pinyin_ids(batch_outputs['input_ids'][i])
            assert len(pinyin_ids) == len(batch_outputs['input_ids'][i])
            batch_pinyin_ids.append(pinyin_ids)
        batch_outputs['pinyin_ids'] = batch_pinyin_ids
        return batch_outputs
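For readers unfamiliar with pypinyin, the conversion above hinges on its TONE3 output; a hedged illustration of the intermediate values (pypinyin must be installed, and the exact heteronym output can vary by version):

```python
from pypinyin import Style, pinyin

# Same call as convert_token_ids_to_pinyin_ids, on a toy sentence.
result = pinyin("天气", style=Style.TONE3, heteronym=True,
                errors=lambda x: [['not chinese'] for _ in x])
print(result)  # e.g. [['tian1'], ['qi4']]; each string is then mapped character-by-character
               # through pinyin_map.json's char2idx into a zero-padded 8-slot id list.
```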
