88import traceback
99import re
1010from bs4 import BeautifulSoup
11+ import os
12+ import hashlib
1113
1214@dataclass
1315class MetricResult :
@@ -121,16 +123,16 @@ def batch_calculate(self, predicted_list: List[Any],
121123 result = self .calculate (pred , gt , ** kwargs )
122124 results .append (result )
123125 return results
124-
126+
125127 @staticmethod
126- def split_content (text : str , content_list : List [Dict [str , Any ]] = None ) -> Dict [str , str ]:
128+ def split_content (text : str , content_list : List [Dict [str , Any ]] = None , field_name : str = None ) -> Dict [str , str ]:
127129 """
128130 统一的内容分割方法,将文本分为代码、公式、表格和剩余文本4个部分。
129-
131+
130132 Args:
131133 text: 原始markdown文本
132134 content_list: 结构化内容列表(来自llm-webkit等)
133-
135+ field_name: 当前处理的字段名称,传递给_extract_from_markdown
134136 Returns:
135137 Dict with keys: 'code', 'formula', 'table', 'text'
136138 """
@@ -139,9 +141,9 @@ def split_content(text: str, content_list: List[Dict[str, Any]] = None) -> Dict[
139141 extracted_content = BaseMetric ._extract_from_content_list (content_list )
140142 if any (extracted_content .values ()):
141143 return extracted_content
142-
143- # 从markdown文本中提取
144- return BaseMetric ._extract_from_markdown (text or "" )
144+
145+ # 从markdown文本中提取,传递字段名称
146+ return BaseMetric ._extract_from_markdown (text or "" , field_name = field_name )
145147
146148 @staticmethod
147149 def _extract_from_content_list (content_list : List [Dict [str , Any ]]) -> Dict [str , str ]:
@@ -193,12 +195,12 @@ def _recursive_extract(items):
193195 'text' : '\n ' .join (extracted ['text' ])
194196 }
195197
196- @staticmethod
197- def _extract_from_markdown (text : str ) -> Dict [str , str ]:
198+ @staticmethod
199+ def _extract_from_markdown (text : str , field_name : str = None ) -> Dict [str , str ]:
198200 """从markdown文本中提取各种类型的内容"""
199201 if not text :
200202 return {'code' : '' , 'formula' : '' , 'table' : '' , 'text' : '' }
201-
203+
202204 # 收集所有需要移除的内容片段
203205 extracted_segments = []
204206 code_parts = []
@@ -291,34 +293,56 @@ def _extract_from_markdown(text: str) -> Dict[str, str]:
291293 if code_content .strip ():
292294 code_parts .append (code_content )
293295
294- # 提取公式
296+ # 提取公式 - 新的两步处理逻辑
295297 formula_parts = []
296- # 统一的公式提取模式
298+
299+ # 第一步:先用正则提取公式
300+ regex_formulas = []
297301 latex_patterns = [
298- # r'(?<!\\)\$\$([^$]+)\$\$(?!\\)', # Display math (not escaped)
299- # r'(?<!\\)\$([^$\n]+)\$(?![\\\$])', # Inline math (not escaped)
300- # r'\\begin\{equation\*?\}(.*?)\\end\{equation\*?\}', # Equation environment
301- # r'\\begin\{align\*?\}(.*?)\\end\{align\*?\}', # Align environment
302- # r'\\begin\{gather\*?\}(.*?)\\end\{gather\*?\}', # Gather environment
303- # r'\\begin\{eqnarray\*?\}(.*?)\\end\{eqnarray\*?\}', # Eqnarray environment
304- # r'\\begin\{multline\*?\}(.*?)\\end\{multline\*?\}', # Multline environment
305- # r'\\begin\{split\}(.*?)\\end\{split\}', # Split environment
306- # r'(?<!\\)\$\$([^$]+)\$\$(?!\\)',
307- # r'(?<!\\)\$([^$\n\w][^$\n]*[^$\n\w])\$(?![\\\$])',
308- r'(?<!\\)\$\$(.*?)(?<!\\)\$\$' , # 行间 $$...$$,确保 $ 没有被转义
309- r'(?<!\\)\\\[(.*?)(?<!\\)\\\]' , # 行间 \[...\],确保 \ 没有被转义
310- r'(?<!\\)\$(.*?)(?<!\\)\$' , # 行内 $...$,确保 $ 没有被转义
311- # r'(?<!\\)\$(.*?)(?<!\\)\$(?!\d)', # 第二个$后面不能是数字
312- r'(?<!\\)\\\((.*?)(?<!\\)\\\)' , # 行内 \(...\),确保 \ 没有被转义
302+ r'(?<!\\)\$\$(.*?)(?<!\\)\$\$' , # 行间 $$...$$
303+ r'(?<!\\)\\\[(.*?)(?<!\\)\\\]' , # 行间 \[...\]
304+ r'(?<!\\)\$(.*?)(?<!\\)\$' , # 行内 $...$
305+ r'(?<!\\)\\\((.*?)(?<!\\)\\\)' , # 行内 \(...\)
313306 ]
314-
307+
315308 for pattern in latex_patterns :
316309 for match in re .finditer (pattern , text , re .DOTALL ):
317- formula_full = match .group (0 ) # 完整匹配(包含$符号)
318- formula_content = match .group (1 ) # 只是公式内容
310+ formula_full = match .group (0 )
311+ formula_content = match .group (1 )
319312 extracted_segments .append (formula_full )
320313 if formula_content .strip ():
321- formula_parts .append (formula_content .strip ())
314+ regex_formulas .append (formula_content .strip ())
315+
316+ # 第二步:根据字段类型决定是否需要API修正
317+ if field_name == "groundtruth_content" :
318+ print (f"[DEBUG] 检测到groundtruth内容,仅使用正则提取公式" )
319+ formula_parts = regex_formulas
320+ else :
321+ print (f"[DEBUG] 检测到llm_webkit_md内容,使用正则+API修正模式" )
322+ # 对于llm_webkit_md,将正则结果传递给API进行修正
323+ if regex_formulas :
324+ # 将正则提取的公式作为输入传递给API
325+ regex_formulas_text = '\n ' .join (regex_formulas )
326+ print (f"[DEBUG] 正则提取到 { len (regex_formulas )} 个公式,准备API修正" )
327+
328+ cache_dir = os .path .join (os .path .dirname (os .path .abspath (__file__ )), '.cache' )
329+ os .makedirs (cache_dir , exist_ok = True )
330+
331+ # 使用正则结果的哈希作为缓存文件名
332+ text_hash = hashlib .md5 (regex_formulas_text .encode ('utf-8' )).hexdigest ()
333+ cache_file = os .path .join (cache_dir , f'formula_correction_cache_{ text_hash } .json' )
334+
335+ try :
336+ from .formula_extractor import correct_formulas_with_llm
337+ corrected_formulas = correct_formulas_with_llm (regex_formulas , cache_file )
338+ formula_parts = corrected_formulas
339+ print (f"[DEBUG] API修正成功,最终得到 { len (formula_parts )} 个公式" )
340+ except Exception as e :
341+ print (f"[DEBUG] API修正失败: { type (e ).__name__ } : { e } ,使用正则结果" )
342+ formula_parts = regex_formulas
343+ else :
344+ print (f"[DEBUG] 正则未提取到公式,跳过API修正" )
345+ formula_parts = []
322346
323347 # 提取表格
324348 table_parts = []
@@ -468,4 +492,4 @@ def __str__(self) -> str:
468492 return f"{ self .__class__ .__name__ } (name='{ self .name } ')"
469493
470494 def __repr__ (self ) -> str :
471- return self .__str__ ()
495+ return self .__str__ ()
0 commit comments