
Commit cb41dc5

Assistant device selection
1 parent 9b8b332 commit cb41dc5

File tree

setup.py
text2text/assistant.py

2 files changed: +33 −14 lines

setup.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 
 setuptools.setup(
     name="text2text",
-    version="1.9.3",
+    version="1.9.4",
     author="artitw",
     author_email="artitw@gmail.com",
     description="Text2Text Language Modeling Toolkit",

text2text/assistant.py

Lines changed: 32 additions & 13 deletions
@@ -14,6 +14,16 @@
 from openai import OpenAI
 
 
+def get_most_free_cuda_memory():
+  cuda_device = -1
+  max_free_memory = 0
+  for i in range(torch.cuda.device_count()):
+    device_free_memory = torch.cuda.mem_get_info(i)[0]
+    if device_free_memory > max_free_memory:
+      max_free_memory = device_free_memory
+      cuda_device = i
+  return cuda_device, max_free_memory / (1024 ** 3)
+
 def can_use_apt():
   # Check if the OS is Linux and if it is a Debian-based distribution
   if platform.system() == "Linux":
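
The new get_most_free_cuda_memory() helper scans every visible GPU and returns the index of the one with the most free memory, converted to GiB. It is built on torch.cuda.mem_get_info(device), which wraps cudaMemGetInfo and returns a (free_bytes, total_bytes) tuple. A quick standalone way to see the same numbers the helper sees (this snippet is ours, not part of the commit):

import torch

if torch.cuda.is_available():
  for i in range(torch.cuda.device_count()):
    free_b, total_b = torch.cuda.mem_get_info(i)  # (free, total) in bytes
    print(f"cuda:{i}: {free_b / 1024**3:.1f} of {total_b / 1024**3:.1f} GiB free")

Note that if there are no CUDA devices, or none reports any free memory, the helper returns -1 as the index, so callers can detect the no-GPU case.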
@@ -52,6 +62,7 @@ def __init__(self, **kwargs):
       "task": "generate",
     })
     self.config["device"] = "cuda" if torch.cuda.is_available() else "cpu"
+    self.cuda_device = 0
     self.min_device_memory_gb = kwargs.get("min_device_memory_gb", 8)
     self.server_proc = None
     self.load_model()
@@ -91,29 +102,33 @@ def is_server_up(self):
     return False
 
   def set_available_device(self, num_tries=0):
-    if num_tries > 3:
+    if num_tries > torch.cuda.device_count() + 3:
       warnings.warn(f"{num_tries} times setting device. Aborting.")
       return
 
     memory_cuda = 0
     if torch.cuda.is_available():
       torch.cuda.empty_cache()
-      memory_cuda = torch.cuda.mem_get_info()[0] / (1024 ** 3)
-
+      memory_cuda = torch.cuda.mem_get_info(self.cuda_device)[0] / (1024 ** 3)
+
     gc.collect()
     memory_cpu = psutil.virtual_memory().available / (1024 ** 3)
 
     if self.config["device"] == "cuda" and memory_cuda < self.min_device_memory_gb:
-      warnings.warn(f"{self.config['device']} {memory_cuda}GB RAM free is less than {self.min_device_memory_gb}GB specified.")
-      if memory_cuda+memory_cpu >= self.min_device_memory_gb:
+      warnings.warn(f"{self.config['device']} {memory_cuda} GB RAM free is less than {self.min_device_memory_gb} GB specified.")
+      most_free_cuda, most_free_memory = get_most_free_cuda_memory()
+      self.cuda_device = most_free_cuda
+      if most_free_memory >= self.min_device_memory_gb:
+        warnings.warn(f"Try cuda:{self.cuda_device} with {most_free_memory} GB RAM free")
+      elif most_free_memory+memory_cpu >= self.min_device_memory_gb:
         self.config["cpu-offload-gb"] = memory_cpu
-        warnings.warn(f"{memory_cpu}GB cpu offloading")
+        warnings.warn(f"cuda:{self.cuda_device} with {memory_cpu} GB cpu offloading")
       else:
         self.config["device"] = "cpu"
         warnings.warn(f"Set device to {self.config['device']}")
-        self.set_available_device(num_tries=num_tries+1)
+      self.set_available_device(num_tries=num_tries+1)
     elif memory_cpu < self.min_device_memory_gb:
-      warnings.warn(f"{self.config['device']} {memory_cpu}GB RAM free is less than {self.min_device_memory_gb}GB specified.")
+      warnings.warn(f"{self.config['device']} {memory_cpu} GB RAM free is less than {self.min_device_memory_gb} GB specified.")
       pids = kill_processes("vllm")
       warnings.warn(f"Killed processes {pids}")
       self.config["device"] = "cuda" if torch.cuda.is_available() else "cpu"
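
Stripped of the config plumbing and retry recursion, set_available_device now implements a three-way cascade: run fully on the freest GPU if it alone has enough memory, otherwise keep the GPU and offload the remainder to host RAM via vLLM's cpu-offload-gb option, otherwise fall back to CPU entirely. A condensed sketch of just that decision (ours; names simplified):

def choose_placement(need_gb, free_gpu_gb, free_cpu_gb):
  # 1. The freest GPU alone is enough -> run fully on that GPU.
  if free_gpu_gb >= need_gb:
    return {"device": "cuda"}
  # 2. GPU plus host RAM together are enough -> keep GPU, offload the rest.
  if free_gpu_gb + free_cpu_gb >= need_gb:
    return {"device": "cuda", "cpu-offload-gb": free_cpu_gb}
  # 3. Otherwise run entirely on CPU.
  return {"device": "cpu"}

After changing a setting, the real method calls itself to re-measure against the newly selected device, presumably why the retry cap grows from a fixed 3 to torch.cuda.device_count() + 3: in the worst case every GPU gets one look.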
@@ -125,9 +140,13 @@ def serve_model(self):
     args_strs = [f"--{k} {self.config[k]}" for k in self.config]
     args_str = ' '.join(args_strs)
     cmd_str = f"python -m vllm.entrypoints.openai.api_server {args_str}"
+    env_mod = dict(os.environ)
+    if self.config["device"] == "cuda":
+      env_mod = dict(os.environ, CUDA_VISIBLE_DEVICES=str(self.cuda_device))
     try:
       self.server_proc = subprocess.Popen(
         shlex.split(cmd_str),
+        env=env_mod,
         stdout=subprocess.PIPE,
         stderr=subprocess.STDOUT,
         text=True,
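
Because the vLLM server runs as a child process, the chosen GPU is pinned by injecting CUDA_VISIBLE_DEVICES into that process's environment; inside the child, the selected card then shows up as cuda:0. A minimal sketch of the same pattern (ours; the child command is illustrative):

import os, shlex, subprocess

cuda_device = 1  # hypothetical index picked by get_most_free_cuda_memory()
env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(cuda_device))
proc = subprocess.Popen(
  shlex.split("python -c 'import torch; print(torch.cuda.device_count())'"),
  env=env,  # the child sees only the one pinned GPU
  stdout=subprocess.PIPE,
  stderr=subprocess.STDOUT,
  text=True,
)
print(proc.communicate()[0])  # prints 1 on a multi-GPU box: only the pinned card is visible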
@@ -138,12 +157,12 @@ def serve_model(self):
 
   def wait_for_startup(self):
     while True:
+      time.sleep(1.0)
       output = self.server_proc.stdout.readline()
       if self.server_proc.poll() is not None:
         raise Exception(output)
       if "Application startup complete" in output:
         break
-      time.sleep(1.0)
 
   def load_model(self):
     pbar = tqdm(total=5, desc=f'Model Setup ({self.config["port"]})')
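
wait_for_startup tails the server's stdout until vLLM logs its readiness line. Moving time.sleep(1.0) to the top of the loop paces every iteration, including the first, presumably so the liveness check does not race a child that has not finished booting. Factored out, the pattern looks like this (ours, with RuntimeError in place of the bare Exception):

import time

def wait_for_marker(proc, marker="Application startup complete"):
  while True:
    time.sleep(1.0)                  # give the child a beat before each check
    output = proc.stdout.readline()  # blocks until the child emits a line
    if proc.poll() is not None:      # child exited: surface its last output
      raise RuntimeError(output)
    if marker in output:             # vLLM logs this once the API is up
      return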
@@ -171,10 +190,10 @@ def chat_completion(self, messages = [{"role": "user", "content": "hello"}], **k
     if schema:
       try:
         completion = self.client.beta.chat.completions.parse(
-            model=self.config["model"],
-            messages=messages,
-            response_format=schema,
-            extra_body=dict(guided_decoding_backend="outlines"),
+          model=self.config["model"],
+          messages=messages,
+          response_format=schema,
+          extra_body=dict(guided_decoding_backend="outlines"),
         )
         schema_response = completion.choices[0].message.parsed
         return schema_response
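
The re-indented block is the structured-output path of chat_completion: the OpenAI Python client's client.beta.chat.completions.parse accepts a Pydantic model as response_format, and extra_body forwards vLLM's guided_decoding_backend="outlines" option. A hedged usage sketch, assuming schema is picked up from chat_completion's **kwargs (the City model is our example, not part of this commit):

from pydantic import BaseModel
import text2text as t2t

class City(BaseModel):
  name: str
  population: int

asst = t2t.Assistant()
city = asst.chat_completion(
  messages=[{"role": "user", "content": "Largest city in France?"}],
  schema=City,
)
print(city.name, city.population)  # response parsed into the City model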
