 from abc import abstractmethod
 from collections.abc import Callable, Iterable
 from sqlite3 import OperationalError
+from typing import TypeVar, Union

 from ratemate import RateLimit

 from lmwrapper.batch_config import CompletionWindow
 from lmwrapper.sqlcache_struct import BatchPredictionPlaceholder
-from lmwrapper.structs import LM_CHAT_DIALOG_COERCIBLE_TYPES, LmPrediction, LmPrompt
+from lmwrapper.structs import LM_CHAT_DIALOG_COERCIBLE_TYPES, LmPrediction, LmPrompt, T


 class LmPredictor:
@@ -23,13 +24,13 @@ def __init__(

         self._disk_cache = SqlBackedCache(self)

-    def find_prediction_class(self, prompt):
-        return LmPrediction
+    def find_prediction_class(self, prompt: LmPrompt[T]):
+        return LmPrediction[T]

     def predict(
         self,
-        prompt: LmPrompt | str | LM_CHAT_DIALOG_COERCIBLE_TYPES,
-    ) -> LmPrediction | list[LmPrediction]:
+        prompt: LmPrompt[T] | str | LM_CHAT_DIALOG_COERCIBLE_TYPES,
+    ) -> LmPrediction[T] | list[LmPrediction[T]]:
         prompt = self._cast_prompt(prompt)
         should_cache = self._cache_default if prompt.cache is None else prompt.cache
         if should_cache and prompt.model_internals_request is not None:
@@ -74,7 +75,7 @@ def predict(
             return vals[0]
         return vals

-    def _read_cached_values(self, prompt: LmPrompt) -> list[LmPrediction]:
+    def _read_cached_values(self, prompt: LmPrompt[T]) -> list[LmPrediction[T]]:
         """
         Checks the cache for any matches of the prompt. Returns a list
         as if num_completions is >1 we might have multiple items
@@ -101,9 +102,9 @@ def _read_cached_values(self, prompt: LmPrompt) -> list[LmPrediction]:

     def predict_many(
         self,
-        prompts: list[LmPrompt],
+        prompts: list[LmPrompt[T]],
         completion_window: CompletionWindow,
-    ) -> Iterable[LmPrediction | list[LmPrediction]]:
+    ) -> Iterable[LmPrediction[T] | list[LmPrediction[T]]]:
         self._validate_predict_many_prompts(prompts)
         for prompt in prompts:
             val = self.predict(prompt)
@@ -125,11 +126,11 @@ def _validate_predict_many_prompts(self, prompts):

     def remove_prompt_from_cache(
         self,
-        prompt: str | LmPrompt,
+        prompt: str | LmPrompt[T],
     ) -> bool:
         return self._disk_cache.delete(prompt)

-    def _validate_prompt(self, prompt: LmPrompt, raise_on_invalid: bool = True) -> bool:
+    def _validate_prompt(self, prompt: LmPrompt[T], raise_on_invalid: bool = True) -> bool:
         """Called on prediction to make sure the prompt is valid for the model"""
         return True

@@ -145,11 +146,11 @@ def supports_token_operations(self) -> bool:
     @abstractmethod
     def _predict_maybe_cached(
         self,
-        prompt: LmPrompt,
-    ) -> list[LmPrediction]:
+        prompt: LmPrompt[T],
+    ) -> LmPrediction[T] | list[LmPrediction[T]]:
         pass

-    def _cast_prompt(self, prompt: str | LmPrompt) -> LmPrompt:
+    def _cast_prompt(self, prompt: str | LmPrompt[T] | list) -> LmPrompt[T]:
         if isinstance(prompt, str):
             return LmPrompt(prompt, 100)
         if isinstance(prompt, list):
@@ -180,14 +181,14 @@ def _cast_prompt(self, prompt: str | LmPrompt) -> LmPrompt:
             )
             raise ValueError(msg)

-    def estimate_tokens_in_prompt(self, prompt: LmPrompt) -> int:
+    def estimate_tokens_in_prompt(self, prompt: LmPrompt[T]) -> int:
         raise NotImplementedError

     @property
     def token_limit(self):
         raise NotImplementedError

-    def could_completion_go_over_token_limit(self, prompt: LmPrompt) -> bool:
+    def could_completion_go_over_token_limit(self, prompt: LmPrompt[T]) -> bool:
         count = self.estimate_tokens_in_prompt(prompt)
         return (
             count + (prompt.max_tokens or self.default_tokens_generated)
@@ -255,10 +256,12 @@ def supports_prefilled_chat(self) -> bool:


 def get_mock_predictor(
-    predict_func: Callable[[LmPrompt], LmPrediction] = None,
+    predict_func: Callable[[LmPrompt[T]], LmPrediction[T]] = None,
     is_chat_model: bool = False,
 ):
     """Gets a mock predictor. By default returns whatever the prompt txt is"""
+
+    S = TypeVar('S')  # Local TypeVar for the mock predictor

     class MockPredict(LmPredictor):
         def get_model_cache_key(self):
@@ -268,7 +271,7 @@ def get_model_cache_key(self):
         def is_chat_model(self) -> bool:
             return is_chat_model

-        def _predict_maybe_cached(self, prompt):
+        def _predict_maybe_cached(self, prompt: LmPrompt[S]) -> LmPrediction[S]:
             if predict_func is None:
                 return LmPrediction(
                     prompt.get_text_as_string_default_form(),
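For context on how the new generics are meant to be consumed, here is a minimal usage sketch against the mock predictor defined at the bottom of this diff. It assumes the T exported from lmwrapper.structs parameterizes the prompt's typed payload (its definition is not shown in this hunk), that LmPrompt accepts max_tokens and cache keyword arguments (consistent with the prompt.max_tokens and prompt.cache attributes used above), and that LmPrediction exposes completion_text as its first field, as suggested by the mock's constructor call.

# Minimal usage sketch (assumptions noted above); the mock echoes the prompt text back.
from lmwrapper.structs import LmPrompt

predictor = get_mock_predictor()                 # mock model from this file
prompt = LmPrompt("echo me", max_tokens=16, cache=False)
pred = predictor.predict(prompt)                 # annotated as LmPrediction[T] | list[LmPrediction[T]]
# With a single completion, predict() returns one prediction (vals[0] above),
# so a type checker narrowing aside, this prints the echoed prompt text.
print(pred.completion_text)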