-
Notifications
You must be signed in to change notification settings - Fork 2
Add user metadata to LmPrompt class #11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
907cc4b
a848c4d
b80cd02
0c29dc5
e0ed21f
0cf836a
fc34031
7b928b4
a345f4a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,15 +2,16 @@ | |
| from abc import abstractmethod | ||
| from collections.abc import Callable, Iterable | ||
| from sqlite3 import OperationalError | ||
| from typing import Generic, TypeVar, Union | ||
|
|
||
| from ratemate import RateLimit | ||
|
|
||
| from lmwrapper.batch_config import CompletionWindow | ||
| from lmwrapper.sqlcache_struct import BatchPredictionPlaceholder | ||
| from lmwrapper.structs import LM_CHAT_DIALOG_COERCIBLE_TYPES, LmPrediction, LmPrompt | ||
| from lmwrapper.structs import LM_CHAT_DIALOG_COERCIBLE_TYPES, LmPrediction, LmPrompt, T | ||
|
|
||
|
|
||
| class LmPredictor: | ||
| class LmPredictor(Generic[T]): | ||
| _rate_limit: RateLimit | None = None | ||
|
|
||
| def __init__( | ||
|
|
@@ -23,13 +24,13 @@ def __init__( | |
|
|
||
| self._disk_cache = SqlBackedCache(self) | ||
|
|
||
| def find_prediction_class(self, prompt): | ||
| return LmPrediction | ||
| def find_prediction_class(self, prompt: LmPrompt[T]): | ||
| return LmPrediction[T] | ||
|
|
||
| def predict( | ||
| self, | ||
| prompt: LmPrompt | str | LM_CHAT_DIALOG_COERCIBLE_TYPES, | ||
| ) -> LmPrediction | list[LmPrediction]: | ||
| prompt: Union[LmPrompt[T], str, LM_CHAT_DIALOG_COERCIBLE_TYPES], | ||
|
||
| ) -> Union[LmPrediction[T], list[LmPrediction[T]]]: | ||
| prompt = self._cast_prompt(prompt) | ||
| should_cache = self._cache_default if prompt.cache is None else prompt.cache | ||
| if should_cache and prompt.model_internals_request is not None: | ||
|
|
@@ -74,7 +75,7 @@ def predict( | |
| return vals[0] | ||
| return vals | ||
|
|
||
| def _read_cached_values(self, prompt: LmPrompt) -> list[LmPrediction]: | ||
| def _read_cached_values(self, prompt: LmPrompt[T]) -> list[LmPrediction[T]]: | ||
| """ | ||
| Checks the cache for any matches of the prompt. Returns a list | ||
| as if num_completions is >1 we might have multiple items | ||
|
|
@@ -101,9 +102,9 @@ def _read_cached_values(self, prompt: LmPrompt) -> list[LmPrediction]: | |
|
|
||
| def predict_many( | ||
| self, | ||
| prompts: list[LmPrompt], | ||
| prompts: list[LmPrompt[T]], | ||
| completion_window: CompletionWindow, | ||
| ) -> Iterable[LmPrediction | list[LmPrediction]]: | ||
| ) -> Iterable[Union[LmPrediction[T], list[LmPrediction[T]]]]: | ||
| self._validate_predict_many_prompts(prompts) | ||
| for prompt in prompts: | ||
| val = self.predict(prompt) | ||
|
|
@@ -125,11 +126,11 @@ def _validate_predict_many_prompts(self, prompts): | |
|
|
||
| def remove_prompt_from_cache( | ||
| self, | ||
| prompt: str | LmPrompt, | ||
| prompt: Union[str, LmPrompt[T]], | ||
| ) -> bool: | ||
| return self._disk_cache.delete(prompt) | ||
|
|
||
| def _validate_prompt(self, prompt: LmPrompt, raise_on_invalid: bool = True) -> bool: | ||
| def _validate_prompt(self, prompt: LmPrompt[T], raise_on_invalid: bool = True) -> bool: | ||
| """Called on prediction to make sure the prompt is valid for the model""" | ||
| return True | ||
|
|
||
|
|
@@ -145,11 +146,11 @@ def supports_token_operations(self) -> bool: | |
| @abstractmethod | ||
| def _predict_maybe_cached( | ||
| self, | ||
| prompt: LmPrompt, | ||
| ) -> list[LmPrediction]: | ||
| prompt: LmPrompt[T], | ||
| ) -> Union[LmPrediction[T], list[LmPrediction[T]]]: | ||
| pass | ||
|
|
||
| def _cast_prompt(self, prompt: str | LmPrompt) -> LmPrompt: | ||
| def _cast_prompt(self, prompt: Union[str, LmPrompt[T], list]) -> LmPrompt[T]: | ||
| if isinstance(prompt, str): | ||
| return LmPrompt(prompt, 100) | ||
| if isinstance(prompt, list): | ||
|
|
@@ -180,14 +181,14 @@ def _cast_prompt(self, prompt: str | LmPrompt) -> LmPrompt: | |
| ) | ||
| raise ValueError(msg) | ||
|
|
||
| def estimate_tokens_in_prompt(self, prompt: LmPrompt) -> int: | ||
| def estimate_tokens_in_prompt(self, prompt: LmPrompt[T]) -> int: | ||
| raise NotImplementedError | ||
|
|
||
| @property | ||
| def token_limit(self): | ||
| raise NotImplementedError | ||
|
|
||
| def could_completion_go_over_token_limit(self, prompt: LmPrompt) -> bool: | ||
| def could_completion_go_over_token_limit(self, prompt: LmPrompt[T]) -> bool: | ||
| count = self.estimate_tokens_in_prompt(prompt) | ||
| return ( | ||
| count + (prompt.max_tokens or self.default_tokens_generated) | ||
|
|
@@ -255,20 +256,22 @@ def supports_prefilled_chat(self) -> bool: | |
|
|
||
|
|
||
| def get_mock_predictor( | ||
| predict_func: Callable[[LmPrompt], LmPrediction] = None, | ||
| predict_func: Callable[[LmPrompt[T]], LmPrediction[T]] = None, | ||
| is_chat_model: bool = False, | ||
| ): | ||
| """Gets a mock predictor. By default returns whatever the prompt txt is""" | ||
|
|
||
| S = TypeVar('S') # Local TypeVar for the mock predictor | ||
|
|
||
| class MockPredict(LmPredictor): | ||
| class MockPredict(LmPredictor[S]): | ||
|
||
| def get_model_cache_key(self): | ||
| return "mock_predictor" | ||
|
|
||
| @property | ||
| def is_chat_model(self) -> bool: | ||
| return is_chat_model | ||
|
|
||
| def _predict_maybe_cached(self, prompt): | ||
| def _predict_maybe_cached(self, prompt: LmPrompt[S]) -> LmPrediction[S]: | ||
| if predict_func is None: | ||
| return LmPrediction( | ||
| prompt.get_text_as_string_default_form(), | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,13 +1,17 @@ | ||
| import contextlib | ||
| import dataclasses | ||
| import json | ||
| import pickle | ||
| import statistics | ||
| from dataclasses import dataclass, field | ||
| from typing import Any, Optional, Union | ||
| from typing import Any, Generic, Optional, TypeVar, Union | ||
|
|
||
| from lmwrapper.internals import ModelInternalsResults | ||
| from lmwrapper.utils import StrEnum | ||
|
|
||
| # Define TypeVar for user metadata | ||
| T = TypeVar('T') | ||
|
|
||
| LM_CHAT_DIALOG_COERCIBLE_TYPES = Union[ | ||
| str, | ||
| list[Union["LmChatTurn", tuple[str, str], dict, str]], | ||
|
|
@@ -34,7 +38,7 @@ class StopMode(StrEnum): | |
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class LmPrompt: | ||
| class LmPrompt(Generic[T]): | ||
| text: str | LM_CHAT_DIALOG_COERCIBLE_TYPES | ||
| """The actual text of the prompt. If it is a LM_CHAT_DIALOG_COERCIBLE_TYPES | ||
| which can become a LmChatDialog (such as a list of strings) it will be converted | ||
|
|
@@ -125,6 +129,11 @@ class LmPrompt: | |
| """Whether or not to add special tokens when encoding the prompt.""" | ||
| model_internals_request: Optional["ModelInternalsRequest"] = None | ||
| """Used to attempt to get hidden states and attentions from the model.""" | ||
| user_metadata: Optional[T] = None | ||
|
||
| """Optional user-defined metadata that gets transferred to the resulting LmPrediction. | ||
| This is not used for caching and can be any type. It's useful for tracking | ||
| additional information with each prompt and prediction (e.g., ground truth labels, | ||
| problem identifiers).""" | ||
|
|
||
| # TODO: make a auto_reduce_max_tokens to reduce when might go over. | ||
|
|
||
|
|
@@ -208,11 +217,15 @@ def get_text_as_string_default_form(self) -> str: | |
| else: | ||
| return self.text | ||
|
|
||
| def dict_serialize(self) -> dict: | ||
| def dict_serialize(self, include_user_metadata: bool = False) -> dict: | ||
| """ | ||
| Serialize the prompt into a json-compatible dictionary. Note this is not | ||
| guaranteed to be the same as the JSON representation for use | ||
| in an openai api call. This is just for serialization purposes. | ||
|
|
||
DNGros marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| Args: | ||
| include_user_metadata: Whether to include user_metadata in serialization. | ||
| Default is False since user_metadata should not be part of the cache key. | ||
| """ | ||
| out = { | ||
| "max_tokens": self.max_tokens, | ||
|
|
@@ -232,6 +245,16 @@ def dict_serialize(self) -> dict: | |
| out["text"] = self.get_text_as_chat().as_dicts() | ||
| else: | ||
| out["text"] = self.text | ||
|
|
||
| if include_user_metadata and self.user_metadata is not None: | ||
| try: | ||
| # Test if it's JSON serializable | ||
| json.dumps(self.user_metadata) | ||
| out["user_metadata"] = self.user_metadata | ||
| except (TypeError, ValueError): | ||
| # If metadata isn't JSON serializable, leave it out | ||
| pass | ||
|
|
||
| return out | ||
|
|
||
|
|
||
|
|
@@ -316,10 +339,10 @@ def to_default_string_prompt(self) -> str: | |
|
|
||
|
|
||
| @dataclass | ||
| class LmPrediction: | ||
| class LmPrediction(Generic[T]): | ||
DNGros marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| completion_text: str | None | ||
| """The new text generated. It might be None if errors""" | ||
| prompt: LmPrompt | ||
| prompt: LmPrompt[T] | ||
| metad: Any | ||
| internals: ModelInternalsResults | None = field(default=None, kw_only=True) | ||
| error_message: str | None = field(default=None, kw_only=True) | ||
|
|
@@ -458,10 +481,11 @@ def dict_serialize( | |
| self, | ||
| pull_out_props: bool = True, | ||
| include_metad: bool = False, | ||
| include_user_metadata: bool = False, | ||
| ) -> dict[str, Any]: | ||
| out = { | ||
| "completion_text": self.completion_text, | ||
| "prompt": self.prompt.dict_serialize(), | ||
| "prompt": self.prompt.dict_serialize(include_user_metadata=include_user_metadata), | ||
| "was_cached": self.was_cached, | ||
| "error_message": self.error_message, | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think the predictor should be generic. It can accept different kinds of prompts. The returned prediction should just have the same type parameter as the input prompt.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in fc34031. I've removed the Generic[T] from the LmPredictor class definition while still maintaining the type parameters in method signatures.