Skip to content
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ please make a Github Issue.
- [X] Anthropic interface (basic)
- [X] Claude system messages
- [X] Use the huggingface chat templates for chat models if available
- [ ] Be able to add user metadata to a prompt
- [X] Be able to add user metadata to a prompt
- [ ] Automatic cache eviction to limit count or disk size (right now have to run a SQL query to delete entries before a certain time or matching your criteria)
- [ ] Multimodal/images in super easy format (like automatically process pil, opencv, etc)
- [ ] sort through usage of quantized models
Expand Down
26 changes: 24 additions & 2 deletions lmwrapper/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ class LmPrompt:
"""Whether or not to add special tokens when encoding the prompt."""
model_internals_request: Optional["ModelInternalsRequest"] = None
"""Used to attempt to get hidden states and attentions from the model."""
user_metadata: Any = None
"""Optional user-defined metadata that gets transferred to the resulting LmPrediction.
This is not used for caching and can be any type. It's useful for tracking
additional information with each prompt and prediction (e.g., ground truth labels,
problem identifiers)."""

# TODO: make a auto_reduce_max_tokens to reduce when might go over.

Expand Down Expand Up @@ -208,11 +213,15 @@ def get_text_as_string_default_form(self) -> str:
else:
return self.text

def dict_serialize(self) -> dict:
def dict_serialize(self, include_user_metadata: bool = False) -> dict:
"""
Serialize the prompt into a json-compatible dictionary. Note this is not
guaranteed to be the same as the JSON representation for use
in an openai api call. This is just for serialization purposes.

Args:
include_user_metadata: Whether to include user_metadata in serialization.
Default is False since user_metadata should not be part of the cache key.
"""
out = {
"max_tokens": self.max_tokens,
Expand All @@ -232,6 +241,18 @@ def dict_serialize(self) -> dict:
out["text"] = self.get_text_as_chat().as_dicts()
else:
out["text"] = self.text

if include_user_metadata and self.user_metadata is not None:
try:
# Try to serialize the user_metadata
import json
# Test if it's JSON serializable
json.dumps(self.user_metadata)
out["user_metadata"] = self.user_metadata
except (TypeError, ValueError):
# If metadata isn't JSON serializable, leave it out
pass

return out


Expand Down Expand Up @@ -458,10 +479,11 @@ def dict_serialize(
self,
pull_out_props: bool = True,
include_metad: bool = False,
include_user_metadata: bool = False,
) -> dict[str, Any]:
out = {
"completion_text": self.completion_text,
"prompt": self.prompt.dict_serialize(),
"prompt": self.prompt.dict_serialize(include_user_metadata=include_user_metadata),
"was_cached": self.was_cached,
"error_message": self.error_message,
}
Expand Down
81 changes: 81 additions & 0 deletions test/test_models_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,6 +1107,87 @@ def test_system_prompt(lm):
assert pred.completion_text.strip() == "PARIS"


@pytest.mark.parametrize("lm", ALL_MODELS, ids=get_model_name)
def test_user_metadata_predict(lm):
    """user_metadata attached to a prompt round-trips onto the prediction's prompt."""
    # String-valued metadata survives a predict() call untouched.
    string_prompt = LmPrompt(
        "Count from 1 to 5:",
        max_tokens=10,
        cache=False,
        temperature=0,
        user_metadata="test_metadata_string",
    )
    string_pred = lm.predict(string_prompt)
    assert string_pred.prompt.user_metadata == "test_metadata_string"

    # Dict-valued metadata survives with all of its entries intact.
    dict_meta = {"label": "counting", "id": 12345}
    dict_prompt = LmPrompt(
        "Count from 1 to 5:",
        max_tokens=10,
        cache=False,
        temperature=0,
        user_metadata=dict_meta,
    )
    dict_pred = lm.predict(dict_prompt)
    assert dict_pred.prompt.user_metadata == dict_meta
    assert dict_pred.prompt.user_metadata["label"] == "counting"
    assert dict_pred.prompt.user_metadata["id"] == 12345


@pytest.mark.parametrize("lm", ALL_MODELS, ids=get_model_name)
def test_user_metadata_predict_many(lm):
    """Each result from predict_many keeps the user_metadata of its own prompt."""
    question_ids = [1, 2, 3]
    prompts = []
    for qid in question_ids:
        prompts.append(
            LmPrompt(
                f"What is {qid} + {qid}?",
                max_tokens=5,
                cache=False,
                temperature=0,
                user_metadata={"question_id": qid, "expected_answer": qid * 2},
            )
        )

    # Fan the prompts out and confirm each prediction carries its prompt's metadata.
    predictions = list(lm.predict_many(prompts, CompletionWindow.ASAP))

    for qid, prediction in zip(question_ids, predictions):
        meta = prediction.prompt.user_metadata
        assert meta["question_id"] == qid
        assert meta["expected_answer"] == qid * 2


@pytest.mark.parametrize("lm", ALL_MODELS, ids=get_model_name)
def test_user_metadata_with_cache_hit(lm):
    """Test that user_metadata is preserved when there's a cache hit.

    The cache key must not include user_metadata, so two prompts differing
    only in their metadata share one cache entry, while each result still
    carries the metadata of the prompt that produced it.
    """
    # First prompt with metadata to populate the cache.
    prompt1 = LmPrompt(
        "What is 5 + 5?",
        max_tokens=5,
        cache=True,
        temperature=0,
        user_metadata={"original": True, "id": 1}
    )
    result1 = lm.predict(prompt1)
    # The priming call carries its own metadata regardless of cache state.
    assert result1.prompt.user_metadata == {"original": True, "id": 1}

    # Same prompt with different metadata should use cache but keep the new metadata
    prompt2 = LmPrompt(
        "What is 5 + 5?",
        max_tokens=5,
        cache=True,
        temperature=0,
        user_metadata={"original": False, "id": 2}
    )
    result2 = lm.predict(prompt2)

    # Check that the result is from cache but has the new metadata.
    assert result2.was_cached
    # Identity check (`is False`), not `== False` — PEP 8 E712.
    assert result2.prompt.user_metadata["original"] is False
    assert result2.prompt.user_metadata["id"] == 2


@pytest.mark.parametrize(
"lm", CHAT_MODELS,
ids=get_model_name
Expand Down