 from ..utils import decode_stopping_sequences_where_needed, construct_prompts
 import json
 
-from typing import Generator
-from transformers import TextIteratorStreamer
-from threading import Thread
-from queue import Empty
-import asyncio
-
 if TYPE_CHECKING:
     from llama_cpp import Llama, LogitsProcessorList, StoppingCriteriaList
 
@@ -104,97 +98,12 @@ def _add_default_generate_kwargs(
         return generate_kwargs
 
     def __call__(self, inputs: List[str], **kwargs) -> List[Response]:
-        logger.info(f"prompt_format: {self.prompt_format}")
-        logger.info(f"before construct_prompts: {inputs}")
-        inputs = construct_prompts(inputs, prompt_format=self.prompt_format)
-        logger.info(f"after construct_prompts: {inputs}")
-
-        tokenized_inputs = self.tokenizer.encode(inputs)
-        kwargs = self._add_default_generate_kwargs(
-            kwargs,
-            model_inputs={"inputs": inputs,
-                          "tokenized_inputs": tokenized_inputs},
-        )
-
-        chat_completion = False
-        try:
-            inputs_bak = inputs
-            inputs = [json.loads(prompt, strict=False) for prompt in inputs]
-            chat_completion = True
-        except Exception as ex:
-            logger.error(f"Exception apply_chat_template: {ex}")
-            logger.info("Seems no chat template from user")
-            inputs = inputs_bak
-
-        logger.info(f"Forward params: {kwargs}, model_inputs {inputs}")
-        responses = []
-        for input in inputs:
-            st = time.monotonic()
-            if chat_completion:
-                kwargs.pop('stopping_criteria', None)
-                kwargs.pop('echo', None)
-                logger.info(f"Forward params: {kwargs}, model_inputs {inputs}")
-                output = self.model.create_chat_completion(
-                    messages=input,
-                    **kwargs
-                )
-                text = output["choices"][0]["message"]["content"].replace("\u200b", "").strip()
-            else:
-                output = self.model(input, **kwargs)
-                text = output["choices"][0]["text"].replace("\u200b", "").strip()
-
-
-            logger.info(f"llm's raw response is: {output}")
-            gen_time = time.monotonic() - st
-
-            responses.append(
-                Response(
-                    generated_text=text,
-                    num_generated_tokens=output["usage"]["completion_tokens"],
-                    num_input_tokens=output["usage"]["prompt_tokens"],
-                    num_generated_tokens_batch=output["usage"]["completion_tokens"],
-                    num_input_tokens_batch=output["usage"]["prompt_tokens"],
-                    preprocessing_time=None,
-                    postprocessing_time=None,
-                    generation_time=gen_time,
-                )
-            )
-        return responses
-
-    # def stream(
-    #     self,
-    #     inputs: List[Union[str, Prompt]],
-    #     **kwargs,
-    # ) -> Iterator[List[Response]]:
-    #     tokenized_inputs = self.tokenizer.encode(inputs[0])
-    #     kwargs = self._add_default_generate_kwargs(
-    #         kwargs,
-    #         model_inputs={"inputs": inputs,
-    #                       "tokenized_inputs": tokenized_inputs},
-    #     )
+        streams = [list() for _ in range(len(inputs))]
+        for batch_response in self.stream(inputs, **kwargs):
+            for i, response in enumerate(batch_response):
+                streams[i].append(response)
 
-    #     logger.info(f"Forward params: {kwargs}, model_inputs {inputs}")
-    #     first_token_done = False
-    #     for input in inputs:
-    #         for output in self.model(input, stream=True, **kwargs):
-    #             st = time.monotonic()
-    #             gen_time = time.monotonic() - st
-    #             text = output["choices"][0]["text"].replace("\u200b", "")
-    #             if not first_token_done:
-    #                 text = text.lstrip()
-    #                 first_token_done = True
-    #             yield [
-    #                 Response(
-    #                     generated_text=text,
-    #                     num_generated_tokens=1,
-    #                     num_input_tokens=len(tokenized_inputs),
-    #                     num_generated_tokens_batch=1,
-    #                     num_input_tokens_batch=len(tokenized_inputs),
-    #                     preprocessing_time=None,
-    #                     postprocessing_time=None,
-    #                     generation_time=gen_time,
-    #                 )
-    #             ]
+        return [Response.merge_stream(*stream) for stream in streams]
 
     def preprocess(self, prompts: List[str], **generate_kwargs):
         pass
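
Aside (not from the diff): the rewritten __call__ now just drains stream() and folds each per-prompt list of streamed chunks into one Response via Response.merge_stream. That helper's implementation is not shown in this commit; below is a minimal sketch of the merge it presumably performs, assuming chunks only need their text concatenated and their token counts summed (the function name is illustrative, not from the repo):

    def merge_stream_sketch(*chunks: "Response") -> "Response":
        # Concatenate the streamed text fragments in arrival order.
        text = "".join(c.generated_text for c in chunks)
        # Generated-token counts accumulate per chunk; the prompt length is the
        # same for every chunk, so it is taken from the first one.
        return Response(
            generated_text=text,
            num_generated_tokens=sum(c.num_generated_tokens for c in chunks),
            num_input_tokens=chunks[0].num_input_tokens if chunks else 0,
            num_generated_tokens_batch=sum(c.num_generated_tokens_batch for c in chunks),
            num_input_tokens_batch=chunks[0].num_input_tokens_batch if chunks else 0,
            preprocessing_time=None,
            postprocessing_time=None,
            generation_time=sum(c.generation_time or 0.0 for c in chunks),
        )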
@@ -231,6 +140,14 @@ def stream(
         logger.info(f"stream prompt: {inputs}")
         inputs = construct_prompts(inputs, prompt_format=self.prompt_format)
         logger.info(f"stream inputs: {inputs}")
+
+        tokenized_inputs = self.tokenizer.encode(inputs)
+        kwargs = self._add_default_generate_kwargs(
+            kwargs,
+            model_inputs={"inputs": inputs,
+                          "tokenized_inputs": tokenized_inputs},
+        )
+
         chat_completion = False
         try:
             inputs_bak = inputs
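
Aside (not from the diff): stream() uses the same convention the old __call__ did to decide between chat and plain completion: each prompt is attempted as JSON (e.g. a list of {"role": ..., "content": ...} messages), and if parsing fails the batch falls back to plain text. A standalone sketch of that dispatch, with an illustrative helper name:

    import json

    def detect_chat_inputs(prompts):
        # Returns (parsed_inputs, chat_completion). JSON-parsable prompts are
        # routed to create_chat_completion(); anything else stays a plain string
        # and goes through the text-completion path.
        try:
            return [json.loads(p, strict=False) for p in prompts], True
        except Exception:
            return prompts, False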
@@ -250,20 +167,22 @@ def stream(
         for idx, input in enumerate(inputs):
             tokenized_inputs = self.tokenizer.encode(input)
             if chat_completion:
-                kwargs.pop('stopping_sequences', None)
+                # kwargs.pop('stopping_sequences', None)
+                kwargs.pop('stopping_criteria', None)
                 kwargs.pop('echo', None)
                 logger.info(f"chat generate_kwargs: {kwargs}")
                 output = self.model.create_chat_completion(messages=input, stream=True, **kwargs)
                 for chunk in output:
                     st = time.monotonic()
                     gen_time = time.monotonic() - st
                     delta = chunk['choices'][0]['delta']
+
                     val = ''
                     if 'role' in delta:
                         val = ''
                     elif 'content' in delta:
                         val = delta['content']
-                    logger.info(f'LlamaCppPipeline -> create_chat_completion -> Yield -> "{val}"')
+                    # logger.info(f'LlamaCppPipeline -> create_chat_completion -> Yield -> "{val}"')
                     if val:
                         yield [
                             Response(
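
Aside (not from the diff): the delta handling above follows llama-cpp-python's OpenAI-style streaming format, where the first chunk's delta usually carries only the role and later chunks carry content fragments, which is why role-only chunks yield an empty string here. A minimal stand-alone consumption example (the model path is illustrative):

    from llama_cpp import Llama

    llm = Llama(model_path="model.gguf")
    stream = llm.create_chat_completion(
        messages=[{"role": "user", "content": "Hi"}],
        stream=True,
    )
    for chunk in stream:
        delta = chunk["choices"][0]["delta"]
        # Skip the role-only chunk; print content fragments as they arrive.
        if "content" in delta:
            print(delta["content"], end="", flush=True)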
@@ -296,29 +215,29 @@ def stream(
                     st = time.monotonic()
                     gen_time = time.monotonic() - st
                     chunk = token["choices"][0]["text"].replace("\u200b", "")
-                    logger.info(f'LlamaCppPipeline -> generate -> Yield -> "{chunk}"')
-                    if val:
-                        yield [
-                            Response(
-                                generated_text=val,
-                                num_generated_tokens=1,
-                                num_input_tokens=len(tokenized_inputs),
-                                num_generated_tokens_batch=1,
-                                num_input_tokens_batch=len(tokenized_inputs),
-                                preprocessing_time=None,
-                                postprocessing_time=None,
-                                generation_time=gen_time,
-                            )
-                            if i == idx else
-                            Response(
-                                generated_text="",
-                                num_generated_tokens=0,
-                                num_input_tokens=len(tokenized_inputs),
-                                num_generated_tokens_batch=0,
-                                num_input_tokens_batch=len(tokenized_inputs),
-                                preprocessing_time=None,
-                                postprocessing_time=None,
-                                generation_time=gen_time,
-                            )
-                            for i in range(batch_size)
-                        ]
+                    # logger.info(f'LlamaCppPipeline -> generate -> Yield -> "{chunk}"')
+                    # if chunk:
+                    yield [
+                        Response(
+                            generated_text=chunk,
+                            num_generated_tokens=1,
+                            num_input_tokens=len(tokenized_inputs),
+                            num_generated_tokens_batch=1,
+                            num_input_tokens_batch=len(tokenized_inputs),
+                            preprocessing_time=None,
+                            postprocessing_time=None,
+                            generation_time=gen_time,
+                        )
+                        if i == idx else
+                        Response(
+                            generated_text="",
+                            num_generated_tokens=0,
+                            num_input_tokens=len(tokenized_inputs),
+                            num_generated_tokens_batch=0,
+                            num_input_tokens_batch=len(tokenized_inputs),
+                            preprocessing_time=None,
+                            postprocessing_time=None,
+                            generation_time=gen_time,
+                        )
+                        for i in range(batch_size)
+                    ]
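
Aside (not from the diff): each yield from stream() is a batch-sized list in which only index idx (the prompt currently being generated) carries the newly generated text; the other slots are empty placeholder Responses so downstream consumers keep prompt alignment. The rewritten __call__ relies on exactly this shape when it appends element i of every yielded batch to streams[i] and merges each list at the end. A small consumption sketch (the pipeline variable and prompts are illustrative):

    prompts = ["Tell me a joke.", "Summarize llama.cpp in one sentence."]
    for batch in pipeline.stream(prompts):
        for i, r in enumerate(batch):
            # Only one prompt per yielded batch has new text.
            if r.generated_text:
                print(f"[prompt {i}] {r.generated_text}", end="", flush=True)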