Skip to content

Commit b8a91fc

Browse files
authored
Add tests for message weights (#335)
1 parent c6db833 commit b8a91fc

File tree

1 file changed

+89
-0
lines changed

1 file changed

+89
-0
lines changed

tests/unit/test_files_checks.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,3 +303,92 @@ def test_check_jsonl_empty_messages(tmp_path: Path):
303303
assert (
304304
"Expected a non-empty list of messages. Found empty list" in report["message"]
305305
)
306+
307+
308+
def test_check_jsonl_valid_weights_all_messages(tmp_path: Path):
    """A JSONL file in which every message carries a weight of 0 or 1 passes.

    Two conversations are written, one JSON object per line, and the report
    returned by ``check_file`` is expected to flag the check as passed and to
    count both samples.
    """
    two_turn_dialogue = {
        "messages": [
            {"role": "user", "content": "Hello", "weight": 1},
            {"role": "assistant", "content": "Hi there!", "weight": 0},
            {"role": "user", "content": "How are you?", "weight": 1},
            {"role": "assistant", "content": "I'm doing well!", "weight": 1},
        ]
    }
    system_prompt_dialogue = {
        "messages": [
            {"role": "system", "content": "You are helpful", "weight": 0},
            {"role": "user", "content": "What's the weather?", "weight": 1},
            {"role": "assistant", "content": "It's sunny today!", "weight": 1},
        ]
    }
    samples = [two_turn_dialogue, system_prompt_dialogue]

    jsonl_path = tmp_path / "valid_weights_all.jsonl"
    # Lines are joined with "\n" only, so the file has no trailing newline.
    jsonl_path.write_text("\n".join(map(json.dumps, samples)))

    result = check_file(jsonl_path)
    assert result["is_check_passed"]
    assert result["num_samples"] == len(samples)
333+
334+
335+
def test_check_jsonl_valid_weights_mixed_with_none(tmp_path: Path):
    """Messages with and without a ``weight`` key may coexist in one file.

    The first conversation mixes weighted and unweighted messages; the second
    has no weights at all. ``check_file`` is expected to pass the file and
    report both samples.
    """
    partially_weighted = {
        "messages": [
            {"role": "user", "content": "Hello", "weight": 1},
            {"role": "assistant", "content": "Hi there!", "weight": 0},
            {"role": "user", "content": "How are you?"},
            {"role": "assistant", "content": "I'm doing well!"},
        ]
    }
    unweighted = {
        "messages": [
            {"role": "user", "content": "What's the weather?"},
            {"role": "assistant", "content": "It's sunny today!"},
        ]
    }
    samples = [partially_weighted, unweighted]

    jsonl_path = tmp_path / "valid_weights_mixed.jsonl"
    # Lines are joined with "\n" only, so the file has no trailing newline.
    jsonl_path.write_text("\n".join(map(json.dumps, samples)))

    result = check_file(jsonl_path)
    assert result["is_check_passed"]
    assert result["num_samples"] == len(samples)
359+
360+
361+
def test_check_jsonl_invalid_weight_float(tmp_path: Path):
    """A float weight is rejected with the integer-type error message.

    The first message uses ``1.0`` as its weight; the report from
    ``check_file`` must mark the check as failed and mention that the weight
    must be an integer.
    """
    float_weight_dialogue = {
        "messages": [
            # json.dumps serializes 1.0 as "1.0", so the float survives the
            # round trip to disk instead of collapsing to the integer 1.
            {"role": "user", "content": "Hello", "weight": 1.0},
            {"role": "assistant", "content": "Hi there!", "weight": 0},
        ]
    }
    samples = [float_weight_dialogue]

    jsonl_path = tmp_path / "invalid_weight_float.jsonl"
    jsonl_path.write_text("\n".join(map(json.dumps, samples)))

    result = check_file(jsonl_path)
    assert not result["is_check_passed"]
    assert "Weight must be an integer" in result["message"]
377+
378+
379+
def test_check_jsonl_invalid_weight(tmp_path: Path):
    """An integer weight other than 0 or 1 is rejected.

    The first message uses a weight of 2; the report from ``check_file`` must
    mark the check as failed and state that the weight must be either 0 or 1.
    """
    out_of_range_dialogue = {
        "messages": [
            {"role": "user", "content": "Hello", "weight": 2},
            {"role": "assistant", "content": "Hi there!", "weight": 0},
        ]
    }
    samples = [out_of_range_dialogue]

    jsonl_path = tmp_path / "invalid_weight.jsonl"
    jsonl_path.write_text("\n".join(map(json.dumps, samples)))

    result = check_file(jsonl_path)
    assert not result["is_check_passed"]
    assert "Weight must be either 0 or 1" in result["message"]

0 commit comments

Comments
 (0)