Skip to content
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ please make a Github Issue.
- [X] Anthropic interface (basic)
- [X] Claude system messages
- [X] Use the huggingface chat templates for chat models if available
- [ ] Be able to add user metadata to a prompt
- [X] Be able to add user metadata to a prompt
- [ ] Automatic cache eviction to limit count or disk size (right now have to run a SQL query to delete entries before a certain time or matching your criteria)
- [ ] Multimodal/images in super easy format (like automatically process pil, opencv, etc)
- [ ] sort through usage of quantized models
Expand Down
8 changes: 8 additions & 0 deletions lmwrapper/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ class LmPrompt:
"""Whether or not to add special tokens when encoding the prompt."""
model_internals_request: Optional["ModelInternalsRequest"] = None
"""Used to attempt to get hidden states and attentions from the model."""
user_metadata: Any = None
"""Optional user-defined metadata that gets transferred to the resulting LmPrediction.
This is not used for caching and can be any type. It's useful for tracking
additional information with each prompt and prediction (e.g., ground truth labels,
problem identifiers)."""

# TODO: make a auto_reduce_max_tokens to reduce when might go over.

Expand Down Expand Up @@ -213,6 +218,9 @@ def dict_serialize(self) -> dict:
Serialize the prompt into a json-compatible dictionary. Note this is not
guaranteed to be the same as the JSON representation for use
in an openai api call. This is just for serialization purposes.

Note: user_metadata is deliberately excluded from serialization as it should
not be part of the cache key.
"""
out = {
"max_tokens": self.max_tokens,
Expand Down
52 changes: 52 additions & 0 deletions test/test_models_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,6 +1107,58 @@ def test_system_prompt(lm):
assert pred.completion_text.strip() == "PARIS"


@pytest.mark.parametrize("lm", ALL_MODELS, ids=get_model_name)
def test_user_metadata_predict(lm):
    """Metadata attached to a prompt survives a predict() round trip unchanged."""
    # String-valued metadata
    string_prompt = LmPrompt(
        "Count from 1 to 5:",
        max_tokens=10,
        cache=False,
        temperature=0,
        user_metadata="test_metadata_string",
    )
    string_pred = lm.predict(string_prompt)
    assert string_pred.prompt.user_metadata == "test_metadata_string"

    # Dict-valued metadata (arbitrary types are allowed)
    meta = {"label": "counting", "id": 12345}
    dict_prompt = LmPrompt(
        "Count from 1 to 5:",
        max_tokens=10,
        cache=False,
        temperature=0,
        user_metadata=meta,
    )
    dict_pred = lm.predict(dict_prompt)
    assert dict_pred.prompt.user_metadata == meta
    assert dict_pred.prompt.user_metadata["label"] == "counting"
    assert dict_pred.prompt.user_metadata["id"] == 12345


@pytest.mark.parametrize("lm", ALL_MODELS, ids=get_model_name)
def test_user_metadata_predict_many(lm):
    """user_metadata on each prompt is preserved through predict_many().

    Each prompt carries distinct metadata so that results can be matched
    back to the prompt that produced them.
    """
    # Create prompts with different metadata
    prompts = [
        LmPrompt(
            f"What is {i} + {i}?",
            max_tokens=5,
            cache=False,
            temperature=0,
            user_metadata={"question_id": i, "expected_answer": i * 2},
        )
        for i in range(1, 4)
    ]

    # Use CompletionWindow.ASAP — CompletionWindow(None) constructs an
    # invalid window rather than requesting immediate completion.
    results = list(lm.predict_many(prompts, CompletionWindow.ASAP))

    # One prediction per prompt, in order, each with its metadata intact.
    assert len(results) == len(prompts)
    for i, result in enumerate(results):
        expected_id = i + 1
        expected_answer = expected_id * 2
        assert result.prompt.user_metadata["question_id"] == expected_id
        assert result.prompt.user_metadata["expected_answer"] == expected_answer


@pytest.mark.parametrize(
"lm", CHAT_MODELS,
ids=get_model_name
Expand Down