class LocalProvider(CodeExecutionProvider):
    """Lightweight local execution provider for development.

    WARNING: This executes code on the local machine and is intended only for
    development/testing. Do not use with untrusted inputs.

    Current implementation supports Python evaluation scripts only, matching
    the evaluation script template constructed in `code_reward`.

    Args:
        num_parallel: Maximum number of scripts executed concurrently. Values
            below 1 are clamped to 1, because a semaphore created with 0
            permits would block every script forever.
        timeout: Per-script wall-clock limit in seconds. A script exceeding it
            is killed and scored 0.0. ``None`` disables the limit.
    """

    # Language tags treated as Python; every other tag scores 0.0.
    _PYTHON_ALIASES = frozenset({"python", "py", "python3"})

    # Appended to every script so the score reaches stdout. The original
    # template ends with `evaluate_code(code_snippet, test_cases)`, which
    # returns a value in notebook-like environments but prints nothing when
    # run as a plain script.
    _PRINT_RESULT_SNIPPET = textwrap.dedent(
        """
        try:
            __or1_res = evaluate_code(code_snippet, test_cases)
            print(__or1_res)
        except Exception:
            pass
        """
    )

    def __init__(self, num_parallel: int = 2, timeout: Optional[float] = 30.0):
        self.num_parallel = num_parallel
        self.timeout = timeout

    def execute_scripts(self, scripts: List[str], languages: List[str]) -> List[float]:
        """Run every script locally and return one score per script.

        Args:
            scripts: Evaluation scripts to execute.
            languages: Language tag for each script, aligned with ``scripts``.

        Returns:
            A list with exactly ``len(scripts)`` floats. If the batch as a
            whole fails (for example mismatched input lengths, or being called
            from inside a running event loop), every score is 0.0 and the
            failure is logged.
        """
        if not scripts:
            return []
        try:
            return asyncio.run(self._run_async(scripts, languages, self.num_parallel))
        except Exception:
            # Best-effort contract: never raise to the reward function, but
            # leave a trace so a batch of zeros can be diagnosed.
            import logging

            logging.getLogger(__name__).exception(
                "LocalProvider failed to execute scripts; returning zero rewards"
            )
            return [0.0] * len(scripts)

    async def _run_async(self, scripts: List[str], languages: List[str], num_parallel: int) -> List[float]:
        """Execute all scripts concurrently, at most ``num_parallel`` at a time.

        Raises:
            ValueError: If ``scripts`` and ``languages`` differ in length;
                zipping them would silently drop entries and misalign scores.
        """
        if len(scripts) != len(languages):
            raise ValueError(
                f"Got {len(scripts)} scripts but {len(languages)} languages; lengths must match"
            )
        semaphore = asyncio.Semaphore(max(1, num_parallel))
        tasks = [self._run_script(script, lang, semaphore) for script, lang in zip(scripts, languages)]
        return list(await asyncio.gather(*tasks))

    async def _run_script(self, script: str, language: str, semaphore: asyncio.Semaphore) -> float:
        """Run one script in a child interpreter and return its printed score.

        Returns 0.0 for non-Python languages, on timeout, on any execution
        error, or when no line of stdout parses as a float.
        """
        # Only Python evaluation scripts are supported in local mode. A
        # non-string tag is treated as unsupported instead of raising, so one
        # bad entry cannot zero out the whole batch.
        if not isinstance(language, str) or language.lower() not in self._PYTHON_ALIASES:
            return 0.0

        tmp_path = None
        proc = None
        async with semaphore:
            try:
                # Python source files are UTF-8 by default; writing in the
                # locale encoding could corrupt non-ASCII scripts.
                with tempfile.NamedTemporaryFile(
                    "w", suffix=".py", delete=False, encoding="utf-8"
                ) as f:
                    tmp_path = f.name
                    f.write(script)
                    f.write("\n\n")
                    f.write(self._PRINT_RESULT_SNIPPET)

                proc = await asyncio.create_subprocess_exec(
                    self._python_executable(),
                    tmp_path,
                    stdout=PIPE,
                    stderr=PIPE,
                )
                stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=self.timeout)
                return self._parse_score(stdout.decode(errors="ignore"))
            except Exception:
                # Covers timeouts, spawn failures and I/O errors alike: a
                # script that cannot be evaluated earns no reward.
                return 0.0
            finally:
                try:
                    # Reached with a live child on timeout or cancellation.
                    await self._reap(proc)
                finally:
                    self._remove_quietly(tmp_path)

    @staticmethod
    def _python_executable() -> str:
        """Return the interpreter to use: $PYTHON, else the running one."""
        import sys

        return os.environ.get("PYTHON") or sys.executable or "python"

    @staticmethod
    def _parse_score(text: str) -> float:
        """Return the last stdout line that parses as a float, else 0.0."""
        for line in reversed(text.splitlines()):
            candidate = line.strip()
            if not candidate:
                continue
            try:
                return float(candidate)
            except ValueError:
                continue
        return 0.0

    @staticmethod
    async def _reap(proc) -> None:
        """Kill ``proc`` if it is still running and wait for it to exit."""
        if proc is None or proc.returncode is not None:
            return
        try:
            proc.kill()
        except ProcessLookupError:
            # The process exited between the returncode check and the kill.
            pass
        await proc.wait()

    @staticmethod
    def _remove_quietly(path: Optional[str]) -> None:
        """Delete the temporary script, ignoring filesystem errors."""
        if not path:
            return
        try:
            os.remove(path)
        except OSError:
            pass
class TestLocalProvider(unittest.TestCase):
    """Exercises `code_reward` end to end with the "local" provider."""

    def test_local_python_code_reward(self):
        """Matching stdout is rewarded with 1.0, mismatching stdout with 0.0."""

        def expects_hello():
            # Build a fresh spec per sample so the two entries share no state.
            single_case = {
                "input": "",
                "output": "hello",
                "type": "stdin_stdout",
            }
            return {"language": "python", "test_cases": [single_case]}

        # Two samples: the first prints the expected text, the second does not.
        matching_sample = [{"content": "```python\nprint('hello')\n```"}]
        mismatching_sample = [{"content": "```python\nprint('bye')\n```"}]

        scores = code_reward(
            [matching_sample, mismatching_sample],
            provider_type="local",
            verification_info=[expects_hello(), expects_hello()],
            num_parallel=2,
        )

        self.assertEqual(scores[0], 1.0)
        self.assertEqual(scores[1], 0.0)