Commit 410758a

Merge pull request #11 from DaiseyCode/copilot/fix-10
Add user metadata to LmPrompt class
2 parents 27326c3 + a345f4a

7 files changed

Lines changed: 145 additions & 29 deletions
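
At a glance, the change lets an LmPrompt carry arbitrary user metadata that rides along to the resulting LmPrediction. A minimal sketch of the new surface (the field name and behavior are taken from the diffs below; the prompt text here is illustrative):

```python
from lmwrapper.structs import LmPrompt

# New in 0.16.4.0: any value can be attached to a prompt. It is ignored for
# caching, but comes back on pred.prompt.metadata after a predict() call.
prompt = LmPrompt(
    "What is the capital of France?",
    max_tokens=10,
    temperature=0,
    metadata={"country": "France"},
)
assert prompt.metadata["country"] == "France"
```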

.vscode/settings.json

Lines changed: 2 additions & 1 deletion
@@ -3,5 +3,6 @@
         "-s", "-vv", "--runslow",
     ],
     "python.testing.unittestEnabled": false,
-    "python.testing.pytestEnabled": true
+    "python.testing.pytestEnabled": true,
+    "python.defaultInterpreterPath": "~/miniconda3/envs/lmwrapper/bin/python",
 }

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -7,6 +7,10 @@ breaking changes, Y is new features or larger non-breaking changes, and Z is sma
 However, it is still pre-1.0 software, and does not claim to
 be super stable.
 
+## [0.16.4.0]
+
+### Added
+- Added metadata field to LmPrompt class with generic type support
 
 ## [0.16.3.0]
 

README.md

Lines changed: 6 additions & 4 deletions
@@ -201,6 +201,7 @@ def make_prompts(data) -> list[LmPrompt]:
         max_tokens=10,
         temperature=0,
         cache=True,
+        metadata={"country": country}
     )
     for country in data
 ]
@@ -216,9 +217,10 @@ predictions = lm.predict_many(
     # the non-batching API at a higher cost.
 ) # The batch is submitted here
 
-for ex, pred in zip(data, predictions): # Will wait for the batch to complete
-    print(f"Country: {ex} --- Capital: {pred.completion_text}")
-    if ex == "France": assert pred.completion_text == "Paris"
+for pred in predictions: # Will wait for the batch to complete
+    country = pred.prompt.metadata['country']
+    print(f"Country: {country} --- Capital: {pred.completion_text}")
+    if country == "France": assert pred.completion_text == "Paris"
 # ...
 ```
 
@@ -364,7 +366,7 @@ please make a Github Issue.
 - [X] Anthropic interface (basic)
 - [X] Claude system messages
 - [X] Use the huggingface chat templates for chat models if available
-- [ ] Be able to add user metadata to a prompt
+- [X] Be able to add metadata to a prompt
 - [ ] Automatic cache eviction to limit count or disk size (right now have to run a SQL query to delete entries before a certain time or matching your criteria)
 - [ ] Multimodal/images in super easy format (like automatically process pil, opencv, etc)
 - [ ] sort through usage of quantized models
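
Assembled from the fragments above, the updated batch flow reads roughly as follows. This is a sketch: the predictor construction and prompt wording are placeholders, and CompletionWindow.ASAP (the window the new tests use) stands in for whatever window the README example actually passes.

```python
from lmwrapper.batch_config import CompletionWindow
from lmwrapper.structs import LmPrompt

# Any lmwrapper predictor works here; get_huggingface_lm("gpt2") is just an
# assumed stand-in for whichever model the README configures.
from lmwrapper.huggingface_wrapper import get_huggingface_lm
lm = get_huggingface_lm("gpt2")

data = ["France", "Japan", "Brazil"]
prompts = [
    LmPrompt(
        f"What is the capital of {country}?",  # illustrative wording
        max_tokens=10,
        temperature=0,
        cache=True,
        metadata={"country": country},  # rides along with each prompt
    )
    for country in data
]

predictions = lm.predict_many(prompts, CompletionWindow.ASAP)

for pred in predictions:  # waits for the predictions to complete
    country = pred.prompt.metadata["country"]  # no need to zip with `data`
    print(f"Country: {country} --- Capital: {pred.completion_text}")
```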

lmwrapper/abstract_predictor.py

Lines changed: 20 additions & 17 deletions
@@ -2,12 +2,13 @@
 from abc import abstractmethod
 from collections.abc import Callable, Iterable
 from sqlite3 import OperationalError
+from typing import TypeVar, Union
 
 from ratemate import RateLimit
 
 from lmwrapper.batch_config import CompletionWindow
 from lmwrapper.sqlcache_struct import BatchPredictionPlaceholder
-from lmwrapper.structs import LM_CHAT_DIALOG_COERCIBLE_TYPES, LmPrediction, LmPrompt
+from lmwrapper.structs import LM_CHAT_DIALOG_COERCIBLE_TYPES, LmPrediction, LmPrompt, T
 
 
 class LmPredictor:
@@ -23,13 +24,13 @@ def __init__(
 
         self._disk_cache = SqlBackedCache(self)
 
-    def find_prediction_class(self, prompt):
-        return LmPrediction
+    def find_prediction_class(self, prompt: LmPrompt[T]):
+        return LmPrediction[T]
 
     def predict(
         self,
-        prompt: LmPrompt | str | LM_CHAT_DIALOG_COERCIBLE_TYPES,
-    ) -> LmPrediction | list[LmPrediction]:
+        prompt: LmPrompt[T] | str | LM_CHAT_DIALOG_COERCIBLE_TYPES,
+    ) -> LmPrediction[T] | list[LmPrediction[T]]:
         prompt = self._cast_prompt(prompt)
         should_cache = self._cache_default if prompt.cache is None else prompt.cache
         if should_cache and prompt.model_internals_request is not None:
@@ -74,7 +75,7 @@ def predict(
             return vals[0]
         return vals
 
-    def _read_cached_values(self, prompt: LmPrompt) -> list[LmPrediction]:
+    def _read_cached_values(self, prompt: LmPrompt[T]) -> list[LmPrediction[T]]:
         """
         Checks the cache for any matches of the prompt. Returns a list
         as if num_completions is >1 we might have multiple items
@@ -101,9 +102,9 @@ def _read_cached_values(self, prompt: LmPrompt) -> list[LmPrediction]:
 
     def predict_many(
         self,
-        prompts: list[LmPrompt],
+        prompts: list[LmPrompt[T]],
         completion_window: CompletionWindow,
-    ) -> Iterable[LmPrediction | list[LmPrediction]]:
+    ) -> Iterable[LmPrediction[T] | list[LmPrediction[T]]]:
         self._validate_predict_many_prompts(prompts)
         for prompt in prompts:
             val = self.predict(prompt)
@@ -125,11 +126,11 @@ def _validate_predict_many_prompts(self, prompts):
 
     def remove_prompt_from_cache(
         self,
-        prompt: str | LmPrompt,
+        prompt: str | LmPrompt[T],
     ) -> bool:
         return self._disk_cache.delete(prompt)
 
-    def _validate_prompt(self, prompt: LmPrompt, raise_on_invalid: bool = True) -> bool:
+    def _validate_prompt(self, prompt: LmPrompt[T], raise_on_invalid: bool = True) -> bool:
         """Called on prediction to make sure the prompt is valid for the model"""
         return True
 
@@ -145,11 +146,11 @@ def supports_token_operations(self) -> bool:
     @abstractmethod
     def _predict_maybe_cached(
         self,
-        prompt: LmPrompt,
-    ) -> list[LmPrediction]:
+        prompt: LmPrompt[T],
+    ) -> LmPrediction[T] | list[LmPrediction[T]]:
         pass
 
-    def _cast_prompt(self, prompt: str | LmPrompt) -> LmPrompt:
+    def _cast_prompt(self, prompt: str | LmPrompt[T] | list) -> LmPrompt[T]:
         if isinstance(prompt, str):
             return LmPrompt(prompt, 100)
         if isinstance(prompt, list):
@@ -180,14 +181,14 @@ def _cast_prompt(self, prompt: str | LmPrompt) -> LmPrompt:
             )
         raise ValueError(msg)
 
-    def estimate_tokens_in_prompt(self, prompt: LmPrompt) -> int:
+    def estimate_tokens_in_prompt(self, prompt: LmPrompt[T]) -> int:
         raise NotImplementedError
 
     @property
     def token_limit(self):
         raise NotImplementedError
 
-    def could_completion_go_over_token_limit(self, prompt: LmPrompt) -> bool:
+    def could_completion_go_over_token_limit(self, prompt: LmPrompt[T]) -> bool:
         count = self.estimate_tokens_in_prompt(prompt)
         return (
             count + (prompt.max_tokens or self.default_tokens_generated)
@@ -255,10 +256,12 @@ def supports_prefilled_chat(self) -> bool:
 
 
 def get_mock_predictor(
-    predict_func: Callable[[LmPrompt], LmPrediction] = None,
+    predict_func: Callable[[LmPrompt[T]], LmPrediction[T]] = None,
     is_chat_model: bool = False,
 ):
     """Gets a mock predictor. By default returns whatever the prompt txt is"""
+
+    S = TypeVar('S')  # Local TypeVar for the mock predictor
 
     class MockPredict(LmPredictor):
         def get_model_cache_key(self):
@@ -268,7 +271,7 @@ def get_model_cache_key(self):
         def is_chat_model(self) -> bool:
             return is_chat_model
 
-        def _predict_maybe_cached(self, prompt):
+        def _predict_maybe_cached(self, prompt: LmPrompt[S]) -> LmPrediction[S]:
             if predict_func is None:
                 return LmPrediction(
                     prompt.get_text_as_string_default_form(),
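
Since LmPrompt and LmPrediction now share the type variable T, the metadata type can be threaded through predictor calls for static checkers; runtime behavior is unchanged. A small sketch (the QuestionMeta dataclass and the helper function are hypothetical, not part of the library):

```python
from dataclasses import dataclass

from lmwrapper.abstract_predictor import LmPredictor
from lmwrapper.structs import LmPrompt


@dataclass
class QuestionMeta:
    question_id: int
    expected_answer: str


def ask(lm: LmPredictor, question: str, meta: QuestionMeta):
    # The annotation LmPrompt[QuestionMeta] lets a checker infer that
    # pred.prompt.metadata is a QuestionMeta (or None) on the way back out.
    prompt: LmPrompt[QuestionMeta] = LmPrompt(
        question,
        max_tokens=10,
        temperature=0,
        cache=False,
        metadata=meta,
    )
    pred = lm.predict(prompt)  # LmPrediction[QuestionMeta] for a single completion
    assert pred.prompt.metadata == meta
    return pred
```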

lmwrapper/structs.py

Lines changed: 30 additions & 6 deletions
@@ -1,13 +1,17 @@
 import contextlib
 import dataclasses
+import json
 import pickle
 import statistics
 from dataclasses import dataclass, field
-from typing import Any, Optional, Union
+from typing import Any, Generic, Optional, TypeVar, Union
 
 from lmwrapper.internals import ModelInternalsResults
 from lmwrapper.utils import StrEnum
 
+# Define TypeVar for user metadata
+T = TypeVar('T')
+
 LM_CHAT_DIALOG_COERCIBLE_TYPES = Union[
     str,
     list[Union["LmChatTurn", tuple[str, str], dict, str]],
@@ -34,7 +38,7 @@ class StopMode(StrEnum):
 
 
 @dataclass(frozen=True)
-class LmPrompt:
+class LmPrompt(Generic[T]):
     text: str | LM_CHAT_DIALOG_COERCIBLE_TYPES
     """The actual text of the prompt. If it is a LM_CHAT_DIALOG_COERCIBLE_TYPES
     which can become a LmChatDialog (such as a list of strings) it will be converted
@@ -125,6 +129,11 @@ class LmPrompt:
     """Whether or not to add special tokens when encoding the prompt."""
     model_internals_request: Optional["ModelInternalsRequest"] = None
     """Used to attempt to get hidden states and attentions from the model."""
+    metadata: Optional[T] = None
+    """Optional user-defined metadata that gets transferred to the resulting LmPrediction.
+    This is not used for caching and can be any type. It's useful for tracking
+    additional information with each prompt and prediction (e.g., ground truth labels,
+    problem identifiers)."""
 
     # TODO: make a auto_reduce_max_tokens to reduce when might go over.
 
@@ -208,11 +217,15 @@ def get_text_as_string_default_form(self) -> str:
         else:
             return self.text
 
-    def dict_serialize(self) -> dict:
+    def dict_serialize(self, include_metadata: bool = False) -> dict:
         """
         Serialize the prompt into a json-compatible dictionary. Note this is not
         guaranteed to be the same as the JSON representation for use
         in an openai api call. This is just for serialization purposes.
+
+        Args:
+            include_metadata: Whether to include metadata in serialization.
+                Default is False since metadata should not be part of the cache key.
         """
         out = {
             "max_tokens": self.max_tokens,
@@ -232,6 +245,16 @@ def dict_serialize(self) -> dict:
             out["text"] = self.get_text_as_chat().as_dicts()
         else:
             out["text"] = self.text
+
+        if include_metadata and self.metadata is not None:
+            try:
+                # Test if it's JSON serializable
+                json.dumps(self.metadata)
+                out["metadata"] = self.metadata
+            except (TypeError, ValueError):
+                # If metadata isn't JSON serializable, leave it out
+                pass
+
         return out
 
 
@@ -316,10 +339,10 @@ def to_default_string_prompt(self) -> str:
 
 
 @dataclass
-class LmPrediction:
+class LmPrediction(Generic[T]):
     completion_text: str | None
     """The new text generated. It might be None if errors"""
-    prompt: LmPrompt
+    prompt: LmPrompt[T]
     metad: Any
     internals: ModelInternalsResults | None = field(default=None, kw_only=True)
     error_message: str | None = field(default=None, kw_only=True)
@@ -458,10 +481,11 @@ def dict_serialize(
         self,
         pull_out_props: bool = True,
         include_metad: bool = False,
+        include_metadata: bool = False,
     ) -> dict[str, Any]:
         out = {
             "completion_text": self.completion_text,
-            "prompt": self.prompt.dict_serialize(),
+            "prompt": self.prompt.dict_serialize(include_metadata=include_metadata),
             "was_cached": self.was_cached,
             "error_message": self.error_message,
         }
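
A quick sketch of what the new dict_serialize flag does, as implied by the hunk above: metadata is opt-in, and anything json.dumps rejects is silently dropped rather than raising.

```python
from lmwrapper.structs import LmPrompt

prompt = LmPrompt("What is 5 + 5?", max_tokens=5, metadata={"question_id": 1})

# Default: metadata is excluded, so serialized forms used for cache keys
# are unaffected by metadata differences.
assert "metadata" not in prompt.dict_serialize()

# Opt in when the metadata should travel with the serialized prompt.
assert prompt.dict_serialize(include_metadata=True)["metadata"] == {"question_id": 1}

# Non-JSON-serializable metadata (here a raw object()) is quietly omitted.
weird = LmPrompt("hi", max_tokens=5, metadata=object())
assert "metadata" not in weird.dict_serialize(include_metadata=True)
```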

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ include = ["lmwrapper*"]
 
 [project]
 name = "lmwrapper"
-version = "0.16.3.0"
+version = "0.16.4.0"
 authors = [
     { name = "David Gros" },
     { name = "Claudio Spiess" },

test/test_models_common.py

Lines changed: 82 additions & 0 deletions
@@ -8,6 +8,7 @@
 import numpy as np
 import pytest
 
+from lmwrapper.abstract_predictor import LmPredictor
 from lmwrapper.batch_config import CompletionWindow
 from lmwrapper.caching import cache_dir, clear_cache_dir
 from lmwrapper.huggingface_wrapper import get_huggingface_lm
@@ -1107,6 +1108,87 @@ def test_system_prompt(lm):
     assert pred.completion_text.strip() == "PARIS"
 
 
+@pytest.mark.parametrize("lm", ALL_MODELS, ids=get_model_name)
+def test_metadata_predict(lm):
+    # Test with string metadata
+    prompt = LmPrompt(
+        "Count from 1 to 5:",
+        max_tokens=10,
+        cache=False,
+        temperature=0,
+        metadata="test_metadata_string"
+    )
+    pred = lm.predict(prompt)
+    assert pred.prompt.metadata == "test_metadata_string"
+
+    # Test with dictionary metadata
+    metadata_dict = {"label": "counting", "id": 12345}
+    prompt = LmPrompt(
+        "Count from 1 to 5:",
+        max_tokens=10,
+        cache=False,
+        temperature=0,
+        metadata=metadata_dict
+    )
+    pred = lm.predict(prompt)
+    assert pred.prompt.metadata == metadata_dict
+    assert pred.prompt.metadata["label"] == "counting"
+    assert pred.prompt.metadata["id"] == 12345
+
+
+@pytest.mark.parametrize("lm", ALL_MODELS, ids=get_model_name)
+def test_metadata_predict_many(lm):
+    # Create prompts with different metadata
+    prompts = [
+        LmPrompt(
+            f"What is {i} + {i}?",
+            max_tokens=5,
+            cache=False,
+            temperature=0,
+            metadata={"question_id": i, "expected_answer": i*2}
+        )
+        for i in range(1, 4)
+    ]
+
+    # Use predict_many and verify metadata is preserved
+    results = list(lm.predict_many(prompts, CompletionWindow.ASAP))
+
+    for i, result in enumerate(results):
+        expected_id = i + 1
+        expected_answer = expected_id * 2
+        assert result.prompt.metadata["question_id"] == expected_id
+        assert result.prompt.metadata["expected_answer"] == expected_answer
+
+
+@pytest.mark.parametrize("lm", ALL_MODELS, ids=get_model_name)
+def test_metadata_with_cache_hit(lm: LmPredictor):
+    """Test that metadata is preserved when there's a cache hit."""
+    # First prompt with metadata to populate the cache
+    prompt1 = LmPrompt(
+        "What is 5 + 5?",
+        max_tokens=5,
+        cache=True,
+        temperature=0,
+        metadata={"original": True, "id": 1}
+    )
+    result1 = lm.predict(prompt1)
+
+    # Same prompt with different metadata should use cache but keep the new metadata
+    prompt2 = LmPrompt(
+        "What is 5 + 5?",
+        max_tokens=5,
+        cache=True,
+        temperature=0,
+        metadata={"original": False, "id": 2}
+    )
+    result2 = lm.predict(prompt2)
+
+    # Check that the result is from cache but has the new metadata
+    assert result2.was_cached
+    assert result2.prompt.metadata["original"] == False
+    assert result2.prompt.metadata["id"] == 2
+
+
 @pytest.mark.parametrize(
     "lm", CHAT_MODELS,
     ids=get_model_name
