-
Notifications
You must be signed in to change notification settings - Fork 2
Add user metadata to LmPrompt class #11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
907cc4b
a848c4d
b80cd02
0c29dc5
e0ed21f
0cf836a
fc34031
7b928b4
a345f4a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1107,6 +1107,58 @@ def test_system_prompt(lm): | |
| assert pred.completion_text.strip() == "PARIS" | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("lm", ALL_MODELS, ids=get_model_name) | ||
| def test_user_metadata_predict(lm): | ||
| # Test with string metadata | ||
| prompt = LmPrompt( | ||
| "Count from 1 to 5:", | ||
| max_tokens=10, | ||
| cache=False, | ||
| temperature=0, | ||
| user_metadata="test_metadata_string" | ||
| ) | ||
| pred = lm.predict(prompt) | ||
| assert pred.prompt.user_metadata == "test_metadata_string" | ||
|
|
||
| # Test with dictionary metadata | ||
| metadata_dict = {"label": "counting", "id": 12345} | ||
| prompt = LmPrompt( | ||
| "Count from 1 to 5:", | ||
| max_tokens=10, | ||
| cache=False, | ||
| temperature=0, | ||
| user_metadata=metadata_dict | ||
| ) | ||
| pred = lm.predict(prompt) | ||
| assert pred.prompt.user_metadata == metadata_dict | ||
| assert pred.prompt.user_metadata["label"] == "counting" | ||
| assert pred.prompt.user_metadata["id"] == 12345 | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("lm", ALL_MODELS, ids=get_model_name) | ||
| def test_user_metadata_predict_many(lm): | ||
| # Create prompts with different metadata | ||
| prompts = [ | ||
| LmPrompt( | ||
| f"What is {i} + {i}?", | ||
| max_tokens=5, | ||
| cache=False, | ||
| temperature=0, | ||
| user_metadata={"question_id": i, "expected_answer": i*2} | ||
| ) | ||
| for i in range(1, 4) | ||
| ] | ||
|
|
||
| # Use predict_many and verify metadata is preserved | ||
| results = list(lm.predict_many(prompts, CompletionWindow(None))) | ||
|
||
|
|
||
| for i, result in enumerate(results): | ||
| expected_id = i + 1 | ||
| expected_answer = expected_id * 2 | ||
| assert result.prompt.user_metadata["question_id"] == expected_id | ||
| assert result.prompt.user_metadata["expected_answer"] == expected_answer | ||
|
|
||
|
|
||
DNGros marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| @pytest.mark.parametrize( | ||
| "lm", CHAT_MODELS, | ||
| ids=get_model_name | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.