 from openai import OpenAI
 
 
+def get_most_free_cuda_memory():
+    cuda_device = -1
+    max_free_memory = 0
+    for i in range(torch.cuda.device_count()):
+        device_free_memory = torch.cuda.mem_get_info(i)[0]
+        if device_free_memory > max_free_memory:
+            max_free_memory = device_free_memory
+            cuda_device = i
+    return cuda_device, max_free_memory / (1024 ** 3)
+
 def can_use_apt():
     # Check if the OS is Linux and if it is a Debian-based distribution
     if platform.system() == "Linux":
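The new `get_most_free_cuda_memory` helper scans every visible GPU with `torch.cuda.mem_get_info` and returns the index and free gigabytes of the emptiest one (or `-1` when no CUDA device is present). A minimal usage sketch, assuming the helper is importable from a hypothetical `model_server` module:

# Sketch only: pick the least-loaded GPU before launching a server.
# "model_server" is an assumed module name for where this helper lives.
from model_server import get_most_free_cuda_memory

device_idx, free_gb = get_most_free_cuda_memory()
if device_idx == -1:
    print("No CUDA device visible; falling back to CPU")
else:
    print(f"cuda:{device_idx} has the most headroom ({free_gb:.1f} GB free)")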
@@ -52,6 +62,7 @@ def __init__(self, **kwargs):
             "task": "generate",
         })
         self.config["device"] = "cuda" if torch.cuda.is_available() else "cpu"
+        self.cuda_device = 0
         self.min_device_memory_gb = kwargs.get("min_device_memory_gb", 8)
         self.server_proc = None
         self.load_model()
@@ -91,29 +102,33 @@ def is_server_up(self):
         return False
 
     def set_available_device(self, num_tries=0):
-        if num_tries > 3:
+        if num_tries > torch.cuda.device_count() + 3:
             warnings.warn(f"{num_tries} times setting device. Aborting.")
             return
 
         memory_cuda = 0
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
-            memory_cuda = torch.cuda.mem_get_info()[0] / (1024 ** 3)
-
+            memory_cuda = torch.cuda.mem_get_info(self.cuda_device)[0] / (1024 ** 3)
+
         gc.collect()
         memory_cpu = psutil.virtual_memory().available / (1024 ** 3)
 
         if self.config["device"] == "cuda" and memory_cuda < self.min_device_memory_gb:
-            warnings.warn(f"{self.config['device']} {memory_cuda} GB RAM free is less than {self.min_device_memory_gb} GB specified.")
-            if memory_cuda + memory_cpu >= self.min_device_memory_gb:
+            warnings.warn(f"{self.config['device']} {memory_cuda} GB RAM free is less than {self.min_device_memory_gb} GB specified.")
+            most_free_cuda, most_free_memory = get_most_free_cuda_memory()
+            self.cuda_device = most_free_cuda
+            if most_free_memory >= self.min_device_memory_gb:
+                warnings.warn(f"Try cuda:{self.cuda_device} with {most_free_memory} GB RAM free")
+            elif most_free_memory + memory_cpu >= self.min_device_memory_gb:
                 self.config["cpu-offload-gb"] = memory_cpu
-                warnings.warn(f"{memory_cpu} GB cpu offloading")
+                warnings.warn(f"cuda: {self.cuda_device} with {memory_cpu} GB cpu offloading")
             else:
                 self.config["device"] = "cpu"
                 warnings.warn(f"Set device to {self.config['device']}")
-            self.set_available_device(num_tries=num_tries + 1)
+            self.set_available_device(num_tries=num_tries + 1)
         elif memory_cpu < self.min_device_memory_gb:
-            warnings.warn(f"{self.config['device']} {memory_cpu} GB RAM free is less than {self.min_device_memory_gb} GB specified.")
+            warnings.warn(f"{self.config['device']} {memory_cpu} GB RAM free is less than {self.min_device_memory_gb} GB specified.")
             pids = kill_processes("vllm")
             warnings.warn(f"Killed processes {pids}")
             self.config["device"] = "cuda" if torch.cuda.is_available() else "cpu"
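The revised fallback ladder now tries the GPU with the most free memory first, then CUDA with CPU offloading, and only then pure CPU. A condensed, self-contained sketch of that decision order (memory figures are illustrative, not taken from the diff):

# Sketch of the device-selection ladder above; the numbers are made up.
def pick_device(free_gpu_gb, free_cpu_gb, min_gb=8):
    if free_gpu_gb >= min_gb:
        return {"device": "cuda"}                                  # fits on the GPU alone
    if free_gpu_gb + free_cpu_gb >= min_gb:
        return {"device": "cuda", "cpu-offload-gb": free_cpu_gb}   # spill weights to system RAM
    return {"device": "cpu"}                                       # last resort

print(pick_device(free_gpu_gb=5.0, free_cpu_gb=24.0))
# -> {'device': 'cuda', 'cpu-offload-gb': 24.0}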
@@ -125,9 +140,13 @@ def serve_model(self):
         args_strs = [f"--{k} {self.config[k]}" for k in self.config]
         args_str = ' '.join(args_strs)
         cmd_str = f"python -m vllm.entrypoints.openai.api_server {args_str}"
+        env_mod = dict(os.environ)
+        if self.config["device"] == "cuda":
+            env_mod = dict(os.environ, CUDA_VISIBLE_DEVICES=str(self.cuda_device))
         try:
             self.server_proc = subprocess.Popen(
                 shlex.split(cmd_str),
+                env=env_mod,
                 stdout=subprocess.PIPE,
                 stderr=subprocess.STDOUT,
                 text=True,
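Pinning the server to the selected GPU is done by copying the environment and overriding `CUDA_VISIBLE_DEVICES` before `subprocess.Popen`; inside the child process that GPU then appears as `cuda:0`. A standalone sketch of the same pattern:

# Sketch only: restrict a child process to one GPU via CUDA_VISIBLE_DEVICES.
import os, shlex, subprocess

env_mod = dict(os.environ, CUDA_VISIBLE_DEVICES="1")  # child will only see physical GPU 1
cmd = 'python -c "import os; print(os.environ[\'CUDA_VISIBLE_DEVICES\'])"'
proc = subprocess.Popen(shlex.split(cmd), env=env_mod, stdout=subprocess.PIPE, text=True)
print(proc.communicate()[0].strip())  # -> 1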
@@ -138,12 +157,12 @@ def serve_model(self):
 
     def wait_for_startup(self):
         while True:
+            time.sleep(1.0)
             output = self.server_proc.stdout.readline()
             if self.server_proc.poll() is not None:
                 raise Exception(output)
             if "Application startup complete" in output:
                 break
-            time.sleep(1.0)
 
     def load_model(self):
         pbar = tqdm(total=5, desc=f'Model Setup ({self.config["port"]})')
@@ -171,10 +190,10 @@ def chat_completion(self, messages=[{"role": "user", "content": "hello"}], **k
         if schema:
             try:
                 completion = self.client.beta.chat.completions.parse(
-                    model=self.config["model"],
-                    messages=messages,
-                    response_format=schema,
-                    extra_body=dict(guided_decoding_backend="outlines"),
+                    model=self.config["model"],
+                    messages=messages,
+                    response_format=schema,
+                    extra_body=dict(guided_decoding_backend="outlines"),
                 )
                 schema_response = completion.choices[0].message.parsed
                 return schema_response
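For structured output, the client hands a pydantic model to `client.beta.chat.completions.parse` and forces vLLM's outlines guided-decoding backend through `extra_body`. A hedged usage sketch, assuming the surrounding class is instantiated as `server` and that `schema` is passed in through `**kwargs` (both assumptions, not shown in the diff):

# Sketch only: "server" is an assumed instance of the class this diff modifies,
# and passing schema as a keyword is inferred from the `if schema:` check above.
from pydantic import BaseModel

class Verdict(BaseModel):
    label: str
    confidence: float

answer = server.chat_completion(
    messages=[{"role": "user", "content": "Is the sky blue? Reply as JSON."}],
    schema=Verdict,
)
print(answer.label, answer.confidence)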