@@ -80,7 +80,7 @@ class ChatCompletionRequest(BaseModel):
     top_p: Optional[float] = 0.0  # default value of 0.0
     user: Optional[str] = None
 
-repo_str = 'tess-xl-exl2-speculative'
+repo_str = 'commandr-exl2-speculative'
 #repo_str = 'theprofessor-exl2-speculative'
 
 parser = argparse.ArgumentParser(description='Run server with specified port.')
@@ -108,12 +108,12 @@ class ChatCompletionRequest(BaseModel):
 config.model_dir = repo_id
 config.prepare()
 
-use_dynamic_rope_scaling = True
+use_dynamic_rope_scaling = False
 dynamic_rope_mult = 1.5
 dynamic_rope_offset = 0.0
 
 ropescale = 1.0
-max_context = 12288
+max_context = 8096
 config.scale_alpha_value = ropescale
 config.max_seq_len = max_context
 base_model_native_max = 4096
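This hunk turns the dynamic RoPE adjustment off and lowers the context ceiling to 8096 tokens. For readers skimming the diff, here is a minimal sketch of the kind of heuristic the names `use_dynamic_rope_scaling`, `dynamic_rope_mult`, and `dynamic_rope_offset` suggest; the real logic lives elsewhere in this file, so treat this as an assumption, not the server's actual code:

```python
# Assumed shape of the dynamic RoPE heuristic (not the server's actual code):
# keep alpha at 1.0 inside the base model's native window, otherwise grow it
# with the overshoot ratio, tuned by dynamic_rope_mult and dynamic_rope_offset.
def dynamic_alpha(seq_len: int, native_max: int = 4096,
                  mult: float = 1.5, offset: float = 0.0) -> float:
    if seq_len <= native_max:
        return 1.0
    return mult * (seq_len / native_max) + offset
```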
@@ -123,18 +123,18 @@ class ChatCompletionRequest(BaseModel):
 draft_config.model_dir = specrepo_id
 draft_config.prepare()
 
-draft_ropescale = 3.0
-num_speculative_tokens = 5
-speculative_prob_threshold = 0.25
+draft_ropescale = 1.0
+num_speculative_tokens = 3
+speculative_prob_threshold = 0.15
 draft_config.scale_alpha_value = draft_ropescale
 draft_config.max_seq_len = max_context
-draft_model_native_max = 2048
+draft_model_native_max = 8048
 
 model = ExLlamaV2(config)
 print("Loading model: " + repo_id)
 #cache = ExLlamaV2Cache(model, lazy=True, max_seq_len=20480)
 #model.load_autosplit(cache)
-model.load([16, 18, 18, 18])
+model.load([12, 20, 20, 20])
 
 draft = ExLlamaV2(draft_config)
 print("Loading draft model: " + specrepo_id)
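The list passed to `model.load` is exllamav2's `gpu_split` argument; as far as I know, each entry is a per-GPU VRAM budget in GB, so this change moves roughly 4 GB of the main model off GPU 0 and onto GPUs 1-3 (presumably to leave headroom on GPU 0 for the draft model and caches):

```python
# gpu_split: per-device VRAM caps in GB (exllamav2 convention).
# Before: GPU 0 16 GB, GPUs 1-3 18 GB each. After: GPU 0 12 GB, GPUs 1-3 20 GB.
model.load([12, 20, 20, 20])

# The commented-out alternative above instead lets exllamav2 place layers itself:
# cache = ExLlamaV2Cache(model, lazy=True, max_seq_len=20480)
# model.load_autosplit(cache)
```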
@@ -394,7 +394,13 @@ def process_prompts():
 
         new_text = tokenizer.decode(input_ids[i][:, -2:-1], decode_special_tokens=False)[0]
         new_text2 = tokenizer.decode(input_ids[i][:, -2:], decode_special_tokens=False)[0]
-        diff = new_text2[len(new_text):]
+        if '�' in new_text:
+            diff = new_text2
+        else:
+            diff = new_text2[len(new_text):]
+
+        if '�' in diff:
+            diff = ""
 
         #print(diff)
         reason = None
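The new branch guards against the tokenizer's placeholder output when a streamed token ends partway through a multi-byte character: the decoder emits U+FFFD ('�') until the remaining bytes arrive. A standalone illustration of that failure mode in plain Python (no exllamav2 required):

```python
# Decoding half of a multi-byte UTF-8 sequence yields U+FFFD replacement
# characters, which is exactly what the '�' checks above detect and suppress.
emoji_bytes = "🙂".encode("utf-8")       # 4 bytes: f0 9f 99 82
partial = emoji_bytes[:2].decode("utf-8", errors="replace")
print(partial)                           # replacement character(s): '�'
print(emoji_bytes.decode("utf-8"))       # '🙂' once all 4 bytes are present
```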
@@ -620,6 +626,33 @@ async def format_prompt_mixtral(messages):
             formatted_prompt += f" {message.content}</s> "  # Prep for user follow-up
     return formatted_prompt
 
+async def format_prompt_commandr(messages):
+    formatted_prompt = ""
+    system_message_found = False
+
+    # Check for a system message first
+    for message in messages:
+        if message.role == "system":
+            system_message_found = True
+            break
+
+    # If no system message was found, prepend a default one
+    if not system_message_found:
+        formatted_prompt += "<BOS_TOKEN><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful assistant.<|END_OF_TURN_TOKEN|>"  # generic placeholder preamble
+
+    for message in messages:
+        if message.role == "system":
+            formatted_prompt += f"<BOS_TOKEN><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{message.content}<|END_OF_TURN_TOKEN|>"
+        elif message.role == "user":
+            formatted_prompt += f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{message.content}<|END_OF_TURN_TOKEN|>"
+        elif message.role == "assistant":
+            formatted_prompt += f"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{message.content}<|END_OF_TURN_TOKEN|>"
+    # Open a chatbot turn to prompt the model for its next response
+    formatted_prompt += "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
+    return formatted_prompt
+
+
+
 @app.post('/v1/chat/completions')
 async def mainchat(request: ChatCompletionRequest):
 
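A quick sanity check of the new formatter, with made-up messages (the real server passes pydantic `Message` objects, but anything exposing `.role` and `.content` works):

```python
import asyncio
from types import SimpleNamespace

msgs = [
    SimpleNamespace(role="system", content="Be concise."),
    SimpleNamespace(role="user", content="Hello!"),
]
print(asyncio.run(format_prompt_commandr(msgs)))
# <BOS_TOKEN><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Be concise.<|END_OF_TURN_TOKEN|>\
# <|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello!<|END_OF_TURN_TOKEN|>\
# <|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
# (one unbroken string; wrapped here for readability)
```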
@@ -639,6 +672,8 @@ async def mainchat(request: ChatCompletionRequest):
         prompt = await format_prompt_nous(request.messages)
     elif repo_str == 'tess-xl-exl2' or repo_str == 'tess-xl-exl2-speculative':
         prompt = await format_prompt_tess(request.messages)
+    elif repo_str == 'commandr-exl2' or repo_str == 'commandr-exl2-speculative':
+        prompt = await format_prompt_commandr(request.messages)
     else:
         prompt = await format_prompt(request.messages)
     print(prompt)