Generic adapter support in the grpc server #32
Signed-off-by: Joe Runde <[email protected]>
vllm/tgis_utils/args.py
@@ -82,6 +82,8 @@ def add_tgis_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
     parser.add_argument('--tls-key-path', type=str)
     # map to ssl_ca_certs
     parser.add_argument('--tls-client-ca-cert-path', type=str)
+    # add a path when lora adapters will be loaded from
+    parser.add_argument('--lora-adapter-cache', type=str)
open to ideas on naming here
Thanks @joerunde this looks great! Just minor comments
vllm/entrypoints/grpc/adapters.py
    lora_id = request.lora_id
    if lora_id:
        if not lora_adapter_store:
            # using raise/format instead of .error so mypy knows this raises
            raise ValueError(TGISValidationError.LoraDisabled.value.format())

        local_lora_path = os.path.join(lora_adapter_store.cache_path, lora_id)

        # Do a bit of up-front validation so that we don't ask the engine
        # to try to load an invalid adapter
        if not os.path.exists(local_lora_path):
            TGISValidationError.LoraAdapterNotFound.error(
                lora_id, "directory does not exist")
        if not os.path.exists(
                os.path.join(local_lora_path, "adapter_config.json")):
            TGISValidationError.LoraAdapterNotFound.error(
                lora_id, "invalid adapter: no adapter_config.json found")

        # We need to track a unique integer for vLLM to identify the lora
        # adapters
        if lora_id not in lora_adapter_store.unique_id_map:
            lora_adapter_store.unique_id_map[
                lora_id] = lora_adapter_store.next_unique_id
            lora_adapter_store.next_unique_id += 1
        unique_id = lora_adapter_store.unique_id_map[lora_id]
        lora_request = LoRARequest(lora_name=lora_id,
                                   lora_int_id=unique_id,
                                   lora_local_path=local_lora_path)
    else:
        lora_request = None

    if request.prefix_id:
        # TODO: hook up PromptAdapterRequest once implemented in the engine
        raise ValueError("prefix_id not implemented yet")

    # Second return slot left here for the incoming PromptAdapterRequest
    # See https://github.com/vllm-project/vllm/pull/4645/files
    return lora_request, None
How about flattening this a bit?
    if request.prefix_id:
        # TODO: hook up PromptAdapterRequest once implemented in the engine
        raise ValueError("prefix_id not implemented yet")
    lora_id = request.lora_id
    if not lora_id:
        return None, None
    if not lora_adapter_store:
        # using raise/format instead of .error so mypy knows this raises
        raise ValueError(TGISValidationError.LoraDisabled.value.format())
    local_lora_path = os.path.join(lora_adapter_store.cache_path, lora_id)
    # Do a bit of up-front validation so that we don't ask the engine
    # to try to load an invalid adapter
    if not os.path.exists(local_lora_path):
        TGISValidationError.LoraAdapterNotFound.error(
            lora_id, "directory does not exist")
    if not os.path.exists(
            os.path.join(local_lora_path, "adapter_config.json")):
        TGISValidationError.LoraAdapterNotFound.error(
            lora_id, "invalid adapter: no adapter_config.json found")
    # We need to track a unique integer for vLLM to identify the lora
    # adapters
    if lora_id not in lora_adapter_store.unique_id_map:
        lora_adapter_store.unique_id_map[
            lora_id] = lora_adapter_store.next_unique_id
        lora_adapter_store.next_unique_id += 1
    unique_id = lora_adapter_store.unique_id_map[lora_id]
    lora_request = LoRARequest(lora_name=lora_id,
                               lora_int_id=unique_id,
                               lora_local_path=local_lora_path)
    # Second return slot left here for the incoming PromptAdapterRequest
    # See https://github.com/vllm-project/vllm/pull/4645/files
    return lora_request, None
Hah, I un-nested but then re-nested so that the file checking and opening will only happen if the adapter wasn't already loaded
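The trade-off described in this reply (early returns, but disk checks only on a cache miss) can be sketched roughly as follows. This is an illustration under assumptions, not the code that landed in the PR: `AdapterStore` and `resolve_adapter` are hypothetical names, and plain `ValueError`s stand in for the `TGISValidationError` helpers.

```python
import dataclasses
import os
from typing import Dict, Tuple


@dataclasses.dataclass
class AdapterStore:
    # Hypothetical stand-in for the adapter store in the PR; the field
    # names mirror the ones visible in the diff above.
    cache_path: str
    next_unique_id: int = 1
    unique_id_map: Dict[str, int] = dataclasses.field(default_factory=dict)


def resolve_adapter(store: AdapterStore, adapter_id: str) -> Tuple[str, int]:
    """Return (local_path, unique_id), touching the disk only on a cache miss."""
    local_path = os.path.join(store.cache_path, adapter_id)

    # Fast path: the adapter was validated on an earlier request.
    if adapter_id in store.unique_id_map:
        return local_path, store.unique_id_map[adapter_id]

    # Slow path: validate up front so the engine is never asked to load
    # an adapter that is missing or malformed.
    if not os.path.isdir(local_path):
        raise ValueError(f"{adapter_id}: directory does not exist")
    if not os.path.exists(os.path.join(local_path, "adapter_config.json")):
        raise ValueError(f"{adapter_id}: no adapter_config.json found")

    # Assign the unique integer only after validation has passed.
    store.unique_id_map[adapter_id] = store.next_unique_id
    store.next_unique_id += 1
    return local_path, store.unique_id_map[adapter_id]
```

With this shape, a second request for the same adapter returns from the fast path without any `os.path` checks, which is the behaviour the reply is trying to preserve.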
vllm/entrypoints/grpc/grpc_server.py
@@ -224,7 +224,7 @@ async def GenerateStream(
             sampling_params, truncate_input_tokens, request.request.text,
             context)

-        lora_request, _ = await self._validate_adapters(request, context)
+        adapter_kwargs, _ = await self._validate_adapters(request, context)
Not a tuple now right?
oh yeah, totally not. Interestingly python seems totally fine with the unpacking mismatch if you leave an underscore, TIL
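A note on the unpacking point: Python checks `a, _ = value` only at runtime, against however many items `value` yields, so static tooling may not flag a mismatch. The sketch below assumes, purely for illustration, that the validation helper returns a dict of keyword arguments (as the name `adapter_kwargs` suggests); `unpack_pair` is a hypothetical helper and the dict keys are made up.

```python
def unpack_pair(value):
    """Mimic `adapter_kwargs, _ = value` and report what happens."""
    try:
        first, _ = value
    except ValueError as exc:
        return f"ValueError: {exc}"
    return first


# A real 2-tuple unpacks as intended.
print(unpack_pair(({"lora_request": None}, None)))

# An empty dict yields no items, so unpacking into two names raises ValueError.
print(unpack_pair({}))

# A dict with exactly two keys unpacks without error, but `first` is
# silently bound to the first key rather than to the dict itself.
print(unpack_pair({"lora_request": 1, "prompt_adapter_request": 2}))
```

So the underscore does not make the mismatch safe; it only succeeds when the returned object happens to yield exactly two items.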
vllm/entrypoints/grpc/adapters.py
            TGISValidationError.AdapterNotFound.error(
                adapter_id, "invalid adapter: no adapter_config.json found")

        # NB: blocks event loop
I think this will be important to address: moving all of the file access off the event loop.
Yeah, I looked into this a bit and it sounds like the asyncio file access in third-party libs is... not very good.
I'm not 100% up to speed on event loops. Would we want to make a new executor for this, sort of like
file_load_executor = ThreadPoolExecutor(max_workers=n)
await loop.run_in_executor(file_load_executor, _load_the_config_json_file, ...)
or would that just also block the loop?
Yeah, exactly. We should probably make that function contain all the code that runs when we don't find the adapter in the dict (i.e. checking on disk, loading it, etc.).
There's a default asyncio executor that can be used for this kind of thing, or we may want a static one rather than creating one on the fly (not that you were necessarily suggesting that).
Cool, I'll see if I can get that working quickly
@njhill can I get a run from your static analysis on this change?
    # If not already cached, we need to validate that files exist and
    # grab the type out of the adapter_config.json file
    if (adapter_metadata := adapter_store.adapters.get(adapter_id)) is None:
        local_adapter_path = os.path.join(adapter_store.cache_path, adapter_id)
I think we should sanitize the adapter_id here to make sure that the user can't send funny things like ../../../etc/passwd.
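One way to do this kind of check is to resolve the joined path and confirm it still sits inside the cache directory. The sketch below is an assumption about how such a check could look, not necessarily what the PR ended up doing; `safe_adapter_path` is a hypothetical name.

```python
import os


def safe_adapter_path(cache_path: str, adapter_id: str) -> str:
    """Join adapter_id onto cache_path, rejecting ids that escape the cache."""
    base = os.path.realpath(cache_path)
    # realpath collapses ".." segments and resolves symlinks; an absolute
    # adapter_id makes os.path.join discard `base` entirely, which the
    # containment check below also catches.
    candidate = os.path.realpath(os.path.join(base, adapter_id))
    # The candidate must be strictly inside the cache directory.
    if candidate == base or os.path.commonpath([base, candidate]) != base:
        raise ValueError(f"invalid adapter id: {adapter_id!r}")
    return candidate
```

Under this check, an id such as `my-lora` resolves to a directory under the cache, while `../../../etc/passwd`, an absolute path, or an empty id raises `ValueError` before any file is opened.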
done!
I've left a comment suggesting a security improvement, but otherwise it looks good to me.
Adds support for multi-lora adapters.
Passing tests added over in this PR: https://github.ibm.com/ai-foundation/tgis-deploy-tests/pull/25/files