Merged
2 changes: 1 addition & 1 deletion README.md
@@ -364,7 +364,7 @@ please make a Github Issue.
- [X] Anthropic interface (basic)
- [X] Claude system messages
- [X] Use the huggingface chat templates for chat models if available
- [ ] Be able to add user metadata to a prompt
- [X] Be able to add user metadata to a prompt
- [ ] Automatic cache eviction to limit count or disk size (right now have to run a SQL query to delete entries before a certain time or matching your criteria)
- [ ] Multimodal/images in super easy format (like automatically process pil, opencv, etc)
- [ ] sort through usage of quantized models
41 changes: 22 additions & 19 deletions lmwrapper/abstract_predictor.py
@@ -2,15 +2,16 @@
from abc import abstractmethod
from collections.abc import Callable, Iterable
from sqlite3 import OperationalError
from typing import Generic, TypeVar, Union

from ratemate import RateLimit

from lmwrapper.batch_config import CompletionWindow
from lmwrapper.sqlcache_struct import BatchPredictionPlaceholder
from lmwrapper.structs import LM_CHAT_DIALOG_COERCIBLE_TYPES, LmPrediction, LmPrompt
from lmwrapper.structs import LM_CHAT_DIALOG_COERCIBLE_TYPES, LmPrediction, LmPrompt, T


class LmPredictor:
class LmPredictor(Generic[T]):
Contributor:

I don't think the predictor should be generic. It can accept different kinds of prompts; it should just return the same type as the input prompt.

Contributor Author:

Fixed in fc34031. I've removed the Generic[T] from the LmPredictor class definition while still maintaining the type parameters in method signatures.
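To illustrate the pattern, here is a minimal sketch (not the actual diff): the class itself is no longer generic, while the methods keep the `T` parameter. The simplified `LmPrompt`/`LmPrediction` stand-ins below are assumptions for the example; the real classes live in `lmwrapper.structs`.

```python
from dataclasses import dataclass
from typing import Generic, Optional, TypeVar

T = TypeVar("T")  # metadata TypeVar, as introduced in lmwrapper/structs.py


@dataclass
class LmPrompt(Generic[T]):  # simplified stand-in for the real class
    text: str
    metadata: Optional[T] = None


@dataclass
class LmPrediction(Generic[T]):  # simplified stand-in for the real class
    completion_text: str
    prompt: LmPrompt[T]


class LmPredictor:  # the class itself is NOT Generic[T]
    def predict(self, prompt: LmPrompt[T]) -> LmPrediction[T]:
        # T is bound per call, so the prediction's metadata type
        # still follows the prompt's metadata type.
        return LmPrediction("stub completion", prompt)
```

A type checker then infers `predict(LmPrompt("hi", metadata=123))` as `LmPrediction[int]`, even though `LmPredictor` itself is not subscriptable.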

_rate_limit: RateLimit | None = None

def __init__(
@@ -23,13 +24,13 @@ def __init__(

self._disk_cache = SqlBackedCache(self)

def find_prediction_class(self, prompt):
return LmPrediction
def find_prediction_class(self, prompt: LmPrompt[T]):
return LmPrediction[T]

def predict(
self,
prompt: LmPrompt | str | LM_CHAT_DIALOG_COERCIBLE_TYPES,
) -> LmPrediction | list[LmPrediction]:
prompt: Union[LmPrompt[T], str, LM_CHAT_DIALOG_COERCIBLE_TYPES],
Contributor:

Why get rid of the | union syntax? It was fine before.

Contributor Author:

Fixed in fc34031. I've restored the pipe syntax (|) for union types throughout the file.
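For reference, the two spellings are equivalent; a quick sketch (not part of the diff):

```python
from typing import Optional, Union


def pick_union(x: Union[int, str]) -> Optional[str]:  # typing.Union / Optional spelling
    return str(x) if x else None


def pick_pipe(x: int | str) -> str | None:  # PEP 604 pipe spelling (Python 3.10+)
    return str(x) if x else None
```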

) -> Union[LmPrediction[T], list[LmPrediction[T]]]:
prompt = self._cast_prompt(prompt)
should_cache = self._cache_default if prompt.cache is None else prompt.cache
if should_cache and prompt.model_internals_request is not None:
@@ -74,7 +75,7 @@ def predict(
return vals[0]
return vals

def _read_cached_values(self, prompt: LmPrompt) -> list[LmPrediction]:
def _read_cached_values(self, prompt: LmPrompt[T]) -> list[LmPrediction[T]]:
"""
Checks the cache for any matches of the prompt. Returns a list
as if num_completions is >1 we might have multiple items
@@ -101,9 +102,9 @@ def _read_cached_values(self, prompt: LmPrompt) -> list[LmPrediction]:

def predict_many(
self,
prompts: list[LmPrompt],
prompts: list[LmPrompt[T]],
completion_window: CompletionWindow,
) -> Iterable[LmPrediction | list[LmPrediction]]:
) -> Iterable[Union[LmPrediction[T], list[LmPrediction[T]]]]:
self._validate_predict_many_prompts(prompts)
for prompt in prompts:
val = self.predict(prompt)
@@ -125,11 +126,11 @@ def _validate_predict_many_prompts(self, prompts):

def remove_prompt_from_cache(
self,
prompt: str | LmPrompt,
prompt: Union[str, LmPrompt[T]],
) -> bool:
return self._disk_cache.delete(prompt)

def _validate_prompt(self, prompt: LmPrompt, raise_on_invalid: bool = True) -> bool:
def _validate_prompt(self, prompt: LmPrompt[T], raise_on_invalid: bool = True) -> bool:
"""Called on prediction to make sure the prompt is valid for the model"""
return True

@@ -145,11 +146,11 @@ def supports_token_operations(self) -> bool:
@abstractmethod
def _predict_maybe_cached(
self,
prompt: LmPrompt,
) -> list[LmPrediction]:
prompt: LmPrompt[T],
) -> Union[LmPrediction[T], list[LmPrediction[T]]]:
pass

def _cast_prompt(self, prompt: str | LmPrompt) -> LmPrompt:
def _cast_prompt(self, prompt: Union[str, LmPrompt[T], list]) -> LmPrompt[T]:
if isinstance(prompt, str):
return LmPrompt(prompt, 100)
if isinstance(prompt, list):
@@ -180,14 +181,14 @@ def _cast_prompt(self, prompt: str | LmPrompt) -> LmPrompt:
)
raise ValueError(msg)

def estimate_tokens_in_prompt(self, prompt: LmPrompt) -> int:
def estimate_tokens_in_prompt(self, prompt: LmPrompt[T]) -> int:
raise NotImplementedError

@property
def token_limit(self):
raise NotImplementedError

def could_completion_go_over_token_limit(self, prompt: LmPrompt) -> bool:
def could_completion_go_over_token_limit(self, prompt: LmPrompt[T]) -> bool:
count = self.estimate_tokens_in_prompt(prompt)
return (
count + (prompt.max_tokens or self.default_tokens_generated)
@@ -255,20 +256,22 @@ def supports_prefilled_chat(self) -> bool:


def get_mock_predictor(
predict_func: Callable[[LmPrompt], LmPrediction] = None,
predict_func: Callable[[LmPrompt[T]], LmPrediction[T]] = None,
is_chat_model: bool = False,
):
"""Gets a mock predictor. By default returns whatever the prompt txt is"""

S = TypeVar('S') # Local TypeVar for the mock predictor

class MockPredict(LmPredictor):
class MockPredict(LmPredictor[S]):
Contributor:

the predictor is not subscriptable

Contributor Author:

Fixed in a345f4a. Removed the generic type parameter from the MockPredict class definition, changing class MockPredict(LmPredictor[S]): to class MockPredict(LmPredictor):.
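A rough sketch of the resulting shape, reusing the simplified stand-in classes from the earlier comment (the `predict` override below is an illustration, not the exact lmwrapper internals): the local `S` TypeVar appears only in method annotations, so the base class no longer needs to be subscriptable.

```python
from typing import Callable, Optional, TypeVar

S = TypeVar("S")  # local TypeVar, scoped to the mock helper


def get_mock_predictor(
    predict_func: Optional[Callable[[LmPrompt[S]], LmPrediction[S]]] = None,
):
    """Returns a mock predictor; by default it echoes the prompt text back."""

    class MockPredict(LmPredictor):  # base class is NOT subscripted
        def predict(self, prompt: LmPrompt[S]) -> LmPrediction[S]:
            if predict_func is None:
                return LmPrediction(prompt.text, prompt)
            return predict_func(prompt)

    return MockPredict()


mock = get_mock_predictor()
print(mock.predict(LmPrompt("hello")).completion_text)  # -> hello
```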

def get_model_cache_key(self):
return "mock_predictor"

@property
def is_chat_model(self) -> bool:
return is_chat_model

def _predict_maybe_cached(self, prompt):
def _predict_maybe_cached(self, prompt: LmPrompt[S]) -> LmPrediction[S]:
if predict_func is None:
return LmPrediction(
prompt.get_text_as_string_default_form(),
36 changes: 30 additions & 6 deletions lmwrapper/structs.py
@@ -1,13 +1,17 @@
import contextlib
import dataclasses
import json
import pickle
import statistics
from dataclasses import dataclass, field
from typing import Any, Optional, Union
from typing import Any, Generic, Optional, TypeVar, Union

from lmwrapper.internals import ModelInternalsResults
from lmwrapper.utils import StrEnum

# Define TypeVar for user metadata
T = TypeVar('T')

LM_CHAT_DIALOG_COERCIBLE_TYPES = Union[
str,
list[Union["LmChatTurn", tuple[str, str], dict, str]],
@@ -34,7 +38,7 @@ class StopMode(StrEnum):


@dataclass(frozen=True)
class LmPrompt:
class LmPrompt(Generic[T]):
text: str | LM_CHAT_DIALOG_COERCIBLE_TYPES
"""The actual text of the prompt. If it is a LM_CHAT_DIALOG_COERCIBLE_TYPES
which can become a LmChatDialog (such as a list of strings) it will be converted
@@ -125,6 +129,11 @@ class LmPrompt:
"""Whether or not to add special tokens when encoding the prompt."""
model_internals_request: Optional["ModelInternalsRequest"] = None
"""Used to attempt to get hidden states and attentions from the model."""
user_metadata: Optional[T] = None
Contributor:

Thinking about this more, maybe I just prefer metadata instead of user_metadata. A little less verbose, and I think this seems better. Refactor and rename all the uses of this.

Contributor Author:

Addressed in a345f4a. Renamed all instances of user_metadata to metadata throughout the codebase, including in structs.py, tests, README.md, and CHANGELOG.md.
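For illustration, a small usage sketch with the renamed field, using the repo's mock predictor; treat it as a sketch rather than canonical usage (the exact echo output follows the mock's default behaviour):

```python
from lmwrapper.abstract_predictor import get_mock_predictor
from lmwrapper.structs import LmPrompt

lm = get_mock_predictor()  # by default echoes the prompt text back

prompt = LmPrompt(
    "What is 2 + 2?",
    max_tokens=5,
    cache=False,
    metadata={"question_id": 7, "expected_answer": 4},  # renamed from user_metadata
)
pred = lm.predict(prompt)

# The metadata rides along on the prediction's prompt; it is not used
# as part of the cache key.
assert pred.prompt.metadata["question_id"] == 7
```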

"""Optional user-defined metadata that gets transferred to the resulting LmPrediction.
This is not used for caching and can be any type. It's useful for tracking
additional information with each prompt and prediction (e.g., ground truth labels,
problem identifiers)."""

# TODO: make a auto_reduce_max_tokens to reduce when might go over.

@@ -208,11 +217,15 @@ def get_text_as_string_default_form(self) -> str:
else:
return self.text

def dict_serialize(self) -> dict:
def dict_serialize(self, include_user_metadata: bool = False) -> dict:
"""
Serialize the prompt into a json-compatible dictionary. Note this is not
guaranteed to be the same as the JSON representation for use
in an openai api call. This is just for serialization purposes.

Args:
include_user_metadata: Whether to include user_metadata in serialization.
Default is False since user_metadata should not be part of the cache key.
"""
out = {
"max_tokens": self.max_tokens,
Expand All @@ -232,6 +245,16 @@ def dict_serialize(self) -> dict:
out["text"] = self.get_text_as_chat().as_dicts()
else:
out["text"] = self.text

if include_user_metadata and self.user_metadata is not None:
try:
# Test if it's JSON serializable
json.dumps(self.user_metadata)
out["user_metadata"] = self.user_metadata
except (TypeError, ValueError):
# If metadata isn't JSON serializable, leave it out
pass

return out


@@ -316,10 +339,10 @@ def to_default_string_prompt(self) -> str:


@dataclass
class LmPrediction:
class LmPrediction(Generic[T]):
completion_text: str | None
"""The new text generated. It might be None if errors"""
prompt: LmPrompt
prompt: LmPrompt[T]
metad: Any
internals: ModelInternalsResults | None = field(default=None, kw_only=True)
error_message: str | None = field(default=None, kw_only=True)
@@ -458,10 +481,11 @@ def dict_serialize(
self,
pull_out_props: bool = True,
include_metad: bool = False,
include_user_metadata: bool = False,
) -> dict[str, Any]:
out = {
"completion_text": self.completion_text,
"prompt": self.prompt.dict_serialize(),
"prompt": self.prompt.dict_serialize(include_user_metadata=include_user_metadata),
"was_cached": self.was_cached,
"error_message": self.error_message,
}
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -25,7 +25,7 @@ include = ["lmwrapper*"]

[project]
name = "lmwrapper"
version = "0.16.3.0"
version = "0.16.4.0"
authors = [
{ name = "David Gros" },
{ name = "Claudio Spiess" },
81 changes: 81 additions & 0 deletions test/test_models_common.py
@@ -1107,6 +1107,87 @@ def test_system_prompt(lm):
assert pred.completion_text.strip() == "PARIS"


@pytest.mark.parametrize("lm", ALL_MODELS, ids=get_model_name)
def test_user_metadata_predict(lm):
# Test with string metadata
prompt = LmPrompt(
"Count from 1 to 5:",
max_tokens=10,
cache=False,
temperature=0,
user_metadata="test_metadata_string"
)
pred = lm.predict(prompt)
assert pred.prompt.user_metadata == "test_metadata_string"

# Test with dictionary metadata
metadata_dict = {"label": "counting", "id": 12345}
prompt = LmPrompt(
"Count from 1 to 5:",
max_tokens=10,
cache=False,
temperature=0,
user_metadata=metadata_dict
)
pred = lm.predict(prompt)
assert pred.prompt.user_metadata == metadata_dict
assert pred.prompt.user_metadata["label"] == "counting"
assert pred.prompt.user_metadata["id"] == 12345


@pytest.mark.parametrize("lm", ALL_MODELS, ids=get_model_name)
def test_user_metadata_predict_many(lm):
# Create prompts with different metadata
prompts = [
LmPrompt(
f"What is {i} + {i}?",
max_tokens=5,
cache=False,
temperature=0,
user_metadata={"question_id": i, "expected_answer": i*2}
)
for i in range(1, 4)
]

# Use predict_many and verify metadata is preserved
results = list(lm.predict_many(prompts, CompletionWindow.ASAP))

for i, result in enumerate(results):
expected_id = i + 1
expected_answer = expected_id * 2
assert result.prompt.user_metadata["question_id"] == expected_id
assert result.prompt.user_metadata["expected_answer"] == expected_answer


@pytest.mark.parametrize("lm", ALL_MODELS, ids=get_model_name)
def test_user_metadata_with_cache_hit(lm):
"""Test that user_metadata is preserved when there's a cache hit."""
# First prompt with metadata to populate the cache
prompt1 = LmPrompt(
"What is 5 + 5?",
max_tokens=5,
cache=True,
temperature=0,
user_metadata={"original": True, "id": 1}
)
result1 = lm.predict(prompt1)

# Same prompt with different metadata should use cache but keep the new metadata
prompt2 = LmPrompt(
"What is 5 + 5?",
max_tokens=5,
cache=True,
temperature=0,
user_metadata={"original": False, "id": 2}
)
result2 = lm.predict(prompt2)

# Check that the result is from cache but has the new metadata
assert result2.was_cached
assert result2.prompt.user_metadata["original"] == False
assert result2.prompt.user_metadata["id"] == 2


@pytest.mark.parametrize(
"lm", CHAT_MODELS,
ids=get_model_name