From 591f3c15cd37792a819c272b560754eff1889337 Mon Sep 17 00:00:00 2001 From: Marco Zocca Date: Thu, 20 Feb 2025 04:57:46 +0100 Subject: [PATCH 01/22] wip adding HTGen dataset and benchmark --- .../rewards/api/code/unfoldml/htgen.py | 130 ++++++++++++++++++ src/open_r1/rewards/code/htgen.py | 36 +++++ 2 files changed, 166 insertions(+) create mode 100644 src/open_r1/rewards/api/code/unfoldml/htgen.py create mode 100644 src/open_r1/rewards/code/htgen.py diff --git a/src/open_r1/rewards/api/code/unfoldml/htgen.py b/src/open_r1/rewards/api/code/unfoldml/htgen.py new file mode 100644 index 000000000..db4474b98 --- /dev/null +++ b/src/open_r1/rewards/api/code/unfoldml/htgen.py @@ -0,0 +1,130 @@ +from json import loads, JSONDecodeError + +from requests import Response, post +from requests.exceptions import HTTPError + +api_server_url = "https://htgen.unfoldml.com" + +def gen_triples_33(n_examples:int, + max_ast_depth:int = 3, + n_stmt:int = 5, + n_pre_terms:int = 1, + n_post_terms:int = 1, + seed:int = 1234, + endpoint = '/gen33', + ): + """ + Yield program triples (Precondition, Statements, Postconditions) from the API, + together with their program traces plus a initial variable environment and + whether they are totally correct or they . + :param max_ast_depth: maximum AST depth of generated expressions + :param n_stmt: no. of statements in the generated program + :param n_pre_terms: no. of AND/OR terms in the generated pre-conditions + :param n_post_terms: no. of AND/OR terms in the generated post-conditions + :param seed: random seed for the PRNG + :returns: iterable of dict e.g. + + { + 'env_initial': ['v0 = 15', 'v1 = 42', 'v2 = -36', 'v3 = 73', 'v4 = 72', 'v5 = 64'], # starting program state + 'env_trace': [], # no execution trace because the starting env doesn't satisfy the precondition + 'label': 'bad_pre', # bad precondition (one of 'bad_pre', 'bad_post', 'ok_total') + 'pre': 'v3 > (2 + v4)', + 'program': ['v3 = v5', 'v4 = (4 - (4 - (v5 - 4)))', 'v5 = v4', 'v4 = (v5 - v3)', 'v3 = 4'], + 'post': 'v3 > v4', + 'prng_state_out': [1300, 1], + 'rej_iters': 1, # number of rejection sampling iterations + 'rej_iters_time_s': 0.056072775 # time it took to generate this triple [seconds] + } + """ + cfg = { + "n_examples": n_examples, + "max_ast_depth": max_ast_depth, + "n_stmt": n_stmt, + "n_pre_terms": n_pre_terms, + "n_post_terms": n_post_terms, + "sm_gen_seed": seed, + "sm_gen_gamma": 1 + } + url = f"{api_server_url}/{endpoint}" + try: + res = post(url, json= cfg, stream= True) + res.raise_for_status() + for chunk in res.iter_lines(chunk_size= None, delimiter=b"\r\n"): + try: + v = loads(chunk) + if not isinstance(v, dict): + v = None + except JSONDecodeError as e: + v = None + if v is not None: + yield v + except HTTPError as he: + print(f"HTTP error: {he}") + raise he + + +def verify_triple_33( + is_total:bool, + preconditions:str = "True", + program:str = "v4 = (0 - v3)\nv3 = v3\nv5 = v4", + postconditions:str = "v5 == (0 - v3)", + endpoint:str = '/prove33' + ): + """ + Verify a program triple and compare with a model prediction + of whether the triple is totally correct or not. 
+ :param is_total: inferred correctness label + :param preconditions: + :param program: + :param postconditions: + :returns: whether the SMT verifier agrees with the label provided: + + {'prediction_is_correct': True} + """ + cfg = { + "pre": preconditions, + "program": program, + "post": postconditions, + "is_total": is_total, + } + url = f"{api_server_url}/{endpoint}" + try: + res = post(url, json= cfg, stream= True) + res.raise_for_status() + try: + v = res.json() + except JSONDecodeError: + v = None + print(v) + # else: + except HTTPError as he: + print(f"HTTP error: {he}") + raise he + + + + +if __name__ == "__main__": + # # generate triples + for t in gen_triples_33(n_examples = 1): + print(t) + # { + # 'env_initial': ['v0 = 15', 'v1 = 42', 'v2 = -36', 'v3 = 73', 'v4 = 72', 'v5 = 64'], + # 'env_trace': [], # no execution trace because the starting env doesn't satisfy the precondition + # 'label': 'bad_pre', # bad precondition + # 'pre': 'v3 > (2 + v4)', + # 'program': ['v3 = v5', 'v4 = (4 - (4 - (v5 - 4)))', 'v5 = v4', 'v4 = (v5 - v3)', 'v3 = 4'], + # 'post': 'v3 > v4', + # 'prng_state_out': [1300, 1], + # 'rej_iters': 1, # number of rejection sampling iterations + # 'rej_iters_time_s': 0.056072775 # time it took to generate this triple [seconds] + # } + + # # verify a triple against an inferred total correctness label + verify_triple_33( + is_total = True, + preconditions = "True", + program = "v4 = (0 - v3)\nv3 = v3\nv5 = v4", + postconditions = "v5 == (0 - v3)" + ) + # {'prediction_is_correct': True} \ No newline at end of file diff --git a/src/open_r1/rewards/code/htgen.py b/src/open_r1/rewards/code/htgen.py new file mode 100644 index 000000000..777dfe09c --- /dev/null +++ b/src/open_r1/rewards/code/htgen.py @@ -0,0 +1,36 @@ +from datasets import Dataset, IterableDataset + +from open_r1.rewards.api.code.unfoldml.htgen import gen_triples_33, verify_triple_33 + + +# # GRPOTrainer requires 1. a dataset and 2. 
a verification callback


def mk_dataset(
        max_ast_depth:int = 3,
        n_stmt:int = 5,
        n_pre_terms:int = 1,
        n_post_terms:int = 1,
        seed:int = 1234,
    ):
    """
    construct an interable dataset for GRPOTrainer
    """
    dataset = IterableDataset.from_generator(
        gen_triples_33(
            max_ast_depth = max_ast_depth,
            n_stmt = n_stmt,
            n_pre_terms = n_pre_terms,
            n_post_terms = n_post_terms,
            seed = seed,
        )
    )
    return dataset

def total_correctness_reward(completions, solution, **kwargs):
    """
    verification callback for GRPOTRainer
    """
    # pass the completion together with the reference solution to 'verify_triple_X'
    # and score the result
    pass

From 69b44cab803798fe7e0662d2992b89b7721e22d4 Mon Sep 17 00:00:00 2001
From: Marco Zocca
Date: Thu, 20 Feb 2025 05:22:09 +0100
Subject: [PATCH 02/22] add API test

---
 setup.py                                   |  1 +
 .../rewards/api/code/unfoldml/htgen.py     |  2 +-
 tests/test_api.py                          | 40 +++++++++++++++++++
 tests/test_rewards.py                      |  1 -
 4 files changed, 42 insertions(+), 2 deletions(-)
 create mode 100644 tests/test_api.py

diff --git a/setup.py b/setup.py
index 907269c20..e31d9fcc9 100644
--- a/setup.py
+++ b/setup.py
@@ -62,6 +62,7 @@
        "peft>=0.14.0",
        "pytest",
        "python-dotenv",
+        "requests",
        "ruff>=0.9.0",
        "safetensors>=0.3.3",
        "sentencepiece>=0.1.99",
diff --git a/src/open_r1/rewards/api/code/unfoldml/htgen.py b/src/open_r1/rewards/api/code/unfoldml/htgen.py
index db4474b98..5947a8739 100644
--- a/src/open_r1/rewards/api/code/unfoldml/htgen.py
+++ b/src/open_r1/rewards/api/code/unfoldml/htgen.py
@@ -95,7 +95,7 @@ def verify_triple_33(
            v = res.json()
        except JSONDecodeError:
            v = None
-        print(v)
+        return v
        # else:
    except HTTPError as he:
        print(f"HTTP error: {he}")
diff --git a/tests/test_api.py b/tests/test_api.py
new file mode 100644
index 000000000..d5fa705d1
--- /dev/null
+++ b/tests/test_api.py
@@ -0,0 +1,40 @@
+import unittest
+
+from open_r1.rewards.api.code.unfoldml.htgen import gen_triples_33, verify_triple_33
+
+
+class TestApi(unittest.TestCase):
+    def test_gen_triples_structure():
+        n_stmt = 3
+        for o in gen_triples_33(n_examples = 1, n_stmt = n_stmt):
+            len_program = len(o['program'])
+            self.assertEqual(len_program, n_stmt)
+    def test_verify_triple_result():
+        is_total = True
+        preconditions = "True" # trivial precondition
+        program = "v4 = (0 - v3)\nv3 = v3\nv5 = v4"
+        post_ok = "v5 == (0 - v3)" # post-condition that verifies
+        post_not_ok = "v5 == (1 - v3)" # post-condition that does not verify
+        # # should return True
+        o = verify_triple_33(
+            is_total = is_total,
+            preconditions = preconditions,
+            program = program,
+            postconditions = post_ok
+        )
+        res_ok = o['prediction_is_correct']
+        self.assertEqual(res_ok, True)
+        # # should return False
+        o = verify_triple_33(
+            is_total = is_total,
+            preconditions = preconditions,
+            program = program,
+            postconditions = post_not_ok
+        )
+        res_not_ok = o['prediction_is_correct']
+        salf.assertEqual(res_not_ok, False)
+
+
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file
diff --git a/tests/test_rewards.py b/tests/test_rewards.py
index 9e41bdb0d..7bcf807c5 100644
--- a/tests/test_rewards.py
+++ b/tests/test_rewards.py
@@ -9,7 +9,6 @@
    reasoning_steps_reward,
)

-
class TestRewards(unittest.TestCase):
    def test_accuracy_reward_correct_answer(self):
        """Test accuracy_reward with a correct answer."""
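Taken together, the two patches above provide a streaming triple generator and an SMT-backed verification call. Below is a minimal sketch of how the two endpoints compose end to end — hypothetical wiring, since the reward callback itself is still a stub at this point in the series; it assumes the htgen API server is reachable and simply replays the ground-truth label, so every prediction should score 1.0:

```python
from open_r1.rewards.api.code.unfoldml.htgen import gen_triples_33, verify_triple_33

# Stream a few labeled triples from the generator endpoint.
for triple in gen_triples_33(n_examples=3, n_stmt=5, seed=1234):
    # Ground-truth totality label shipped with each triple:
    # one of 'ok_total', 'bad_pre', 'bad_post'.
    is_total = triple["label"] == "ok_total"

    # Ask the SMT verifier whether the "predicted" label agrees;
    # the endpoint answers {'prediction_is_correct': bool}.
    res = verify_triple_33(
        is_total=is_total,
        preconditions=triple["pre"],
        program="\n".join(triple["program"]),  # 'program' arrives as a list of statements
        postconditions=triple["post"],
    )
    reward = 1.0 if res is not None and res.get("prediction_is_correct") else 0.0
    print(triple["label"], reward)
```

Since the oracle label is fed back to the verifier, any 0.0 here would indicate a transport or server error rather than a wrong prediction.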
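The tests in tests/test_api.py above exercise the live https://htgen.unfoldml.com server. For offline CI runs, the module-level `post` imported from requests can be patched out; the following is a sketch assuming unittest.mock, with a canned payload whose shape is taken from the docstrings above (the payload is assumed, not a recorded server reply):

```python
import unittest
from unittest.mock import MagicMock, patch

from open_r1.rewards.api.code.unfoldml.htgen import verify_triple_33


class TestApiOffline(unittest.TestCase):
    @patch("open_r1.rewards.api.code.unfoldml.htgen.post")
    def test_verify_triple_mocked(self, mock_post):
        # Canned response standing in for the /prove33 endpoint.
        response = MagicMock()
        response.json.return_value = {"prediction_is_correct": True}
        mock_post.return_value = response

        o = verify_triple_33(
            is_total=True,
            preconditions="True",
            program="v4 = (0 - v3)\nv3 = v3\nv5 = v4",
            postconditions="v5 == (0 - v3)",
        )
        self.assertTrue(o["prediction_is_correct"])


if __name__ == "__main__":
    unittest.main()
```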
From 58cbde43267685cd6071ca332041b115ebf24b04 Mon Sep 17 00:00:00 2001
From: Marco Zocca
Date: Sat, 22 Feb 2025 07:56:56 +0100
Subject: [PATCH 05/22] construct prompt in the dataset generator

---
src/open_r1/rewards/code/htgen.py | 53 +++++++++++++++++++++++++++++-- 1 file changed, 50 insertions(+), 3 deletions(-) diff --git a/src/open_r1/rewards/code/htgen.py b/src/open_r1/rewards/code/htgen.py index 777dfe09c..a55bc6157 100644 --- a/src/open_r1/rewards/code/htgen.py +++ b/src/open_r1/rewards/code/htgen.py @@ -5,6 +5,50 @@ # # GRPOTrainer requires 1. a dataset and 2. a verification callback +def mk_dataset_row(o): + """ + Construct the prompt from the raw API data + """ + label = o['label'] + pre = o['pre'] + program = o['program'] # list of statements + post = o['post'] + + program_str = '\n'.join(program) + + label_is_total = label == 'ok_total' + + prompt_hdr = ( + f"Below you are given a Python program triple, made of a precondition predicate, " + f"a sequence of program statements, and a post-condition predicate." + f"The precondition returns True if the variable environment before beginning the " + f"program execution satisfies the predicate, and False otherwise. " + f"Similarly, the postcondition returns True if the program environment after the last " + f"statement satisfies the predicate, and False otherwise. " + f"Note that there might be unsatisfiable or contradictory predicates, that make the solution unreachable." + f"With this information, you should judge whether the program is 'total', i.e. " + f"whether the post-condition evaluates to True for all possible variable assigments " + f"that satisfy the precondition." + ) + + prompt_question = ( + f"Given a program triple made of program '{program_str}', preconditions '{pre}' and postcondition '{post}', is the postcondition " + f"always True at the end of the program ? Please return 'True' or 'False'." + ) + + # # concatenate header and question into a prompt + prompt_problem = f"{prompt_hdr}\n{prompt_question}" + + solution = label_is_total # boolean + + # # construct a row of the dataset + o_out = { + "problem": prompt_problem, + "solution": label_is_total + } + + return o_out + def mk_dataset( max_ast_depth:int = 3, @@ -16,15 +60,18 @@ def mk_dataset( """ construct an interable dataset for GRPOTrainer """ - dataset = IterableDataset.from_generator( - gen_triples_33( + gen = gen_triples_33( max_ast_depth = max_ast_depth, n_stmt = n_stmt, n_pre_terms = n_pre_terms, n_post_terms = n_post_terms, seed = seed, ) - ) + + # produce prompts from the raw API data + gen_prompts = (mk_dataset_row(o) for o in gen if o is not None) + + dataset = IterableDataset.from_generator(gen_prompts) return dataset def total_correctness_reward(completions, solution, **kwargs): From 2d225e98e8e2ffb98bf5238bdcb2839a03e06942 Mon Sep 17 00:00:00 2001 From: Marco Zocca Date: Sat, 22 Feb 2025 08:13:44 +0100 Subject: [PATCH 06/22] prompt construction --- src/open_r1/rewards/code/htgen.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/open_r1/rewards/code/htgen.py b/src/open_r1/rewards/code/htgen.py index ccfc8e5a2..89000f310 100644 --- a/src/open_r1/rewards/code/htgen.py +++ b/src/open_r1/rewards/code/htgen.py @@ -16,16 +16,16 @@ def mk_dataset_row(o): program_str = '\n'.join(program) - label_is_total = label == 'ok_total' + prompt_hdr = ( f"Below you are given a Python program triple, made of a precondition predicate, " - f"a sequence of program statements, and a post-condition predicate." + f"a sequence of program statements, and a post-condition predicate. " f"The precondition returns True if the variable environment before beginning the " f"program execution satisfies the predicate, and False otherwise. 
" f"Similarly, the postcondition returns True if the program environment after the last " f"statement satisfies the predicate, and False otherwise. " - f"Note that there might be unsatisfiable or contradictory predicates, that make the solution unreachable." + f"Note that there might be unsatisfiable or contradictory predicates, that make the solution False by definition. " f"With this information, you should judge whether the program is 'total', i.e. " f"whether the post-condition evaluates to True for all possible variable assigments " f"that satisfy the precondition." @@ -39,7 +39,7 @@ def mk_dataset_row(o): # # concatenate header and question into a prompt prompt_problem = f"{prompt_hdr}\n{prompt_question}" - solution = label_is_total # boolean + label_is_total = label == 'ok_total' # boolean # # construct a row of the dataset o_out = { From d28e8f7de7bf186018f9f517d37890267d531012 Mon Sep 17 00:00:00 2001 From: Marco Zocca Date: Sat, 22 Feb 2025 09:18:17 +0100 Subject: [PATCH 07/22] fix some typos and add more docstrings --- src/open_r1/rewards/code/htgen.py | 26 ++++++++++++++++---------- tests/test_api.py | 6 +++--- tests/test_rewards.py | 2 ++ 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/src/open_r1/rewards/code/htgen.py b/src/open_r1/rewards/code/htgen.py index 89000f310..e4c9894f5 100644 --- a/src/open_r1/rewards/code/htgen.py +++ b/src/open_r1/rewards/code/htgen.py @@ -5,9 +5,18 @@ # # GRPOTrainer requires 1. a dataset and 2. a verification callback -def mk_dataset_row(o): +# # # Tasks +# TOTALITY_CHECK : is the program "total"? +# FIX_PRE, FIX_POST, FIX_PROGRAM : modify either part of a triple to achieve either a total triple or other proof result + +def quotes(s:str): + """markdown triple backticks for a piece of code""" + return f"```{str}```" + +# totality check task +def mk_row_totality_check(o): """ - Construct the prompt from the raw API data + Construct the prompt """ label = o['label'] pre = o['pre'] @@ -16,8 +25,6 @@ def mk_dataset_row(o): program_str = '\n'.join(program) - - prompt_hdr = ( f"Below you are given a Python program triple, made of a precondition predicate, " f"a sequence of program statements, and a post-condition predicate. " @@ -32,7 +39,8 @@ def mk_dataset_row(o): ) prompt_question = ( - f"Given a program triple made of program '{program_str}', preconditions '{pre}' and postcondition '{post}', is the postcondition " + f"Given a program triple made of program {quotes(program_str)}, " + f"preconditions {quotes(pre)} and postcondition {quotes(post)}, is the postcondition " f"always True at the end of the program ? Please return 'True' or 'False'." 
) @@ -49,9 +57,7 @@ def mk_dataset_row(o): return o_out - - -def mk_dataset( +def mk_dataset_totality_check( max_ast_depth:int = 3, n_stmt:int = 5, n_pre_terms:int = 1, @@ -70,13 +76,13 @@ def mk_dataset( ) # produce prompts from the raw API data - gen_prompts = (mk_dataset_row(o) for o in gen if o is not None) + gen_prompts = (mk_row_totality_check(o) for o in gen if o is not None) dataset = IterableDataset.from_generator(gen_prompts) return dataset -def total_correctness_reward(completions, solution, **kwargs): +def totality_check_reward(completions, solution, **kwargs): """ verification callback for GRPOTRainer """ diff --git a/tests/test_api.py b/tests/test_api.py index d5fa705d1..a07325b2c 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -4,12 +4,12 @@ class TestApi(unittest.TestCase): - def test_gen_triples_structure(): + def test_gen_triples_structure(self): n_stmt = 3 for o in gen_triples_33(n_examples = 1, n_stmt = n_stmt): len_program = len(o['program']) self.assertEqual(len_program, n_stmt) - def test_verify_triple_result(): + def test_verify_triple_result(self): is_total = True preconditions = "True" # trivial precondition program = "v4 = (0 - v3)\nv3 = v3\nv5 = v4" @@ -32,7 +32,7 @@ def test_verify_triple_result(): postconditions = post_not_ok ) res_not_ok = o['prediction_is_correct'] - salf.assertEqual(res_not_ok, False) + self.assertEqual(res_not_ok, False) diff --git a/tests/test_rewards.py b/tests/test_rewards.py index 7956b2cd5..5d4de44fb 100644 --- a/tests/test_rewards.py +++ b/tests/test_rewards.py @@ -10,6 +10,8 @@ reasoning_steps_reward, ) +from open_r1.rewards.code.htgen import totality_check_reward + class TestRewards(unittest.TestCase): def test_accuracy_reward_correct_answer(self): """Test accuracy_reward with a correct answer.""" From 40fdef852ccc54dbf6cc63d9fe773fbe831f5d75 Mon Sep 17 00:00:00 2001 From: Marco Zocca Date: Sat, 22 Feb 2025 12:56:15 +0100 Subject: [PATCH 08/22] add reward --- src/open_r1/rewards/code/htgen.py | 59 +++++++++++++++++++++++++++---- 1 file changed, 52 insertions(+), 7 deletions(-) diff --git a/src/open_r1/rewards/code/htgen.py b/src/open_r1/rewards/code/htgen.py index e4c9894f5..5d433e1b3 100644 --- a/src/open_r1/rewards/code/htgen.py +++ b/src/open_r1/rewards/code/htgen.py @@ -13,10 +13,13 @@ def quotes(s:str): """markdown triple backticks for a piece of code""" return f"```{str}```" -# totality check task + +# TOTALITY_CHECK task def mk_row_totality_check(o): """ Construct the prompt + NB: the rows have a 'prompt' column as required by the GRPOTrainer interface: + https://huggingface.co/docs/trl/main/grpo_trainer#trl.GRPOTrainer.train_dataset """ label = o['label'] pre = o['pre'] @@ -51,8 +54,9 @@ def mk_row_totality_check(o): # # construct a row of the dataset o_out = { - "problem": prompt_problem, - "solution": label_is_total + "prompt": prompt_problem, + "ground_truth": label_is_total, + "triple": {"pre": pre, "program":program, "post": post} } return o_out @@ -82,10 +86,51 @@ def mk_dataset_totality_check( return dataset -def totality_check_reward(completions, solution, **kwargs): +def totality_check_reward(completions, ground_truth, **kwargs): """ verification callback for GRPOTRainer + :param completions: list of truthy values produced by the model + :param ground_truth: list of boolean ground truth values + :returns: list of float 1s or 0s with the prediction scores that match the ground truth """ - # pass the completion together with the reference solution to 'verify_triple_X' - # and score the result - pass + if 
not isinstance(completions[0], bool): + completions = [bool(c) for c in completions] + def verify(predicted, actual): + if predicted == actual: + return 1.0 + else: + return 0.0 + + return [verify(predicted, actual) for (predicted, actual) in zip(completions, ground_truth)] + + + + +if __name__ == "__main__": + compls = [True] + ground_truth = ["True"] + res = totality_check_reward(compls, ground_truth) + print(res) + + +# # # verify against API + +# def totality_oracle_reward(completions, triples, **kwargs): +# """ +# verification callback for GRPOTRainer +# :param completions: list of truthy values produced by the model +# :param triples: list of program triples dicts {"pre":: string, "program":: string, "post:: string} +# """ + +# def verify(pre, program, post, is_total): +# res = verify_triple_33( +# preconditions = pre, +# program = program, +# postconditions = post, +# is_total = is_total +# ) +# if res is not None: +# prediction = res['prediction_is_correct'] +# return 1.0 if prediction else 0.0 +# else: +# return 0.0 \ No newline at end of file From 3eae18cbeafc36182ba57f75f0be83d04e340b76 Mon Sep 17 00:00:00 2001 From: Marco Zocca Date: Sun, 23 Feb 2025 04:55:57 +0100 Subject: [PATCH 09/22] fix typos --- src/open_r1/rewards/api/code/unfoldml/htgen.py | 4 +++- src/open_r1/rewards/code/htgen.py | 3 +-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/open_r1/rewards/api/code/unfoldml/htgen.py b/src/open_r1/rewards/api/code/unfoldml/htgen.py index 5947a8739..7d1355029 100644 --- a/src/open_r1/rewards/api/code/unfoldml/htgen.py +++ b/src/open_r1/rewards/api/code/unfoldml/htgen.py @@ -5,7 +5,8 @@ api_server_url = "https://htgen.unfoldml.com" -def gen_triples_33(n_examples:int, +def gen_triples_33( + n_examples:int, max_ast_depth:int = 3, n_stmt:int = 5, n_pre_terms:int = 1, @@ -17,6 +18,7 @@ def gen_triples_33(n_examples:int, Yield program triples (Precondition, Statements, Postconditions) from the API, together with their program traces plus a initial variable environment and whether they are totally correct or they . + :param n_examples: number of triples to generate :param max_ast_depth: maximum AST depth of generated expressions :param n_stmt: no. of statements in the generated program :param n_pre_terms: no. 
of AND/OR terms in the generated pre-conditions diff --git a/src/open_r1/rewards/code/htgen.py b/src/open_r1/rewards/code/htgen.py index 5d433e1b3..a9c15dc3a 100644 --- a/src/open_r1/rewards/code/htgen.py +++ b/src/open_r1/rewards/code/htgen.py @@ -11,7 +11,7 @@ def quotes(s:str): """markdown triple backticks for a piece of code""" - return f"```{str}```" + return f"```{s}```" # TOTALITY_CHECK task @@ -58,7 +58,6 @@ def mk_row_totality_check(o): "ground_truth": label_is_total, "triple": {"pre": pre, "program":program, "post": post} } - return o_out def mk_dataset_totality_check( From 410b4f925c038fdd4e64e6b590b050569c3b2bec Mon Sep 17 00:00:00 2001 From: Marco Zocca Date: Tue, 25 Feb 2025 05:49:38 +0100 Subject: [PATCH 10/22] add unit test for code rewards --- .../rewards/api/code/unfoldml/htgen.py | 8 +++- src/open_r1/rewards/code/htgen.py | 45 ++++++++++++++----- tests/test_rewards.py | 2 - tests/test_rewards_code.py | 18 ++++++++ 4 files changed, 57 insertions(+), 16 deletions(-) create mode 100644 tests/test_rewards_code.py diff --git a/src/open_r1/rewards/api/code/unfoldml/htgen.py b/src/open_r1/rewards/api/code/unfoldml/htgen.py index 7d1355029..49c80c065 100644 --- a/src/open_r1/rewards/api/code/unfoldml/htgen.py +++ b/src/open_r1/rewards/api/code/unfoldml/htgen.py @@ -16,8 +16,10 @@ def gen_triples_33( ): """ Yield program triples (Precondition, Statements, Postconditions) from the API, - together with their program traces plus a initial variable environment and - whether they are totally correct or they . + together with their program traces, plus a initial variable environment and + whether they are totally correct ('ok_total'), or fail to satisfy either specification. + NB: '33' stands for the number of constant and mutable identifiers in the program + :param n_examples: number of triples to generate :param max_ast_depth: maximum AST depth of generated expressions :param n_stmt: no. of statements in the generated program @@ -75,6 +77,8 @@ def verify_triple_33( """ Verify a program triple and compare with a model prediction of whether the triple is totally correct or not. 
+ NB: '33' stands for the number of constant and mutable identifiers in the program + :param is_total: inferred correctness label :param preconditions: :param program: diff --git a/src/open_r1/rewards/code/htgen.py b/src/open_r1/rewards/code/htgen.py index a9c15dc3a..67f4ad29e 100644 --- a/src/open_r1/rewards/code/htgen.py +++ b/src/open_r1/rewards/code/htgen.py @@ -60,7 +60,9 @@ def mk_row_totality_check(o): } return o_out -def mk_dataset_totality_check( + +def mk_dataset_iter_totality_check( + n_examples:int, max_ast_depth:int = 3, n_stmt:int = 5, n_pre_terms:int = 1, @@ -68,32 +70,51 @@ def mk_dataset_totality_check( seed:int = 1234, ): """ - construct an interable dataset for GRPOTrainer + returns an interable of prompts for GRPOTrainer """ gen = gen_triples_33( - max_ast_depth = max_ast_depth, - n_stmt = n_stmt, - n_pre_terms = n_pre_terms, - n_post_terms = n_post_terms, - seed = seed, - ) - + n_examples= n_examples, + max_ast_depth = max_ast_depth, + n_stmt = n_stmt, + n_pre_terms = n_pre_terms, + n_post_terms = n_post_terms, + seed = seed, + ) # produce prompts from the raw API data gen_prompts = (mk_row_totality_check(o) for o in gen if o is not None) + return gen_prompts +def mk_dataset_totality_check( + n_examples:int, + max_ast_depth:int = 3, + n_stmt:int = 5, + n_pre_terms:int = 1, + n_post_terms:int = 1, + seed:int = 1234, + ): + """ + construct an interable dataset for GRPOTrainer + """ + gen_prompts = mk_dataset_iter_totality_check( + n_examples= n_examples, + max_ast_depth= max_ast_depth, + n_stmt = n_stmt, + n_pre_terms = n_pre_terms, + n_post_terms = n_post_terms, + seed = seed, + ) dataset = IterableDataset.from_generator(gen_prompts) - return dataset def totality_check_reward(completions, ground_truth, **kwargs): """ verification callback for GRPOTRainer - :param completions: list of truthy values produced by the model + :param completions: list of "True"/"False" strings produced by the model :param ground_truth: list of boolean ground truth values :returns: list of float 1s or 0s with the prediction scores that match the ground truth """ if not isinstance(completions[0], bool): - completions = [bool(c) for c in completions] + completions = [True if c == "True" else False for c in completions] def verify(predicted, actual): if predicted == actual: return 1.0 diff --git a/tests/test_rewards.py b/tests/test_rewards.py index f3095ac1e..1b455cfdf 100644 --- a/tests/test_rewards.py +++ b/tests/test_rewards.py @@ -11,8 +11,6 @@ tag_count_reward, ) -from open_r1.rewards.code.htgen import totality_check_reward - class TestRewards(unittest.TestCase): def test_accuracy_reward_correct_answer(self): """Test accuracy_reward with a correct answer.""" diff --git a/tests/test_rewards_code.py b/tests/test_rewards_code.py new file mode 100644 index 000000000..f4069ab9f --- /dev/null +++ b/tests/test_rewards_code.py @@ -0,0 +1,18 @@ +from open_r1.rewards.code.htgen import totality_check_reward + +class TestRewardsCode(unittest.TestCase): + def test_totality_check_reward_correct(self): + """Test totality_check_reward""" + completion = ["True"] + solution = [True] + reward = totality_check_reward(completion, solution) + self.assertEqual(reward, 1.0) + def test_totality_check_reward_wrong_format(self): + """Test totality_check_reward, wrong format""" + completion = ["The triple is total"] + solution = [True] + reward = totality_check_reward(completion, solution) + self.assertEqual(reward, 0.0) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 
d63b4a278d94496c7568305eda853c133a981d0e Mon Sep 17 00:00:00 2001 From: Marco Zocca Date: Fri, 28 Feb 2025 10:11:42 +0100 Subject: [PATCH 11/22] docstring --- src/open_r1/rewards/code/htgen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/open_r1/rewards/code/htgen.py b/src/open_r1/rewards/code/htgen.py index 67f4ad29e..72ea16783 100644 --- a/src/open_r1/rewards/code/htgen.py +++ b/src/open_r1/rewards/code/htgen.py @@ -26,7 +26,7 @@ def mk_row_totality_check(o): program = o['program'] # list of statements post = o['post'] - program_str = '\n'.join(program) + program_str = '\n'.join(program) # single program string prompt_hdr = ( f"Below you are given a Python program triple, made of a precondition predicate, " From c84f6456df12797722577f181af42f7dd88e931a Mon Sep 17 00:00:00 2001 From: Marco Z Date: Sun, 2 Mar 2025 09:26:56 +0000 Subject: [PATCH 12/22] fix makefile and tests --- Makefile | 5 +++-- .../api/code/unfoldml/htgen.py | 0 .../{rewards => rewards_internal}/code/htgen.py | 2 +- tests/test_api.py | 2 +- tests/test_rewards_code.py | 11 ++++++----- 5 files changed, 11 insertions(+), 9 deletions(-) rename src/open_r1/{rewards => rewards_internal}/api/code/unfoldml/htgen.py (100%) rename src/open_r1/{rewards => rewards_internal}/code/htgen.py (98%) diff --git a/Makefile b/Makefile index c775ed66f..59cd47f6f 100644 --- a/Makefile +++ b/Makefile @@ -6,9 +6,10 @@ export PYTHONPATH = src check_dirs := src tests -# dev dependencies +# install dev dependencies (NB uses '.' instead of 'source' to work with dash/Codespaces as well) install: - uv venv openr1 --python 3.11 && source openr1/bin/activate && uv pip install --upgrade pip + curl -LsSf https://astral.sh/uv/install.sh | sh + uv venv openr1 --python 3.11 && . openr1/bin/activate && uv pip install --upgrade pip uv pip install vllm==0.7.2 uv pip install setuptools GIT_LFS_SKIP_SMUDGE=1 uv pip install -e ".[dev]" diff --git a/src/open_r1/rewards/api/code/unfoldml/htgen.py b/src/open_r1/rewards_internal/api/code/unfoldml/htgen.py similarity index 100% rename from src/open_r1/rewards/api/code/unfoldml/htgen.py rename to src/open_r1/rewards_internal/api/code/unfoldml/htgen.py diff --git a/src/open_r1/rewards/code/htgen.py b/src/open_r1/rewards_internal/code/htgen.py similarity index 98% rename from src/open_r1/rewards/code/htgen.py rename to src/open_r1/rewards_internal/code/htgen.py index 72ea16783..67b33c299 100644 --- a/src/open_r1/rewards/code/htgen.py +++ b/src/open_r1/rewards_internal/code/htgen.py @@ -1,6 +1,6 @@ from datasets import Dataset, IterableDataset -from open_r1.rewards.api.code.unfoldml.htgen import gen_triples_33, verify_triple_33 +from open_r1.rewards_internal.api.code.unfoldml.htgen import gen_triples_33, verify_triple_33 # # GRPOTrainer requires 1. a dataset and 2. 
a verification callback diff --git a/tests/test_api.py b/tests/test_api.py index a07325b2c..10852a24a 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,6 +1,6 @@ import unittest -from open_r1.rewards.api.code.unfoldml.htgen import gen_triples_33, verify_triple_33 +from open_r1.rewards_internal.api.code.unfoldml.htgen import gen_triples_33, verify_triple_33 class TestApi(unittest.TestCase): diff --git a/tests/test_rewards_code.py b/tests/test_rewards_code.py index f4069ab9f..1ef867743 100644 --- a/tests/test_rewards_code.py +++ b/tests/test_rewards_code.py @@ -1,18 +1,19 @@ -from open_r1.rewards.code.htgen import totality_check_reward +import unittest +from open_r1.rewards_internal.code.htgen import totality_check_reward class TestRewardsCode(unittest.TestCase): def test_totality_check_reward_correct(self): """Test totality_check_reward""" completion = ["True"] solution = [True] - reward = totality_check_reward(completion, solution) - self.assertEqual(reward, 1.0) + rewards = totality_check_reward(completion, solution) + self.assertEqual(rewards[0], 1.0) def test_totality_check_reward_wrong_format(self): """Test totality_check_reward, wrong format""" completion = ["The triple is total"] solution = [True] - reward = totality_check_reward(completion, solution) - self.assertEqual(reward, 0.0) + rewards = totality_check_reward(completion, solution) + self.assertEqual(rewards[0], 0.0) if __name__ == "__main__": unittest.main() \ No newline at end of file From ff0db1e08c2e6acfc57c40b1deb8a93f69f7b696 Mon Sep 17 00:00:00 2001 From: Marco Z Date: Sun, 2 Mar 2025 11:56:00 +0000 Subject: [PATCH 13/22] fix code rewards test --- .../api/code/unfoldml/htgen.py | 9 +- src/open_r1/rewards_internal/code/htgen.py | 156 +++++++++++------- tests/test_api.py | 8 +- tests/test_rewards_code.py | 12 +- 4 files changed, 118 insertions(+), 67 deletions(-) diff --git a/src/open_r1/rewards_internal/api/code/unfoldml/htgen.py b/src/open_r1/rewards_internal/api/code/unfoldml/htgen.py index 49c80c065..f25bce9c0 100644 --- a/src/open_r1/rewards_internal/api/code/unfoldml/htgen.py +++ b/src/open_r1/rewards_internal/api/code/unfoldml/htgen.py @@ -5,7 +5,7 @@ api_server_url = "https://htgen.unfoldml.com" -def gen_triples_33( +def gen_triples( n_examples:int, max_ast_depth:int = 3, n_stmt:int = 5, @@ -17,8 +17,8 @@ def gen_triples_33( """ Yield program triples (Precondition, Statements, Postconditions) from the API, together with their program traces, plus a initial variable environment and - whether they are totally correct ('ok_total'), or fail to satisfy either specification. - NB: '33' stands for the number of constant and mutable identifiers in the program + whether they are totally correct ('ok_total'), or fail to satisfy either specification ('bad_pre', 'bad_post'). + NB: in the backend, we have distinct REST endpoints for various combinations of (constant, mutable) variables, e.g. '33' and '55'. :param n_examples: number of triples to generate :param max_ast_depth: maximum AST depth of generated expressions @@ -26,6 +26,7 @@ def gen_triples_33( :param n_pre_terms: no. of AND/OR terms in the generated pre-conditions :param n_post_terms: no. of AND/OR terms in the generated post-conditions :param seed: random seed for the PRNG + :param endpoint: REST endpoint of the request. '33' stands for 3 constants and 3 mutable identifiers :returns: iterable of dict e.g. 
{ @@ -67,7 +68,7 @@ def gen_triples_33( raise he -def verify_triple_33( +def verify_triple( is_total:bool, preconditions:str = "True", program:str = "v4 = (0 - v3)\nv3 = v3\nv5 = v4", diff --git a/src/open_r1/rewards_internal/code/htgen.py b/src/open_r1/rewards_internal/code/htgen.py index 67b33c299..329d404da 100644 --- a/src/open_r1/rewards_internal/code/htgen.py +++ b/src/open_r1/rewards_internal/code/htgen.py @@ -1,23 +1,92 @@ from datasets import Dataset, IterableDataset -from open_r1.rewards_internal.api.code.unfoldml.htgen import gen_triples_33, verify_triple_33 +from open_r1.rewards_internal.api.code.unfoldml.htgen import gen_triples, verify_triple +def quotes(s:str): + """markdown triple backticks for a piece of code""" + return f"```{s}```" + +# # header of all prompts, describing Hoare logic at a high level +prompt_hdr = ( + f"Below you are given a Python program triple, made of a precondition predicate, " + f"a sequence of program statements, and a post-condition predicate. " + f"The precondition returns True if the variable environment before beginning the " + f"program execution satisfies the predicate, and False otherwise. " + f"Similarly, the postcondition returns True if the program environment after the last " + f"statement satisfies the predicate, and False otherwise. " + ) + +prompt_contradict_warning = ( + f"Note that there might be unsatisfiable or contradictory predicates such as 'v1 < v1' or 'v3 > 5 + v3' that make the solution False by definition. " +) # # GRPOTrainer requires 1. a dataset and 2. a verification callback + + # # # Tasks -# TOTALITY_CHECK : is the program "total"? -# FIX_PRE, FIX_POST, FIX_PROGRAM : modify either part of a triple to achieve either a total triple or other proof result -def quotes(s:str): - """markdown triple backticks for a piece of code""" - return f"```{s}```" +# FIX_TRIPLE task : modify either part of a triple to achieve either a total triple or other proof result +def mk_row_fix_triple(o): + """ + FIX_TRIPLE task: Construct the prompt + NB: the rows have a 'prompt' column as required by the GRPOTrainer interface: + https://huggingface.co/docs/trl/main/grpo_trainer#trl.GRPOTrainer.train_dataset + """ + label = o['label'] # {'ok_total', 'bad_pre', 'bad_post'} + pre = o['pre'] + program = o['program'] # list of statements + post = o['post'] + + program_str = '\n'.join(program) # single program string + + match label: + case 'ok_total': + which_triple_el = 'program' + case 'bad_pre': + which_triple_el = 'precondition' + case 'bad_post': + which_triple_el = 'postcondition' + + # assemble task prompt + prompt_task = ( + f"Given a program triple made of program {quotes(program_str)}, " + f"precondition {quotes(pre)} and postcondition {quotes(post)}, " + f"You should modify the {which_triple_el} such that the resulting triple is total." 
+ ) + + # # concatenate header, task and question into a prompt + prompt_problem = f"{prompt_hdr}\n{prompt_contradict_warning}\n{prompt_task}" + + o_out = { + "prompt": prompt_problem, + "ground_truth": label, + "triple": {"pre": pre, "program":program, "post": post} + } + return o_out + + + + + + + + + + + + + + + + + # TOTALITY_CHECK task def mk_row_totality_check(o): """ - Construct the prompt + TOTALITY_CHECK task: Construct the prompt NB: the rows have a 'prompt' column as required by the GRPOTrainer interface: https://huggingface.co/docs/trl/main/grpo_trainer#trl.GRPOTrainer.train_dataset """ @@ -28,27 +97,20 @@ def mk_row_totality_check(o): program_str = '\n'.join(program) # single program string - prompt_hdr = ( - f"Below you are given a Python program triple, made of a precondition predicate, " - f"a sequence of program statements, and a post-condition predicate. " - f"The precondition returns True if the variable environment before beginning the " - f"program execution satisfies the predicate, and False otherwise. " - f"Similarly, the postcondition returns True if the program environment after the last " - f"statement satisfies the predicate, and False otherwise. " - f"Note that there might be unsatisfiable or contradictory predicates, that make the solution False by definition. " - f"With this information, you should judge whether the program is 'total', i.e. " + prompt_task = ( + f"You should judge whether the program is 'total', i.e. " f"whether the post-condition evaluates to True for all possible variable assigments " f"that satisfy the precondition." ) prompt_question = ( f"Given a program triple made of program {quotes(program_str)}, " - f"preconditions {quotes(pre)} and postcondition {quotes(post)}, is the postcondition " - f"always True at the end of the program ? Please return 'True' or 'False'." + f"precondition {quotes(pre)} and postcondition {quotes(post)}, is the postcondition " + f"always True at the end of the program ? Please only return 'True' or 'False'." 
) - # # concatenate header and question into a prompt - prompt_problem = f"{prompt_hdr}\n{prompt_question}" + # # concatenate header, task and question into a prompt + prompt_problem = f"{prompt_hdr}\n{prompt_contradict_warning}\n{prompt_task}\n{prompt_question}" label_is_total = label == 'ok_total' # boolean @@ -61,29 +123,6 @@ def mk_row_totality_check(o): return o_out -def mk_dataset_iter_totality_check( - n_examples:int, - max_ast_depth:int = 3, - n_stmt:int = 5, - n_pre_terms:int = 1, - n_post_terms:int = 1, - seed:int = 1234, - ): - """ - returns an interable of prompts for GRPOTrainer - """ - gen = gen_triples_33( - n_examples= n_examples, - max_ast_depth = max_ast_depth, - n_stmt = n_stmt, - n_pre_terms = n_pre_terms, - n_post_terms = n_post_terms, - seed = seed, - ) - # produce prompts from the raw API data - gen_prompts = (mk_row_totality_check(o) for o in gen if o is not None) - return gen_prompts - def mk_dataset_totality_check( n_examples:int, max_ast_depth:int = 3, @@ -91,18 +130,24 @@ def mk_dataset_totality_check( n_pre_terms:int = 1, n_post_terms:int = 1, seed:int = 1234, + endpoint:str = '/gen33' ): """ construct an interable dataset for GRPOTrainer """ - gen_prompts = mk_dataset_iter_totality_check( - n_examples= n_examples, - max_ast_depth= max_ast_depth, - n_stmt = n_stmt, - n_pre_terms = n_pre_terms, - n_post_terms = n_post_terms, - seed = seed, - ) + # produce prompts from the API data + def gen_prompts(): + for o in gen_triples( + n_examples= n_examples, + max_ast_depth = max_ast_depth, + n_stmt = n_stmt, + n_pre_terms = n_pre_terms, + n_post_terms = n_post_terms, + seed = seed, + endpoint= endpoint + ): + if o is not None: + yield mk_row_totality_check(o) dataset = IterableDataset.from_generator(gen_prompts) return dataset @@ -115,22 +160,17 @@ def totality_check_reward(completions, ground_truth, **kwargs): """ if not isinstance(completions[0], bool): completions = [True if c == "True" else False for c in completions] - def verify(predicted, actual): + def compare(predicted, actual): if predicted == actual: return 1.0 else: return 0.0 - return [verify(predicted, actual) for (predicted, actual) in zip(completions, ground_truth)] + return [compare(predicted, actual) for (predicted, actual) in zip(completions, ground_truth)] -if __name__ == "__main__": - compls = [True] - ground_truth = ["True"] - res = totality_check_reward(compls, ground_truth) - print(res) # # # verify against API diff --git a/tests/test_api.py b/tests/test_api.py index 10852a24a..c7cb6dab9 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,12 +1,12 @@ import unittest -from open_r1.rewards_internal.api.code.unfoldml.htgen import gen_triples_33, verify_triple_33 +from open_r1.rewards_internal.api.code.unfoldml.htgen import gen_triples, verify_triple class TestApi(unittest.TestCase): def test_gen_triples_structure(self): n_stmt = 3 - for o in gen_triples_33(n_examples = 1, n_stmt = n_stmt): + for o in gen_triples(n_examples = 1, n_stmt = n_stmt): len_program = len(o['program']) self.assertEqual(len_program, n_stmt) def test_verify_triple_result(self): @@ -16,7 +16,7 @@ def test_verify_triple_result(self): post_ok = "v5 == (0 - v3)" # post-condition that verifies post_not_ok = "v5 == (1 - v3)" # post-condition that does not verify # # should return True - o = verify_triple_33( + o = verify_triple( is_total = is_total, preconditions = preconditions, program = program, @@ -25,7 +25,7 @@ def test_verify_triple_result(self): res_ok = o['prediction_is_correct'] self.assertEqual(res_ok, True) # # 
should return False
-        o = verify_triple_33(
+        o = verify_triple(
            is_total = is_total,
            preconditions = preconditions,
            program = program,
            postconditions = post_not_ok
        )
        res_not_ok = o['prediction_is_correct']
        self.assertEqual(res_not_ok, False)
diff --git a/tests/test_rewards_code.py b/tests/test_rewards_code.py
index 1ef867743..bb9ec14bc 100644
--- a/tests/test_rewards_code.py
+++ b/tests/test_rewards_code.py
@@ -1,5 +1,5 @@
 import unittest
-from open_r1.rewards_internal.code.htgen import totality_check_reward
+from open_r1.rewards_internal.code.htgen import totality_check_reward, mk_dataset_totality_check

 class TestRewardsCode(unittest.TestCase):
+    def test_mk_dataset_totality_check_format_correct(self):
+        """test output format of dataset generator mk_dataset_iter_totality_check"""
+        ds = mk_dataset_totality_check(n_examples= 1)
+        examples = list(ds)
+        prompt = examples[0]['prompt']
+        label = examples[0]['ground_truth']
+        triple = examples[0]['triple']
+        self.assertIsInstance(prompt, str)
+        self.assertIsInstance(label, bool)
+        self.assertIsInstance(triple, dict)
    def test_totality_check_reward_correct(self):
        """Test totality_check_reward"""
        completion = ["True"]
        solution = [True]
        rewards = totality_check_reward(completion, solution)
        self.assertEqual(rewards[0], 1.0)
    def test_totality_check_reward_wrong_format(self):
        """Test totality_check_reward, wrong format"""
        completion = ["The triple is total"]
        solution = [True]
        rewards = totality_check_reward(completion, solution)
        self.assertEqual(rewards[0], 0.0)

From 1793479a58f6a528802af01c2fb505a18b2a5fd4 Mon Sep 17 00:00:00 2001
From: Marco Z
Date: Sun, 2 Mar 2025 14:56:57 +0000
Subject: [PATCH 14/22] add prompt and fix_triple reward

---
 .../api/code/unfoldml/htgen.py             |  56 +++---
 src/open_r1/rewards_internal/code/htgen.py | 175 +++++++++++++++---
 tests/test_rewards_code.py                 |  37 +++++-
 3 files changed, 212 insertions(+), 56 deletions(-)

diff --git a/src/open_r1/rewards_internal/api/code/unfoldml/htgen.py b/src/open_r1/rewards_internal/api/code/unfoldml/htgen.py
index f25bce9c0..a792c64e7 100644
--- a/src/open_r1/rewards_internal/api/code/unfoldml/htgen.py
+++ b/src/open_r1/rewards_internal/api/code/unfoldml/htgen.py
@@ -110,28 +110,38 @@
+def verify_triple_v2(
+        preconditions:str = "True",
+        program:str = "v4 = (0 - v3)\nv3 = v3\nv5 = v4",
+        postconditions:str = "v5 == (0 - v3)",
+        endpoint:str = '/v2/prove33'
+    ):
+    """
+    Verify a program triple, V2 endpoint
+    NB: '33' stands for the number of constant and mutable identifiers in the program
+
+    :param preconditions:
+    :param program:
+    :param postconditions:
+    :returns: the proof result from the SMT verifier, e.g. {'result': 'proven_total'}, or {'result': 'failed', 'vcs': [...]} with the violated verification conditions
+
+    """
+    triple = {
+        "pre": preconditions,
+        "program": program,
+        "post": postconditions
+    }
+    url = 
f"{api_server_url}/{endpoint}" + try: + res = post(url, json= triple, stream= True) + res.raise_for_status() + try: + v = res.json() + except JSONDecodeError: + v = None + return v + # else: + except HTTPError as he: + print(f"HTTP error: {he}") + raise he + diff --git a/src/open_r1/rewards_internal/code/htgen.py b/src/open_r1/rewards_internal/code/htgen.py index 329d404da..f9d6d9495 100644 --- a/src/open_r1/rewards_internal/code/htgen.py +++ b/src/open_r1/rewards_internal/code/htgen.py @@ -1,6 +1,6 @@ from datasets import Dataset, IterableDataset -from open_r1.rewards_internal.api.code.unfoldml.htgen import gen_triples, verify_triple +from open_r1.rewards_internal.api.code.unfoldml.htgen import gen_triples, verify_triple, verify_triple_v2 def quotes(s:str): """markdown triple backticks for a piece of code""" @@ -9,24 +9,81 @@ def quotes(s:str): # # header of all prompts, describing Hoare logic at a high level prompt_hdr = ( f"Below you are given a Python program triple, made of a precondition predicate, " - f"a sequence of program statements, and a post-condition predicate. " + f"a sequence of program statements, and a postcondition predicate. " f"The precondition returns True if the variable environment before beginning the " f"program execution satisfies the predicate, and False otherwise. " f"Similarly, the postcondition returns True if the program environment after the last " f"statement satisfies the predicate, and False otherwise. " + f"We say that a triple is correct if, whenever the precondition holds for a given variable " + f"assignment, executing the program will produce a variable assignment that satisfies the postcondition. " ) prompt_contradict_warning = ( f"Note that there might be unsatisfiable or contradictory predicates such as 'v1 < v1' or 'v3 > 5 + v3' that make the solution False by definition. " ) -# # GRPOTrainer requires 1. a dataset and 2. a verification callback +def explain_vcs(o): + """ + produce a string explanation of why the verification conditions are violated with counterexamples + """ + def render_env(es): + return ', '.join(es) + def start(): + return o['state_start'] + def end(): + return o['state_end'] + def info(): + return o['info'] + def state_diff(): + s1 = start() + s2 = end() + states_diff = list(set(s1) - set(s2)) + return ', '.join(states_diff) + match o['vc']: + case 'bad_precondition': + s = start() + return f"the environment {render_env(s)} does not satisfy the precondition." + case 'bad_postcondition': + s = start() + sd = state_diff() + return f"if the program starts in state {render_env(s)}, the final environment {sd} does not satisfy the postcondition." + case 'unstable': + ident = info() # identifier that mutates throughout the program + sd = state_diff() + return f"variable {ident} is not immutable and variable assignments {sd} do not satisfy the postcondition." 
+        case vc :
+            raise RuntimeWarning(f"Verification condition '{vc}' currently not supported.") # TODO all other verification conditions:
+    # case 'abort_reachable':
+    #     return []
+    # case 'invariant_broken_upon_loop_entry':
+    #     return []
+    # case 'invariant_broken_in_loop_body':
+    #     return []
+    # case 'measure_not_non_negative':
+    #     return []
+    # case 'measure_doesnt_decrease':
+    #     return []
+
+def explain_proof_result(wp_proof_result):
+    match wp_proof_result['result']:
+        case 'proven_total':
+            expl = None
+        case 'failed':
+            vcs = wp_proof_result['vcs'] # verification conditions
+            vces = [explain_vcs(v) for v in vcs]
+            vc_explanation = ' '.join(vces)
+            plur = 's' if len(vces)>1 else ''
+            expl = f"Currently, the program triple fails {len(vces)} verification condition{plur}: {vc_explanation}"
+    return expl
+
+
 # # GRPOTrainer requires 1. a dataset and 2. a verification callback

 # # # Tasks

-# FIX_TRIPLE task : modify either part of a triple to achieve either a total triple or other proof result
+# FIX_TRIPLE task : modify the program to satisfy the pre- and post-conditions
 def mk_row_fix_triple(o):
     """
     FIX_TRIPLE task: Construct the prompt
     NB: the rows have a 'prompt' column as required by the GRPOTrainer interface:
     https://huggingface.co/docs/trl/main/grpo_trainer#trl.GRPOTrainer.train_dataset
     """
     label = o['label'] # {'ok_total', 'bad_pre', 'bad_post'}
     pre = o['pre']
     program = o['program'] # list of statements
     post = o['post']

+    wp_proof_result = o['wp_proof_result']
+    explanation_wpr = explain_proof_result(wp_proof_result)
+
-    program_str = '\n'.join(program) # single program string
+    program_str = '\\n'.join(program) # single program string

-    match label:
-        case 'ok_total':
-            which_triple_el = 'program'
-        case 'bad_pre':
-            which_triple_el = 'precondition'
-        case 'bad_post':
-            which_triple_el = 'postcondition'
+    # # task variant: modify either pre- or post-condition
+    # match label:
+    #     case 'ok_total':
+    #         which_triple_el = 'program'
+    #     case 'bad_pre':
+    #         which_triple_el = 'precondition'
+    #     case 'bad_post':
+    #         which_triple_el = 'postcondition'
+
+    # # task: consider pre- and post-condition as fixed (i.e. the program specification)
+    which_triple_el = 'program' # only modify the program

     # assemble task prompt
     prompt_task = (
         f"Given a program triple made of program {quotes(program_str)}, "
         f"precondition {quotes(pre)} and postcondition {quotes(post)}, "
-        f"You should modify the {which_triple_el} such that the resulting triple is total."
+        f"you should modify the {which_triple_el} such that the resulting triple is total. "
+        f"{explanation_wpr if explanation_wpr is not None else ''} "
+        f"With this information, the correct program that satisfies the given precondition and postcondition is: "
     )

     # # concatenate header, task and question into a prompt
     prompt_problem = f"{prompt_hdr}\n{prompt_contradict_warning}\n{prompt_task}"

     o_out = {
@@ -65,20 +131,53 @@ def mk_row_fix_triple(o):
     }
     return o_out

-
-
-
-
-
-
-
-
-
-
+def mk_dataset_fix_triple(
+        n_examples:int,
+        max_ast_depth:int = 3,
+        n_stmt:int = 5,
+        n_pre_terms:int = 1,
+        n_post_terms:int = 1,
+        seed:int = 1234,
+        endpoint:str = '/gen33'
+    ):
+    """
+    construct an iterable dataset for the 'fix_triple' task
+    """
+    ds = mk_dataset(mk = mk_row_fix_triple,
+        n_examples= n_examples,
+        max_ast_depth = max_ast_depth,
+        n_stmt = n_stmt,
+        n_pre_terms = n_pre_terms,
+        n_post_terms = n_post_terms,
+        seed = seed,
+        endpoint= endpoint
+        )
+    return ds

+def fix_triple_reward(completions, ground_truth_triples, **kwargs):
+    """
+    verification callback for fix_triple task
+    :param completions: list of program strings produced by the model
+    :param ground_truth_triples: list of input program ground truth triples
+    :returns: list of float 1s or 0s with the prediction scores that match the ground truth
+    """
+    def compare(completion:str, triple:dict):
+        pre = triple['pre']
+        post = triple['post']
+        res = verify_triple_v2(
+            preconditions = pre,
+            program = completion,
+            postconditions = post
+        )
+        if res is not None:
+            print(res)
+            reward = 1.0 if res.get('result') == 'proven_total' else 0.0
+            return reward
+        else:
+            return 0.0
+    return [compare(predicted, gtt) for (predicted, gtt) in zip(completions, ground_truth_triples)]
@@ -133,7 +232,31 @@ def mk_dataset_totality_check(
     endpoint:str = '/gen33'
     ):
     """
-    construct an interable dataset for GRPOTrainer
+    construct an iterable dataset for the 'totality_check' task
+    """
+    ds = mk_dataset(mk = mk_row_totality_check,
+        n_examples= n_examples,
+        max_ast_depth = max_ast_depth,
+        n_stmt = n_stmt,
+        n_pre_terms = n_pre_terms,
+        n_post_terms = n_post_terms,
+        seed = seed,
+        endpoint= endpoint
+        )
+    return ds
+
+def mk_dataset(
+    mk,
+    n_examples:int,
+    max_ast_depth:int = 3,
+    n_stmt:int = 5,
+    n_pre_terms:int = 1,
+    n_post_terms:int = 1,
+    seed:int = 1234,
+    endpoint:str = '/gen33'
+    ):
+    """
+    construct an iterable dataset for GRPO
     """
     # produce prompts from the API data
     def gen_prompts():
         for o in gen_triples(
             n_examples= n_examples,
             max_ast_depth = max_ast_depth,
             n_stmt = n_stmt,
             n_pre_terms = n_pre_terms,
             n_post_terms = n_post_terms,
             seed = seed,
             endpoint= endpoint
         ):
             if o is not None:
-                yield mk_row_totality_check(o)
+                yield mk(o)
     dataset = IterableDataset.from_generator(gen_prompts)
     return dataset
diff --git a/tests/test_rewards_code.py b/tests/test_rewards_code.py
index bb9ec14bc..0722b84f5 100644
--- a/tests/test_rewards_code.py
+++ b/tests/test_rewards_code.py
@@ -1,5 +1,5 @@
 import unittest
-from open_r1.rewards_internal.code.htgen import totality_check_reward, mk_dataset_totality_check
+from open_r1.rewards_internal.code.htgen import totality_check_reward, mk_dataset_totality_check, fix_triple_reward, mk_dataset_fix_triple

 class TestRewardsCode(unittest.TestCase):
     def test_mk_dataset_totality_check_format_correct(self):
         """test output format of dataset generator mk_dataset_iter_totality_check"""
         ds = mk_dataset_totality_check(n_examples= 1)
         examples = list(ds)
         prompt = examples[0]['prompt']
         label = examples[0]['ground_truth']
         triple = examples[0]['triple']
         self.assertIsInstance(prompt, str)
         self.assertIsInstance(label, bool)
self.assertIsInstance(triple, dict) + def test_mk_dataset_fix_triple_format_correct(self): + """test output format of dataset generator mk_dataset_fix_triple""" + ds = mk_dataset_fix_triple(n_examples= 1, seed= 5556) + examples = list(ds) + ex = examples[0] + prompt = ex['prompt'] + # print(prompt) # DEBUG + label = ex['ground_truth'] + triple = ex['triple'] + self.assertIsInstance(prompt, str) + self.assertIsInstance(label, str) + self.assertIn(label, ['ok_total', 'bad_pre', 'bad_post']) + self.assertIsInstance(triple, dict) + # self.assertIsInstance(diagnosis, str) + def test_totality_check_reward_correct(self): + """Test totality_check_reward""" + completion = ["True"] + solution = [True] + rewards = totality_check_reward(completion, solution) + self.assertEqual(rewards[0], 1.0) def test_totality_check_reward_wrong_format(self): """Test totality_check_reward, wrong format""" completion = ["The triple is total"] solution = [True] rewards = totality_check_reward(completion, solution) self.assertEqual(rewards[0], 0.0) + def test_fix_triple_reward(self): + triple = { + "pre": "v3 > 0 && v4 > 2", + "program": "v5 = 2\nv3 = v5\nv4 = ((5 + (3 + v3)) + (v4 + v5))\nv4 = 9\nv4 = (v3 - 7)", + "post": "v5 > 6" + } + completion = "v5 = 2\nv3 = v5\nv4 = ((5 + (3 + v3)) + (v4 + v5))\nv5 = v4" + rewards = fix_triple_reward([completion], [triple]) + self.assertEqual(rewards[0], 1.0) if __name__ == "__main__": unittest.main() \ No newline at end of file From 330308381260e60c1906481639df55e6ba1be03a Mon Sep 17 00:00:00 2001 From: Marco Z Date: Sun, 2 Mar 2025 16:19:36 +0000 Subject: [PATCH 15/22] fix makefile to activate venv correctly --- Makefile | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 59cd47f6f..aec4c40ba 100644 --- a/Makefile +++ b/Makefile @@ -6,10 +6,13 @@ export PYTHONPATH = src check_dirs := src tests -# install dev dependencies (NB uses '.' instead of 'source' to work with dash/Codespaces as well) +# install dev dependencies +# # NB uses '.' instead of 'source' to work with dash/Codespaces as well) install: curl -LsSf https://astral.sh/uv/install.sh | sh - uv venv openr1 --python 3.11 && . openr1/bin/activate && uv pip install --upgrade pip + uv venv openr1 --python 3.11 + (. openr1/bin/activate) + uv pip install --upgrade pip uv pip install vllm==0.7.2 uv pip install setuptools GIT_LFS_SKIP_SMUDGE=1 uv pip install -e ".[dev]" From 532a01247a729b94e405ff6f787e34acefb9e194 Mon Sep 17 00:00:00 2001 From: Marco Z Date: Sun, 2 Mar 2025 16:28:56 +0000 Subject: [PATCH 16/22] fix_triple task: add reward tests and docstrings --- src/open_r1/rewards_internal/code/htgen.py | 12 +++++++----- tests/test_rewards_code.py | 15 ++++++++++++--- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/src/open_r1/rewards_internal/code/htgen.py b/src/open_r1/rewards_internal/code/htgen.py index f9d6d9495..0a70fe61a 100644 --- a/src/open_r1/rewards_internal/code/htgen.py +++ b/src/open_r1/rewards_internal/code/htgen.py @@ -86,7 +86,10 @@ def explain_proof_result(wp_proof_result): # FIX_TRIPLE task : modify the program to satisfy the pre- and post-conditions def mk_row_fix_triple(o): """ - FIX_TRIPLE task: Construct the prompt + FIX_TRIPLE task: given a program triple, modify the program such that it satisfies + pre- and post-conditions (i.e. 
the program spec) + + This function constructs the prompt from a dict that comes from the dataset API NB: the rows have a 'prompt' column as required by the GRPOTrainer interface: https://huggingface.co/docs/trl/main/grpo_trainer#trl.GRPOTrainer.train_dataset """ @@ -158,8 +161,8 @@ def mk_dataset_fix_triple( def fix_triple_reward(completions, ground_truth_triples, **kwargs): """ verification callback for fix_triple task - :param completions: list of program strings produced by the model - :param ground_truth_triples: list of input program ground truth triples + :param completions: list of program strings (produced by the model) + :param ground_truth_triples: list of input program ground truth triples (coming from the dataset) :returns: list of float 1s or 0s with the prediction scores that match the ground truth """ def compare(completion:str, triple:dict): @@ -171,8 +174,7 @@ def compare(completion:str, triple:dict): postconditions = post ) if res is not None: - print(res) - reward = 1.0 if res.get('result') is 'proven_total' else 0.0 + reward = 1.0 if res.get('result') == 'proven_total' else 0.0 return reward else: return 0.0 diff --git a/tests/test_rewards_code.py b/tests/test_rewards_code.py index 0722b84f5..096ba8908 100644 --- a/tests/test_rewards_code.py +++ b/tests/test_rewards_code.py @@ -18,14 +18,12 @@ def test_mk_dataset_fix_triple_format_correct(self): examples = list(ds) ex = examples[0] prompt = ex['prompt'] - # print(prompt) # DEBUG label = ex['ground_truth'] triple = ex['triple'] self.assertIsInstance(prompt, str) self.assertIsInstance(label, str) self.assertIn(label, ['ok_total', 'bad_pre', 'bad_post']) self.assertIsInstance(triple, dict) - # self.assertIsInstance(diagnosis, str) def test_totality_check_reward_correct(self): """Test totality_check_reward""" completion = ["True"] @@ -38,7 +36,8 @@ def test_totality_check_reward_wrong_format(self): solution = [True] rewards = totality_check_reward(completion, solution) self.assertEqual(rewards[0], 0.0) - def test_fix_triple_reward(self): + def test_fix_triple_reward_correct(self): + """fix_triple task: assert a correct completion gives 1.0 reward""" triple = { "pre": "v3 > 0 && v4 > 2", "program": "v5 = 2\nv3 = v5\nv4 = ((5 + (3 + v3)) + (v4 + v5))\nv4 = 9\nv4 = (v3 - 7)", @@ -47,6 +46,16 @@ def test_fix_triple_reward(self): completion = "v5 = 2\nv3 = v5\nv4 = ((5 + (3 + v3)) + (v4 + v5))\nv5 = v4" rewards = fix_triple_reward([completion], [triple]) self.assertEqual(rewards[0], 1.0) + def test_fix_triple_reward_wrong_0(self): + """fix_triple task: asserts an incorrect completion gives 0.0 reward""" + triple = { + "pre": "v3 > 0 && v4 > 2", + "program": "v5 = 2\nv3 = v5\nv4 = ((5 + (3 + v3)) + (v4 + v5))\nv4 = 9\nv4 = (v3 - 7)", + "post": "v5 > 6" + } + completion = "v5 = 2\nv3 = v5\nv4 = ((5 + (3 + v3)) + (v4 + v5))\nv5 = v3 + v3" + rewards = fix_triple_reward([completion], [triple]) + self.assertEqual(rewards[0], 0.0) if __name__ == "__main__": unittest.main() \ No newline at end of file From 66969e8bc44b69c15bab112cdf3e48106176da88 Mon Sep 17 00:00:00 2001 From: Marco Z Date: Sun, 2 Mar 2025 16:33:39 +0000 Subject: [PATCH 17/22] readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index b520ee022..641d89351 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ To install `uv`, follow the [UV Installation Guide](https://docs.astral.sh/uv/ge > [!NOTE] -> As a shortcut, run `make install` to setup development libraries (spelled out below). 
Afterwards if everything is setup correctly and you have a functioning CUDA, you can install `flash-attn` and try out the Open-R1 models.
+> As a shortcut, run `make install` to set up `uv` and the development libraries (spelled out below). Afterwards, if everything is set up correctly and you have a functioning CUDA, you can install `flash-attn` and try out the Open-R1 models.
 
 
 ```shell

From 3f88b068d023cc56e87c209c4f279c4e20f2ae82 Mon Sep 17 00:00:00 2001
From: Marco Z
Date: Sun, 2 Mar 2025 16:41:10 +0000
Subject: [PATCH 18/22] fix style and quality

---
 .../api/code/unfoldml/htgen.py                |  97 ++++---
 src/open_r1/rewards_internal/code/htgen.py    | 266 +++++++++---------
 tests/test_api.py                             |  32 +--
 tests/test_rewards.py                         |   1 +
 tests/test_rewards_code.py                    |  39 ++-
 5 files changed, 220 insertions(+), 215 deletions(-)

diff --git a/src/open_r1/rewards_internal/api/code/unfoldml/htgen.py b/src/open_r1/rewards_internal/api/code/unfoldml/htgen.py
index a792c64e7..c510d7333 100644
--- a/src/open_r1/rewards_internal/api/code/unfoldml/htgen.py
+++ b/src/open_r1/rewards_internal/api/code/unfoldml/htgen.py
@@ -1,22 +1,24 @@
-from json import loads, JSONDecodeError
+from json import JSONDecodeError, loads
 
-from requests import Response, post
+from requests import post
 from requests.exceptions import HTTPError
 
+
 api_server_url = "https://htgen.unfoldml.com"
 
+
 def gen_triples(
-    n_examples:int,
-    max_ast_depth:int = 3,
-    n_stmt:int = 5,
-    n_pre_terms:int = 1,
-    n_post_terms:int = 1,
-    seed:int = 1234,
-    endpoint = '/gen33',
-    ):
+    n_examples: int,
+    max_ast_depth: int = 3,
+    n_stmt: int = 5,
+    n_pre_terms: int = 1,
+    n_post_terms: int = 1,
+    seed: int = 1234,
+    endpoint="/gen33",
+):
    """
    Yield program triples (Precondition, Statements, Postconditions) from the API,
-    together with their program traces, plus a initial variable environment and
+    together with their program traces, plus an initial variable environment and
    whether they are totally correct ('ok_total'), or fail to satisfy either specification ('bad_pre', 'bad_post').
    NB: in the backend, we have distinct REST endpoints for various combinations of (constant, mutable) variables, e.g. '33' and '55'.
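+    (presumably the two digits count the constant and mutable identifiers, so '33' matches the six variables v0..v5 in the example below)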
@@ -33,10 +35,10 @@ def gen_triples( 'env_initial': ['v0 = 15', 'v1 = 42', 'v2 = -36', 'v3 = 73', 'v4 = 72', 'v5 = 64'], # starting program state 'env_trace': [], # no execution trace because the starting env doesn't satisfy the precondition 'label': 'bad_pre', # bad precondition (one of 'bad_pre', 'bad_post', 'ok_total') - 'pre': 'v3 > (2 + v4)', - 'program': ['v3 = v5', 'v4 = (4 - (4 - (v5 - 4)))', 'v5 = v4', 'v4 = (v5 - v3)', 'v3 = 4'], - 'post': 'v3 > v4', - 'prng_state_out': [1300, 1], + 'pre': 'v3 > (2 + v4)', + 'program': ['v3 = v5', 'v4 = (4 - (4 - (v5 - 4)))', 'v5 = v4', 'v4 = (v5 - v3)', 'v3 = 4'], + 'post': 'v3 > v4', + 'prng_state_out': [1300, 1], 'rej_iters': 1, # number of rejection sampling iterations 'rej_iters_time_s': 0.056072775 # time it took to generate this triple [seconds] } @@ -48,42 +50,43 @@ def gen_triples( "n_pre_terms": n_pre_terms, "n_post_terms": n_post_terms, "sm_gen_seed": seed, - "sm_gen_gamma": 1 - } + "sm_gen_gamma": 1, + } url = f"{api_server_url}/{endpoint}" try: - res = post(url, json= cfg, stream= True) + res = post(url, json=cfg, stream=True) res.raise_for_status() - for chunk in res.iter_lines(chunk_size= None, delimiter=b"\r\n"): + for chunk in res.iter_lines(chunk_size=None, delimiter=b"\r\n"): try: v = loads(chunk) if not isinstance(v, dict): v = None except JSONDecodeError as e: + print(f"JSON decode error: {e}") v = None - if v is not None: + if v is not None: yield v except HTTPError as he: print(f"HTTP error: {he}") raise he - + def verify_triple( - is_total:bool, - preconditions:str = "True", - program:str = "v4 = (0 - v3)\nv3 = v3\nv5 = v4", - postconditions:str = "v5 == (0 - v3)", - endpoint:str = '/prove33' - ): + is_total: bool, + preconditions: str = "True", + program: str = "v4 = (0 - v3)\nv3 = v3\nv5 = v4", + postconditions: str = "v5 == (0 - v3)", + endpoint: str = "/prove33", +): """ - Verify a program triple and compare with a model prediction + Verify a program triple and compare with a model prediction of whether the triple is totally correct or not. 
NB: '33' stands for the number of constant and mutable identifiers in the program :param is_total: inferred correctness label - :param preconditions: - :param program: - :param postconditions: + :param preconditions: + :param program: + :param postconditions: :returns: whether the SMT verifier agrees with the label provided: {'prediction_is_correct': True} @@ -96,52 +99,46 @@ def verify_triple( } url = f"{api_server_url}/{endpoint}" try: - res = post(url, json= cfg, stream= True) + res = post(url, json=cfg, stream=True) res.raise_for_status() try: v = res.json() except JSONDecodeError: v = None return v - # else: + # else: except HTTPError as he: print(f"HTTP error: {he}") raise he - def verify_triple_v2( - preconditions:str = "True", - program:str = "v4 = (0 - v3)\nv3 = v3\nv5 = v4", - postconditions:str = "v5 == (0 - v3)", - endpoint:str = '/v2/prove33' - ): + preconditions: str = "True", + program: str = "v4 = (0 - v3)\nv3 = v3\nv5 = v4", + postconditions: str = "v5 == (0 - v3)", + endpoint: str = "/v2/prove33", +): """ Verify a program triple, V2 endpoint NB: '33' stands for the number of constant and mutable identifiers in the program - :param preconditions: - :param program: - :param postconditions: + :param preconditions: + :param program: + :param postconditions: :returns: whether the SMT verifier agrees with the label provided: """ - triple = { - "pre": preconditions, - "program": program, - "post": postconditions - } + triple = {"pre": preconditions, "program": program, "post": postconditions} url = f"{api_server_url}/{endpoint}" try: - res = post(url, json= triple, stream= True) + res = post(url, json=triple, stream=True) res.raise_for_status() try: v = res.json() except JSONDecodeError: v = None return v - # else: + # else: except HTTPError as he: print(f"HTTP error: {he}") raise he - diff --git a/src/open_r1/rewards_internal/code/htgen.py b/src/open_r1/rewards_internal/code/htgen.py index 0a70fe61a..09285ba64 100644 --- a/src/open_r1/rewards_internal/code/htgen.py +++ b/src/open_r1/rewards_internal/code/htgen.py @@ -1,60 +1,67 @@ -from datasets import Dataset, IterableDataset +from datasets import IterableDataset -from open_r1.rewards_internal.api.code.unfoldml.htgen import gen_triples, verify_triple, verify_triple_v2 +from open_r1.rewards_internal.api.code.unfoldml.htgen import gen_triples, verify_triple_v2 -def quotes(s:str): + +def quotes(s: str): """markdown triple backticks for a piece of code""" return f"```{s}```" + # # header of all prompts, describing Hoare logic at a high level prompt_hdr = ( - f"Below you are given a Python program triple, made of a precondition predicate, " - f"a sequence of program statements, and a postcondition predicate. " - f"The precondition returns True if the variable environment before beginning the " - f"program execution satisfies the predicate, and False otherwise. " - f"Similarly, the postcondition returns True if the program environment after the last " - f"statement satisfies the predicate, and False otherwise. " - f"We say that a triple is correct if, whenever the precondition holds for a given variable " - f"assignment, executing the program will produce a variable assignment that satisfies the postcondition. " - ) - -prompt_contradict_warning = ( - f"Note that there might be unsatisfiable or contradictory predicates such as 'v1 < v1' or 'v3 > 5 + v3' that make the solution False by definition. 
" + "Below you are given a Python program triple, made of a precondition predicate, " + "a sequence of program statements, and a postcondition predicate. " + "The precondition returns True if the variable environment before beginning the " + "program execution satisfies the predicate, and False otherwise. " + "Similarly, the postcondition returns True if the program environment after the last " + "statement satisfies the predicate, and False otherwise. " + "We say that a triple is correct if, whenever the precondition holds for a given variable " + "assignment, executing the program will produce a variable assignment that satisfies the postcondition. " ) +prompt_contradict_warning = "Note that there might be unsatisfiable or contradictory predicates such as 'v1 < v1' or 'v3 > 5 + v3' that make the solution False by definition. " def explain_vcs(o): """ produce a string explanation of why the verification conditions are violated with counterexamples """ + def render_env(es): - return ', '.join(es) + return ", ".join(es) + def start(): - return o['state_start'] + return o["state_start"] + def end(): - return o['state_end'] + return o["state_end"] + def info(): - return o['info'] + return o["info"] + def state_diff(): s1 = start() s2 = end() states_diff = list(set(s1) - set(s2)) - return ', '.join(states_diff) - match o['vc']: - case 'bad_precondition': + return ", ".join(states_diff) + + match o["vc"]: + case "bad_precondition": s = start() return f"the environment {render_env(s)} does not satisfy the precondition." - case 'bad_postcondition': + case "bad_postcondition": s = start() sd = state_diff() return f"if the program starts in state {render_env(s)}, the final environment {sd} does not satisfy the postcondition." - case 'unstable': - ident = info() # identifier that mutates throughout the program + case "unstable": + ident = info() # identifier that mutates throughout the program sd = state_diff() return f"variable {ident} is not immutable and variable assignments {sd} do not satisfy the postcondition." - case vc : - raise RuntimeWarning(f"Verification condition '{vc}' currently not supported.") # TODO all other verification conditions: + case vc: + raise RuntimeWarning( + f"Verification condition '{vc}' currently not supported." + ) # TODO all other verification conditions: # case 'abort_reachable': # return [] # case 'invariant_broken_upon_loop_entry': @@ -66,44 +73,46 @@ def state_diff(): # case 'measure_doesnt_decrease': # return [] + def explain_proof_result(wp_proof_result): - match wp_proof_result['result']: - case 'proven_total': + match wp_proof_result["result"]: + case "proven_total": expl = None - case 'failed': - vcs = wp_proof_result['vcs'] # verification conditions + case "failed": + vcs = wp_proof_result["vcs"] # verification conditions vces = [explain_vcs(v) for v in vcs] - vc_explanation = ' '.join(vces) - plur = 's' if len(vces)>1 else '' + vc_explanation = " ".join(vces) + plur = "s" if len(vces) > 1 else "" expl = f"Currently, the program triple fails {len(vces)} verification condition{plur}: {vc_explanation}" return expl # # GRPOTrainer requires 1. a dataset and 2. a verification callback -# # # Tasks +# # # Tasks + # FIX_TRIPLE task : modify the program to satisfy the pre- and post-conditions def mk_row_fix_triple(o): """ - FIX_TRIPLE task: given a program triple, modify the program such that it satisfies + FIX_TRIPLE task: given a program triple, modify the program such that it satisfies pre- and post-conditions (i.e. 
the program spec)
-
+
     This function constructs the prompt from a dict that comes from the dataset API
 
     NB: the rows have a 'prompt' column as required by the GRPOTrainer interface:
-    https://huggingface.co/docs/trl/main/grpo_trainer#trl.GRPOTrainer.train_dataset
+    https://huggingface.co/docs/trl/main/grpo_trainer#trl.GRPOTrainer.train_dataset
     """
-    label = o['label'] # {'ok_total', 'bad_pre', 'bad_post'}
-    pre = o['pre']
-    program = o['program'] # list of statements
-    post = o['post']
+    label = o["label"]  # {'ok_total', 'bad_pre', 'bad_post'}
+    pre = o["pre"]
+    program = o["program"]  # list of statements
+    post = o["post"]
 
-    wp_proof_result = o['wp_proof_result']
+    wp_proof_result = o["wp_proof_result"]
     explanation_wpr = explain_proof_result(wp_proof_result)
 
-    program_str = '\\n'.join(program) # single program string
+    program_str = "\\n".join(program)  # single program string
 
-    # # task variant: modify either pre- or post-condition
+    # # task variant: modify either pre- or post-condition
     # match label:
     #     case 'ok_total':
     #         which_triple_el = 'program'
     #     case 'bad_pre':
     #         which_triple_el = 'precondition'
     #     case 'bad_post':
     #         which_triple_el = 'postcondition'
 
     # # task: consider pre- and post-condition as fixed (i.e. the program specification)
-    which_triple_el = 'program' # only modify the program
-
+    which_triple_el = "program"  # only modify the program
+
     # assemble task prompt
     prompt_task = (
         f"Given a program triple made of program {quotes(program_str)}, "
         f"precondition {quotes(pre)} and postcondition {quotes(post)}, "
         f"you should modify the {which_triple_el} such that the resulting triple is total. "
         f"{explanation_wpr if explanation_wpr is not None else ''} "
-        f"With this information, the correct program that satisfies the given precondition and postcondition is: "
+        "With this information, the correct program that satisfies the given precondition and postcondition is: "
     )
 
     # # concatenate header, task and question into a prompt
     prompt_problem = f"{prompt_hdr}\n{prompt_contradict_warning}\n{prompt_task}"
 
-    o_out = {
-        "prompt": prompt_problem,
-        "ground_truth": label,
-        "triple": {"pre": pre, "program":program, "post": post}
-    }
+    o_out = {"prompt": prompt_problem, "ground_truth": label, "triple": {"pre": pre, "program": program, "post": post}}
     return o_out
 
+
 def mk_dataset_fix_triple(
-    n_examples:int,
-    max_ast_depth:int = 3,
-    n_stmt:int = 5,
-    n_pre_terms:int = 1,
-    n_post_terms:int = 1,
-    seed:int = 1234,
-    endpoint:str = '/gen33'
-    ):
+    n_examples: int,
+    max_ast_depth: int = 3,
+    n_stmt: int = 5,
+    n_pre_terms: int = 1,
+    n_post_terms: int = 1,
+    seed: int = 1234,
+    endpoint: str = "/gen33",
+):
     """
-    construct an interable dataset for the 'fix_triple' task
+    construct an iterable dataset for the 'fix_triple' task
     """
-    ds = mk_dataset(mk = mk_row_fix_triple,
-        n_examples= n_examples,
-        max_ast_depth = max_ast_depth,
-        n_stmt = n_stmt,
-        n_pre_terms = n_pre_terms,
-        n_post_terms = n_post_terms,
-        seed = seed,
-        endpoint= endpoint
-        )
+    ds = mk_dataset(
+        mk=mk_row_fix_triple,
+        n_examples=n_examples,
+        max_ast_depth=max_ast_depth,
+        n_stmt=n_stmt,
+        n_pre_terms=n_pre_terms,
+        n_post_terms=n_post_terms,
+        seed=seed,
+        endpoint=endpoint,
+    )
     return ds
 
@@ -165,16 +172,13 @@ def fix_triple_reward(completions, ground_truth_triples, **kwargs):
     :param ground_truth_triples: list of input program ground truth triples (coming from the dataset)
     :returns: list of float 1s or 0s with the prediction scores that match the ground truth
     """
-    def compare(completion:str, triple:dict):
-        pre = triple['pre']
-        post = triple['post']
-        res = verify_triple_v2(
-            preconditions = pre,
-            program = completion,
-            postconditions = post
-            )
+
+    def compare(completion: str, triple: dict):
+        pre = triple["pre"]
+        post = triple["post"]
+        res = verify_triple_v2(preconditions=pre, program=completion, postconditions=post)
         if res is not None:
-            reward = 1.0 if res.get('result') == 'proven_total' else 0.0
+            reward = 1.0 if res.get("result") == "proven_total" else 0.0
             return reward
         else:
             return 0.0
@@ -182,100 +186,103 @@ def compare(completion:str, triple:dict):
 
     return [compare(predicted, gtt) for (predicted, gtt) in zip(completions, ground_truth_triples)]
 
-
-
 # TOTALITY_CHECK task
 def mk_row_totality_check(o):
     """
     TOTALITY_CHECK task: Construct the prompt
     NB: the rows have a 'prompt' column as required by the GRPOTrainer interface:
-    https://huggingface.co/docs/trl/main/grpo_trainer#trl.GRPOTrainer.train_dataset
+    https://huggingface.co/docs/trl/main/grpo_trainer#trl.GRPOTrainer.train_dataset
     """
-    label = o['label']
-    pre = o['pre']
-    program = o['program'] # list of statements
-    post = o['post']
+    label = o["label"]
+    pre = o["pre"]
+    program = o["program"]  # list of statements
+    post = o["post"]
 
-    program_str = '\n'.join(program) # single program string
+    program_str = "\n".join(program)  # single program string
 
     prompt_task = (
-        f"You should judge whether the program is 'total', i.e. "
-        f"whether the post-condition evaluates to True for all possible variable assigments "
-        f"that satisfy the precondition."
+        "You should judge whether the program is 'total', i.e. "
+        "whether the post-condition evaluates to True for all possible variable assignments "
+        "that satisfy the precondition."
     )
 
     prompt_question = (
         f"Given a program triple made of program {quotes(program_str)}, "
         f"precondition {quotes(pre)} and postcondition {quotes(post)}, is the postcondition "
-        f"always True at the end of the program ? Please only return 'True' or 'False'."
+        "always True at the end of the program? Please only return 'True' or 'False'."
    )
 
     # # concatenate header, task and question into a prompt
     prompt_problem = f"{prompt_hdr}\n{prompt_contradict_warning}\n{prompt_task}\n{prompt_question}"
 
-    label_is_total = label == 'ok_total' # boolean
+    label_is_total = label == "ok_total"  # boolean
 
     # # construct a row of the dataset
     o_out = {
         "prompt": prompt_problem,
         "ground_truth": label_is_total,
-        "triple": {"pre": pre, "program":program, "post": post}
+        "triple": {"pre": pre, "program": program, "post": post},
     }
 
     return o_out
 
 
 def mk_dataset_totality_check(
-    n_examples:int,
-    max_ast_depth:int = 3,
-    n_stmt:int = 5,
-    n_pre_terms:int = 1,
-    n_post_terms:int = 1,
-    seed:int = 1234,
-    endpoint:str = '/gen33'
-    ):
+    n_examples: int,
+    max_ast_depth: int = 3,
+    n_stmt: int = 5,
+    n_pre_terms: int = 1,
+    n_post_terms: int = 1,
+    seed: int = 1234,
+    endpoint: str = "/gen33",
+):
     """
-    construct an interable dataset for the 'totality_check' task
+    construct an iterable dataset for the 'totality_check' task
     """
-    ds = mk_dataset(mk = mk_row_totality_check,
-        n_examples= n_examples,
-        max_ast_depth = max_ast_depth,
-        n_stmt = n_stmt,
-        n_pre_terms = n_pre_terms,
-        n_post_terms = n_post_terms,
-        seed = seed,
-        endpoint= endpoint
-        )
+    ds = mk_dataset(
+        mk=mk_row_totality_check,
+        n_examples=n_examples,
+        max_ast_depth=max_ast_depth,
+        n_stmt=n_stmt,
+        n_pre_terms=n_pre_terms,
+        n_post_terms=n_post_terms,
+        seed=seed,
+        endpoint=endpoint,
+    )
     return ds
 
+
 def mk_dataset(
     mk,
-    n_examples:int,
-    max_ast_depth:int = 3,
-    n_stmt:int = 5,
-    n_pre_terms:int = 1,
-    n_post_terms:int = 1,
-    seed:int = 1234,
-    endpoint:str = '/gen33'
-    ):
+    n_examples: int,
+    max_ast_depth: int = 3,
+    n_stmt: int = 5,
+    n_pre_terms: int = 1,
+    n_post_terms: int = 1,
+    seed: int = 1234,
+    endpoint: str = "/gen33",
+):
     """
-    construct an interable dataset for GRPO
+    construct an iterable dataset for GRPO
     """
+
     # produce prompts from the API data
     def gen_prompts():
         for o in gen_triples(
-            n_examples= n_examples,
-            max_ast_depth = max_ast_depth,
-            n_stmt = n_stmt,
-            n_pre_terms = n_pre_terms,
-            n_post_terms = n_post_terms,
-            seed = seed,
-            endpoint= endpoint
-            ):
+            n_examples=n_examples,
+            max_ast_depth=max_ast_depth,
+            n_stmt=n_stmt,
+            n_pre_terms=n_pre_terms,
+            n_post_terms=n_post_terms,
+            seed=seed,
+            endpoint=endpoint,
+        ):
             if o is not None:
                 yield mk(o)
 
+
     dataset = IterableDataset.from_generator(gen_prompts)
     return dataset
 
+
 def totality_check_reward(completions, ground_truth, **kwargs):
     """
-    verification callback for GRPOTRainer
+    verification callback for GRPOTrainer
     :param completions: list of truthy values produced by the model
     :param ground_truth: list of ground truth labels
     """
     if not isinstance(completions[0], bool):
         completions = [True if c == "True" else False for c in completions]
+
     def compare(predicted, actual):
         if predicted == actual:
             return 1.0
@@ -292,10 +300,6 @@ def compare(predicted, actual):
             return 0.0
 
     return [compare(predicted, actual) for (predicted, actual) in zip(completions, ground_truth)]
-
-
-
-
 # # # verify against API
 
diff --git a/tests/test_api.py b/tests/test_api.py
index c7cb6dab9..253459918 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -6,35 +6,25 @@ class TestApi(unittest.TestCase):
     def test_gen_triples_structure(self):
         n_stmt = 3
-        for o in gen_triples(n_examples = 1, n_stmt = n_stmt):
-            len_program = len(o['program'])
+        for o in gen_triples(n_examples=1, n_stmt=n_stmt):
+            len_program = len(o["program"])
             self.assertEqual(len_program, n_stmt)
+
     def test_verify_triple_result(self):
         is_total = True
-        preconditions = "True" # trivial precondition
preconditions = "True" # trivial precondition + preconditions = "True" # trivial precondition program = "v4 = (0 - v3)\nv3 = v3\nv5 = v4" - post_ok = "v5 == (0 - v3)" # post-condition that verifies - post_not_ok = "v5 == (1 - v3)" # post-condition that does not verify + post_ok = "v5 == (0 - v3)" # post-condition that verifies + post_not_ok = "v5 == (1 - v3)" # post-condition that does not verify # # should return True - o = verify_triple( - is_total = is_total, - preconditions = preconditions, - program = program, - postconditions = post_ok - ) - res_ok = o['prediction_is_correct'] + o = verify_triple(is_total=is_total, preconditions=preconditions, program=program, postconditions=post_ok) + res_ok = o["prediction_is_correct"] self.assertEqual(res_ok, True) # # should return False - o = verify_triple( - is_total = is_total, - preconditions = preconditions, - program = program, - postconditions = post_not_ok - ) - res_not_ok = o['prediction_is_correct'] + o = verify_triple(is_total=is_total, preconditions=preconditions, program=program, postconditions=post_not_ok) + res_not_ok = o["prediction_is_correct"] self.assertEqual(res_not_ok, False) - if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tests/test_rewards.py b/tests/test_rewards.py index 1b455cfdf..103b9981a 100644 --- a/tests/test_rewards.py +++ b/tests/test_rewards.py @@ -11,6 +11,7 @@ tag_count_reward, ) + class TestRewards(unittest.TestCase): def test_accuracy_reward_correct_answer(self): """Test accuracy_reward with a correct answer.""" diff --git a/tests/test_rewards_code.py b/tests/test_rewards_code.py index 096ba8908..22a26816d 100644 --- a/tests/test_rewards_code.py +++ b/tests/test_rewards_code.py @@ -1,61 +1,74 @@ import unittest -from open_r1.rewards_internal.code.htgen import totality_check_reward, mk_dataset_totality_check, fix_triple_reward, mk_dataset_fix_triple + +from open_r1.rewards_internal.code.htgen import ( + fix_triple_reward, + mk_dataset_fix_triple, + mk_dataset_totality_check, + totality_check_reward, +) + class TestRewardsCode(unittest.TestCase): def test_mk_dataset_totality_check_format_correct(self): """test output format of dataset generator mk_dataset_iter_totality_check""" - ds = mk_dataset_totality_check(n_examples= 1) + ds = mk_dataset_totality_check(n_examples=1) examples = list(ds) - prompt = examples[0]['prompt'] - label = examples[0]['ground_truth'] - triple = examples[0]['triple'] + prompt = examples[0]["prompt"] + label = examples[0]["ground_truth"] + triple = examples[0]["triple"] self.assertIsInstance(prompt, str) self.assertIsInstance(label, bool) self.assertIsInstance(triple, dict) + def test_mk_dataset_fix_triple_format_correct(self): """test output format of dataset generator mk_dataset_fix_triple""" - ds = mk_dataset_fix_triple(n_examples= 1, seed= 5556) + ds = mk_dataset_fix_triple(n_examples=1, seed=5556) examples = list(ds) ex = examples[0] - prompt = ex['prompt'] - label = ex['ground_truth'] - triple = ex['triple'] + prompt = ex["prompt"] + label = ex["ground_truth"] + triple = ex["triple"] self.assertIsInstance(prompt, str) self.assertIsInstance(label, str) - self.assertIn(label, ['ok_total', 'bad_pre', 'bad_post']) + self.assertIn(label, ["ok_total", "bad_pre", "bad_post"]) self.assertIsInstance(triple, dict) + def test_totality_check_reward_correct(self): """Test totality_check_reward""" completion = ["True"] solution = [True] rewards = totality_check_reward(completion, solution) self.assertEqual(rewards[0], 1.0) + def 
test_totality_check_reward_wrong_format(self): """Test totality_check_reward, wrong format""" completion = ["The triple is total"] solution = [True] rewards = totality_check_reward(completion, solution) self.assertEqual(rewards[0], 0.0) + def test_fix_triple_reward_correct(self): """fix_triple task: assert a correct completion gives 1.0 reward""" triple = { "pre": "v3 > 0 && v4 > 2", "program": "v5 = 2\nv3 = v5\nv4 = ((5 + (3 + v3)) + (v4 + v5))\nv4 = 9\nv4 = (v3 - 7)", - "post": "v5 > 6" + "post": "v5 > 6", } completion = "v5 = 2\nv3 = v5\nv4 = ((5 + (3 + v3)) + (v4 + v5))\nv5 = v4" rewards = fix_triple_reward([completion], [triple]) self.assertEqual(rewards[0], 1.0) + def test_fix_triple_reward_wrong_0(self): """fix_triple task: asserts an incorrect completion gives 0.0 reward""" triple = { "pre": "v3 > 0 && v4 > 2", "program": "v5 = 2\nv3 = v5\nv4 = ((5 + (3 + v3)) + (v4 + v5))\nv4 = 9\nv4 = (v3 - 7)", - "post": "v5 > 6" + "post": "v5 > 6", } completion = "v5 = 2\nv3 = v5\nv4 = ((5 + (3 + v3)) + (v4 + v5))\nv5 = v3 + v3" rewards = fix_triple_reward([completion], [triple]) self.assertEqual(rewards[0], 0.0) + if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From 604f66f44d3047b612d93132200ecd9a3234054c Mon Sep 17 00:00:00 2001 From: Marco Z Date: Sun, 2 Mar 2025 19:24:55 +0000 Subject: [PATCH 19/22] cannot reliably activate venv within makefile --- Makefile | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/Makefile b/Makefile index aec4c40ba..6e76a411b 100644 --- a/Makefile +++ b/Makefile @@ -5,18 +5,28 @@ export PYTHONPATH = src check_dirs := src tests - -# install dev dependencies -# # NB uses '.' instead of 'source' to work with dash/Codespaces as well) -install: +############### Installation +# 1) install uv +install-uv: curl -LsSf https://astral.sh/uv/install.sh | sh + +# 2) set up virtual environment +venv: uv venv openr1 --python 3.11 - (. openr1/bin/activate) + +# 3) activate virtual env +# # activate NB uses '.' instead of 'source' to work with dash/Codespaces as well) +# (. 
openr1/bin/activate) + +# 4) install dev dependencies +install: uv pip install --upgrade pip uv pip install vllm==0.7.2 uv pip install setuptools GIT_LFS_SKIP_SMUDGE=1 uv pip install -e ".[dev]" +############### Linting + style: ruff format --line-length 119 --target-version py310 $(check_dirs) setup.py isort $(check_dirs) setup.py @@ -26,6 +36,8 @@ quality: isort --check-only $(check_dirs) setup.py flake8 --max-line-length 119 $(check_dirs) setup.py +############### Test + test: pytest -sv tests/ From ed0c48471c90886396be11173bc3f07998f03246 Mon Sep 17 00:00:00 2001 From: Marco Z Date: Sun, 2 Mar 2025 19:28:35 +0000 Subject: [PATCH 20/22] ignore API json parsing errors --- src/open_r1/rewards_internal/api/code/unfoldml/htgen.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/open_r1/rewards_internal/api/code/unfoldml/htgen.py b/src/open_r1/rewards_internal/api/code/unfoldml/htgen.py index c510d7333..e38cb5741 100644 --- a/src/open_r1/rewards_internal/api/code/unfoldml/htgen.py +++ b/src/open_r1/rewards_internal/api/code/unfoldml/htgen.py @@ -61,8 +61,7 @@ def gen_triples( v = loads(chunk) if not isinstance(v, dict): v = None - except JSONDecodeError as e: - print(f"JSON decode error: {e}") + except JSONDecodeError: v = None if v is not None: yield v From 6e4298c0f9daae6190f65fa2c643ded18e8a9c79 Mon Sep 17 00:00:00 2001 From: Marco Zocca Date: Mon, 3 Mar 2025 04:50:43 +0100 Subject: [PATCH 21/22] cleanup and docstrings --- .../api/code/unfoldml/htgen.py | 2 +- src/open_r1/rewards_internal/code/htgen.py | 24 +------------------ 2 files changed, 2 insertions(+), 24 deletions(-) diff --git a/src/open_r1/rewards_internal/api/code/unfoldml/htgen.py b/src/open_r1/rewards_internal/api/code/unfoldml/htgen.py index e38cb5741..9b2cc5b4d 100644 --- a/src/open_r1/rewards_internal/api/code/unfoldml/htgen.py +++ b/src/open_r1/rewards_internal/api/code/unfoldml/htgen.py @@ -124,7 +124,7 @@ def verify_triple_v2( :param preconditions: :param program: :param postconditions: - :returns: whether the SMT verifier agrees with the label provided: + :returns: dict with the proof result. The 'result' key has values {'proven_total', 'proven_partial', 'indeterminate', 'failed'} """ triple = {"pre": preconditions, "program": program, "post": postconditions} diff --git a/src/open_r1/rewards_internal/code/htgen.py b/src/open_r1/rewards_internal/code/htgen.py index 09285ba64..3c5e065d5 100644 --- a/src/open_r1/rewards_internal/code/htgen.py +++ b/src/open_r1/rewards_internal/code/htgen.py @@ -8,7 +8,7 @@ def quotes(s: str): return f"```{s}```" -# # header of all prompts, describing Hoare logic at a high level +# # header to be put in front of all task prompts, describing Hoare logic at a high level prompt_hdr = ( "Below you are given a Python program triple, made of a precondition predicate, " "a sequence of program statements, and a postcondition predicate. 
" @@ -301,25 +301,3 @@ def compare(predicted, actual): return [compare(predicted, actual) for (predicted, actual) in zip(completions, ground_truth)] - -# # # verify against API - -# def totality_oracle_reward(completions, triples, **kwargs): -# """ -# verification callback for GRPOTRainer -# :param completions: list of truthy values produced by the model -# :param triples: list of program triples dicts {"pre":: string, "program":: string, "post:: string} -# """ - -# def verify(pre, program, post, is_total): -# res = verify_triple_33( -# preconditions = pre, -# program = program, -# postconditions = post, -# is_total = is_total -# ) -# if res is not None: -# prediction = res['prediction_is_correct'] -# return 1.0 if prediction else 0.0 -# else: -# return 0.0 From aa97e62cef99e7c318c5de75eb7c295d534f61e2 Mon Sep 17 00:00:00 2001 From: Marco Z Date: Mon, 3 Mar 2025 04:08:43 +0000 Subject: [PATCH 22/22] add test for verify v2 endpoint --- tests/test_api.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/test_api.py b/tests/test_api.py index 253459918..7a3c39fa0 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,6 +1,6 @@ import unittest -from open_r1.rewards_internal.api.code.unfoldml.htgen import gen_triples, verify_triple +from open_r1.rewards_internal.api.code.unfoldml.htgen import gen_triples, verify_triple, verify_triple_v2 class TestApi(unittest.TestCase): @@ -25,6 +25,20 @@ def test_verify_triple_result(self): res_not_ok = o["prediction_is_correct"] self.assertEqual(res_not_ok, False) + def test_verify_v2_triple_result(self): + preconditions = "True" # trivial precondition + program = "v4 = (0 - v3)\nv3 = v3\nv5 = v4" + post_ok = "v5 == (0 - v3)" # post-condition that verifies + post_not_ok = "v5 == (1 - v3)" # post-condition that does not verify + # # should return True + o = verify_triple_v2(preconditions=preconditions, program=program, postconditions=post_ok) + res_ok = o["result"] + self.assertEqual(res_ok, 'proven_total') + # # should return False + o = verify_triple_v2(preconditions=preconditions, program=program, postconditions=post_not_ok) + res_not_ok = o["result"] + self.assertEqual(res_not_ok, 'failed') + if __name__ == "__main__": unittest.main()