Skip to content

Commit e4194fc

Browse files
authored
Merge pull request #47 from 1041206149/LLM_dafen
使用LLM修正预测公式
2 parents 0a6ce17 + 47b32fa commit e4194fc

File tree

4 files changed

+174
-35
lines changed

4 files changed

+174
-35
lines changed

examples/multi_extractor_compare.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def all_extractor_comparison():
88
print("\n=== 多抽取器对比演示 ===\n")
99

1010
# 创建数据集
11-
dataset_path = Path("data/sample_dataset.jsonl")
11+
dataset_path = Path("../data/WebMainBench_llm-webkit_v1_WebMainBench_7887_within_formula.jsonl")
1212
dataset = DataLoader.load_jsonl(dataset_path)
1313

1414
# 创建webkit抽取器

webmainbench/data/saver.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,14 +303,14 @@ def save_dataset_with_extraction(results: Union["EvaluationResult", Dict[str, An
303303

304304
# 解析预测值(predicted)
305305
predicted_content = extraction_result.get('extracted_content', '')
306-
predicted_parts = BaseMetric._extract_from_markdown(predicted_content) # 关键:解析预测内容
306+
predicted_parts = BaseMetric._extract_from_markdown(predicted_content, field_name="llm_webkit_md") # 关键:解析预测内容
307307
for part_type in ['code', 'formula', 'table', 'text']:
308308
sample_dict[f'{current_extractor_name}_predicted_{part_type}'] = predicted_parts.get(part_type, '')
309309

310310
# 解析真实值(groundtruth)- 只需要解析一次
311311
if extractor_names: # 只有当存在extractor时才解析
312312
groundtruth_content = sample_dict.get('groundtruth_content', '')
313-
groundtruth_parts = BaseMetric._extract_from_markdown(groundtruth_content) # 关键:解析真实内容
313+
groundtruth_parts = BaseMetric._extract_from_markdown(groundtruth_content, field_name="groundtruth_content") # 关键:解析真实内容
314314
for part_type in ['code', 'formula', 'table', 'text']:
315315
# 使用第一个extractor的名字作为前缀,或者使用通用前缀
316316
prefix = extractor_names[0] if len(extractor_names) == 1 else 'groundtruth'

webmainbench/metrics/base.py

Lines changed: 56 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import traceback
99
import re
1010
from bs4 import BeautifulSoup
11+
import os
12+
import hashlib
1113

1214
@dataclass
1315
class MetricResult:
@@ -121,16 +123,16 @@ def batch_calculate(self, predicted_list: List[Any],
121123
result = self.calculate(pred, gt, **kwargs)
122124
results.append(result)
123125
return results
124-
126+
125127
@staticmethod
126-
def split_content(text: str, content_list: List[Dict[str, Any]] = None) -> Dict[str, str]:
128+
def split_content(text: str, content_list: List[Dict[str, Any]] = None, field_name: str = None) -> Dict[str, str]:
127129
"""
128130
统一的内容分割方法,将文本分为代码、公式、表格和剩余文本4个部分。
129-
131+
130132
Args:
131133
text: 原始markdown文本
132134
content_list: 结构化内容列表(来自llm-webkit等)
133-
135+
field_name: 当前处理的字段名称,传递给_extract_from_markdown
134136
Returns:
135137
Dict with keys: 'code', 'formula', 'table', 'text'
136138
"""
@@ -139,9 +141,9 @@ def split_content(text: str, content_list: List[Dict[str, Any]] = None) -> Dict[
139141
extracted_content = BaseMetric._extract_from_content_list(content_list)
140142
if any(extracted_content.values()):
141143
return extracted_content
142-
143-
# 从markdown文本中提取
144-
return BaseMetric._extract_from_markdown(text or "")
144+
145+
# 从markdown文本中提取,传递字段名称
146+
return BaseMetric._extract_from_markdown(text or "", field_name=field_name)
145147

146148
@staticmethod
147149
def _extract_from_content_list(content_list: List[Dict[str, Any]]) -> Dict[str, str]:
@@ -193,12 +195,12 @@ def _recursive_extract(items):
193195
'text': '\n'.join(extracted['text'])
194196
}
195197

196-
@staticmethod
197-
def _extract_from_markdown(text: str) -> Dict[str, str]:
198+
@staticmethod
199+
def _extract_from_markdown(text: str, field_name: str = None) -> Dict[str, str]:
198200
"""从markdown文本中提取各种类型的内容"""
199201
if not text:
200202
return {'code': '', 'formula': '', 'table': '', 'text': ''}
201-
203+
202204
# 收集所有需要移除的内容片段
203205
extracted_segments = []
204206
code_parts = []
@@ -291,34 +293,56 @@ def _extract_from_markdown(text: str) -> Dict[str, str]:
291293
if code_content.strip():
292294
code_parts.append(code_content)
293295

294-
# 提取公式
296+
# 提取公式 - 新的两步处理逻辑
295297
formula_parts = []
296-
# 统一的公式提取模式
298+
299+
# 第一步:先用正则提取公式
300+
regex_formulas = []
297301
latex_patterns = [
298-
# r'(?<!\\)\$\$([^$]+)\$\$(?!\\)', # Display math (not escaped)
299-
# r'(?<!\\)\$([^$\n]+)\$(?![\\\$])', # Inline math (not escaped)
300-
# r'\\begin\{equation\*?\}(.*?)\\end\{equation\*?\}', # Equation environment
301-
# r'\\begin\{align\*?\}(.*?)\\end\{align\*?\}', # Align environment
302-
# r'\\begin\{gather\*?\}(.*?)\\end\{gather\*?\}', # Gather environment
303-
# r'\\begin\{eqnarray\*?\}(.*?)\\end\{eqnarray\*?\}', # Eqnarray environment
304-
# r'\\begin\{multline\*?\}(.*?)\\end\{multline\*?\}', # Multline environment
305-
# r'\\begin\{split\}(.*?)\\end\{split\}', # Split environment
306-
# r'(?<!\\)\$\$([^$]+)\$\$(?!\\)',
307-
# r'(?<!\\)\$([^$\n\w][^$\n]*[^$\n\w])\$(?![\\\$])',
308-
r'(?<!\\)\$\$(.*?)(?<!\\)\$\$', # 行间 $$...$$,确保 $ 没有被转义
309-
r'(?<!\\)\\\[(.*?)(?<!\\)\\\]', # 行间 \[...\],确保 \ 没有被转义
310-
r'(?<!\\)\$(.*?)(?<!\\)\$', # 行内 $...$,确保 $ 没有被转义
311-
# r'(?<!\\)\$(.*?)(?<!\\)\$(?!\d)', # 第二个$后面不能是数字
312-
r'(?<!\\)\\\((.*?)(?<!\\)\\\)', # 行内 \(...\),确保 \ 没有被转义
302+
r'(?<!\\)\$\$(.*?)(?<!\\)\$\$', # 行间 $$...$$
303+
r'(?<!\\)\\\[(.*?)(?<!\\)\\\]', # 行间 \[...\]
304+
r'(?<!\\)\$(.*?)(?<!\\)\$', # 行内 $...$
305+
r'(?<!\\)\\\((.*?)(?<!\\)\\\)', # 行内 \(...\)
313306
]
314-
307+
315308
for pattern in latex_patterns:
316309
for match in re.finditer(pattern, text, re.DOTALL):
317-
formula_full = match.group(0) # 完整匹配(包含$符号)
318-
formula_content = match.group(1) # 只是公式内容
310+
formula_full = match.group(0)
311+
formula_content = match.group(1)
319312
extracted_segments.append(formula_full)
320313
if formula_content.strip():
321-
formula_parts.append(formula_content.strip())
314+
regex_formulas.append(formula_content.strip())
315+
316+
# 第二步:根据字段类型决定是否需要API修正
317+
if field_name == "groundtruth_content":
318+
print(f"[DEBUG] 检测到groundtruth内容,仅使用正则提取公式")
319+
formula_parts = regex_formulas
320+
else:
321+
print(f"[DEBUG] 检测到llm_webkit_md内容,使用正则+API修正模式")
322+
# 对于llm_webkit_md,将正则结果传递给API进行修正
323+
if regex_formulas:
324+
# 将正则提取的公式作为输入传递给API
325+
regex_formulas_text = '\n'.join(regex_formulas)
326+
print(f"[DEBUG] 正则提取到 {len(regex_formulas)} 个公式,准备API修正")
327+
328+
cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.cache')
329+
os.makedirs(cache_dir, exist_ok=True)
330+
331+
# 使用正则结果的哈希作为缓存文件名
332+
text_hash = hashlib.md5(regex_formulas_text.encode('utf-8')).hexdigest()
333+
cache_file = os.path.join(cache_dir, f'formula_correction_cache_{text_hash}.json')
334+
335+
try:
336+
from .formula_extractor import correct_formulas_with_llm
337+
corrected_formulas = correct_formulas_with_llm(regex_formulas, cache_file)
338+
formula_parts = corrected_formulas
339+
print(f"[DEBUG] API修正成功,最终得到 {len(formula_parts)} 个公式")
340+
except Exception as e:
341+
print(f"[DEBUG] API修正失败: {type(e).__name__}: {e},使用正则结果")
342+
formula_parts = regex_formulas
343+
else:
344+
print(f"[DEBUG] 正则未提取到公式,跳过API修正")
345+
formula_parts = []
322346

323347
# 提取表格
324348
table_parts = []
@@ -468,4 +492,4 @@ def __str__(self) -> str:
468492
return f"{self.__class__.__name__}(name='{self.name}')"
469493

470494
def __repr__(self) -> str:
471-
return self.__str__()
495+
return self.__str__()
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# webmainbench/metrics/formula_extractor.py
2+
import json
3+
import os
4+
from openai import OpenAI
5+
6+
def correct_formulas_with_llm(regex_formulas, cache_file=None):
7+
"""使用LLM API修正正则提取的公式"""
8+
9+
if not regex_formulas:
10+
print(f"[DEBUG] 输入公式列表为空,跳过API修正")
11+
return []
12+
13+
# 检查缓存
14+
if cache_file and os.path.exists(cache_file):
15+
try:
16+
with open(cache_file, 'r', encoding='utf-8') as f:
17+
cached_result = json.load(f)
18+
print(f"[DEBUG] 从缓存加载修正结果: {len(cached_result)} 个")
19+
return cached_result
20+
except Exception as e:
21+
print(f"[DEBUG] 缓存读取失败: {e}")
22+
23+
# API配置
24+
client = OpenAI(
25+
base_url="",
26+
api_key=""
27+
)
28+
29+
# 将正则提取的公式转换为文本
30+
formulas_text = '\n'.join(regex_formulas)
31+
32+
CORRECTION_PROMPT = '''任务:请从以下正则表达式提取的内容中,识别并保留真正的LaTeX数学公式,剔除货币形式的内容。
33+
34+
### 识别规则
35+
**真正的数学公式**(保留):
36+
- 包含数学符号:+ - × ÷ = < > ≤ ≥ ± ∞ ∑ ∫ ∂ √ ^ _ { } 等
37+
- 包含希腊字母:α β γ δ θ λ μ π σ ω 等
38+
- 包含LaTeX命令:\\frac \\sum \\int \\sqrt \\alpha \\beta \\sin \\cos 等
39+
- 包含数学表达式:变量、函数、方程等
40+
41+
**货币形式内容**(剔除):
42+
- 仅包含数字、逗号、小数点的价格:如 1,150.00
43+
- 纯粹的金额数值:如 25.99、1,200、5.50
44+
- 不包含任何数学运算符或数学符号的数字
45+
46+
### 处理要求
47+
1. **严格区分**:只保留真正的数学公式,剔除所有货币价格
48+
2. **格式标准化**:统一公式格式,确保LaTeX语法正确
49+
3. **保持原意**:不修改数学公式内容
50+
51+
### 输出格式
52+
- 每个有效的数学公式独占一行
53+
- 只输出公式内容,不包含$符号或其他包装
54+
- 如果输入不是有效的数学公式(如货币),则输出<空>
55+
- 按原顺序输出保留的公式
56+
57+
### 示例 1 (剔除后有有效公式)
58+
输入:1,150.00 → 剔除(货币)
59+
输入:x^2 + y^2 = r^2 → 保留(数学公式)
60+
输入:25.99 → 剔除(货币)
61+
输入:\\frac{a}{b} + c → 保留(数学公式)
62+
63+
### 示例 2 (剔除后无有效公式)
64+
输入:1,150.00 → 剔除(货币)
65+
输入:25.99 → 剔除(货币)
66+
67+
输出:<空>
68+
69+
注意,输出结果中不要添加任何解释!。
70+
[输入内容列表开始]'''
71+
72+
try:
73+
print(f"[DEBUG] 开始调用 OpenAI API 进行公式修正...")
74+
response = client.chat.completions.create(
75+
model="deepseek-chat",
76+
temperature=0,
77+
messages=[
78+
{"role": "user", "content": CORRECTION_PROMPT + f"\n{formulas_text}\n" + '''[输入内容列表结束]
79+
---
80+
请按要求识别并输出真正的数学公式,剔除货币形式的内容。
81+
---'''}
82+
]
83+
)
84+
85+
result_text = response.choices[0].message.content.strip()
86+
print(f"[DEBUG] API 返回修正结果: {repr(result_text)}")
87+
88+
# 检测返回内容是否包含"空"字 - 如果包含则整个结果为空
89+
if '空' in result_text:
90+
print(f"[DEBUG] 检测到API返回包含'空'字,将整个结果设置为空列表")
91+
corrected_formulas = []
92+
elif not result_text:
93+
corrected_formulas = []
94+
else:
95+
# 正常解析返回的公式列表
96+
corrected_formulas = [line.strip() for line in result_text.split('\n') if line.strip()]
97+
98+
print(f"[DEBUG] 修正后的公式列表: {corrected_formulas}")
99+
100+
# 保存缓存
101+
if cache_file:
102+
try:
103+
os.makedirs(os.path.dirname(cache_file), exist_ok=True)
104+
with open(cache_file, 'w', encoding='utf-8') as f:
105+
json.dump(corrected_formulas, f, ensure_ascii=False, indent=2)
106+
print(f"[DEBUG] 修正结果已缓存到: {cache_file}")
107+
except Exception as e:
108+
print(f"[DEBUG] 缓存保存失败: {e}")
109+
110+
return corrected_formulas
111+
112+
except Exception as e:
113+
print(f"[DEBUG] API 修正异常: {type(e).__name__}: {e}")
114+
print(f"[DEBUG] 回退到原始正则结果")
115+
return regex_formulas

0 commit comments

Comments
 (0)