Skip to content

Commit e2faa21

Browse files
topdu
authored and committed
fix unirec postprocess
1 parent efcd54e commit e2faa21

File tree

2 files changed

+15
-24
lines changed

2 files changed

+15
-24
lines changed

demo_unirec.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from threading import Thread
44

55
import numpy as np
6-
import re
76
from openrec.postprocess.unirec_postprocess import clean_special_tokens
87
from openrec.preprocess import create_operators, transform
98
from tools.engine.config import Config
@@ -41,18 +40,6 @@ def set_device(device):
4140
transforms, ratio_resize_flag = build_rec_process(cfg)
4241
ops = create_operators(transforms, global_config)
4342

44-
rules = [
45-
(r'-<\|sn\|>', ''),
46-
(r' <\|sn\|>', ' '),
47-
(r'<\|sn\|>', ' '),
48-
(r'<\|unk\|>', ''),
49-
(r'<s>', ''),
50-
(r'</s>', ''),
51-
(r'\uffff', ''),
52-
(r'_{4,}', '___'),
53-
(r'\.{4,}', '...'),
54-
]
55-
5643

5744
# --- 2. 定义流式生成函数 ---
5845
def stream_chat_with_image(input_image, history):
@@ -74,17 +61,20 @@ def stream_chat_with_image(input_image, history):
7461
'input_ids': None,
7562
'attention_mask': None
7663
}
77-
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
64+
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=2048)
7865
# 后台线程运行生成
7966
thread = Thread(target=model.generate, kwargs=generation_kwargs)
8067
thread.start()
8168
# 流式输出
82-
generated_text = ''
8369
history = history + [('🖼️(图片)', '')]
70+
generated_text_ori = ''
8471
for new_text in streamer:
85-
generated_text += clean_special_tokens(new_text)
86-
for rule in rules:
87-
generated_text = re.sub(rule[0], rule[1], generated_text)
72+
generated_text_ori += new_text
73+
generated_text = clean_special_tokens(
74+
generated_text_ori.replace(' ', ''))
75+
text = generated_text.replace('<tdcolspan=', '<td colspan=')
76+
text = text.replace('<tdrowspan=', '<td rowspan=')
77+
generated_text = text.replace('"colspan=', '" colspan=')
8878
history[-1] = ('🖼️(图片)', generated_text)
8979
yield history
9080

openrec/postprocess/unirec_postprocess.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,11 @@
1515

1616

1717
def clean_special_tokens(text):
18-
text = text.replace(' ', '').replace('Ġ', ' ').replace('Ċ', '\n').replace(
19-
'<|bos|>', '').replace('<|eos|>', '').replace('<|pad|>', '')
18+
text = text.replace('Ġ',
19+
' ').replace('Ċ', '\n').replace('<|bos|>', '').replace(
20+
'<|eos|>', '').replace('<|pad|>', '')
2021
for rule in rules:
2122
text = re.sub(rule[0], rule[1], text)
22-
text = text.replace('<tdcolspan=', '<td colspan=')
23-
text = text.replace('<tdrowspan=', '<td rowspan=')
24-
text = text.replace('"colspan=', '" colspan=')
2523
return text
2624

2725

@@ -44,7 +42,10 @@ def __init__(self,
4442
def __call__(self, preds, batch=None, *args, **kwargs):
4543
result_list = []
4644
pred_ids = preds
47-
res = self.tokenizer.batch_decode(pred_ids, skip_special_tokens=False)
45+
res = [
46+
''.join(self.tokenizer.convert_ids_to_tokens(seq.tolist()))
47+
for seq in pred_ids
48+
]
4849
for i in range(len(res)):
4950
res[i] = clean_special_tokens(res[i])
5051
result_list.append(

0 commit comments

Comments (0)