Skip to content

Commit 69b44ca

Browse files
author
Marco Zocca
committed
add API test
1 parent 591f3c1 commit 69b44ca

File tree

4 files changed

+42
-2
lines changed

4 files changed

+42
-2
lines changed

setup.py

+1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
"peft>=0.14.0",
6363
"pytest",
6464
"python-dotenv",
65+
"requests",
6566
"ruff>=0.9.0",
6667
"safetensors>=0.3.3",
6768
"sentencepiece>=0.1.99",

src/open_r1/rewards/api/code/unfoldml/htgen.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def verify_triple_33(
9595
v = res.json()
9696
except JSONDecodeError:
9797
v = None
98-
print(v)
98+
return v
9999
# else:
100100
except HTTPError as he:
101101
print(f"HTTP error: {he}")

tests/test_api.py

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import unittest
2+
3+
from open_r1.rewards.api.code.unfoldml.htgen import gen_triples_33, verify_triple_33
4+
5+
6+
class TestApi(unittest.TestCase):
7+
def test_gen_triples_structure():
8+
n_stmt = 3
9+
for o in gen_triples_33(n_examples = 1, n_stmt = n_stmt):
10+
len_program = len(o['program'])
11+
self.assertEqual(len_program, n_stmt)
12+
def test_verify_triple_result():
13+
is_total = True
14+
preconditions = "True" # trivial precondition
15+
program = "v4 = (0 - v3)\nv3 = v3\nv5 = v4"
16+
post_ok = "v5 == (0 - v3)" # post-condition that verifies
17+
post_not_ok = "v5 == (1 - v3)" # post-condition that does not verify
18+
# # should return True
19+
o = verify_triple_33(
20+
is_total = is_total,
21+
preconditions = preconditions,
22+
program = program,
23+
postconditions = post_ok
24+
)
25+
res_ok = o['prediction_is_correct']
26+
self.assertEqual(res_ok, True)
27+
# # should return False
28+
o = verify_triple_33(
29+
is_total = is_total,
30+
preconditions = preconditions,
31+
program = program,
32+
postconditions = post_not_ok
33+
)
34+
res_not_ok = o['prediction_is_correct']
35+
salf.assertEqual(res_not_ok, False)
36+
37+
38+
39+
if __name__ == "__main__":
40+
unittest.main()

tests/test_rewards.py

-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
reasoning_steps_reward,
1010
)
1111

12-
1312
class TestRewards(unittest.TestCase):
1413
def test_accuracy_reward_correct_answer(self):
1514
"""Test accuracy_reward with a correct answer."""

0 commit comments

Comments
 (0)