@@ -75,6 +75,7 @@ def run_vllm(
75
75
device : str ,
76
76
enable_prefix_caching : bool ,
77
77
gpu_memory_utilization : float = 0.9 ,
78
+ download_dir : Optional [str ] = None ,
78
79
) -> float :
79
80
from vllm import LLM , SamplingParams
80
81
llm = LLM (model = model ,
@@ -89,7 +90,8 @@ def run_vllm(
89
90
enforce_eager = enforce_eager ,
90
91
kv_cache_dtype = kv_cache_dtype ,
91
92
device = device ,
92
- enable_prefix_caching = enable_prefix_caching )
93
+ enable_prefix_caching = enable_prefix_caching ,
94
+ download_dir = download_dir )
93
95
94
96
# Add the requests to the engine.
95
97
for prompt , _ , output_len in requests :
@@ -208,12 +210,14 @@ def main(args: argparse.Namespace):
208
210
args .output_len )
209
211
210
212
if args .backend == "vllm" :
211
- elapsed_time = run_vllm (
212
- requests , args .model , args .tokenizer , args .quantization ,
213
- args .tensor_parallel_size , args .seed , args .n , args .use_beam_search ,
214
- args .trust_remote_code , args .dtype , args .max_model_len ,
215
- args .enforce_eager , args .kv_cache_dtype , args .device ,
216
- args .enable_prefix_caching , args .gpu_memory_utilization )
213
+ elapsed_time = run_vllm (requests , args .model , args .tokenizer ,
214
+ args .quantization , args .tensor_parallel_size ,
215
+ args .seed , args .n , args .use_beam_search ,
216
+ args .trust_remote_code , args .dtype ,
217
+ args .max_model_len , args .enforce_eager ,
218
+ args .kv_cache_dtype , args .device ,
219
+ args .enable_prefix_caching ,
220
+ args .gpu_memory_utilization , args .download_dir )
217
221
elif args .backend == "hf" :
218
222
assert args .tensor_parallel_size == 1
219
223
elapsed_time = run_hf (requests , args .model , tokenizer , args .n ,
@@ -314,6 +318,11 @@ def main(args: argparse.Namespace):
314
318
"--enable-prefix-caching" ,
315
319
action = 'store_true' ,
316
320
help = "enable automatic prefix caching for vLLM backend." )
321
+ parser .add_argument ('--download-dir' ,
322
+ type = str ,
323
+ default = None ,
324
+ help = 'directory to download and load the weights, '
325
+ 'default to the default cache dir of huggingface' )
317
326
args = parser .parse_args ()
318
327
if args .tokenizer is None :
319
328
args .tokenizer = args .model
0 commit comments