 
 import requests
 import sseclient
+import subprocess
 
 from copy import copy, deepcopy
 from fastapi import FastAPI
@@ -83,9 +84,10 @@ class ChatCompletionRequest(BaseModel):
     top_p: Optional[float] = 0.0  # default value of 0.0
     user: Optional[str] = None
 
-repo_str = 'Yi-34B-Chat'
+# repo_str = 'Yi-34B-Chat'
 #repo_str = 'theprofessor-exl2-speculative'
-#repo_str = 'tinyllama-exl2-speculative'
+#repo_str = 'venus-exl2-speculative'
+repo_str = 'miqu-exl2-speculative'
 
 parser = argparse.ArgumentParser(description='Run server with specified port.')
 
@@ -138,7 +140,7 @@ class ChatCompletionRequest(BaseModel):
 print("Loading model: " + repo_id)
 #cache = ExLlamaV2Cache(model, lazy=True, max_seq_len = 20480)
 #model.load_autosplit(cache)
-model.load()
+model.load([19, 9, 10, 10])
 
 draft = ExLlamaV2(draft_config)
 print("Loading draft model: " + specrepo_id)
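
Context for the `model.load([19, 9, 10, 10])` change: calling `load()` with no argument puts the whole model on the first GPU, while passing a list gives ExLlamaV2 a manual per-device VRAM budget in GB, one entry per GPU. A minimal sketch of the two loading styles, with a placeholder model path:

```python
from exllamav2 import ExLlamaV2, ExLlamaV2Config

config = ExLlamaV2Config()
config.model_dir = "/models/miqu-exl2"  # placeholder path, not from the diff
config.prepare()

model = ExLlamaV2(config)

# Manual split: GB of VRAM to use on each of the four GPUs. GPU 0 gets a
# different budget here, presumably because it also holds other tensors.
model.load([19, 9, 10, 10])

# Alternative: the commented-out lazy cache + load_autosplit() path lets
# ExLlamaV2 fill the GPUs automatically instead.
```
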
@@ -179,8 +181,8 @@ class ChatCompletionRequest(BaseModel):
 # Global variable for storing partial responses
 partial_responses = {}
 
-max_parallel_seqs = 5
-num_of_gpus = 1
+max_parallel_seqs = 3
+num_of_gpus = 4
 
 dynamic_merge = True
 ## Dynamic Merge Slicing
@@ -224,13 +226,15 @@ class ChatCompletionRequest(BaseModel):
 
 
 
-    # Yi-34B-Chat
+    #Venus 120b
     layer_arrangement.extend(range(0, 20))
     layer_ranges = [
         (10, 30),   # adjusted for Python's zero-indexing and end-exclusive range.
         (20, 40),
         (30, 50),
         (40, 60),
+        (50, 70),
+        (60, 80),
     ]
 
 
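
The two extra ranges extend the self-merge from a 60-layer base (Yi-34B) to an 80-layer one, following the Venus-120b / Goliath recipe: the first 20 layers once, then overlapping 20-layer windows stacked on top. A sketch of how `layer_ranges` likely expands into the final arrangement, with variable names taken from the diff:

```python
layer_arrangement = []
layer_arrangement.extend(range(0, 20))      # layers 0-19, used once

layer_ranges = [
    (10, 30), (20, 40), (30, 50),
    (40, 60), (50, 70), (60, 80),
]
for start, end in layer_ranges:
    layer_arrangement.extend(range(start, end))

# 20 + 6 * 20 = 140 layer slots built from an 80-layer base model;
# the overlapping windows repeat mid-stack layers, as in Venus-120b-style merges.
print(len(layer_arrangement))  # 140
```
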
@@ -422,7 +426,7 @@ class ChatCompletionRequest(BaseModel):
             device_tensors.get_scratch_slice(model_self.temp_b_size()),
             device_tensors.get_scratch_slice(model_self.temp_dq_size()),
             model.config.max_input_len * model.config.max_batch_size,
-            model.config.architecture == "Gemma")
+            model.config.arch.mlp_act_func == "gelu")
 
 model.modules += old_modules[-2:]
 model.head_layer_idx = len(model.modules) - 1
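
This flag tells the MLP scratch setup whether the model uses a GELU activation. The change appears to track an exllamav2 refactor that moved architecture metadata from the flat `config.architecture` string into a `config.arch` descriptor, and keying on the activation function also covers GELU models other than Gemma. A minimal illustration, with attribute names as they appear in the diff:

```python
# Old check: only the Gemma family was treated as GELU-activated.
needs_gelu_path = model.config.architecture == "Gemma"

# New check: any architecture whose MLP activation is GELU takes that path,
# which keeps working after the config.arch refactor.
needs_gelu_path = model.config.arch.mlp_act_func == "gelu"
```
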
@@ -439,15 +443,15 @@ class ChatCompletionRequest(BaseModel):
     print(key)
 print("Num of hidden layers:" + str(model.config.num_hidden_layers))
 # Load LoRA
-lora_directory = "../exllamav2/checkpoint-100/"
+# lora_directory = "../exllamav2/checkpoint-100/"
 #lora_directory = "../exllamav2/unsloth/unsloth_outputs_expand/checkpoint-7000/"
 #lora_directory = "../exllamav2/unsloth/unsloth_outputs_expand8x/checkpoint-12000/"
 #lora_directory = "../exllamav2/unsloth/unsloth_outputs_yi_lima/checkpoint-8000/"
 #lora_directory = "../exllamav2/unsloth/trained_unsloth_tinyllama_lima/"
 #lora_directory = "../exllamav2/openhermes_out_stacked_94layers/checkpoint-11000/"
 #lora_directory = "../exllamav2/openhermes_out/checkpoint-6500/"
-lora = ExLlamaV2Lora.from_directory(model, lora_directory)
-# lora = None
+# lora = ExLlamaV2Lora.from_directory(model, lora_directory)
+lora = None
 
 
 
@@ -667,7 +671,13 @@ def process_prompts():
 
             new_text = tokenizer.decode(input_ids[i][:, -2:-1], decode_special_tokens=False)[0]
             new_text2 = tokenizer.decode(input_ids[i][:, -2:], decode_special_tokens=False)[0]
-            diff = new_text2[len(new_text):]
+            if '�' in new_text:
+                diff = new_text2
+            else:
+                diff = new_text2[len(new_text):]
+
+            if '�' in diff:
+                diff = ""
 
             #print(diff)
             reason = None
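
Context for this hunk: the streamer decodes the last one and two tokens separately and emits their difference. When a multi-byte UTF-8 character is split across tokens, the partial decode yields the replacement character '�', so `len(new_text)` is no longer a valid slice offset, and output containing '�' must be held back until the character completes. A standalone sketch of the same idea, with the tokenizer and id tensor as placeholders:

```python
REPLACEMENT = '\ufffd'  # '�', produced when a decode ends mid-codepoint

def next_fragment(tokenizer, ids):
    """Return the newly generated text fragment for a growing token sequence."""
    prev_text = tokenizer.decode(ids[:, -2:-1], decode_special_tokens=False)[0]
    full_text = tokenizer.decode(ids[:, -2:], decode_special_tokens=False)[0]

    if REPLACEMENT in prev_text:
        # The one-token decode was cut mid-codepoint; its length is not a
        # usable slice offset, so fall back to the full two-token decode.
        fragment = full_text
    else:
        fragment = full_text[len(prev_text):]

    # If the fragment itself still ends mid-codepoint, emit nothing and let
    # the next token complete the character.
    return "" if REPLACEMENT in fragment else fragment
```
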
@@ -909,13 +919,13 @@ async def mainchat(request: ChatCompletionRequest):
         prompt = await format_prompt_zephyr(request.messages)
     elif repo_str == 'Starling-LM-7B-alpha':
         prompt = await format_prompt_starling(request.messages)
-    elif repo_str == 'Mixtral-8x7B-Instruct-v0.1-GPTQ':
+    elif repo_str == 'Mixtral-8x7B-Instruct-v0.1-GPTQ' or repo_str == 'miqu-exl2-speculative':
         prompt = await format_prompt_mixtral(request.messages)
     elif repo_str == 'Yi-34B-Chat-GPTQ' or repo_str == 'Nous-Hermes-2-Yi-34B-GPTQ' or repo_str == 'theprofessor-exl2-speculative' or repo_str == 'Yi-34B-Chat':
         prompt = await format_prompt_yi(request.messages)
     elif repo_str == 'Nous-Capybara-34B-GPTQ' or repo_str == 'goliath-120b-GPTQ' or repo_str == 'goliath-120b-exl2' or repo_str == 'goliath-120b-exl2-rpcal':
         prompt = await format_prompt_nous(request.messages)
-    elif repo_str == 'tess-xl-exl2' or repo_str == 'tess-xl-exl2-speculative':
+    elif repo_str == 'tess-xl-exl2' or repo_str == 'tess-xl-exl2-speculative' or repo_str == 'venus-exl2-speculative':
         prompt = await format_prompt_tess(request.messages)
     elif repo_str == 'tinyllama-exl2-speculative':
         prompt = await format_prompt_zephyr(request.messages)
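
miqu is routed to the Mixtral formatter because it follows the Mistral instruct template. The body of `format_prompt_mixtral` is outside this diff; a plausible minimal sketch of such a formatter (function name from the diff, implementation and message attributes assumed):

```python
async def format_prompt_mixtral(messages):
    # Mistral/Mixtral instruct template: <s>[INST] user [/INST] assistant</s>
    # System content is commonly folded into the first user turn.
    prompt = "<s>"
    system = ""
    for m in messages:
        if m.role == "system":
            system = m.content + "\n\n"
        elif m.role == "user":
            prompt += f"[INST] {system}{m.content} [/INST]"
            system = ""
        elif m.role == "assistant":
            prompt += f" {m.content}</s>"
    return prompt
```
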
@@ -953,6 +963,29 @@ async def mainchat(request: ChatCompletionRequest):
 async def get_status():
     return {"ping": sum(prompt_length)}
 
+@app.get("/nvidia-smi")
+async def get_nvidia_smi():
+    # Execute the nvidia-smi command
+    result = subprocess.run(
+        ["nvidia-smi", "--query-gpu=utilization.gpu,memory.used,memory.total", "--format=csv,noheader"],
+        capture_output=True, text=True
+    )
+    nvidia_smi_output = result.stdout.strip()  # Remove any extra whitespace
+    # Split the output by lines and then by commas
+    gpu_data = []
+    for line in nvidia_smi_output.split("\n"):
+        utilization, memory_used, memory_total = line.split(", ")
+        # Strip the '%' and 'MiB' and convert to appropriate types
+        utilization = float(utilization.strip(' %'))
+        memory_used = int(memory_used.strip(' MiB'))
+        memory_total = int(memory_total.strip(' MiB'))
+        gpu_data.append({
+            "utilization": utilization,
+            "memory_used": memory_used,
+            "memory_total": memory_total
+        })
+    return gpu_data
+
 if __name__ == "__main__":
     import uvicorn
 
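
The new endpoint exposes per-GPU telemetry over HTTP alongside the existing status route. A quick client-side check against a locally running instance; the port is whatever was passed via `--port`, shown here as a hypothetical 8000:

```python
import requests

resp = requests.get("http://localhost:8000/nvidia-smi")
for idx, gpu in enumerate(resp.json()):
    print(f"GPU {idx}: {gpu['utilization']}% util, "
          f"{gpu['memory_used']}/{gpu['memory_total']} MiB")
```
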