diff --git a/examples/multi_extractor_compare.py b/examples/multi_extractor_compare.py index 4affe52..b2f3258 100644 --- a/examples/multi_extractor_compare.py +++ b/examples/multi_extractor_compare.py @@ -8,7 +8,7 @@ def all_extractor_comparison(): print("\n=== 多抽取器对比演示 ===\n") # 创建数据集 - dataset_path = Path("data/sample_dataset.jsonl") + dataset_path = Path("../data/WebMainBench_llm-webkit_v1_WebMainBench_7887_within_formula.jsonl") dataset = DataLoader.load_jsonl(dataset_path) # 创建webkit抽取器 diff --git a/webmainbench/data/saver.py b/webmainbench/data/saver.py index e88a5ef..247dc4f 100644 --- a/webmainbench/data/saver.py +++ b/webmainbench/data/saver.py @@ -303,14 +303,14 @@ def save_dataset_with_extraction(results: Union["EvaluationResult", Dict[str, An # 解析预测值(predicted) predicted_content = extraction_result.get('extracted_content', '') - predicted_parts = BaseMetric._extract_from_markdown(predicted_content) # 关键:解析预测内容 + predicted_parts = BaseMetric._extract_from_markdown(predicted_content, field_name="llm_webkit_md") # 关键:解析预测内容 for part_type in ['code', 'formula', 'table', 'text']: sample_dict[f'{current_extractor_name}_predicted_{part_type}'] = predicted_parts.get(part_type, '') # 解析真实值(groundtruth)- 只需要解析一次 if extractor_names: # 只有当存在extractor时才解析 groundtruth_content = sample_dict.get('groundtruth_content', '') - groundtruth_parts = BaseMetric._extract_from_markdown(groundtruth_content) # 关键:解析真实内容 + groundtruth_parts = BaseMetric._extract_from_markdown(groundtruth_content, field_name="groundtruth_content") # 关键:解析真实内容 for part_type in ['code', 'formula', 'table', 'text']: # 使用第一个extractor的名字作为前缀,或者使用通用前缀 prefix = extractor_names[0] if len(extractor_names) == 1 else 'groundtruth' diff --git a/webmainbench/metrics/base.py b/webmainbench/metrics/base.py index 7f6019e..10a7acb 100644 --- a/webmainbench/metrics/base.py +++ b/webmainbench/metrics/base.py @@ -8,6 +8,8 @@ import traceback import re from bs4 import BeautifulSoup +import os +import hashlib @dataclass class MetricResult: @@ -121,16 +123,16 @@ def batch_calculate(self, predicted_list: List[Any], result = self.calculate(pred, gt, **kwargs) results.append(result) return results - + @staticmethod - def split_content(text: str, content_list: List[Dict[str, Any]] = None) -> Dict[str, str]: + def split_content(text: str, content_list: List[Dict[str, Any]] = None, field_name: str = None) -> Dict[str, str]: """ 统一的内容分割方法,将文本分为代码、公式、表格和剩余文本4个部分。 - + Args: text: 原始markdown文本 content_list: 结构化内容列表(来自llm-webkit等) - + field_name: 当前处理的字段名称,传递给_extract_from_markdown Returns: Dict with keys: 'code', 'formula', 'table', 'text' """ @@ -139,9 +141,9 @@ def split_content(text: str, content_list: List[Dict[str, Any]] = None) -> Dict[ extracted_content = BaseMetric._extract_from_content_list(content_list) if any(extracted_content.values()): return extracted_content - - # 从markdown文本中提取 - return BaseMetric._extract_from_markdown(text or "") + + # 从markdown文本中提取,传递字段名称 + return BaseMetric._extract_from_markdown(text or "", field_name=field_name) @staticmethod def _extract_from_content_list(content_list: List[Dict[str, Any]]) -> Dict[str, str]: @@ -193,12 +195,12 @@ def _recursive_extract(items): 'text': '\n'.join(extracted['text']) } - @staticmethod - def _extract_from_markdown(text: str) -> Dict[str, str]: + @staticmethod + def _extract_from_markdown(text: str, field_name: str = None) -> Dict[str, str]: """从markdown文本中提取各种类型的内容""" if not text: return {'code': '', 'formula': '', 'table': '', 'text': ''} - + # 收集所有需要移除的内容片段 extracted_segments = [] code_parts = [] @@ -291,34 +293,56 @@ def _extract_from_markdown(text: str) -> Dict[str, str]: if code_content.strip(): code_parts.append(code_content) - # 提取公式 + # 提取公式 - 新的两步处理逻辑 formula_parts = [] - # 统一的公式提取模式 + + # 第一步:先用正则提取公式 + regex_formulas = [] latex_patterns = [ - # r'(? str: return f"{self.__class__.__name__}(name='{self.name}')" def __repr__(self) -> str: - return self.__str__() \ No newline at end of file + return self.__str__() diff --git a/webmainbench/metrics/formula_extractor.py b/webmainbench/metrics/formula_extractor.py new file mode 100644 index 0000000..6ef1e4d --- /dev/null +++ b/webmainbench/metrics/formula_extractor.py @@ -0,0 +1,115 @@ +# webmainbench/metrics/formula_extractor.py +import json +import os +from openai import OpenAI + +def correct_formulas_with_llm(regex_formulas, cache_file=None): + """使用LLM API修正正则提取的公式""" + + if not regex_formulas: + print(f"[DEBUG] 输入公式列表为空,跳过API修正") + return [] + + # 检查缓存 + if cache_file and os.path.exists(cache_file): + try: + with open(cache_file, 'r', encoding='utf-8') as f: + cached_result = json.load(f) + print(f"[DEBUG] 从缓存加载修正结果: {len(cached_result)} 个") + return cached_result + except Exception as e: + print(f"[DEBUG] 缓存读取失败: {e}") + + # API配置 + client = OpenAI( + base_url="", + api_key="" + ) + + # 将正则提取的公式转换为文本 + formulas_text = '\n'.join(regex_formulas) + + CORRECTION_PROMPT = '''任务:请从以下正则表达式提取的内容中,识别并保留真正的LaTeX数学公式,剔除货币形式的内容。 + + ### 识别规则 + **真正的数学公式**(保留): + - 包含数学符号:+ - × ÷ = < > ≤ ≥ ± ∞ ∑ ∫ ∂ √ ^ _ { } 等 + - 包含希腊字母:α β γ δ θ λ μ π σ ω 等 + - 包含LaTeX命令:\\frac \\sum \\int \\sqrt \\alpha \\beta \\sin \\cos 等 + - 包含数学表达式:变量、函数、方程等 + + **货币形式内容**(剔除): + - 仅包含数字、逗号、小数点的价格:如 1,150.00 + - 纯粹的金额数值:如 25.99、1,200、5.50 + - 不包含任何数学运算符或数学符号的数字 + + ### 处理要求 + 1. **严格区分**:只保留真正的数学公式,剔除所有货币价格 + 2. **格式标准化**:统一公式格式,确保LaTeX语法正确 + 3. **保持原意**:不修改数学公式内容 + + ### 输出格式 + - 每个有效的数学公式独占一行 + - 只输出公式内容,不包含$符号或其他包装 + - 如果输入不是有效的数学公式(如货币),则输出<空> + - 按原顺序输出保留的公式 + + ### 示例 1 (剔除后有有效公式) + 输入:1,150.00 → 剔除(货币) + 输入:x^2 + y^2 = r^2 → 保留(数学公式) + 输入:25.99 → 剔除(货币) + 输入:\\frac{a}{b} + c → 保留(数学公式) + + ### 示例 2 (剔除后无有效公式) + 输入:1,150.00 → 剔除(货币) + 输入:25.99 → 剔除(货币) + + 输出:<空> + + 注意,输出结果中不要添加任何解释!。 + [输入内容列表开始]''' + + try: + print(f"[DEBUG] 开始调用 OpenAI API 进行公式修正...") + response = client.chat.completions.create( + model="deepseek-chat", + temperature=0, + messages=[ + {"role": "user", "content": CORRECTION_PROMPT + f"\n{formulas_text}\n" + '''[输入内容列表结束] +--- +请按要求识别并输出真正的数学公式,剔除货币形式的内容。 +---'''} + ] + ) + + result_text = response.choices[0].message.content.strip() + print(f"[DEBUG] API 返回修正结果: {repr(result_text)}") + + # 检测返回内容是否包含"空"字 - 如果包含则整个结果为空 + if '空' in result_text: + print(f"[DEBUG] 检测到API返回包含'空'字,将整个结果设置为空列表") + corrected_formulas = [] + elif not result_text: + corrected_formulas = [] + else: + # 正常解析返回的公式列表 + corrected_formulas = [line.strip() for line in result_text.split('\n') if line.strip()] + + print(f"[DEBUG] 修正后的公式列表: {corrected_formulas}") + + # 保存缓存 + if cache_file: + try: + os.makedirs(os.path.dirname(cache_file), exist_ok=True) + with open(cache_file, 'w', encoding='utf-8') as f: + json.dump(corrected_formulas, f, ensure_ascii=False, indent=2) + print(f"[DEBUG] 修正结果已缓存到: {cache_file}") + except Exception as e: + print(f"[DEBUG] 缓存保存失败: {e}") + + return corrected_formulas + + except Exception as e: + print(f"[DEBUG] API 修正异常: {type(e).__name__}: {e}") + print(f"[DEBUG] 回退到原始正则结果") + return regex_formulas \ No newline at end of file