Skip to content

Commit df5fa34

Browse files
committed
fixed template for commandr
1 parent c690f3a commit df5fa34

File tree

1 file changed

+44
-9
lines changed

1 file changed

+44
-9
lines changed

llm_exl2_client_multi_speculative.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ class ChatCompletionRequest(BaseModel):
8080
top_p: Optional[float] = 0.0 # default value of 0.0
8181
user: Optional[str] = None
8282

83-
repo_str = 'tess-xl-exl2-speculative'
83+
repo_str = 'commandr-exl2-speculative'
8484
#repo_str = 'theprofessor-exl2-speculative'
8585

8686
parser = argparse.ArgumentParser(description='Run server with specified port.')
@@ -108,12 +108,12 @@ class ChatCompletionRequest(BaseModel):
108108
config.model_dir = repo_id
109109
config.prepare()
110110

111-
use_dynamic_rope_scaling = True
111+
use_dynamic_rope_scaling = False
112112
dynamic_rope_mult = 1.5
113113
dynamic_rope_offset = 0.0
114114

115115
ropescale = 1.0
116-
max_context = 12288
116+
max_context = 8096
117117
config.scale_alpha_value = ropescale
118118
config.max_seq_len = max_context
119119
base_model_native_max = 4096
@@ -123,18 +123,18 @@ class ChatCompletionRequest(BaseModel):
123123
draft_config.model_dir = specrepo_id
124124
draft_config.prepare()
125125

126-
draft_ropescale = 3.0
127-
num_speculative_tokens = 5
128-
speculative_prob_threshold = 0.25
126+
draft_ropescale = 1.0
127+
num_speculative_tokens = 3
128+
speculative_prob_threshold = 0.15
129129
draft_config.scale_alpha_value = draft_ropescale
130130
draft_config.max_seq_len = max_context
131-
draft_model_native_max = 2048
131+
draft_model_native_max = 8048
132132

133133
model = ExLlamaV2(config)
134134
print("Loading model: " + repo_id)
135135
#cache = ExLlamaV2Cache(model, lazy=True, max_seq_len = 20480)
136136
#model.load_autosplit(cache)
137-
model.load([16,18,18,18])
137+
model.load([12,20,20,20])
138138

139139
draft = ExLlamaV2(draft_config)
140140
print("Loading draft model: " + specrepo_id)
@@ -394,7 +394,13 @@ def process_prompts():
394394

395395
new_text = tokenizer.decode(input_ids[i][:, -2:-1], decode_special_tokens=False)[0]
396396
new_text2 = tokenizer.decode(input_ids[i][:, -2:], decode_special_tokens=False)[0]
397-
diff = new_text2[len(new_text):]
397+
if '�' in new_text:
398+
diff = new_text2
399+
else:
400+
diff = new_text2[len(new_text):]
401+
402+
if '�' in diff:
403+
diff = ""
398404

399405
#print(diff)
400406
reason = None
@@ -620,6 +626,33 @@ async def format_prompt_mixtral(messages):
620626
formatted_prompt += f" {message.content}</s> " # Prep for user follow-up
621627
return formatted_prompt
622628

629+
async def format_prompt_commandr(messages):
    """Build a Cohere Command-R chat prompt from a list of messages.

    Each message is expected to have ``.role`` ("system" / "user" /
    "assistant") and ``.content``. The prompt is wrapped in Command-R
    special tokens, and a trailing CHATBOT turn-open token is appended
    so the model generates the next assistant reply.

    Returns the formatted prompt string.
    """
    formatted_prompt = ""

    # Detect whether the caller supplied a system message.
    system_message_found = any(message.role == "system" for message in messages)

    # If no system message was found, prepend a default one.
    # (Previously this interpolated a leftover loop variable, which
    # duplicated the last chat message — and raised NameError on an
    # empty message list.)
    if not system_message_found:
        default_system = "You are a helpful assistant."
        formatted_prompt += f"<BOS_TOKEN><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{default_system}<|END_OF_TURN_TOKEN|>"

    for message in messages:
        if message.role == "system":
            formatted_prompt += f"<BOS_TOKEN><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{message.content}<|END_OF_TURN_TOKEN|>"
        elif message.role == "user":
            formatted_prompt += f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{message.content}<|END_OF_TURN_TOKEN|>"
        elif message.role == "assistant":
            formatted_prompt += f"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{message.content}<|END_OF_TURN_TOKEN|>"
    # Open a CHATBOT turn to prompt the model for the next response.
    formatted_prompt += "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
    return formatted_prompt
653+
654+
655+
623656
@app.post('/v1/chat/completions')
624657
async def mainchat(request: ChatCompletionRequest):
625658

@@ -639,6 +672,8 @@ async def mainchat(request: ChatCompletionRequest):
639672
prompt = await format_prompt_nous(request.messages)
640673
elif repo_str == 'tess-xl-exl2' or repo_str == 'tess-xl-exl2-speculative':
641674
prompt = await format_prompt_tess(request.messages)
675+
elif repo_str == 'commandr-exl2' or repo_str == 'commandr-exl2-speculative':
676+
prompt = await format_prompt_commandr(request.messages)
642677
else:
643678
prompt = await format_prompt(request.messages)
644679
print(prompt)

0 commit comments

Comments
 (0)