Merged · Changes from 6 commits
5 changes: 3 additions & 2 deletions examples/multi_extractor_compare.py
@@ -8,14 +8,15 @@ def all_extractor_comparison():
print("\n=== 多抽取器对比演示 ===\n")

# 创建数据集
dataset_path = Path("../data/WebMainBench_llm-webkit_v1_WebMainBench_7887_within_formula.jsonl")
dataset_path = Path("../data/sample_dataset.jsonl")
dataset = DataLoader.load_jsonl(dataset_path)

# Create the webkit extractor
config = {
"use_preprocessed_html": True, # 🔑 Key setting: enable preprocessed-HTML mode
"preprocessed_html_field": "llm_webkit_html" # name of the preprocessed-HTML field
}

webkit_extractor = ExtractorFactory.create("llm-webkit", config=config)
# Create the magic-html extractor
magic_extractor = ExtractorFactory.create("magic-html")
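Note: the example now loads the small bundled sample dataset instead of the full 7,887-sample benchmark file. A minimal sketch of how these two extractors might be driven over that dataset is shown below; the `webmainbench.data` / `webmainbench.extractors` module paths, the sample's `html` field, and the `extract()` call are assumptions, since the example's import lines and comparison loop fall outside this hunk.

```python
from pathlib import Path

# Assumed module paths; the example's own imports are not shown in this diff.
from webmainbench.data import DataLoader
from webmainbench.extractors import ExtractorFactory

dataset = DataLoader.load_jsonl(Path("../data/sample_dataset.jsonl"))

webkit_extractor = ExtractorFactory.create(
    "llm-webkit",
    config={
        "use_preprocessed_html": True,                 # read pre-rendered HTML instead of raw HTML
        "preprocessed_html_field": "llm_webkit_html",  # field that holds the preprocessed HTML
    },
)
magic_extractor = ExtractorFactory.create("magic-html")

# Hypothetical comparison loop; the real extraction API may differ.
for sample in dataset:
    for name, extractor in (("llm-webkit", webkit_extractor), ("magic-html", magic_extractor)):
        result = extractor.extract(sample["html"])     # method name and field assumed
        print(name, len(str(result)))
```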
3 changes: 2 additions & 1 deletion requirements.txt
@@ -12,4 +12,5 @@ https://github.com/opendatalab/magic-html/releases/download/magic_html-0.1.5-rel
streamlit
markdown
jieba
apted
apted
openai
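The new `openai` dependency backs the LLM-assisted formula correction path used by the metrics package (`correct_formulas_with_llm` in `formula_extractor`). Purely as an illustration of why the package is needed, and not the repository's actual implementation, a correction call against an OpenAI-compatible endpoint might look roughly like this; the prompt, model name, and base URL are placeholders.

```python
from openai import OpenAI

def correct_formula_sketch(formula: str, base_url: str, api_key: str, model: str) -> str:
    """Illustrative sketch only: ask an OpenAI-compatible endpoint to clean up a LaTeX formula."""
    client = OpenAI(api_key=api_key, base_url=base_url)
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "Return only the corrected LaTeX formula."},
            {"role": "user", "content": formula},
        ],
        temperature=0,
    )
    return response.choices[0].message.content.strip()
```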
8 changes: 8 additions & 0 deletions webmainbench/metrics/__init__.py
@@ -11,6 +11,10 @@
from .teds_metrics import TEDSMetric, StructureTEDSMetric
from .calculator import MetricCalculator
from .mainhtml_calculator import MainHTMLMetricCalculator
from .base_extractor import ContentExtractor
from .formula_extractor import FormulaExtractor
from .code_extractor import CodeExtractor
from .table_extractor import TableExtractor

__all__ = [
"BaseMetric",
@@ -27,4 +31,8 @@
"TextEditMetric",
"MetricCalculator",
"MainHTMLMetricCalculator",
'ContentExtractor',
'FormulaExtractor',
'CodeExtractor',
'TableExtractor',
]
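With these exports, the extractor helpers become part of the package's public surface. A minimal usage sketch, assuming the constructor signature and the `extract(text, field_name)` call pattern seen in `base.py` below:

```python
from webmainbench.metrics import CodeExtractor, FormulaExtractor, TableExtractor

# Config keys mirror those built in base.py; use_llm=True would enable LLM-assisted correction.
config = {
    'llm_base_url': '',
    'llm_api_key': '',
    'llm_model': '',
    'use_llm': False,
}

formula_extractor = FormulaExtractor(config)
table_extractor = TableExtractor(config)
code_extractor = CodeExtractor(config)

sample_md = "Euler's identity: $e^{i\\pi} + 1 = 0$\n\n| a | b |\n| --- | --- |\n| 1 | 2 |"
# field_name argument assumed from how base.py invokes extract()
print(formula_extractor.extract(sample_md, "groundtruth_content"))
print(table_extractor.extract(sample_md, "groundtruth_content"))
```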
248 changes: 26 additions & 222 deletions webmainbench/metrics/base.py
@@ -4,12 +4,7 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, Any, List, Optional, Union
import traceback
import re
from bs4 import BeautifulSoup
import os
import hashlib
from typing import Dict, Any, List, Optional

@dataclass
class MetricResult:
@@ -144,7 +139,7 @@ def split_content(text: str, content_list: List[Dict[str, Any]] = None, field_na

# Extract from the markdown text, passing the field name along
return BaseMetric._extract_from_markdown(text or "", field_name=field_name)

@staticmethod
def _extract_from_content_list(content_list: List[Dict[str, Any]]) -> Dict[str, str]:
"""从content_list中递归提取各种类型的内容"""
@@ -194,233 +189,42 @@ def _recursive_extract(items):
'table': '\n'.join(extracted['table']),
'text': '\n'.join(extracted['text'])
}

@staticmethod
def _extract_from_markdown(text: str, field_name: str = None) -> Dict[str, str]:
"""从markdown文本中提取各种类型的内容"""
if not text:
return {'code': '', 'formula': '', 'table': '', 'text': ''}

# Collect every content segment that later needs to be removed from the text
extracted_segments = []
code_parts = []
# # Also match fenced code blocks ```...```
# pattern = r'(```[\s\S]*?```)'
# for match in re.finditer(pattern, text):
# code_segment = match.group(0)
# extracted_segments.append(code_segment)
#
# if code_segment.startswith('```'):
# # Handle fenced code blocks (keep internal indentation)
# lines = code_segment.split('\n')
# # Strip the leading/trailing ``` markers
# content_lines = lines[1:-1]
# # Keep the original indentation and just join the content
# code_content = '\n'.join(content_lines)
# else:
# # Handle inline code (strip only the outer backticks and surrounding whitespace)
# code_content = code_segment[1:-1].strip()
#
# if code_content: # only keep non-empty content
# code_parts.append(code_content)

# 1. Handle triple-backtick fenced code blocks first (highest priority)
backtick_pattern = r'(```[\s\S]*?```)'
for match in re.finditer(backtick_pattern, text):
code_segment = match.group(0)

if code_segment.startswith('```'):
# Handle the fenced code block
lines = code_segment.split('\n')
# Strip the leading/trailing ``` markers
content_lines = lines[1:-1]
code_content = '\n'.join(content_lines)
else:
# Handle inline code
code_content = code_segment[1:-1].strip()

if code_content:
code_parts.append(code_content)

# 2. Handle indented code blocks with a stricter match
# Pattern: a blank line before + several consecutive indented lines + a blank line after
# Key point: every matched line must be indented
indent_pattern = r'(?:\n\s*\n)((?:(?: {4,}|\t+)[^\n]*(?:\n|$)){2,})(?=\n\s*\n|$)'

for match in re.finditer(indent_pattern, text, re.MULTILINE):
code_segment = match.group(1)

# Validate: make sure every line is indented (avoid mixing indented and non-indented lines)
lines = code_segment.split('\n')
all_indented = all(
line.startswith(' ') or line.startswith('\t') or not line.strip()
for line in lines
if line.strip() # blank lines don't count
)

if not all_indented:
continue # skip blocks that contain non-indented lines

# Further check for code-like characteristics
non_empty_lines = [line.strip() for line in lines if line.strip()]
if len(non_empty_lines) < 2: # require at least 2 non-empty lines
continue

# Check for obvious non-code features
has_list_features = any(
re.match(r'^[-•*]\s', line) or
re.match(r'^\d+\.\s', line) or
re.search(r'\$[\d,]', line) or
re.search(r'\b(million|billion|thousand)\b', line, re.IGNORECASE)
for line in non_empty_lines
)

if has_list_features:
continue # skip list-like content

# Clean up the code segment
cleaned_lines = []
for line in code_segment.split('\n'):
if line.strip():
if line.startswith(' '):
cleaned_lines.append(line[4:])
elif line.startswith('\t'):
cleaned_lines.append(line[1:])
else:
cleaned_lines.append(line)

code_content = '\n'.join(cleaned_lines)
if code_content.strip():
code_parts.append(code_content)

# Extract formulas - new two-step logic
formula_parts = []

# Step 1: extract formulas with a regex pass first
regex_formulas = []
latex_patterns = [
r'(?<!\\)\$\$(.*?)(?<!\\)\$\$', # display math $$...$$
r'(?<!\\)\\\[(.*?)(?<!\\)\\\]', # display math \[...\]
r'(?<!\\)\$(.*?)(?<!\\)\$', # inline math $...$
r'(?<!\\)\\\((.*?)(?<!\\)\\\)', # inline math \(...\)
]

for pattern in latex_patterns:
for match in re.finditer(pattern, text, re.DOTALL):
formula_full = match.group(0)
formula_content = match.group(1)
extracted_segments.append(formula_full)
if formula_content.strip():
regex_formulas.append(formula_content.strip())

# Step 2: decide, based on the field type, whether API correction is needed
if field_name == "groundtruth_content":
print(f"[DEBUG] Groundtruth content detected; extracting formulas with regex only")
formula_parts = regex_formulas
else:
print(f"[DEBUG] llm_webkit_md content detected; using regex + API correction")
# For llm_webkit_md, pass the regex results to the API for correction
if regex_formulas:
# Feed the regex-extracted formulas to the API as input
regex_formulas_text = '\n'.join(regex_formulas)
print(f"[DEBUG] Regex extracted {len(regex_formulas)} formulas; running API correction")

cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.cache')
os.makedirs(cache_dir, exist_ok=True)

# Use a hash of the regex results as the cache file name
text_hash = hashlib.md5(regex_formulas_text.encode('utf-8')).hexdigest()
cache_file = os.path.join(cache_dir, f'formula_correction_cache_{text_hash}.json')

try:
from .formula_extractor import correct_formulas_with_llm
corrected_formulas = correct_formulas_with_llm(regex_formulas, cache_file)
formula_parts = corrected_formulas
print(f"[DEBUG] API修正成功,最终得到 {len(formula_parts)} 个公式")
except Exception as e:
print(f"[DEBUG] API修正失败: {type(e).__name__}: {e},使用正则结果")
formula_parts = regex_formulas
else:
print(f"[DEBUG] 正则未提取到公式,跳过API修正")
formula_parts = []

# Extract tables
table_parts = []

# ===== 1. Extract HTML tables =====
# Use BeautifulSoup instead of regex to avoid nesting issues or incomplete matches
soup = BeautifulSoup(text, "html.parser")
for table in soup.find_all("table"):
# Check whether this table's parent is itself a table-related tag (<td>, <tr>, <tbody>, etc.)
parent_is_table_related = table.find_parent(["td", "tr", "tbody", "table"]) is not None
if not parent_is_table_related: # parent is not table-related, so this is an outer table
html_table = str(table)
extracted_segments.append(html_table)
table_parts.append(html_table)

# ===== 2. Extract Markdown tables =====
lines = text.split('\n')
table_lines = []
in_markdown_table = False
found_separator = False # whether a separator row has been seen

def is_md_table_line(line):
"""判断是否可能是 Markdown 表格行"""
if line.count("|") < 1: # 至少三个竖线
return False
return True

def is_md_separator_line(line):
"""判断是否为 Markdown 分隔行"""
parts = [p.strip() for p in line.split("|")]
# 检查是否所有部分都是分隔符格式
for p in parts:
if p and not re.match(r"^:?\-{3,}:?$", p):
return False
return True
# Build the extractor configuration
config = {
'llm_base_url': '',
'llm_api_key': '',
'llm_model': '',
'use_llm': False # set to True to enable LLM correction
}

def save_table():
"""保存当前表格并清空缓存"""
nonlocal table_lines
# 只有当表格行数大于等于2,且第二行是分隔行时才保存
if len(table_lines) >= 2 and is_md_separator_line(table_lines[1]):
md_table = '\n'.join(table_lines)
extracted_segments.append(md_table)
table_parts.append(md_table)
# Instantiate the concrete extractors directly
from .code_extractor import CodeExtractor
from .formula_extractor import FormulaExtractor
from .table_extractor import TableExtractor

for line in lines:
if is_md_table_line(line):
table_lines.append(line)
in_markdown_table = True
if is_md_separator_line(line):
found_separator = True
else:
if in_markdown_table:
save_table()
table_lines = []
in_markdown_table = False
found_separator = False
code_extractor = CodeExtractor(config)
formula_extractor = FormulaExtractor(config)
table_extractor = TableExtractor(config)

# Handle a Markdown table at the end of the document
if in_markdown_table:
save_table()
# Extract each content type
code_content = code_extractor.extract(text, field_name)
formula_content = formula_extractor.extract(text, field_name)
table_content = table_extractor.extract(text, field_name)

# Extract the remaining text (remove every already-extracted segment)
clean_text = text
for segment in extracted_segments:
clean_text = clean_text.replace(segment, '', 1)

# Collapse redundant blank lines
clean_text = re.sub(r'\n\s*\n', '\n\n', clean_text)
clean_text = clean_text.strip()

return {
'code': '\n'.join(code_parts),
'formula': '\n'.join(formula_parts),
'table': '\n'.join(table_parts),
'text': text # the full original text
'code': code_content,
'formula': formula_content,
'table': table_content,
'text': text # keep the full original text
}

def aggregate_results(self, results: List[MetricResult]) -> MetricResult:
"""
Aggregate multiple metric results.
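The net effect of this refactor is that `_extract_from_markdown` no longer carries its own regex and BeautifulSoup logic: it builds a config dict, instantiates CodeExtractor, FormulaExtractor, and TableExtractor, and delegates to their `extract(text, field_name)` methods, while still returning the same `{'code', 'formula', 'table', 'text'}` shape. A minimal sketch of the caller-facing behavior via `split_content` follows; the import path and the assumption that `split_content` can be called as a static method are inferred from this diff rather than confirmed by it.

```python
from webmainbench.metrics.base import BaseMetric  # module path assumed from the diff

markdown = (
    "Some prose.\n\n"
    "| col_a | col_b |\n"
    "| --- | --- |\n"
    "| 1 | 2 |\n\n"
    "Inline math $E = mc^2$.\n"
)

# field_name="groundtruth_content" followed the regex-only path in the old code;
# under the new design that decision now lives inside FormulaExtractor.
parts = BaseMetric.split_content(markdown, field_name="groundtruth_content")

print(parts["table"])    # extracted table markup
print(parts["formula"])  # extracted formula source
print(parts["text"])     # the original markdown is kept verbatim under 'text'
```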