Skip to content
Merged
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ please make a Github Issue.
- [X] Anthropic interface (basic)
- [X] Claude system messages
- [X] Use the huggingface chat templates for chat models if available
- [ ] Be able to add user metadata to a prompt
- [X] Be able to add user metadata to a prompt
- [ ] Automatic cache eviction to limit count or disk size (right now have to run a SQL query to delete entries before a certain time or matching your criteria)
- [ ] Multimodal/images in super easy format (like automatically process pil, opencv, etc)
- [ ] sort through usage of quantized models
Expand Down
36 changes: 30 additions & 6 deletions lmwrapper/structs.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import contextlib
import dataclasses
import json
import pickle
import statistics
from dataclasses import dataclass, field
from typing import Any, Optional, Union
from typing import Any, Generic, Optional, TypeVar, Union

from lmwrapper.internals import ModelInternalsResults
from lmwrapper.utils import StrEnum

# Define TypeVar for user metadata
T = TypeVar('T')

LM_CHAT_DIALOG_COERCIBLE_TYPES = Union[
str,
list[Union["LmChatTurn", tuple[str, str], dict, str]],
Expand All @@ -34,7 +38,7 @@ class StopMode(StrEnum):


@dataclass(frozen=True)
class LmPrompt:
class LmPrompt(Generic[T]):
text: str | LM_CHAT_DIALOG_COERCIBLE_TYPES
"""The actual text of the prompt. If it is a LM_CHAT_DIALOG_COERCIBLE_TYPES
which can become a LmChatDialog (such as a list of strings) it will be converted
Expand Down Expand Up @@ -125,6 +129,11 @@ class LmPrompt:
"""Whether or not to add special tokens when encoding the prompt."""
model_internals_request: Optional["ModelInternalsRequest"] = None
"""Used to attempt to get hidden states and attentions from the model."""
user_metadata: Optional[T] = None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thinking about this more, maybe I just prefer metadata instead of user_metadata. A little less verbose, and I think this seems better. Refactor and rename all the uses of this.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in a345f4a. Renamed all instances of user_metadata to metadata throughout the codebase, including in structs.py, tests, README.md, and CHANGELOG.md.

"""Optional user-defined metadata that gets transferred to the resulting LmPrediction.
This is not used for caching and can be any type. It's useful for tracking
additional information with each prompt and prediction (e.g., ground truth labels,
problem identifiers)."""

# TODO: make a auto_reduce_max_tokens to reduce when might go over.

Expand Down Expand Up @@ -208,11 +217,15 @@ def get_text_as_string_default_form(self) -> str:
else:
return self.text

def dict_serialize(self) -> dict:
def dict_serialize(self, include_user_metadata: bool = False) -> dict:
"""
Serialize the prompt into a json-compatible dictionary. Note this is not
guaranteed to be the same as the JSON representation for use
in an openai api call. This is just for serialization purposes.

Args:
include_user_metadata: Whether to include user_metadata in serialization.
Default is False since user_metadata should not be part of the cache key.
"""
out = {
"max_tokens": self.max_tokens,
Expand All @@ -232,6 +245,16 @@ def dict_serialize(self) -> dict:
out["text"] = self.get_text_as_chat().as_dicts()
else:
out["text"] = self.text

if include_user_metadata and self.user_metadata is not None:
try:
# Test if it's JSON serializable
json.dumps(self.user_metadata)
out["user_metadata"] = self.user_metadata
except (TypeError, ValueError):
# If metadata isn't JSON serializable, leave it out
pass

return out


Expand Down Expand Up @@ -316,10 +339,10 @@ def to_default_string_prompt(self) -> str:


@dataclass
class LmPrediction:
class LmPrediction(Generic[T]):
completion_text: str | None
"""The new text generated. It might be None if errors"""
prompt: LmPrompt
prompt: LmPrompt[T]
metad: Any
internals: ModelInternalsResults | None = field(default=None, kw_only=True)
error_message: str | None = field(default=None, kw_only=True)
Expand Down Expand Up @@ -458,10 +481,11 @@ def dict_serialize(
self,
pull_out_props: bool = True,
include_metad: bool = False,
include_user_metadata: bool = False,
) -> dict[str, Any]:
out = {
"completion_text": self.completion_text,
"prompt": self.prompt.dict_serialize(),
"prompt": self.prompt.dict_serialize(include_user_metadata=include_user_metadata),
"was_cached": self.was_cached,
"error_message": self.error_message,
}
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ include = ["lmwrapper*"]

[project]
name = "lmwrapper"
version = "0.16.3.0"
version = "0.16.4.0"
authors = [
{ name = "David Gros" },
{ name = "Claudio Spiess" },
Expand Down
81 changes: 81 additions & 0 deletions test/test_models_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,6 +1107,87 @@ def test_system_prompt(lm):
assert pred.completion_text.strip() == "PARIS"


@pytest.mark.parametrize("lm", ALL_MODELS, ids=get_model_name)
def test_user_metadata_predict(lm):
    """user_metadata set on a prompt is carried through to the prediction's prompt."""
    # String-valued metadata round-trips through predict().
    string_prompt = LmPrompt(
        "Count from 1 to 5:",
        max_tokens=10,
        cache=False,
        temperature=0,
        user_metadata="test_metadata_string",
    )
    string_pred = lm.predict(string_prompt)
    assert string_pred.prompt.user_metadata == "test_metadata_string"

    # Dict-valued metadata round-trips as well, with fields intact.
    expected_meta = {"label": "counting", "id": 12345}
    dict_prompt = LmPrompt(
        "Count from 1 to 5:",
        max_tokens=10,
        cache=False,
        temperature=0,
        user_metadata=expected_meta,
    )
    dict_pred = lm.predict(dict_prompt)
    assert dict_pred.prompt.user_metadata == expected_meta
    assert dict_pred.prompt.user_metadata["label"] == "counting"
    assert dict_pred.prompt.user_metadata["id"] == 12345


@pytest.mark.parametrize("lm", ALL_MODELS, ids=get_model_name)
def test_user_metadata_predict_many(lm):
    """Each prompt in a predict_many batch keeps its own user_metadata."""
    # Build three prompts, each tagged with a distinct metadata dict.
    prompts = []
    for question_id in range(1, 4):
        prompts.append(
            LmPrompt(
                f"What is {question_id} + {question_id}?",
                max_tokens=5,
                cache=False,
                temperature=0,
                user_metadata={
                    "question_id": question_id,
                    "expected_answer": question_id * 2,
                },
            ),
        )

    # Run the batch and confirm metadata stayed paired with its result.
    results = list(lm.predict_many(prompts, CompletionWindow.ASAP))

    for question_id, result in zip(range(1, 4), results):
        meta = result.prompt.user_metadata
        assert meta["question_id"] == question_id
        assert meta["expected_answer"] == question_id * 2


@pytest.mark.parametrize("lm", ALL_MODELS, ids=get_model_name)
def test_user_metadata_with_cache_hit(lm):
    """Test that user_metadata is preserved when there's a cache hit.

    user_metadata must not participate in the cache key, so two prompts that
    differ only in metadata should hit the same cache entry while each result
    keeps its own prompt's metadata.
    """
    # First prompt with metadata to populate the cache.
    prompt1 = LmPrompt(
        "What is 5 + 5?",
        max_tokens=5,
        cache=True,
        temperature=0,
        user_metadata={"original": True, "id": 1},
    )
    result1 = lm.predict(prompt1)  # populates the cache; value itself unused
    assert result1 is not None

    # Same prompt text with different metadata should use the cache but keep
    # the new metadata.
    prompt2 = LmPrompt(
        "What is 5 + 5?",
        max_tokens=5,
        cache=True,
        temperature=0,
        user_metadata={"original": False, "id": 2},
    )
    result2 = lm.predict(prompt2)

    # Check that the result is from cache but carries prompt2's metadata.
    assert result2.was_cached
    # Use identity comparison for the boolean literal (was `== False`, which
    # flake8/ruff flag as E712 and which would also pass for falsy non-bools).
    assert result2.prompt.user_metadata["original"] is False
    assert result2.prompt.user_metadata["id"] == 2


@pytest.mark.parametrize(
"lm", CHAT_MODELS,
ids=get_model_name
Expand Down