Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generic adapter support in the grpc server #32

Merged
merged 8 commits into from
Jun 11, 2024
Merged

Generic adapter support in the grpc server #32

merged 8 commits into from
Jun 11, 2024

Conversation

joerunde
Copy link
Collaborator

@joerunde joerunde commented May 24, 2024

Adds support for multi-lora adapters.

Passing tests added over in this PR: https://github.ibm.com/ai-foundation/tgis-deploy-tests/pull/25/files

@@ -82,6 +82,8 @@ def add_tgis_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument('--tls-key-path', type=str)
# map to ssl_ca_certs
parser.add_argument('--tls-client-ca-cert-path', type=str)
# add a path when lora adapters will be loaded from
parser.add_argument('--lora-adapter-cache', type=str)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

open to ideas on naming here

@joerunde joerunde marked this pull request as ready for review May 29, 2024 17:52
Copy link
Member

@njhill njhill left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @joerunde this looks great! Just minor comments

proto/generation.proto Outdated Show resolved Hide resolved
vllm/entrypoints/grpc/adapters.py Outdated Show resolved Hide resolved
Comment on lines 27 to 64
lora_id = request.lora_id
if lora_id:
if not lora_adapter_store:
# using raise/format instead of .error so mypy knows this raises
raise ValueError(TGISValidationError.LoraDisabled.value.format())

local_lora_path = os.path.join(lora_adapter_store.cache_path, lora_id)

# Do a bit of up-front validation so that we don't ask the engine
# to try to load an invalid adapter
if not os.path.exists(local_lora_path):
TGISValidationError.LoraAdapterNotFound.error(
lora_id, "directory does not exist")
if not os.path.exists(
os.path.join(local_lora_path, "adapter_config.json")):
TGISValidationError.LoraAdapterNotFound.error(
lora_id, "invalid adapter: no adapter_config.json found")

# We need to track a unique integer for vLLM to identify the lora
# adapters
if lora_id not in lora_adapter_store.unique_id_map:
lora_adapter_store.unique_id_map[
lora_id] = lora_adapter_store.next_unique_id
lora_adapter_store.next_unique_id += 1
unique_id = lora_adapter_store.unique_id_map[lora_id]
lora_request = LoRARequest(lora_name=lora_id,
lora_int_id=unique_id,
lora_local_path=local_lora_path)
else:
lora_request = None

if request.prefix_id:
# TODO: hook up PromptAdapterRequest once implemented in the engine
raise ValueError("prefix_id not implemented yet")

# Second return slot left here for the incoming PromptAdapterRequest
# See https://github.com/vllm-project/vllm/pull/4645/files
return lora_request, None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about flattening this a bit?

