Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
"-s", "-vv", "--runslow",
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
"python.testing.pytestEnabled": true,
"python.defaultInterpreterPath": "~/miniconda3/envs/lmwrapper/bin/python",
}
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ breaking changes, Y is new features or larger non-breaking changes, and Z is sma
However, it is still pre-1.0 software, and does not claim to
be super stable.

## [0.16.4.0]

### Added
- Added user_metadata field to LmPrompt class with generic type support

## [0.16.3.0]

Expand Down
10 changes: 6 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def make_prompts(data) -> list[LmPrompt]:
max_tokens=10,
temperature=0,
cache=True,
user_metadata={"country": country}
)
for country in data
]
Expand All @@ -216,9 +217,10 @@ predictions = lm.predict_many(
# the non-batching API at a higher cost.
) # The batch is submitted here

for ex, pred in zip(data, predictions): # Will wait for the batch to complete
print(f"Country: {ex} --- Capital: {pred.completion_text}")
if ex == "France": assert pred.completion_text == "Paris"
for pred in predictions: # Will wait for the batch to complete
country = pred.prompt.user_metadata['country']
print(f"Country: {country} --- Capital: {pred.completion_text}")
if country == "France": assert pred.completion_text == "Paris"
# ...
```

Expand Down Expand Up @@ -364,7 +366,7 @@ please make a Github Issue.
- [X] Anthropic interface (basic)
- [X] Claude system messages
- [X] Use the huggingface chat templates for chat models if available
- [ ] Be able to add user metadata to a prompt
- [X] Be able to add user metadata to a prompt
- [ ] Automatic cache eviction to limit count or disk size (right now have to run a SQL query to delete entries before a certain time or matching your criteria)
- [ ] Multimodal/images in super easy format (like automatically process pil, opencv, etc)
- [ ] sort through usage of quantized models
Expand Down
39 changes: 21 additions & 18 deletions lmwrapper/abstract_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from abc import abstractmethod
from collections.abc import Callable, Iterable
from sqlite3 import OperationalError
from typing import TypeVar, Union

from ratemate import RateLimit

from lmwrapper.batch_config import CompletionWindow
from lmwrapper.sqlcache_struct import BatchPredictionPlaceholder
from lmwrapper.structs import LM_CHAT_DIALOG_COERCIBLE_TYPES, LmPrediction, LmPrompt
from lmwrapper.structs import LM_CHAT_DIALOG_COERCIBLE_TYPES, LmPrediction, LmPrompt, T


class LmPredictor:
Expand All @@ -23,13 +24,13 @@ def __init__(

self._disk_cache = SqlBackedCache(self)

def find_prediction_class(self, prompt):
return LmPrediction
def find_prediction_class(self, prompt: LmPrompt[T]):
return LmPrediction[T]

def predict(
self,
prompt: LmPrompt | str | LM_CHAT_DIALOG_COERCIBLE_TYPES,
) -> LmPrediction | list[LmPrediction]:
prompt: LmPrompt[T] | str | LM_CHAT_DIALOG_COERCIBLE_TYPES,
) -> LmPrediction[T] | list[LmPrediction[T]]:
prompt = self._cast_prompt(prompt)
should_cache = self._cache_default if prompt.cache is None else prompt.cache
if should_cache and prompt.model_internals_request is not None:
Expand Down Expand Up @@ -74,7 +75,7 @@ def predict(
return vals[0]
return vals

def _read_cached_values(self, prompt: LmPrompt) -> list[LmPrediction]:
def _read_cached_values(self, prompt: LmPrompt[T]) -> list[LmPrediction[T]]:
"""
Checks the cache for any matches of the prompt. Returns a list
as if num_completions is >1 we might have multiple items
Expand All @@ -101,9 +102,9 @@ def _read_cached_values(self, prompt: LmPrompt) -> list[LmPrediction]:

def predict_many(
self,
prompts: list[LmPrompt],
prompts: list[LmPrompt[T]],
completion_window: CompletionWindow,
) -> Iterable[LmPrediction | list[LmPrediction]]:
) -> Iterable[LmPrediction[T] | list[LmPrediction[T]]]:
self._validate_predict_many_prompts(prompts)
for prompt in prompts:
val = self.predict(prompt)
Expand All @@ -125,11 +126,11 @@ def _validate_predict_many_prompts(self, prompts):

def remove_prompt_from_cache(
self,
prompt: str | LmPrompt,
prompt: str | LmPrompt[T],
) -> bool:
return self._disk_cache.delete(prompt)

def _validate_prompt(self, prompt: LmPrompt, raise_on_invalid: bool = True) -> bool:
def _validate_prompt(self, prompt: LmPrompt[T], raise_on_invalid: bool = True) -> bool:
"""Called on prediction to make sure the prompt is valid for the model"""
return True

Expand All @@ -145,11 +146,11 @@ def supports_token_operations(self) -> bool:
@abstractmethod
def _predict_maybe_cached(
self,
prompt: LmPrompt,
) -> list[LmPrediction]:
prompt: LmPrompt[T],
) -> LmPrediction[T] | list[LmPrediction[T]]:
pass

def _cast_prompt(self, prompt: str | LmPrompt) -> LmPrompt:
def _cast_prompt(self, prompt: str | LmPrompt[T] | list) -> LmPrompt[T]:
if isinstance(prompt, str):
return LmPrompt(prompt, 100)
if isinstance(prompt, list):
Expand Down Expand Up @@ -180,14 +181,14 @@ def _cast_prompt(self, prompt: str | LmPrompt) -> LmPrompt:
)
raise ValueError(msg)

def estimate_tokens_in_prompt(self, prompt: LmPrompt) -> int:
def estimate_tokens_in_prompt(self, prompt: LmPrompt[T]) -> int:
raise NotImplementedError

@property
def token_limit(self):
raise NotImplementedError

def could_completion_go_over_token_limit(self, prompt: LmPrompt) -> bool:
def could_completion_go_over_token_limit(self, prompt: LmPrompt[T]) -> bool:
count = self.estimate_tokens_in_prompt(prompt)
return (
count + (prompt.max_tokens or self.default_tokens_generated)
Expand Down Expand Up @@ -255,20 +256,22 @@ def supports_prefilled_chat(self) -> bool:


def get_mock_predictor(
predict_func: Callable[[LmPrompt], LmPrediction] = None,
predict_func: Callable[[LmPrompt[T]], LmPrediction[T]] = None,
is_chat_model: bool = False,
):
"""Gets a mock predictor. By default returns whatever the prompt txt is"""

S = TypeVar('S') # Local TypeVar for the mock predictor

class MockPredict(LmPredictor):
class MockPredict(LmPredictor[S]):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the predictor is not subscriptable

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in a345f4a. Removed the generic type parameter from MockPredict class definition, changing class MockPredict(LmPredictor[S]): to class MockPredict(LmPredictor):.

def get_model_cache_key(self):
return "mock_predictor"

@property
def is_chat_model(self) -> bool:
return is_chat_model

def _predict_maybe_cached(self, prompt):
def _predict_maybe_cached(self, prompt: LmPrompt[S]) -> LmPrediction[S]:
if predict_func is None:
return LmPrediction(
prompt.get_text_as_string_default_form(),
Expand Down
36 changes: 30 additions & 6 deletions lmwrapper/structs.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import contextlib
import dataclasses
import json
import pickle
import statistics
from dataclasses import dataclass, field
from typing import Any, Optional, Union
from typing import Any, Generic, Optional, TypeVar, Union

from lmwrapper.internals import ModelInternalsResults
from lmwrapper.utils import StrEnum

# Define TypeVar for user metadata
T = TypeVar('T')

LM_CHAT_DIALOG_COERCIBLE_TYPES = Union[
str,
list[Union["LmChatTurn", tuple[str, str], dict, str]],
Expand All @@ -34,7 +38,7 @@ class StopMode(StrEnum):


@dataclass(frozen=True)
class LmPrompt:
class LmPrompt(Generic[T]):
text: str | LM_CHAT_DIALOG_COERCIBLE_TYPES
"""The actual text of the prompt. If it is a LM_CHAT_DIALOG_COERCIBLE_TYPES
which can become a LmChatDialog (such as a list of strings) it will be converted
Expand Down Expand Up @@ -125,6 +129,11 @@ class LmPrompt:
"""Whether or not to add special tokens when encoding the prompt."""
model_internals_request: Optional["ModelInternalsRequest"] = None
"""Used to attempt to get hidden states and attentions from the model."""
user_metadata: Optional[T] = None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thinking about this more, maybe I just prefer metadata instead of user_metadata. A little less verbose, and think this seems better. Refactor and rename all the uses of this.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in a345f4a. Renamed all instances of user_metadata to metadata throughout the codebase, including in structs.py, tests, README.md, and CHANGELOG.md.

"""Optional user-defined metadata that gets transferred to the resulting LmPrediction.
This is not used for caching and can be any type. It's useful for tracking
additional information with each prompt and prediction (e.g., ground truth labels,
problem identifiers)."""

# TODO: make a auto_reduce_max_tokens to reduce when might go over.

Expand Down Expand Up @@ -208,11 +217,15 @@ def get_text_as_string_default_form(self) -> str:
else:
return self.text

def dict_serialize(self) -> dict:
def dict_serialize(self, include_user_metadata: bool = False) -> dict:
"""
Serialize the prompt into a json-compatible dictionary. Note this is not
guaranteed to be the same as the JSON representation for use
in an openai api call. This is just for serialization purposes.

Args:
include_user_metadata: Whether to include user_metadata in serialization.
Default is False since user_metadata should not be part of the cache key.
"""
out = {
"max_tokens": self.max_tokens,
Expand All @@ -232,6 +245,16 @@ def dict_serialize(self) -> dict:
out["text"] = self.get_text_as_chat().as_dicts()
else:
out["text"] = self.text

if include_user_metadata and self.user_metadata is not None:
try:
# Test if it's JSON serializable
json.dumps(self.user_metadata)
out["user_metadata"] = self.user_metadata
except (TypeError, ValueError):
# If metadata isn't JSON serializable, leave it out
pass

return out


Expand Down Expand Up @@ -316,10 +339,10 @@ def to_default_string_prompt(self) -> str:


@dataclass
class LmPrediction:
class LmPrediction(Generic[T]):
completion_text: str | None
"""The new text generated. It might be None if errors"""
prompt: LmPrompt
prompt: LmPrompt[T]
metad: Any
internals: ModelInternalsResults | None = field(default=None, kw_only=True)
error_message: str | None = field(default=None, kw_only=True)
Expand Down Expand Up @@ -458,10 +481,11 @@ def dict_serialize(
self,
pull_out_props: bool = True,
include_metad: bool = False,
include_user_metadata: bool = False,
) -> dict[str, Any]:
out = {
"completion_text": self.completion_text,
"prompt": self.prompt.dict_serialize(),
"prompt": self.prompt.dict_serialize(include_user_metadata=include_user_metadata),
"was_cached": self.was_cached,
"error_message": self.error_message,
}
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ include = ["lmwrapper*"]

[project]
name = "lmwrapper"
version = "0.16.3.0"
version = "0.16.4.0"
authors = [
{ name = "David Gros" },
{ name = "Claudio Spiess" },
Expand Down
82 changes: 82 additions & 0 deletions test/test_models_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
import pytest

from lmwrapper.abstract_predictor import LmPredictor
from lmwrapper.batch_config import CompletionWindow
from lmwrapper.caching import cache_dir, clear_cache_dir
from lmwrapper.huggingface_wrapper import get_huggingface_lm
Expand Down Expand Up @@ -1107,6 +1108,87 @@ def test_system_prompt(lm):
assert pred.completion_text.strip() == "PARIS"


@pytest.mark.parametrize("lm", ALL_MODELS, ids=get_model_name)
def test_user_metadata_predict(lm):
    """Verify that user_metadata set on a prompt is carried through to the prediction."""
    # A plain string value should round-trip unchanged.
    str_prompt = LmPrompt(
        "Count from 1 to 5:",
        max_tokens=10,
        cache=False,
        temperature=0,
        user_metadata="test_metadata_string",
    )
    assert lm.predict(str_prompt).prompt.user_metadata == "test_metadata_string"

    # A dict value should round-trip unchanged as well.
    meta = {"label": "counting", "id": 12345}
    dict_prompt = LmPrompt(
        "Count from 1 to 5:",
        max_tokens=10,
        cache=False,
        temperature=0,
        user_metadata=meta,
    )
    result = lm.predict(dict_prompt)
    assert result.prompt.user_metadata == meta
    assert result.prompt.user_metadata["label"] == "counting"
    assert result.prompt.user_metadata["id"] == 12345


@pytest.mark.parametrize("lm", ALL_MODELS, ids=get_model_name)
def test_user_metadata_predict_many(lm):
    """Verify that per-prompt metadata survives a predict_many batch unchanged."""
    # Build three prompts, each tagged with its own question id and answer.
    prompts = []
    for i in range(1, 4):
        prompts.append(
            LmPrompt(
                f"What is {i} + {i}?",
                max_tokens=5,
                cache=False,
                temperature=0,
                user_metadata={"question_id": i, "expected_answer": i * 2},
            )
        )

    # Each result must carry the metadata from its originating prompt.
    predictions = list(lm.predict_many(prompts, CompletionWindow.ASAP))
    for question_id, result in enumerate(predictions, start=1):
        assert result.prompt.user_metadata["question_id"] == question_id
        assert result.prompt.user_metadata["expected_answer"] == question_id * 2


@pytest.mark.parametrize("lm", ALL_MODELS, ids=get_model_name)
def test_user_metadata_with_cache_hit(lm: LmPredictor):
    """Test that user_metadata is preserved when there's a cache hit.

    user_metadata must not be part of the cache key, so a second prompt with
    the same text but different metadata reuses the cached completion while
    the result carries the *new* prompt's metadata.
    """
    # First prompt with metadata to populate the cache.
    prompt1 = LmPrompt(
        "What is 5 + 5?",
        max_tokens=5,
        cache=True,
        temperature=0,
        user_metadata={"original": True, "id": 1},
    )
    result1 = lm.predict(prompt1)
    # Fix: result1 was previously unused (F841); assert its metadata too.
    assert result1.prompt.user_metadata["id"] == 1

    # Same prompt with different metadata should use cache but keep the new metadata.
    prompt2 = LmPrompt(
        "What is 5 + 5?",
        max_tokens=5,
        cache=True,
        temperature=0,
        user_metadata={"original": False, "id": 2},
    )
    result2 = lm.predict(prompt2)

    # Check that the result is from cache but has the new metadata.
    assert result2.was_cached
    # Fix: compare to False with `is`, not `==` (PEP 8 / ruff E712).
    assert result2.prompt.user_metadata["original"] is False
    assert result2.prompt.user_metadata["id"] == 2


@pytest.mark.parametrize(
"lm", CHAT_MODELS,
ids=get_model_name
Expand Down