Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prompt adapter support #44

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,8 @@ async def generate(
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: PromptAdapterRequest to use for generation,
if any.

Yields:
The output `RequestOutput` objects from the LLMEngine
Expand Down
33 changes: 24 additions & 9 deletions vllm/entrypoints/grpc/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
SingleGenerationRequest)
from vllm.entrypoints.grpc.validation import TGISValidationError
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest

global_thread_pool = None # used for loading adapter files from disk

Expand All @@ -24,6 +25,7 @@ class AdapterMetadata:
unique_id: int # Unique integer for vllm to identify the adapter
adapter_type: str # The string name of the peft adapter type, e.g. LORA
full_path: str
full_config: Dict # The loaded adapter_config.json dict


@dataclasses.dataclass
Expand All @@ -34,8 +36,9 @@ class AdapterStore:


async def validate_adapters(
request: Union[SingleGenerationRequest, BatchedGenerationRequest],
adapter_store: Optional[AdapterStore]) -> Dict[str, LoRARequest]:
request: Union[SingleGenerationRequest, BatchedGenerationRequest],
adapter_store: Optional[AdapterStore]
) -> Dict[str, Union[LoRARequest, PromptAdapterRequest]]:
"""Takes the adapter name from the request and constructs a valid
engine request if one is set. Raises if the requested adapter
does not exist or adapter type is unsupported
Expand All @@ -44,6 +47,9 @@ async def validate_adapters(
"""
global global_thread_pool
adapter_id = request.adapter_id
# Backwards compatibility for `prefix_id` arg
if not adapter_id and request.prefix_id:
adapter_id = request.prefix_id

if adapter_id and not adapter_store:
TGISValidationError.AdaptersDisabled.error()
Expand All @@ -62,16 +68,17 @@ async def validate_adapters(
global_thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=2)

adapter_type = await loop.run_in_executor(global_thread_pool,
_get_adapter_type_from_file,
adapter_id,
local_adapter_path)
adapter_config = await (loop.run_in_executor(
global_thread_pool, _load_adapter_config_from_file, adapter_id,
local_adapter_path))
adapter_type = adapter_config.get("peft_type", None)

# Add to cache
adapter_metadata = AdapterMetadata(
unique_id=adapter_store.next_unique_id,
adapter_type=adapter_type,
full_path=local_adapter_path)
full_path=local_adapter_path,
full_config=adapter_config)
adapter_store.adapters[adapter_id] = adapter_metadata

# Build the proper vllm request object
Expand All @@ -80,12 +87,20 @@ async def validate_adapters(
lora_int_id=adapter_metadata.unique_id,
lora_local_path=adapter_metadata.full_path)
return {"lora_request": lora_request}
elif adapter_metadata.adapter_type == "PROMPT_TUNING":
prompt_adapter_request = PromptAdapterRequest(
prompt_adapter_id=adapter_metadata.unique_id,
prompt_adapter_name=adapter_id,
prompt_adapter_local_path=adapter_metadata.full_path,
prompt_adapter_num_virtual_tokens=adapter_metadata.full_config.get(
"num_virtual_tokens", 0))
return {"prompt_adapter_request": prompt_adapter_request}

# All other types unsupported
TGISValidationError.AdapterUnsupported.error(adapter_metadata.adapter_type)


def _get_adapter_type_from_file(adapter_id: str, adapter_path: str) -> str:
def _load_adapter_config_from_file(adapter_id: str, adapter_path: str) -> Dict:
"""This function does all the filesystem access required to deduce the type
of the adapter. It's run in a separate thread pool executor so that file
access does not block the main event loop."""
Expand All @@ -102,7 +117,7 @@ def _get_adapter_type_from_file(adapter_id: str, adapter_path: str) -> str:
with open(adapter_config_path) as adapter_config_file:
adapter_config = json.load(adapter_config_file)

return adapter_config.get("peft_type", None)
return adapter_config


def _reject_bad_adapter_id(adapter_id: str) -> None:
Expand Down
9 changes: 6 additions & 3 deletions vllm/entrypoints/grpc/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from vllm.inputs import TextTokensPrompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import Logprob
from vllm.tgis_utils import logs
from vllm.tgis_utils.guided_decoding import (
Expand Down Expand Up @@ -122,9 +123,11 @@ def __init__(self, engine: AsyncLLMEngine, args: argparse.Namespace):
self.default_include_stop_seqs = args.default_include_stop_seqs

self.adapter_store: Optional[AdapterStore] = None
if args.adapter_cache:
# Backwards compatibility for TGIS: PREFIX_STORE_PATH
adapter_cache_path = args.adapter_cache or args.prefix_store_path
if adapter_cache_path:
self.adapter_store = AdapterStore(
cache_path=args.adapter_cache,
cache_path=adapter_cache_path,
adapters={}
)

Expand Down Expand Up @@ -476,7 +479,7 @@ async def _validate_adapters(self,
request: Union[SingleGenerationRequest,
BatchedGenerationRequest],
context: ServicerContext) \
-> Dict[str, LoRARequest]:
-> Dict[str, Union[LoRARequest, PromptAdapterRequest]]:
try:
adapters = await validate_adapters(
request=request, adapter_store=self.adapter_store)
Expand Down
4 changes: 4 additions & 0 deletions vllm/tgis_utils/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ def add_tgis_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument('--tls-client-ca-cert-path', type=str)
# add a path when peft adapters will be loaded from
parser.add_argument('--adapter-cache', type=str)
# backwards-compatibility support for tgis prompt tuning
parser.add_argument('--prefix-store-path',
type=str,
help="Deprecated, use --adapter-cache")

# TODO check/add other args here

Expand Down
Loading