Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/multi_extractor_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def all_extractor_comparison():
print("\n=== 多抽取器对比演示 ===\n")

# 创建数据集
dataset_path = Path("data/sample_dataset.jsonl")
dataset_path = Path("../data/WebMainBench_llm-webkit_v1_WebMainBench_7887_within_formula.jsonl")
dataset = DataLoader.load_jsonl(dataset_path)

# 创建webkit抽取器
Expand Down
4 changes: 2 additions & 2 deletions webmainbench/data/saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,14 +303,14 @@ def save_dataset_with_extraction(results: Union["EvaluationResult", Dict[str, An

# 解析预测值(predicted)
predicted_content = extraction_result.get('extracted_content', '')
predicted_parts = BaseMetric._extract_from_markdown(predicted_content) # 关键:解析预测内容
predicted_parts = BaseMetric._extract_from_markdown(predicted_content, field_name="llm_webkit_md") # 关键:解析预测内容
for part_type in ['code', 'formula', 'table', 'text']:
sample_dict[f'{current_extractor_name}_predicted_{part_type}'] = predicted_parts.get(part_type, '')

# 解析真实值(groundtruth)- 只需要解析一次
if extractor_names: # 只有当存在extractor时才解析
groundtruth_content = sample_dict.get('groundtruth_content', '')
groundtruth_parts = BaseMetric._extract_from_markdown(groundtruth_content) # 关键:解析真实内容
groundtruth_parts = BaseMetric._extract_from_markdown(groundtruth_content, field_name="groundtruth_content") # 关键:解析真实内容
for part_type in ['code', 'formula', 'table', 'text']:
# 使用第一个extractor的名字作为前缀,或者使用通用前缀
prefix = extractor_names[0] if len(extractor_names) == 1 else 'groundtruth'
Expand Down
88 changes: 56 additions & 32 deletions webmainbench/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import traceback
import re
from bs4 import BeautifulSoup
import os
import hashlib

@dataclass
class MetricResult:
Expand Down Expand Up @@ -121,16 +123,16 @@ def batch_calculate(self, predicted_list: List[Any],
result = self.calculate(pred, gt, **kwargs)
results.append(result)
return results

@staticmethod
def split_content(text: str, content_list: List[Dict[str, Any]] = None) -> Dict[str, str]:
def split_content(text: str, content_list: List[Dict[str, Any]] = None, field_name: str = None) -> Dict[str, str]:
"""
统一的内容分割方法,将文本分为代码、公式、表格和剩余文本4个部分。

Args:
text: 原始markdown文本
content_list: 结构化内容列表(来自llm-webkit等)

field_name: 当前处理的字段名称,传递给_extract_from_markdown
Returns:
Dict with keys: 'code', 'formula', 'table', 'text'
"""
Expand All @@ -139,9 +141,9 @@ def split_content(text: str, content_list: List[Dict[str, Any]] = None) -> Dict[
extracted_content = BaseMetric._extract_from_content_list(content_list)
if any(extracted_content.values()):
return extracted_content
# 从markdown文本中提取
return BaseMetric._extract_from_markdown(text or "")

# 从markdown文本中提取,传递字段名称
return BaseMetric._extract_from_markdown(text or "", field_name=field_name)

@staticmethod
def _extract_from_content_list(content_list: List[Dict[str, Any]]) -> Dict[str, str]:
Expand Down Expand Up @@ -193,12 +195,12 @@ def _recursive_extract(items):
'text': '\n'.join(extracted['text'])
}

@staticmethod
def _extract_from_markdown(text: str) -> Dict[str, str]:
@staticmethod
def _extract_from_markdown(text: str, field_name: str = None) -> Dict[str, str]:
"""从markdown文本中提取各种类型的内容"""
if not text:
return {'code': '', 'formula': '', 'table': '', 'text': ''}

# 收集所有需要移除的内容片段
extracted_segments = []
code_parts = []
Expand Down Expand Up @@ -291,34 +293,56 @@ def _extract_from_markdown(text: str) -> Dict[str, str]:
if code_content.strip():
code_parts.append(code_content)

# 提取公式
# 提取公式 - 新的两步处理逻辑
formula_parts = []
# 统一的公式提取模式

# 第一步:先用正则提取公式
regex_formulas = []
latex_patterns = [
# r'(?<!\\)\$\$([^$]+)\$\$(?!\\)', # Display math (not escaped)
# r'(?<!\\)\$([^$\n]+)\$(?![\\\$])', # Inline math (not escaped)
# r'\\begin\{equation\*?\}(.*?)\\end\{equation\*?\}', # Equation environment
# r'\\begin\{align\*?\}(.*?)\\end\{align\*?\}', # Align environment
# r'\\begin\{gather\*?\}(.*?)\\end\{gather\*?\}', # Gather environment
# r'\\begin\{eqnarray\*?\}(.*?)\\end\{eqnarray\*?\}', # Eqnarray environment
# r'\\begin\{multline\*?\}(.*?)\\end\{multline\*?\}', # Multline environment
# r'\\begin\{split\}(.*?)\\end\{split\}', # Split environment
# r'(?<!\\)\$\$([^$]+)\$\$(?!\\)',
# r'(?<!\\)\$([^$\n\w][^$\n]*[^$\n\w])\$(?![\\\$])',
r'(?<!\\)\$\$(.*?)(?<!\\)\$\$', # 行间 $$...$$,确保 $ 没有被转义
r'(?<!\\)\\\[(.*?)(?<!\\)\\\]', # 行间 \[...\],确保 \ 没有被转义
r'(?<!\\)\$(.*?)(?<!\\)\$', # 行内 $...$,确保 $ 没有被转义
# r'(?<!\\)\$(.*?)(?<!\\)\$(?!\d)', # 第二个$后面不能是数字
r'(?<!\\)\\\((.*?)(?<!\\)\\\)', # 行内 \(...\),确保 \ 没有被转义
r'(?<!\\)\$\$(.*?)(?<!\\)\$\$', # 行间 $$...$$
r'(?<!\\)\\\[(.*?)(?<!\\)\\\]', # 行间 \[...\]
r'(?<!\\)\$(.*?)(?<!\\)\$', # 行内 $...$
r'(?<!\\)\\\((.*?)(?<!\\)\\\)', # 行内 \(...\)
]

for pattern in latex_patterns:
for match in re.finditer(pattern, text, re.DOTALL):
formula_full = match.group(0) # 完整匹配(包含$符号)
formula_content = match.group(1) # 只是公式内容
formula_full = match.group(0)
formula_content = match.group(1)
extracted_segments.append(formula_full)
if formula_content.strip():
formula_parts.append(formula_content.strip())
regex_formulas.append(formula_content.strip())

# 第二步:根据字段类型决定是否需要API修正
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这块逻辑感觉需要优化一下

if field_name == "groundtruth_content":
print(f"[DEBUG] 检测到groundtruth内容,仅使用正则提取公式")
formula_parts = regex_formulas
else:
print(f"[DEBUG] 检测到llm_webkit_md内容,使用正则+API修正模式")
# 对于llm_webkit_md,将正则结果传递给API进行修正
if regex_formulas:
# 将正则提取的公式作为输入传递给API
regex_formulas_text = '\n'.join(regex_formulas)
print(f"[DEBUG] 正则提取到 {len(regex_formulas)} 个公式,准备API修正")

cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.cache')
os.makedirs(cache_dir, exist_ok=True)

# 使用正则结果的哈希作为缓存文件名
text_hash = hashlib.md5(regex_formulas_text.encode('utf-8')).hexdigest()
cache_file = os.path.join(cache_dir, f'formula_correction_cache_{text_hash}.json')

try:
from .formula_extractor import correct_formulas_with_llm
corrected_formulas = correct_formulas_with_llm(regex_formulas, cache_file)
formula_parts = corrected_formulas
print(f"[DEBUG] API修正成功,最终得到 {len(formula_parts)} 个公式")
except Exception as e:
print(f"[DEBUG] API修正失败: {type(e).__name__}: {e},使用正则结果")
formula_parts = regex_formulas
else:
print(f"[DEBUG] 正则未提取到公式,跳过API修正")
formula_parts = []

# 提取表格
table_parts = []
Expand Down Expand Up @@ -468,4 +492,4 @@ def __str__(self) -> str:
return f"{self.__class__.__name__}(name='{self.name}')"

def __repr__(self) -> str:
return self.__str__()
return self.__str__()
115 changes: 115 additions & 0 deletions webmainbench/metrics/formula_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# webmainbench/metrics/formula_extractor.py
import json
import os
from openai import OpenAI

def correct_formulas_with_llm(regex_formulas, cache_file=None):
    """Correct regex-extracted formula candidates via an LLM API.

    Sends the candidates to a chat-completion model that keeps genuine LaTeX
    math and drops currency-like numbers (e.g. "1,150.00").  Results are
    cached on disk keyed by the caller-supplied cache file, so identical
    inputs do not trigger repeated API calls.

    Args:
        regex_formulas: list of formula strings extracted by regex.
        cache_file: optional path to a JSON file used as a result cache.

    Returns:
        List of corrected formula strings.  Returns ``[]`` for empty input or
        when the model reports no valid formulas; falls back to
        ``regex_formulas`` unchanged when the API is unreachable or errors.
    """
    if not regex_formulas:
        print(f"[DEBUG] 输入公式列表为空,跳过API修正")
        return []

    # Cache hit: reuse the previously corrected result without calling the API.
    if cache_file and os.path.exists(cache_file):
        try:
            with open(cache_file, 'r', encoding='utf-8') as f:
                cached_result = json.load(f)
            print(f"[DEBUG] 从缓存加载修正结果: {len(cached_result)} 个")
            return cached_result
        except Exception as e:
            print(f"[DEBUG] 缓存读取失败: {e}")

    # 将正则提取的公式转换为文本
    formulas_text = '\n'.join(regex_formulas)

    CORRECTION_PROMPT = '''任务:请从以下正则表达式提取的内容中,识别并保留真正的LaTeX数学公式,剔除货币形式的内容。

### 识别规则
**真正的数学公式**(保留):
- 包含数学符号:+ - × ÷ = < > ≤ ≥ ± ∞ ∑ ∫ ∂ √ ^ _ { } 等
- 包含希腊字母:α β γ δ θ λ μ π σ ω 等
- 包含LaTeX命令:\\frac \\sum \\int \\sqrt \\alpha \\beta \\sin \\cos 等
- 包含数学表达式:变量、函数、方程等

**货币形式内容**(剔除):
- 仅包含数字、逗号、小数点的价格:如 1,150.00
- 纯粹的金额数值:如 25.99、1,200、5.50
- 不包含任何数学运算符或数学符号的数字

### 处理要求
1. **严格区分**:只保留真正的数学公式,剔除所有货币价格
2. **格式标准化**:统一公式格式,确保LaTeX语法正确
3. **保持原意**:不修改数学公式内容

### 输出格式
- 每个有效的数学公式独占一行
- 只输出公式内容,不包含$符号或其他包装
- 如果输入不是有效的数学公式(如货币),则输出<空>
- 按原顺序输出保留的公式

### 示例 1 (剔除后有有效公式)
输入:1,150.00 → 剔除(货币)
输入:x^2 + y^2 = r^2 → 保留(数学公式)
输入:25.99 → 剔除(货币)
输入:\\frac{a}{b} + c → 保留(数学公式)

### 示例 2 (剔除后无有效公式)
输入:1,150.00 → 剔除(货币)
输入:25.99 → 剔除(货币)

输出:<空>

注意,输出结果中不要添加任何解释!。
[输入内容列表开始]'''

    try:
        # Imported lazily so environments without the `openai` package can
        # still import this module and use the regex-only fallback below.
        from openai import OpenAI

        # Credentials come from the environment rather than being hard-coded.
        # NOTE(review): confirm the expected variable names against the
        # deployment configuration.
        client = OpenAI(
            base_url=os.environ.get("OPENAI_BASE_URL", ""),
            api_key=os.environ.get("OPENAI_API_KEY", ""),
        )

        print(f"[DEBUG] 开始调用 OpenAI API 进行公式修正...")
        response = client.chat.completions.create(
            model="deepseek-chat",
            temperature=0,  # deterministic output so the disk cache stays meaningful
            messages=[
                {"role": "user", "content": CORRECTION_PROMPT + f"\n{formulas_text}\n" + '''[输入内容列表结束]
---
请按要求识别并输出真正的数学公式,剔除货币形式的内容。
---'''}
            ]
        )

        result_text = response.choices[0].message.content.strip()
        print(f"[DEBUG] API 返回修正结果: {repr(result_text)}")

        if not result_text or result_text == '<空>':
            # The model signals "no valid formulas" with a lone <空> sentinel.
            corrected_formulas = []
        else:
            # Parse one formula per line, dropping blanks and any stray <空>
            # markers.  (A plain substring check on '空' would incorrectly
            # wipe the whole result whenever any formula contains that
            # character, so we match the sentinel exactly per line.)
            corrected_formulas = [
                line.strip()
                for line in result_text.split('\n')
                if line.strip() and line.strip() != '<空>'
            ]

        print(f"[DEBUG] 修正后的公式列表: {corrected_formulas}")

        # Persist the result so identical inputs skip the API next time.
        if cache_file:
            try:
                cache_dir = os.path.dirname(cache_file)
                if cache_dir:  # guard: dirname is '' for bare filenames
                    os.makedirs(cache_dir, exist_ok=True)
                with open(cache_file, 'w', encoding='utf-8') as f:
                    json.dump(corrected_formulas, f, ensure_ascii=False, indent=2)
                print(f"[DEBUG] 修正结果已缓存到: {cache_file}")
            except Exception as e:
                print(f"[DEBUG] 缓存保存失败: {e}")

        return corrected_formulas

    except Exception as e:
        # Any failure (missing package, bad credentials, network error, API
        # error) degrades gracefully to the uncorrected regex result.
        print(f"[DEBUG] API 修正异常: {type(e).__name__}: {e}")
        print(f"[DEBUG] 回退到原始正则结果")
        return regex_formulas