New GRPO dataset and tasks: formally-verified program correctness #379

Open: wants to merge 38 commits into base: main

Changes from all commits (38 commits):
591f3c1 - wip adding HTGen dataset and benchmark (Feb 20, 2025)
69b44ca - add API test (Feb 20, 2025)
b3a9587 - wip adding HTGen dataset and benchmark (Feb 20, 2025)
41af428 - add API test (Feb 20, 2025)
58cbde4 - construct prompt in the dataset generator (Feb 22, 2025)
aa3523d - Merge branch 'main' into feature/htgen-dataset (ocramz, Feb 22, 2025)
cde339c - merge from upstream (Feb 22, 2025)
2d225e9 - prompt construction (Feb 22, 2025)
d28e8f7 - fix some typos and add more docstrings (Feb 22, 2025)
40fdef8 - add reward (Feb 22, 2025)
4341395 - Merge branch 'main' into feature/htgen-dataset (ocramz, Feb 23, 2025)
3eae18c - fix typos (Feb 23, 2025)
ce8d1b1 - Merge branch 'feature/htgen-dataset' of github.com:unfoldml/open-r1 i… (Feb 23, 2025)
5e2ea33 - Merge branch 'main' into feature/htgen-dataset (ocramz, Feb 25, 2025)
8535700 - Merge branch 'feature/htgen-dataset' of github.com:unfoldml/open-r1 i… (Feb 25, 2025)
410b4f9 - add unit test for code rewards (Feb 25, 2025)
465ae8c - Merge branch 'main' into feature/htgen-dataset (ocramz, Feb 25, 2025)
fbd20c7 - Merge branch 'main' into feature/htgen-dataset (ocramz, Feb 26, 2025)
c567d55 - Merge branch 'main' into feature/htgen-dataset (ocramz, Feb 26, 2025)
d63b4a2 - docstring (Feb 28, 2025)
0c9732c - Merge branch 'main' into feature/htgen-dataset (ocramz, Mar 2, 2025)
c84f645 - fix makefile and tests (ocramz, Mar 2, 2025)
ff0db1e - fix code rewards test (ocramz, Mar 2, 2025)
1793479 - add prompt and fix_triple reward (ocramz, Mar 2, 2025)
3303083 - fix makefile to activate venv correctly (ocramz, Mar 2, 2025)
532a012 - fix_triple task: add reward tests and docstrings (ocramz, Mar 2, 2025)
66969e8 - readme (ocramz, Mar 2, 2025)
3f88b06 - fix style and quality (ocramz, Mar 2, 2025)
604f66f - cannot reliably activate venv within makefile (ocramz, Mar 2, 2025)
ed0c484 - ignore API json parsing errors (ocramz, Mar 2, 2025)
6e4298c - cleanup and docstrings (Mar 3, 2025)
aa97e62 - add test for verify v2 endpoint (ocramz, Mar 3, 2025)
334b4b0 - Merge branch 'main' into feature/htgen-dataset (ocramz, Mar 5, 2025)
3bd8689 - Merge branch 'main' into feature/htgen-dataset (ocramz, Mar 7, 2025)
40736bd - Merge branch 'main' into feature/htgen-dataset (ocramz, Mar 14, 2025)
456543a - Merge branch 'main' into feature/htgen-dataset (ocramz, Mar 18, 2025)
bbf700a - Merge branch 'main' into feature/htgen-dataset (ocramz, Mar 23, 2025)
f3f2166 - Merge branch 'main' into feature/htgen-dataset (ocramz, Mar 29, 2025)

22 changes: 19 additions & 3 deletions Makefile
@@ -5,15 +5,29 @@ export PYTHONPATH = src

check_dirs := src tests

############### Installation
# 1) install uv
install-uv:
curl -LsSf https://astral.sh/uv/install.sh | sh

# dev dependencies
install:
uv venv openr1 --python 3.11 && . openr1/bin/activate && uv pip install --upgrade pip
# 2) set up virtual environment
venv:
uv venv openr1 --python 3.11

# 3) activate virtual env
# NB: use '.' instead of 'source' so that it also works with dash/Codespaces:
#   . openr1/bin/activate

# 4) install dev dependencies
install:
uv pip install --upgrade pip
uv pip install vllm==0.7.2
uv pip install setuptools
uv pip install flash-attn --no-build-isolation
GIT_LFS_SKIP_SMUDGE=1 uv pip install -e ".[dev]"

############### Linting

style:
ruff format --line-length 119 --target-version py310 $(check_dirs) setup.py
isort $(check_dirs) setup.py
@@ -23,6 +37,8 @@ quality:
isort --check-only $(check_dirs) setup.py
flake8 --max-line-length 119 $(check_dirs) setup.py

############### Test

test:
pytest -sv --ignore=tests/slow/ tests/

2 changes: 1 addition & 1 deletion README.md
@@ -56,7 +56,7 @@ To install `uv`, follow the [UV Installation Guide](https://docs.astral.sh/uv/ge


> [!NOTE]
> As a shortcut, run `make install` to setup development libraries (spelled out below). Afterwards, if everything is setup correctly you can try out the Open-R1 models.
> As a shortcut, run `make install` to set up `uv` and the development libraries (spelled out below). Afterwards, if everything is set up correctly and you have a working CUDA installation, you can install `flash-attn` and try out the Open-R1 models.


```shell
1 change: 1 addition & 0 deletions setup.py
@@ -62,6 +62,7 @@
"peft>=0.14.0",
"pytest",
"python-dotenv",
"requests",
"ruff>=0.9.0",
"safetensors>=0.3.3",
"sentencepiece>=0.1.99",
143 changes: 143 additions & 0 deletions src/open_r1/rewards_internal/api/code/unfoldml/htgen.py
@@ -0,0 +1,143 @@
from json import JSONDecodeError, loads

from requests import post
from requests.exceptions import HTTPError


api_server_url = "https://htgen.unfoldml.com"


def gen_triples(
n_examples: int,
max_ast_depth: int = 3,
n_stmt: int = 5,
n_pre_terms: int = 1,
n_post_terms: int = 1,
seed: int = 1234,
    endpoint: str = "/gen33",
):
"""
Yield program triples (Precondition, Statements, Postconditions) from the API,
    together with their program traces, an initial variable environment, and a label saying
    whether they are totally correct ('ok_total') or fail to satisfy the pre- or post-condition ('bad_pre', 'bad_post').
NB: in the backend, we have distinct REST endpoints for various combinations of (constant, mutable) variables, e.g. '33' and '55'.

:param n_examples: number of triples to generate
:param max_ast_depth: maximum AST depth of generated expressions
:param n_stmt: no. of statements in the generated program
:param n_pre_terms: no. of AND/OR terms in the generated pre-conditions
:param n_post_terms: no. of AND/OR terms in the generated post-conditions
:param seed: random seed for the PRNG
:param endpoint: REST endpoint of the request. '33' stands for 3 constants and 3 mutable identifiers
:returns: iterable of dict e.g.

{
'env_initial': ['v0 = 15', 'v1 = 42', 'v2 = -36', 'v3 = 73', 'v4 = 72', 'v5 = 64'], # starting program state
'env_trace': [], # no execution trace because the starting env doesn't satisfy the precondition
'label': 'bad_pre', # bad precondition (one of 'bad_pre', 'bad_post', 'ok_total')
'pre': 'v3 > (2 + v4)',
'program': ['v3 = v5', 'v4 = (4 - (4 - (v5 - 4)))', 'v5 = v4', 'v4 = (v5 - v3)', 'v3 = 4'],
'post': 'v3 > v4',
'prng_state_out': [1300, 1],
'rej_iters': 1, # number of rejection sampling iterations
'rej_iters_time_s': 0.056072775 # time it took to generate this triple [seconds]
}
"""
cfg = {
"n_examples": n_examples,
"max_ast_depth": max_ast_depth,
"n_stmt": n_stmt,
"n_pre_terms": n_pre_terms,
"n_post_terms": n_post_terms,
"sm_gen_seed": seed,
"sm_gen_gamma": 1,
}
    url = f"{api_server_url}{endpoint}"  # 'endpoint' already includes its leading slash
try:
res = post(url, json=cfg, stream=True)
res.raise_for_status()
for chunk in res.iter_lines(chunk_size=None, delimiter=b"\r\n"):
try:
v = loads(chunk)
if not isinstance(v, dict):
v = None
except JSONDecodeError:
v = None
if v is not None:
yield v
except HTTPError as he:
print(f"HTTP error: {he}")
raise he
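
# ---------------------------------------------------------------------------
# Illustrative usage sketch (editor's addition, not part of the file added by
# this PR): stream a few triples and tally their correctness labels. The dict
# keys ('label', 'pre', 'program', 'post') follow the docstring example above;
# the helper name is hypothetical.
def _example_label_counts(n_examples: int = 10, seed: int = 42) -> dict:
    from collections import Counter

    counts = Counter()
    for triple in gen_triples(n_examples=n_examples, seed=seed):
        counts[triple["label"]] += 1  # one of 'ok_total', 'bad_pre', 'bad_post'
    return dict(counts)
# ---------------------------------------------------------------------------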


def verify_triple(
is_total: bool,
preconditions: str = "True",
program: str = "v4 = (0 - v3)\nv3 = v3\nv5 = v4",
postconditions: str = "v5 == (0 - v3)",
endpoint: str = "/prove33",
):
"""
Verify a program triple and compare with a model prediction
of whether the triple is totally correct or not.
NB: '33' stands for the number of constant and mutable identifiers in the program

    :param is_total: predicted correctness label (True means the triple is claimed to be totally correct)
    :param preconditions: precondition expression
    :param program: program statements, one per line
    :param postconditions: postcondition expression
    :returns: dict stating whether the SMT verifier agrees with the label provided, e.g.

{'prediction_is_correct': True}
"""
cfg = {
"pre": preconditions,
"program": program,
"post": postconditions,
"is_total": is_total,
}
    url = f"{api_server_url}{endpoint}"  # 'endpoint' already includes its leading slash
try:
res = post(url, json=cfg, stream=True)
res.raise_for_status()
try:
v = res.json()
except JSONDecodeError:
v = None
return v
except HTTPError as he:
print(f"HTTP error: {he}")
raise he
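
# ---------------------------------------------------------------------------
# Illustrative reward sketch (editor's addition, not part of the file added by
# this PR): score a model's predicted correctness label against the verifier.
# Returns 1.0 when the SMT verifier agrees with the prediction and 0.0
# otherwise (or when the API reply could not be parsed). The function name is
# hypothetical.
def _example_prediction_reward(predicted_is_total: bool, pre: str, program: str, post: str) -> float:
    out = verify_triple(
        is_total=predicted_is_total,
        preconditions=pre,
        program=program,
        postconditions=post,
    )
    if out is None:  # the endpoint returned something that was not valid JSON
        return 0.0
    return 1.0 if out.get("prediction_is_correct") else 0.0
# ---------------------------------------------------------------------------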


def verify_triple_v2(
preconditions: str = "True",
program: str = "v4 = (0 - v3)\nv3 = v3\nv5 = v4",
postconditions: str = "v5 == (0 - v3)",
endpoint: str = "/v2/prove33",
):
"""
    Verify a program triple via the V2 endpoint.
NB: '33' stands for the number of constant and mutable identifiers in the program

    :param preconditions: precondition expression
    :param program: program statements, one per line
    :param postconditions: postcondition expression
    :returns: dict with the proof result. The 'result' key takes one of {'proven_total', 'proven_partial', 'indeterminate', 'failed'}

"""
triple = {"pre": preconditions, "program": program, "post": postconditions}
    url = f"{api_server_url}{endpoint}"  # 'endpoint' already includes its leading slash
try:
res = post(url, json=triple, stream=True)
res.raise_for_status()
try:
v = res.json()
except JSONDecodeError:
v = None
return v
except HTTPError as he:
print(f"HTTP error: {he}")
raise he
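
For reference, a minimal sketch of how the V2 endpoint's output could feed a graded reward, assuming only what the docstring above states (a 'result' key taking the values 'proven_total', 'proven_partial', 'indeterminate' or 'failed'); the score table and the helper name are illustrative, not part of this PR.

from open_r1.rewards_internal.api.code.unfoldml.htgen import verify_triple_v2

# Hypothetical mapping from verifier outcome to a graded reward (illustrative only).
_RESULT_SCORES = {
    "proven_total": 1.0,
    "proven_partial": 0.5,
    "indeterminate": 0.1,
    "failed": 0.0,
}


def _example_v2_reward(pre: str, program: str, post: str) -> float:
    out = verify_triple_v2(preconditions=pre, program=program, postconditions=post)
    if out is None:  # unparsable API reply
        return 0.0
    return _RESULT_SCORES.get(out.get("result"), 0.0)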