-
Notifications
You must be signed in to change notification settings - Fork 2
Add user metadata to LmPrompt class #11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
907cc4b
a848c4d
b80cd02
0c29dc5
e0ed21f
0cf836a
fc34031
7b928b4
a345f4a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,13 +1,17 @@ | ||
| import contextlib | ||
| import dataclasses | ||
| import json | ||
| import pickle | ||
| import statistics | ||
| from dataclasses import dataclass, field | ||
| from typing import Any, Optional, Union | ||
| from typing import Any, Generic, Optional, TypeVar, Union | ||
|
|
||
| from lmwrapper.internals import ModelInternalsResults | ||
| from lmwrapper.utils import StrEnum | ||
|
|
||
| # Define TypeVar for user metadata | ||
| T = TypeVar('T') | ||
|
|
||
| LM_CHAT_DIALOG_COERCIBLE_TYPES = Union[ | ||
| str, | ||
| list[Union["LmChatTurn", tuple[str, str], dict, str]], | ||
|
|
@@ -34,7 +38,7 @@ class StopMode(StrEnum): | |
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class LmPrompt: | ||
| class LmPrompt(Generic[T]): | ||
| text: str | LM_CHAT_DIALOG_COERCIBLE_TYPES | ||
| """The actual text of the prompt. If it is a LM_CHAT_DIALOG_COERCIBLE_TYPES | ||
| which can become a LmChatDialog (such as a list of strings) it will be converted | ||
|
|
@@ -125,6 +129,11 @@ class LmPrompt: | |
| """Whether or not to add special tokens when encoding the prompt.""" | ||
| model_internals_request: Optional["ModelInternalsRequest"] = None | ||
| """Used to attempt to get hidden states and attentions from the model.""" | ||
| user_metadata: Optional[T] = None | ||
|
||
| """Optional user-defined metadata that gets transferred to the resulting LmPrediction. | ||
| This is not used for caching and can be any type. It's useful for tracking | ||
| additional information with each prompt and prediction (e.g., ground truth labels, | ||
| problem identifiers).""" | ||
|
|
||
| # TODO: make an auto_reduce_max_tokens to reduce when it might go over. | ||
|
|
||
|
|
@@ -208,11 +217,15 @@ def get_text_as_string_default_form(self) -> str: | |
| else: | ||
| return self.text | ||
|
|
||
| def dict_serialize(self) -> dict: | ||
| def dict_serialize(self, include_user_metadata: bool = False) -> dict: | ||
| """ | ||
| Serialize the prompt into a json-compatible dictionary. Note this is not | ||
| guaranteed to be the same as the JSON representation for use | ||
| in an openai api call. This is just for serialization purposes. | ||
|
|
||
DNGros marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| Args: | ||
| include_user_metadata: Whether to include user_metadata in serialization. | ||
| Default is False since user_metadata should not be part of the cache key. | ||
| """ | ||
| out = { | ||
| "max_tokens": self.max_tokens, | ||
|
|
@@ -232,6 +245,16 @@ def dict_serialize(self) -> dict: | |
| out["text"] = self.get_text_as_chat().as_dicts() | ||
| else: | ||
| out["text"] = self.text | ||
|
|
||
| if include_user_metadata and self.user_metadata is not None: | ||
| try: | ||
| # Test if it's JSON serializable | ||
| json.dumps(self.user_metadata) | ||
| out["user_metadata"] = self.user_metadata | ||
| except (TypeError, ValueError): | ||
| # If metadata isn't JSON serializable, leave it out | ||
| pass | ||
|
|
||
| return out | ||
|
|
||
|
|
||
|
|
@@ -316,10 +339,10 @@ def to_default_string_prompt(self) -> str: | |
|
|
||
|
|
||
| @dataclass | ||
| class LmPrediction: | ||
| class LmPrediction(Generic[T]): | ||
DNGros marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| completion_text: str | None | ||
| """The new text generated. It might be None if errors""" | ||
| prompt: LmPrompt | ||
| prompt: LmPrompt[T] | ||
| metad: Any | ||
| internals: ModelInternalsResults | None = field(default=None, kw_only=True) | ||
| error_message: str | None = field(default=None, kw_only=True) | ||
|
|
@@ -458,10 +481,11 @@ def dict_serialize( | |
| self, | ||
| pull_out_props: bool = True, | ||
| include_metad: bool = False, | ||
| include_user_metadata: bool = False, | ||
| ) -> dict[str, Any]: | ||
| out = { | ||
| "completion_text": self.completion_text, | ||
| "prompt": self.prompt.dict_serialize(), | ||
| "prompt": self.prompt.dict_serialize(include_user_metadata=include_user_metadata), | ||
| "was_cached": self.was_cached, | ||
| "error_message": self.error_message, | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the predictor is not subscriptable
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in a345f4a. Removed the generic type parameter from the MockPredict class definition, changing
`class MockPredict(LmPredictor[S]):` to `class MockPredict(LmPredictor):`.