New GRPO dataset and tasks: formally-verified program correctness #379

Open
ocramz wants to merge 38 commits into main from feature/htgen-dataset

Changes from 12 commits (of 38 total)

Commits
591f3c1
wip adding HTGen dataset and benchmark
Feb 20, 2025
69b44ca
add API test
Feb 20, 2025
b3a9587
wip adding HTGen dataset and benchmark
Feb 20, 2025
41af428
add API test
Feb 20, 2025
58cbde4
construct prompt in the dataset generator
Feb 22, 2025
aa3523d
Merge branch 'main' into feature/htgen-dataset
ocramz Feb 22, 2025
cde339c
merge from upstream
Feb 22, 2025
2d225e9
prompt construction
Feb 22, 2025
d28e8f7
fix some typos and add more docstrings
Feb 22, 2025
40fdef8
add reward
Feb 22, 2025
4341395
Merge branch 'main' into feature/htgen-dataset
ocramz Feb 23, 2025
3eae18c
fix typos
Feb 23, 2025
ce8d1b1
Merge branch 'feature/htgen-dataset' of github.com:unfoldml/open-r1 i…
Feb 23, 2025
5e2ea33
Merge branch 'main' into feature/htgen-dataset
ocramz Feb 25, 2025
8535700
Merge branch 'feature/htgen-dataset' of github.com:unfoldml/open-r1 i…
Feb 25, 2025
410b4f9
add unit test for code rewards
Feb 25, 2025
465ae8c
Merge branch 'main' into feature/htgen-dataset
ocramz Feb 25, 2025
fbd20c7
Merge branch 'main' into feature/htgen-dataset
ocramz Feb 26, 2025
c567d55
Merge branch 'main' into feature/htgen-dataset
ocramz Feb 26, 2025
d63b4a2
docstring
Feb 28, 2025
0c9732c
Merge branch 'main' into feature/htgen-dataset
ocramz Mar 2, 2025
c84f645
fix makefile and tests
ocramz Mar 2, 2025
ff0db1e
fix code rewards test
ocramz Mar 2, 2025
1793479
add prompt and fix_triple reward
ocramz Mar 2, 2025
3303083
fix makefile to activate venv correctly
ocramz Mar 2, 2025
532a012
fix_triple task: add reward tests and docstrings
ocramz Mar 2, 2025
66969e8
readme
ocramz Mar 2, 2025
3f88b06
fix style and quality
ocramz Mar 2, 2025
604f66f
cannot reliably activate venv within makefile
ocramz Mar 2, 2025
ed0c484
ignore API json parsing errors
ocramz Mar 2, 2025
6e4298c
cleanup and docstrings
Mar 3, 2025
aa97e62
add test for verify v2 endpoint
ocramz Mar 3, 2025
334b4b0
Merge branch 'main' into feature/htgen-dataset
ocramz Mar 5, 2025
3bd8689
Merge branch 'main' into feature/htgen-dataset
ocramz Mar 7, 2025
40736bd
Merge branch 'main' into feature/htgen-dataset
ocramz Mar 14, 2025
456543a
Merge branch 'main' into feature/htgen-dataset
ocramz Mar 18, 2025
bbf700a
Merge branch 'main' into feature/htgen-dataset
ocramz Mar 23, 2025
f3f2166
Merge branch 'main' into feature/htgen-dataset
ocramz Mar 29, 2025
1 change: 1 addition & 0 deletions setup.py
@@ -62,6 +62,7 @@
"peft>=0.14.0",
"pytest",
"python-dotenv",
"requests",
"ruff>=0.9.0",
"safetensors>=0.3.3",
"sentencepiece>=0.1.99",
130 changes: 130 additions & 0 deletions src/open_r1/rewards/api/code/unfoldml/htgen.py
@@ -0,0 +1,130 @@
from json import loads, JSONDecodeError

from requests import Response, post
from requests.exceptions import HTTPError

api_server_url = "https://htgen.unfoldml.com"

def gen_triples_33(n_examples:int,
max_ast_depth:int = 3,
n_stmt:int = 5,
n_pre_terms:int = 1,
n_post_terms:int = 1,
seed:int = 1234,
                   endpoint:str = '/gen33',
):
"""
Yield program triples (Precondition, Statements, Postconditions) from the API,
together with their program traces plus a initial variable environment and
whether they are totally correct or they .
:param max_ast_depth: maximum AST depth of generated expressions
:param n_stmt: no. of statements in the generated program
:param n_pre_terms: no. of AND/OR terms in the generated pre-conditions
:param n_post_terms: no. of AND/OR terms in the generated post-conditions
:param seed: random seed for the PRNG
:returns: iterable of dict e.g.

{
'env_initial': ['v0 = 15', 'v1 = 42', 'v2 = -36', 'v3 = 73', 'v4 = 72', 'v5 = 64'], # starting program state
'env_trace': [], # no execution trace because the starting env doesn't satisfy the precondition
'label': 'bad_pre', # bad precondition (one of 'bad_pre', 'bad_post', 'ok_total')
'pre': 'v3 > (2 + v4)',
'program': ['v3 = v5', 'v4 = (4 - (4 - (v5 - 4)))', 'v5 = v4', 'v4 = (v5 - v3)', 'v3 = 4'],
'post': 'v3 > v4',
'prng_state_out': [1300, 1],
'rej_iters': 1, # number of rejection sampling iterations
'rej_iters_time_s': 0.056072775 # time it took to generate this triple [seconds]
}
"""
cfg = {
"n_examples": n_examples,
"max_ast_depth": max_ast_depth,
"n_stmt": n_stmt,
"n_pre_terms": n_pre_terms,
"n_post_terms": n_post_terms,
"sm_gen_seed": seed,
"sm_gen_gamma": 1
}
    url = f"{api_server_url}{endpoint}"  # NB: endpoint already starts with '/'
    try:
        res = post(url, json=cfg, stream=True)
        res.raise_for_status()
        for chunk in res.iter_lines(chunk_size=None, delimiter=b"\r\n"):
            try:
                v = loads(chunk)
                if not isinstance(v, dict):
                    v = None
            except JSONDecodeError:
                v = None
            if v is not None:
                yield v
    except HTTPError as he:
        print(f"HTTP error: {he}")
        raise


def verify_triple_33(
is_total:bool,
preconditions:str = "True",
program:str = "v4 = (0 - v3)\nv3 = v3\nv5 = v4",
postconditions:str = "v5 == (0 - v3)",
endpoint:str = '/prove33'
):
"""
Verify a program triple and compare with a model prediction
of whether the triple is totally correct or not.
:param is_total: inferred correctness label
:param preconditions:
:param program:
:param postconditions:
:returns: whether the SMT verifier agrees with the label provided:

{'prediction_is_correct': True}
"""
cfg = {
"pre": preconditions,
"program": program,
"post": postconditions,
"is_total": is_total,
}
    url = f"{api_server_url}{endpoint}"  # NB: endpoint already starts with '/'
    try:
        res = post(url, json=cfg)
        res.raise_for_status()
        try:
            v = res.json()
        except JSONDecodeError:
            v = None
        return v
    except HTTPError as he:
        print(f"HTTP error: {he}")
        raise




if __name__ == "__main__":
# # generate triples
for t in gen_triples_33(n_examples = 1):
print(t)
# {
# 'env_initial': ['v0 = 15', 'v1 = 42', 'v2 = -36', 'v3 = 73', 'v4 = 72', 'v5 = 64'],
# 'env_trace': [], # no execution trace because the starting env doesn't satisfy the precondition
# 'label': 'bad_pre', # bad precondition
# 'pre': 'v3 > (2 + v4)',
# 'program': ['v3 = v5', 'v4 = (4 - (4 - (v5 - 4)))', 'v5 = v4', 'v4 = (v5 - v3)', 'v3 = 4'],
# 'post': 'v3 > v4',
# 'prng_state_out': [1300, 1],
# 'rej_iters': 1, # number of rejection sampling iterations
# 'rej_iters_time_s': 0.056072775 # time it took to generate this triple [seconds]
# }

    # # verify a triple against an inferred total correctness label
    res = verify_triple_33(
        is_total = True,
        preconditions = "True",
        program = "v4 = (0 - v3)\nv3 = v3\nv5 = v4",
        postconditions = "v5 == (0 - v3)"
    )
    print(res)
    # {'prediction_is_correct': True}
136 changes: 136 additions & 0 deletions src/open_r1/rewards/code/htgen.py
@@ -0,0 +1,136 @@
from datasets import IterableDataset

from open_r1.rewards.api.code.unfoldml.htgen import gen_triples_33, verify_triple_33


# # GRPOTrainer requires 1. a dataset and 2. a verification callback
# # (a minimal wiring sketch is shown below, after this file)

# # # Tasks
# TOTALITY_CHECK : is the program "total"?
# FIX_PRE, FIX_POST, FIX_PROGRAM : modify one part of a triple to achieve a given proof outcome, e.g. a totally-correct triple

def quotes(s:str):
    """wrap a piece of code in markdown triple backticks"""
    return f"```{s}```"


# TOTALITY_CHECK task
def mk_row_totality_check(o):
"""
Construct the prompt
NB: the rows have a 'prompt' column as required by the GRPOTrainer interface:
https://huggingface.co/docs/trl/main/grpo_trainer#trl.GRPOTrainer.train_dataset
"""
label = o['label']
pre = o['pre']
program = o['program'] # list of statements
post = o['post']

program_str = '\n'.join(program)

prompt_hdr = (
f"Below you are given a Python program triple, made of a precondition predicate, "
f"a sequence of program statements, and a post-condition predicate. "
f"The precondition returns True if the variable environment before beginning the "
f"program execution satisfies the predicate, and False otherwise. "
f"Similarly, the postcondition returns True if the program environment after the last "
f"statement satisfies the predicate, and False otherwise. "
f"Note that there might be unsatisfiable or contradictory predicates, that make the solution False by definition. "
f"With this information, you should judge whether the program is 'total', i.e. "
f"whether the post-condition evaluates to True for all possible variable assigments "
f"that satisfy the precondition."
)

    prompt_question = (
        f"Given a program triple made of program {quotes(program_str)}, "
        f"precondition {quotes(pre)} and postcondition {quotes(post)}, is the postcondition "
        f"always True at the end of the program? Please return 'True' or 'False'."
    )

# # concatenate header and question into a prompt
prompt_problem = f"{prompt_hdr}\n{prompt_question}"

label_is_total = label == 'ok_total' # boolean

# # construct a row of the dataset
o_out = {
"prompt": prompt_problem,
"ground_truth": label_is_total,
"triple": {"pre": pre, "program":program, "post": post}
}

return o_out
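
# An illustrative example (values taken from the sample triple in the gen_triples_33
# docstring): for an input row with label 'bad_pre', pre 'v3 > (2 + v4)' and
# post 'v3 > v4', mk_row_totality_check returns
#   {'prompt': '<header + question>', 'ground_truth': False, 'triple': {...}}
# since 'bad_pre' != 'ok_total'.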

def mk_dataset_totality_check(
        n_examples:int,
        max_ast_depth:int = 3,
        n_stmt:int = 5,
        n_pre_terms:int = 1,
        n_post_terms:int = 1,
        seed:int = 1234,
    ):
    """
    Construct an iterable dataset of totality-check prompts for GRPOTrainer.
    """
    # produce prompts from the raw API data; NB: from_generator expects a
    # generator *function*, not a generator object
    def gen_prompts():
        gen = gen_triples_33(
            n_examples = n_examples,
            max_ast_depth = max_ast_depth,
            n_stmt = n_stmt,
            n_pre_terms = n_pre_terms,
            n_post_terms = n_post_terms,
            seed = seed,
        )
        for o in gen:
            if o is not None:
                yield mk_row_totality_check(o)

    dataset = IterableDataset.from_generator(gen_prompts)

    return dataset

def totality_check_reward(completions, ground_truth, **kwargs):
    """
    Verification callback for GRPOTrainer.
    :param completions: list of predictions produced by the model (booleans or 'True'/'False' strings)
    :param ground_truth: list of boolean ground truth values
    :returns: list of rewards, 1.0 where the prediction matches the ground truth and 0.0 otherwise
    """
    def to_bool(c):
        # parse string completions explicitly: bool("False") would be True
        if isinstance(c, str):
            return c.strip().lower() == "true"
        return bool(c)

    predictions = [to_bool(c) for c in completions]
    return [1.0 if predicted == actual else 0.0 for (predicted, actual) in zip(predictions, ground_truth)]




if __name__ == "__main__":
    compls = [True]
    ground_truth = [True]
    res = totality_check_reward(compls, ground_truth)
    print(res)  # [1.0]


# # # verify against API

# def totality_oracle_reward(completions, triples, **kwargs):
#     """
#     verification callback for GRPOTrainer
#     :param completions: list of truthy values produced by the model
#     :param triples: list of program triple dicts {"pre": str, "program": str, "post": str}
#     """

# def verify(pre, program, post, is_total):
# res = verify_triple_33(
# preconditions = pre,
# program = program,
# postconditions = post,
# is_total = is_total
# )
# if res is not None:
# prediction = res['prediction_is_correct']
# return 1.0 if prediction else 0.0
# else:
# return 0.0
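
For reference, a minimal sketch of how the dataset and reward above could be wired into trl's GRPOTrainer (the model name and output directory are placeholders; the GRPOTrainer/GRPOConfig interface is assumed as documented in trl):

# hypothetical wiring sketch, not part of this PR
from trl import GRPOConfig, GRPOTrainer

from open_r1.rewards.code.htgen import mk_dataset_totality_check, totality_check_reward

# streaming dataset of totality-check prompts; the 'ground_truth' column is
# forwarded by GRPOTrainer to the reward function as a keyword argument
train_dataset = mk_dataset_totality_check(n_examples = 1000)

training_args = GRPOConfig(output_dir = "htgen-grpo")  # placeholder output dir

trainer = GRPOTrainer(
    model = "Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model
    reward_funcs = totality_check_reward,
    args = training_args,
    train_dataset = train_dataset,
)
trainer.train()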
40 changes: 40 additions & 0 deletions tests/test_api.py
@@ -0,0 +1,40 @@
import unittest

from open_r1.rewards.api.code.unfoldml.htgen import gen_triples_33, verify_triple_33


class TestApi(unittest.TestCase):
def test_gen_triples_structure(self):
n_stmt = 3
for o in gen_triples_33(n_examples = 1, n_stmt = n_stmt):
len_program = len(o['program'])
            self.assertEqual(len_program, n_stmt)

    def test_verify_triple_result(self):
is_total = True
preconditions = "True" # trivial precondition
program = "v4 = (0 - v3)\nv3 = v3\nv5 = v4"
post_ok = "v5 == (0 - v3)" # post-condition that verifies
post_not_ok = "v5 == (1 - v3)" # post-condition that does not verify
# # should return True
o = verify_triple_33(
is_total = is_total,
preconditions = preconditions,
program = program,
postconditions = post_ok
)
res_ok = o['prediction_is_correct']
self.assertEqual(res_ok, True)
# # should return False
o = verify_triple_33(
is_total = is_total,
preconditions = preconditions,
program = program,
postconditions = post_not_ok
)
res_not_ok = o['prediction_is_correct']
self.assertEqual(res_not_ok, False)



if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions tests/test_rewards.py
@@ -11,6 +11,7 @@
tag_count_reward,
)

from open_r1.rewards.code.htgen import totality_check_reward

class TestRewards(unittest.TestCase):
def test_accuracy_reward_correct_answer(self):