@@ -80,7 +80,7 @@ class ChatCompletionRequest(BaseModel):
8080 top_p : Optional [float ] = 0.0 # default value of 0.0
8181 user : Optional [str ] = None
8282
83- repo_str = 'dbrx-instruct -exl2'
83+ repo_str = 'commandr -exl2'
8484#repo_str = 'theprofessor-exl2-speculative'
8585
8686parser = argparse .ArgumentParser (description = 'Run server with specified port.' )
@@ -283,7 +283,13 @@ def process_prompts():
283283
284284 new_text = tokenizer .decode (input_ids [i ][:, - 2 :- 1 ], decode_special_tokens = False )[0 ]
285285 new_text2 = tokenizer .decode (input_ids [i ][:, - 2 :], decode_special_tokens = False )[0 ]
286- diff = new_text2 [len (new_text ):]
286+ if '�' in new_text :
287+ diff = new_text2
288+ else :
289+ diff = new_text2 [len (new_text ):]
290+
291+ if '�' in diff :
292+ diff = ""
287293
288294 #print(diff)
289295 reason = None
@@ -507,6 +513,21 @@ async def format_prompt_mixtral(messages):
507513 formatted_prompt += f" { message .content } </s> " # Prep for user follow-up
508514 return formatted_prompt
509515
516+ async def format_prompt_commandr (messages ):
517+ formatted_prompt = ""
518+ for message in messages :
519+ if message .role == "system" :
520+ formatted_prompt += f"<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{ message .content } <|END_OF_TURN_TOKEN|>"
521+ elif message .role == "user" :
522+ formatted_prompt += f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{ message .content } <|END_OF_TURN_TOKEN|>"
523+ elif message .role == "assistant" :
524+ formatted_prompt += f"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{ message .content } <|END_OF_TURN_TOKEN|>"
525+ # Add the final "### Assistant:\n" to prompt for the next response
526+ formatted_prompt += "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
527+ return formatted_prompt
528+
529+
530+
510531@app .post ('/v1/chat/completions' )
511532async def mainchat (request : ChatCompletionRequest ):
512533
@@ -526,6 +547,8 @@ async def mainchat(request: ChatCompletionRequest):
526547 prompt = await format_prompt_nous (request .messages )
527548 elif repo_str == 'tess-xl-exl2' or repo_str == 'tess-xl-exl2-speculative' :
528549 prompt = await format_prompt_tess (request .messages )
550+ elif repo_str == 'commandr-exl2' :
551+ prompt = await format_prompt_commandr (request .messages )
529552 else :
530553 prompt = await format_prompt (request .messages )
531554 print (prompt )
0 commit comments