Suggested change
lora_id = request.lora_id
if lora_id:
if not lora_adapter_store:
# using raise/format instead of .error so mypy knows this raises
raise ValueError(TGISValidationError.LoraDisabled.value.format())
local_lora_path = os.path.join(lora_adapter_store.cache_path, lora_id)
# Do a bit of up-front validation so that we don't ask the engine
# to try to load an invalid adapter
if not os.path.exists(local_lora_path):
TGISValidationError.LoraAdapterNotFound.error(
lora_id, "directory does not exist")
if not os.path.exists(
os.path.join(local_lora_path, "adapter_config.json")):
TGISValidationError.LoraAdapterNotFound.error(
lora_id, "invalid adapter: no adapter_config.json found")
# We need to track a unique integer for vLLM to identify the lora
# adapters
if lora_id not in lora_adapter_store.unique_id_map:
lora_adapter_store.unique_id_map[
lora_id] = lora_adapter_store.next_unique_id
lora_adapter_store.next_unique_id += 1
unique_id = lora_adapter_store.unique_id_map[lora_id]
lora_request = LoRARequest(lora_name=lora_id,
lora_int_id=unique_id,
lora_local_path=local_lora_path)
else:
lora_request = None
if request.prefix_id:
# TODO: hook up PromptAdapterRequest once implemented in the engine
raise ValueError("prefix_id not implemented yet")
# Second return slot left here for the incoming PromptAdapterRequest
# See https://github.com/vllm-project/vllm/pull/4645/files
return lora_request, None
if request.prefix_id:
# TODO: hook up PromptAdapterRequest once implemented in the engine
raise ValueError("prefix_id not implemented yet")
lora_id = request.lora_id
if not lora_id:
return None, None
if not lora_adapter_store:
# using raise/format instead of .error so mypy knows this raises
raise ValueError(TGISValidationError.LoraDisabled.value.format())
local_lora_path = os.path.join(lora_adapter_store.cache_path, lora_id)
# Do a bit of up-front validation so that we don't ask the engine
# to try to load an invalid adapter
if not os.path.exists(local_lora_path):
TGISValidationError.LoraAdapterNotFound.error(
lora_id, "directory does not exist")
if not os.path.exists(
os.path.join(local_lora_path, "adapter_config.json")):
TGISValidationError.LoraAdapterNotFound.error(
lora_id, "invalid adapter: no adapter_config.json found")
# We need to track a unique integer for vLLM to identify the lora
# adapters
if lora_id not in lora_adapter_store.unique_id_map:
lora_adapter_store.unique_id_map[
lora_id] = lora_adapter_store.next_unique_id
lora_adapter_store.next_unique_id += 1
unique_id = lora_adapter_store.unique_id_map[lora_id]
lora_request = LoRARequest(lora_name=lora_id,
lora_int_id=unique_id,
lora_local_path=local_lora_path)
# Second return slot left here for the incoming PromptAdapterRequest
# See https://github.com/vllm-project/vllm/pull/4645/files
return lora_request, None

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hah, I un-nested but then re-nested so that the file checking and opening will only happen if the adapter wasn't already loaded

@joerunde joerunde changed the title Lora stuff Generic adapter support in the grpc server Jun 3, 2024
@@ -224,7 +224,7 @@ async def GenerateStream(
sampling_params, truncate_input_tokens, request.request.text,
context)

lora_request, _ = await self._validate_adapters(request, context)
adapter_kwargs, _ = await self._validate_adapters(request, context)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a tuple now right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh yeah, totally not. Interestingly python seems totally fine with the unpacking mismatch if you leave an underscore, TIL

TGISValidationError.AdapterNotFound.error(
adapter_id, "invalid adapter: no adapter_config.json found")

# NB: blocks event loop
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this will be important to address - to remove the all the file access from the event loop

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I looked into this a bit and it sounds like the asyncio file access in third party libs is... not very good.

I'm not 100% up to speed on event loops, would we want to make a new executor for this sorta like

file_load_executor = ThreadPoolExecutor(max_workers=n)
task = _load_the_config_json_file(...)
await loop.run_in_exeuctor(task, file_load_executor)

or would that just also block the loop?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah exactly .. probably should just make that function be the all the code that's run if we don't find adapter in the dict (i.e. checking on disk, loading it etc).

There's a default asyncio executor that can be used for this kind of thing, or we may want a static one rather than creating one on the fly (not that you were necessarily suggesting that).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, I'll see if I can get that working quickly

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@njhill can I get a run from your static analysis on this change?

# If not already cached, we need to validate that files exist and
# grab the type out of the adapter_config.json file
if (adapter_metadata := adapter_store.adapters.get(adapter_id)) is None:
local_adapter_path = os.path.join(adapter_store.cache_path, adapter_id)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should sanitize the adapter_id here to make sure that the user can't send funny things like ../../../etc/passwd.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!

Copy link
Contributor

@maxdebayser maxdebayser left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've left a comment suggesting a security improvement, but otherwise it looks good to me.

Signed-off-by: Joe Runde <[email protected]>
@joerunde joerunde merged commit 79b7364 into main Jun 11, 2024
15 checks passed
@njhill njhill deleted the lora-stuff branch June 13, 2024 16:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants