diff --git a/src/open_r1/configs.py b/src/open_r1/configs.py
index ddb6e53b0..87bc7f466 100644
--- a/src/open_r1/configs.py
+++ b/src/open_r1/configs.py
@@ -229,6 +229,8 @@ class GRPOScriptArguments(ScriptArguments):
             Maximum number of tokens in completion.
         soft_punish_cache (`int`):
             Minimum number of tokens in completion.
+        txt_language (`str`):
+            Target language for the lang_consistency reward.
     """
 
     reward_funcs: list[str] = field(
@@ -329,3 +331,10 @@ class GRPOScriptArguments(ScriptArguments):
         default=4096,
         metadata={"help": "Minimum number of characters in completion."},
     )
+
+    txt_language: str = field(
+        default="en",
+        metadata={
+            "help": "Target language for the lang_consistency reward, as a langdetect language code (https://pypi.org/project/langdetect/)."
+        },
+    )
diff --git a/src/open_r1/rewards.py b/src/open_r1/rewards.py
index 0b3662841..097d3e8ef 100644
--- a/src/open_r1/rewards.py
+++ b/src/open_r1/rewards.py
@@ -22,6 +22,7 @@
 from functools import partial, update_wrapper
 from typing import Callable, Dict, Literal, Optional
 
+from langdetect import detect
 from latex2sympy2_extended import NormalizationConfig
 from math_verify import LatexExtractionConfig, parse, verify
 
@@ -643,6 +644,47 @@ def soft_overlong_punishment_reward(completion_ids: list[list[int]], **kwargs) -> float:
     return soft_overlong_punishment_reward
 
 
+def get_lang_consistency_reward(language: str = "en"):
+    """
+    Returns a reward function that checks whether the language of a generated text matches a target language.
+
+    Args:
+        language (str): The default target language code, as used by langdetect (https://pypi.org/project/langdetect/).
+    """
+
+    CLEAN_PATTERN = re.compile(
+        # 1. Remove XML-like tags such as <think>, <answer>
+        r"</?(think|answer)[^>]*>"
+        # 2. Remove code blocks (both ```...``` and `...`)
+        r"|```[\s\S]*?```|`[^`]*?`"
+        # 3. Remove LaTeX math blocks (e.g., $...$, $$...$$, \[...\], \(...\))
+        r"|\$+(?:(?!\$+)[\s\S])*\$+|\\\[.*?\\\]|\\\(.*?\\\)",
+        flags=re.DOTALL | re.MULTILINE,
+    )
+
+    def clean_content(text):
+        return CLEAN_PATTERN.sub("", text).strip()
+
+    def lang_consistency_reward(completions, **kwargs):
+        """Calculates language consistency scores for a batch of completions."""
+        rewards = []
+        target_languages = kwargs.get("language", [language] * len(completions))  # per-sample dataset column overrides the default
+        for completion, sample_language in zip(completions, target_languages):
+            try:
+                content = completion[0].get("content", "")
+                if not content:
+                    rewards.append(None)
+                    continue
+                cleaned_text = clean_content(content)
+                detected_lang = detect(cleaned_text)
+                rewards.append(1.0 if sample_language == detected_lang else 0.0)
+            except Exception:  # detect() raises if nothing detectable remains after cleaning
+                rewards.append(None)
+        return rewards
+
+    return lang_consistency_reward
+
+
 def get_reward_funcs(script_args) -> list[Callable]:
     REWARD_FUNCS_REGISTRY = {
         "accuracy": accuracy_reward,
@@ -700,6 +742,7 @@ def get_reward_funcs(script_args) -> list[Callable]:
             max_completion_len=script_args.max_completion_len,
             soft_punish_cache=script_args.soft_punish_cache,
         ),
+        "lang_consistency": get_lang_consistency_reward(language=script_args.txt_language),
     }
 
     reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
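
As a quick sanity check, the new factory can be exercised directly outside the trainer. The snippet below is a hypothetical sketch, not part of the diff: the completion format (a list of single-message lists with a `"content"` key) mirrors what `lang_consistency_reward` reads, the per-sample `language` kwarg assumes the trainer forwards extra dataset columns to reward functions (which is what the `kwargs` lookup is written for), and the example texts and expected scores are illustrative.

```python
# Hypothetical usage sketch, not part of this diff.
from langdetect import DetectorFactory

from open_r1.rewards import get_lang_consistency_reward

# langdetect is probabilistic; fixing the seed makes results reproducible.
DetectorFactory.seed = 0

reward_fn = get_lang_consistency_reward(language="en")

# One message list per completion; only the first message's "content" is scored.
completions = [
    [{"content": "<think>Paris is the capital.</think><answer>The capital of France is Paris.</answer>"}],
    [{"content": "<answer>La capital de Francia es París.</answer>"}],
]

# Default target ("en") applied to every sample: the Spanish answer should score 0.0.
print(reward_fn(completions))  # e.g. [1.0, 0.0]

# A per-sample `language` column (forwarded via **kwargs) overrides the default.
print(reward_fn(completions, language=["en", "es"]))  # e.g. [1.0, 1.0]
```

With the registry entry in place, the reward should then be selectable per run via `--reward_funcs lang_consistency`, with the target set by the new `--txt_language` field on `GRPOScriptArguments`.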