diff --git a/.vscode/settings.json b/.vscode/settings.json index aed4db4..bfb00fc 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -3,5 +3,6 @@ "-s", "-vv", "--runslow", ], "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true + "python.testing.pytestEnabled": true, + "python.defaultInterpreterPath": "~/miniconda3/envs/lmwrapper/bin/python", } \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 83e7085..931d9e5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ breaking changes, Y is new features or larger non-breaking changes, and Z is sma However, it is still pre-1.0 software, and does not claim to be super stable. +## [0.16.4.0] + +### Added +- Added metadata field to LmPrompt class with generic type support ## [0.16.3.0] diff --git a/README.md b/README.md index edec27e..5cd7cf5 100644 --- a/README.md +++ b/README.md @@ -201,6 +201,7 @@ def make_prompts(data) -> list[LmPrompt]: max_tokens=10, temperature=0, cache=True, + metadata={"country": country} ) for country in data ] @@ -216,9 +217,10 @@ predictions = lm.predict_many( # the non-batching API at a higher cost. ) # The batch is submitted here -for ex, pred in zip(data, predictions): # Will wait for the batch to complete - print(f"Country: {ex} --- Capital: {pred.completion_text}") - if ex == "France": assert pred.completion_text == "Paris" +for pred in predictions: # Will wait for the batch to complete + country = pred.prompt.metadata['country'] + print(f"Country: {country} --- Capital: {pred.completion_text}") + if country == "France": assert pred.completion_text == "Paris" # ... ``` @@ -364,7 +366,7 @@ please make a Github Issue. - [X] Anthropic interface (basic) - [X] Claude system messages - [X] Use the huggingface chat templates for chat models if available -- [ ] Be able to add user metadata to a prompt +- [X] Be able to add metadata to a prompt - [ ] Automatic cache eviction to limit count or disk size (right now have to run a SQL query to delete entries before a certain time or matching your criteria) - [ ] Multimodal/images in super easy format (like automatically process pil, opencv, etc) - [ ] sort through usage of quantized models diff --git a/lmwrapper/abstract_predictor.py b/lmwrapper/abstract_predictor.py index 06b13fb..deea242 100644 --- a/lmwrapper/abstract_predictor.py +++ b/lmwrapper/abstract_predictor.py @@ -2,12 +2,13 @@ from abc import abstractmethod from collections.abc import Callable, Iterable from sqlite3 import OperationalError +from typing import TypeVar, Union from ratemate import RateLimit from lmwrapper.batch_config import CompletionWindow from lmwrapper.sqlcache_struct import BatchPredictionPlaceholder -from lmwrapper.structs import LM_CHAT_DIALOG_COERCIBLE_TYPES, LmPrediction, LmPrompt +from lmwrapper.structs import LM_CHAT_DIALOG_COERCIBLE_TYPES, LmPrediction, LmPrompt, T class LmPredictor: @@ -23,13 +24,13 @@ def __init__( self._disk_cache = SqlBackedCache(self) - def find_prediction_class(self, prompt): - return LmPrediction + def find_prediction_class(self, prompt: LmPrompt[T]): + return LmPrediction[T] def predict( self, - prompt: LmPrompt | str | LM_CHAT_DIALOG_COERCIBLE_TYPES, - ) -> LmPrediction | list[LmPrediction]: + prompt: LmPrompt[T] | str | LM_CHAT_DIALOG_COERCIBLE_TYPES, + ) -> LmPrediction[T] | list[LmPrediction[T]]: prompt = self._cast_prompt(prompt) should_cache = self._cache_default if prompt.cache is None else prompt.cache if should_cache and prompt.model_internals_request is not None: @@ -74,7 +75,7 @@ def predict( return vals[0] return vals - def _read_cached_values(self, prompt: LmPrompt) -> list[LmPrediction]: + def _read_cached_values(self, prompt: LmPrompt[T]) -> list[LmPrediction[T]]: """ Checks the cache for any matches of the prompt. Returns a list as if num_completions is >1 we might have multiple items @@ -101,9 +102,9 @@ def _read_cached_values(self, prompt: LmPrompt) -> list[LmPrediction]: def predict_many( self, - prompts: list[LmPrompt], + prompts: list[LmPrompt[T]], completion_window: CompletionWindow, - ) -> Iterable[LmPrediction | list[LmPrediction]]: + ) -> Iterable[LmPrediction[T] | list[LmPrediction[T]]]: self._validate_predict_many_prompts(prompts) for prompt in prompts: val = self.predict(prompt) @@ -125,11 +126,11 @@ def _validate_predict_many_prompts(self, prompts): def remove_prompt_from_cache( self, - prompt: str | LmPrompt, + prompt: str | LmPrompt[T], ) -> bool: return self._disk_cache.delete(prompt) - def _validate_prompt(self, prompt: LmPrompt, raise_on_invalid: bool = True) -> bool: + def _validate_prompt(self, prompt: LmPrompt[T], raise_on_invalid: bool = True) -> bool: """Called on prediction to make sure the prompt is valid for the model""" return True @@ -145,11 +146,11 @@ def supports_token_operations(self) -> bool: @abstractmethod def _predict_maybe_cached( self, - prompt: LmPrompt, - ) -> list[LmPrediction]: + prompt: LmPrompt[T], + ) -> LmPrediction[T] | list[LmPrediction[T]]: pass - def _cast_prompt(self, prompt: str | LmPrompt) -> LmPrompt: + def _cast_prompt(self, prompt: str | LmPrompt[T] | list) -> LmPrompt[T]: if isinstance(prompt, str): return LmPrompt(prompt, 100) if isinstance(prompt, list): @@ -180,14 +181,14 @@ def _cast_prompt(self, prompt: str | LmPrompt) -> LmPrompt: ) raise ValueError(msg) - def estimate_tokens_in_prompt(self, prompt: LmPrompt) -> int: + def estimate_tokens_in_prompt(self, prompt: LmPrompt[T]) -> int: raise NotImplementedError @property def token_limit(self): raise NotImplementedError - def could_completion_go_over_token_limit(self, prompt: LmPrompt) -> bool: + def could_completion_go_over_token_limit(self, prompt: LmPrompt[T]) -> bool: count = self.estimate_tokens_in_prompt(prompt) return ( count + (prompt.max_tokens or self.default_tokens_generated) @@ -255,10 +256,12 @@ def supports_prefilled_chat(self) -> bool: def get_mock_predictor( - predict_func: Callable[[LmPrompt], LmPrediction] = None, + predict_func: Callable[[LmPrompt[T]], LmPrediction[T]] = None, is_chat_model: bool = False, ): """Gets a mock predictor. By default returns whatever the prompt txt is""" + + S = TypeVar('S') # Local TypeVar for the mock predictor class MockPredict(LmPredictor): def get_model_cache_key(self): @@ -268,7 +271,7 @@ def get_model_cache_key(self): def is_chat_model(self) -> bool: return is_chat_model - def _predict_maybe_cached(self, prompt): + def _predict_maybe_cached(self, prompt: LmPrompt[S]) -> LmPrediction[S]: if predict_func is None: return LmPrediction( prompt.get_text_as_string_default_form(), diff --git a/lmwrapper/structs.py b/lmwrapper/structs.py index cd193e6..282100c 100644 --- a/lmwrapper/structs.py +++ b/lmwrapper/structs.py @@ -1,13 +1,17 @@ import contextlib import dataclasses +import json import pickle import statistics from dataclasses import dataclass, field -from typing import Any, Optional, Union +from typing import Any, Generic, Optional, TypeVar, Union from lmwrapper.internals import ModelInternalsResults from lmwrapper.utils import StrEnum +# Define TypeVar for user metadata +T = TypeVar('T') + LM_CHAT_DIALOG_COERCIBLE_TYPES = Union[ str, list[Union["LmChatTurn", tuple[str, str], dict, str]], @@ -34,7 +38,7 @@ class StopMode(StrEnum): @dataclass(frozen=True) -class LmPrompt: +class LmPrompt(Generic[T]): text: str | LM_CHAT_DIALOG_COERCIBLE_TYPES """The actual text of the prompt. If it is a LM_CHAT_DIALOG_COERCIBLE_TYPES which can become a LmChatDialog (such as a list of strings) it will be converted @@ -125,6 +129,11 @@ class LmPrompt: """Whether or not to add special tokens when encoding the prompt.""" model_internals_request: Optional["ModelInternalsRequest"] = None """Used to attempt to get hidden states and attentions from the model.""" + metadata: Optional[T] = None + """Optional user-defined metadata that gets transferred to the resulting LmPrediction. + This is not used for caching and can be any type. It's useful for tracking + additional information with each prompt and prediction (e.g., ground truth labels, + problem identifiers).""" # TODO: make a auto_reduce_max_tokens to reduce when might go over. @@ -208,11 +217,15 @@ def get_text_as_string_default_form(self) -> str: else: return self.text - def dict_serialize(self) -> dict: + def dict_serialize(self, include_metadata: bool = False) -> dict: """ Serialize the prompt into a json-compatible dictionary. Note this is not guaranteed to be the same as the JSON representation for use in an openai api call. This is just for serialization purposes. + + Args: + include_metadata: Whether to include metadata in serialization. + Default is False since metadata should not be part of the cache key. """ out = { "max_tokens": self.max_tokens, @@ -232,6 +245,16 @@ def dict_serialize(self) -> dict: out["text"] = self.get_text_as_chat().as_dicts() else: out["text"] = self.text + + if include_metadata and self.metadata is not None: + try: + # Test if it's JSON serializable + json.dumps(self.metadata) + out["metadata"] = self.metadata + except (TypeError, ValueError): + # If metadata isn't JSON serializable, leave it out + pass + return out @@ -316,10 +339,10 @@ def to_default_string_prompt(self) -> str: @dataclass -class LmPrediction: +class LmPrediction(Generic[T]): completion_text: str | None """The new text generated. It might be None if errors""" - prompt: LmPrompt + prompt: LmPrompt[T] metad: Any internals: ModelInternalsResults | None = field(default=None, kw_only=True) error_message: str | None = field(default=None, kw_only=True) @@ -458,10 +481,11 @@ def dict_serialize( self, pull_out_props: bool = True, include_metad: bool = False, + include_metadata: bool = False, ) -> dict[str, Any]: out = { "completion_text": self.completion_text, - "prompt": self.prompt.dict_serialize(), + "prompt": self.prompt.dict_serialize(include_metadata=include_metadata), "was_cached": self.was_cached, "error_message": self.error_message, } diff --git a/pyproject.toml b/pyproject.toml index db81f3e..9d710a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ include = ["lmwrapper*"] [project] name = "lmwrapper" -version = "0.16.3.0" +version = "0.16.4.0" authors = [ { name = "David Gros" }, { name = "Claudio Spiess" }, diff --git a/test/test_models_common.py b/test/test_models_common.py index ea6f759..1b982cc 100644 --- a/test/test_models_common.py +++ b/test/test_models_common.py @@ -8,6 +8,7 @@ import numpy as np import pytest +from lmwrapper.abstract_predictor import LmPredictor from lmwrapper.batch_config import CompletionWindow from lmwrapper.caching import cache_dir, clear_cache_dir from lmwrapper.huggingface_wrapper import get_huggingface_lm @@ -1107,6 +1108,87 @@ def test_system_prompt(lm): assert pred.completion_text.strip() == "PARIS" +@pytest.mark.parametrize("lm", ALL_MODELS, ids=get_model_name) +def test_metadata_predict(lm): + # Test with string metadata + prompt = LmPrompt( + "Count from 1 to 5:", + max_tokens=10, + cache=False, + temperature=0, + metadata="test_metadata_string" + ) + pred = lm.predict(prompt) + assert pred.prompt.metadata == "test_metadata_string" + + # Test with dictionary metadata + metadata_dict = {"label": "counting", "id": 12345} + prompt = LmPrompt( + "Count from 1 to 5:", + max_tokens=10, + cache=False, + temperature=0, + metadata=metadata_dict + ) + pred = lm.predict(prompt) + assert pred.prompt.metadata == metadata_dict + assert pred.prompt.metadata["label"] == "counting" + assert pred.prompt.metadata["id"] == 12345 + + +@pytest.mark.parametrize("lm", ALL_MODELS, ids=get_model_name) +def test_metadata_predict_many(lm): + # Create prompts with different metadata + prompts = [ + LmPrompt( + f"What is {i} + {i}?", + max_tokens=5, + cache=False, + temperature=0, + metadata={"question_id": i, "expected_answer": i*2} + ) + for i in range(1, 4) + ] + + # Use predict_many and verify metadata is preserved + results = list(lm.predict_many(prompts, CompletionWindow.ASAP)) + + for i, result in enumerate(results): + expected_id = i + 1 + expected_answer = expected_id * 2 + assert result.prompt.metadata["question_id"] == expected_id + assert result.prompt.metadata["expected_answer"] == expected_answer + + +@pytest.mark.parametrize("lm", ALL_MODELS, ids=get_model_name) +def test_metadata_with_cache_hit(lm: LmPredictor): + """Test that metadata is preserved when there's a cache hit.""" + # First prompt with metadata to populate the cache + prompt1 = LmPrompt( + "What is 5 + 5?", + max_tokens=5, + cache=True, + temperature=0, + metadata={"original": True, "id": 1} + ) + result1 = lm.predict(prompt1) + + # Same prompt with different metadata should use cache but keep the new metadata + prompt2 = LmPrompt( + "What is 5 + 5?", + max_tokens=5, + cache=True, + temperature=0, + metadata={"original": False, "id": 2} + ) + result2 = lm.predict(prompt2) + + # Check that the result is from cache but has the new metadata + assert result2.was_cached + assert result2.prompt.metadata["original"] == False + assert result2.prompt.metadata["id"] == 2 + + @pytest.mark.parametrize( "lm", CHAT_MODELS, ids=get_model_name