Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scripts/pass_rate_filtering/compute_pass_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import logging
from dataclasses import dataclass
from git import Optional
from typing import Optional
import torch
import sys

Expand Down
90 changes: 90 additions & 0 deletions src/open_r1/utils/code_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@

import abc
import asyncio
import os
import tempfile
import textwrap
from asyncio.subprocess import PIPE
from typing import List, Optional

from ..utils import is_e2b_available, is_morph_available
Expand Down Expand Up @@ -60,6 +64,90 @@ def execute_scripts(self, scripts: List[str], languages: List[str]) -> List[floa
pass


class LocalProvider(CodeExecutionProvider):
"""Lightweight local execution provider for development.

WARNING: This executes code on the local machine and is intended only for
development/testing. Do not use with untrusted inputs.

Current implementation supports Python evaluation scripts only, matching
the evaluation script template constructed in `code_reward`.
"""

def __init__(self, num_parallel: int = 2):
self.num_parallel = num_parallel

def execute_scripts(self, scripts: List[str], languages: List[str]) -> List[float]:
try:
return asyncio.run(self._run_async(scripts, languages, self.num_parallel))
except Exception:
return [0.0] * len(scripts)

async def _run_async(self, scripts: List[str], languages: List[str], num_parallel: int) -> List[float]:
semaphore = asyncio.Semaphore(num_parallel)
tasks = [self._run_script(script, lang, semaphore) for script, lang in zip(scripts, languages)]
results = await asyncio.gather(*tasks)
return results

async def _run_script(self, script: str, language: str, semaphore: asyncio.Semaphore) -> float:
# Only Python evaluation scripts are supported in local mode
if language.lower() not in {"python", "py", "python3"}:
return 0.0

# Ensure the script prints the numeric result. The original template
# ends with `evaluate_code(code_snippet, test_cases)` which returns a value
# in notebook-like environments but doesn't print locally. We append a
# print to make the score available on stdout.
appended = textwrap.dedent(
"""
try:
__or1_res = evaluate_code(code_snippet, test_cases)
print(__or1_res)
except Exception:
pass
"""
)

tmp_path = None
async with semaphore:
try:
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
tmp_path = f.name
f.write(script)
f.write("\n\n")
f.write(appended)

# Use the current python executable
python_exe = os.environ.get("PYTHON", None)
if not python_exe:
# Fallback to 'python' which resolves to the active interpreter in most envs
python_exe = "python"

proc = await asyncio.create_subprocess_exec(
python_exe,
tmp_path,
stdout=PIPE,
stderr=PIPE,
)
stdout, _ = await proc.communicate()
text = stdout.decode(errors="ignore").strip()
# parse last non-empty line as float
for line in reversed([l for l in text.splitlines() if l.strip() != ""]):
try:
return float(line.strip())
except ValueError:
continue
return 0.0
except Exception:
return 0.0
finally:
if tmp_path and os.path.exists(tmp_path):
try:
os.remove(tmp_path)
except Exception:
pass


class E2BProvider(CodeExecutionProvider):
"""Provider that executes code using E2B sandboxes."""

Expand Down Expand Up @@ -362,5 +450,7 @@ def get_provider(provider_type: str = "e2b", **kwargs) -> CodeExecutionProvider:
num_parallel=num_parallel,
morph_router_url=morph_router_url,
)
elif provider_type == "local":
return LocalProvider(num_parallel=num_parallel)
else:
raise ValueError(f"Unknown provider type: {provider_type}")
4 changes: 3 additions & 1 deletion src/open_r1/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@


# Use same as transformers.utils.import_utils
_e2b_available = _is_package_available("e2b")
# The E2B code interpreter package is installed as "e2b-code-interpreter"
# and imported as "e2b_code_interpreter"; check availability by import name.
_e2b_available = _is_package_available("e2b_code_interpreter")


def is_e2b_available() -> bool:
Expand Down
50 changes: 50 additions & 0 deletions tests/test_local_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import unittest

from open_r1.rewards import code_reward


class TestLocalProvider(unittest.TestCase):
def test_local_python_code_reward(self):
# Two samples: one correct, one incorrect
completions = [
[{"content": "```python\nprint('hello')\n```"}],
[{"content": "```python\nprint('bye')\n```"}],
]

verification_info = [
{
"language": "python",
"test_cases": [
{
"input": "",
"output": "hello",
"type": "stdin_stdout",
}
],
},
{
"language": "python",
"test_cases": [
{
"input": "",
"output": "hello",
"type": "stdin_stdout",
}
],
},
]

rewards = code_reward(
completions,
provider_type="local",
verification_info=verification_info,
num_parallel=2,
)

self.assertEqual(rewards[0], 1.0)
self.assertEqual(rewards[1], 0.0)


if __name__ == "__main__":
unittest.main()