Skip to content

Commit 79b7364

Browse files
authored
Generic adapter support in the grpc server (#32)
Adds support for multi-lora adapters. Passing tests added over in this PR: https://github.ibm.com/ai-foundation/tgis-deploy-tests/pull/25/files --------- Signed-off-by: Joe Runde <[email protected]>
1 parent 0fe7794 commit 79b7364

File tree

6 files changed

+167
-5
lines changed

6 files changed

+167
-5
lines changed

proto/generation.proto

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,19 @@ enum DecodingMethod {
2727

2828
message BatchedGenerationRequest {
2929
string model_id = 1;
30+
// Deprecated in favor of adapter_id
3031
optional string prefix_id = 2;
32+
optional string adapter_id = 4;
3133
repeated GenerationRequest requests = 3;
3234

3335
Parameters params = 10;
3436
}
3537

3638
message SingleGenerationRequest {
3739
string model_id = 1;
40+
// Deprecated in favor of adapter_id
3841
optional string prefix_id = 2;
42+
optional string adapter_id = 4;
3943
GenerationRequest request = 3;
4044

4145
Parameters params = 10;

vllm/entrypoints/grpc/adapters.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
"""Contains code to map api requests for adapters (e.g. peft prefixes, LoRA)
2+
into valid LLM engine requests"""
3+
import asyncio
4+
import concurrent.futures
5+
import dataclasses
6+
import json
7+
import os
8+
import re
9+
from pathlib import Path
10+
from typing import Dict, Optional, Union
11+
12+
from vllm.entrypoints.grpc.pb.generation_pb2 import (BatchedGenerationRequest,
13+
SingleGenerationRequest)
14+
from vllm.entrypoints.grpc.validation import TGISValidationError
15+
from vllm.lora.request import LoRARequest
16+
17+
global_thread_pool = None # used for loading adapter files from disk
18+
19+
VALID_ADAPTER_ID_PATTERN = re.compile("[/\\w\\-]+")
20+
21+
22+
@dataclasses.dataclass
23+
class AdapterMetadata:
24+
unique_id: int # Unique integer for vllm to identify the adapter
25+
adapter_type: str # The string name of the peft adapter type, e.g. LORA
26+
full_path: str
27+
28+
29+
@dataclasses.dataclass
30+
class AdapterStore:
31+
cache_path: str # Path to local store of adapters to load from
32+
adapters: Dict[str, AdapterMetadata]
33+
next_unique_id: int = 1
34+
35+
36+
async def validate_adapters(
37+
request: Union[SingleGenerationRequest, BatchedGenerationRequest],
38+
adapter_store: Optional[AdapterStore]) -> Dict[str, LoRARequest]:
39+
"""Takes the adapter name from the request and constructs a valid
40+
engine request if one is set. Raises if the requested adapter
41+
does not exist or adapter type is unsupported
42+
43+
Returns the kwarg dictionary to add to an engine.generate() call.
44+
"""
45+
global global_thread_pool
46+
adapter_id = request.adapter_id
47+
48+
if adapter_id and not adapter_store:
49+
TGISValidationError.AdaptersDisabled.error()
50+
51+
if not adapter_id or not adapter_store:
52+
return {}
53+
54+
# If not already cached, we need to validate that files exist and
55+
# grab the type out of the adapter_config.json file
56+
if (adapter_metadata := adapter_store.adapters.get(adapter_id)) is None:
57+
_reject_bad_adapter_id(adapter_id)
58+
local_adapter_path = os.path.join(adapter_store.cache_path, adapter_id)
59+
60+
loop = asyncio.get_running_loop()
61+
if global_thread_pool is None:
62+
global_thread_pool = concurrent.futures.ThreadPoolExecutor(
63+
max_workers=2)
64+
65+
adapter_type = await loop.run_in_executor(global_thread_pool,
66+
_get_adapter_type_from_file,
67+
adapter_id,
68+
local_adapter_path)
69+
70+
# Add to cache
71+
adapter_metadata = AdapterMetadata(
72+
unique_id=adapter_store.next_unique_id,
73+
adapter_type=adapter_type,
74+
full_path=local_adapter_path)
75+
adapter_store.adapters[adapter_id] = adapter_metadata
76+
77+
# Build the proper vllm request object
78+
if adapter_metadata.adapter_type == "LORA":
79+
lora_request = LoRARequest(lora_name=adapter_id,
80+
lora_int_id=adapter_metadata.unique_id,
81+
lora_local_path=adapter_metadata.full_path)
82+
return {"lora_request": lora_request}
83+
84+
# All other types unsupported
85+
TGISValidationError.AdapterUnsupported.error(adapter_metadata.adapter_type)
86+
87+
88+
def _get_adapter_type_from_file(adapter_id: str, adapter_path: str) -> str:
89+
"""This function does all the filesystem access required to deduce the type
90+
of the adapter. It's run in a separate thread pool executor so that file
91+
access does not block the main event loop."""
92+
if not os.path.exists(adapter_path):
93+
TGISValidationError.AdapterNotFound.error(adapter_id,
94+
"directory does not exist")
95+
96+
adapter_config_path = os.path.join(adapter_path, "adapter_config.json")
97+
if not os.path.exists(adapter_config_path):
98+
TGISValidationError.AdapterNotFound.error(
99+
adapter_id, "invalid adapter: no adapter_config.json found")
100+
101+
# NB: blocks event loop
102+
with open(adapter_config_path) as adapter_config_file:
103+
adapter_config = json.load(adapter_config_file)
104+
105+
return adapter_config.get("peft_type", None)
106+
107+
108+
def _reject_bad_adapter_id(adapter_id: str) -> None:
109+
"""Raise if the adapter id attempts path traversal or has invalid file path
110+
characters"""
111+
if not VALID_ADAPTER_ID_PATTERN.fullmatch(adapter_id):
112+
TGISValidationError.InvalidAdapterID.error(adapter_id)
113+
114+
# Check for path traversal
115+
root_path = Path("/some/file/root")
116+
derived_path = root_path / adapter_id
117+
if not os.path.normpath(derived_path).startswith(str(root_path) + "/"):
118+
TGISValidationError.InvalidAdapterID.error(adapter_id)

vllm/entrypoints/grpc/grpc_server.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from vllm import (AsyncLLMEngine, CompletionOutput, RequestOutput,
1515
SamplingParams)
1616
from vllm.config import ModelConfig
17+
from vllm.entrypoints.grpc.adapters import AdapterStore, validate_adapters
1718
from vllm.entrypoints.grpc.pb import generation_pb2_grpc # type: ignore
1819
# yapf: disable
1920
from vllm.entrypoints.grpc.pb.generation_pb2 import (BatchedGenerationRequest,
@@ -33,6 +34,7 @@
3334
from vllm.entrypoints.openai.serving_completion import merge_async_iterators
3435
from vllm.inputs import TextTokensPrompt
3536
from vllm.logger import init_logger
37+
from vllm.lora.request import LoRARequest
3638
from vllm.sequence import Logprob
3739
from vllm.tgis_utils import logs
3840
from vllm.tgis_utils.guided_decoding import (
@@ -119,6 +121,13 @@ def __init__(self, engine: AsyncLLMEngine, args: argparse.Namespace):
119121
self.skip_special_tokens = not args.output_special_tokens
120122
self.default_include_stop_seqs = args.default_include_stop_seqs
121123

124+
self.adapter_store: Optional[AdapterStore] = None
125+
if args.adapter_cache:
126+
self.adapter_store = AdapterStore(
127+
cache_path=args.adapter_cache,
128+
adapters={}
129+
)
130+
122131
async def _post_init(self):
123132
self.config = await self.engine.get_model_config()
124133
# self.tokenizer_group = await self.engine.get_tokenizer_group()
@@ -148,6 +157,9 @@ async def Generate(self, request: BatchedGenerationRequest,
148157

149158
generators = []
150159
max_is_token_limit = [False] * request_count
160+
161+
adapter_kwargs = await self._validate_adapters(request, context)
162+
151163
for i, req in enumerate(request.requests):
152164
input_ids, max_is_token_limit[i]\
153165
= await self._validate_prompt_and_tokenize(
@@ -161,7 +173,8 @@ async def Generate(self, request: BatchedGenerationRequest,
161173
# re-tokenized when `prompt_token_ids` is supplied
162174
self.engine.generate(inputs=inputs,
163175
sampling_params=sampling_params,
164-
request_id=f"{request_id}-{i}"),
176+
request_id=f"{request_id}-{i}",
177+
**adapter_kwargs),
165178
)
166179

167180
# TODO handle cancellation
@@ -218,6 +231,7 @@ async def GenerateStream(
218231
sampling_params, truncate_input_tokens, request.request.text,
219232
context)
220233

234+
adapter_kwargs = await self._validate_adapters(request, context)
221235
inputs = TextTokensPrompt(
222236
prompt=request.request.text,
223237
prompt_token_ids=input_ids
@@ -228,7 +242,8 @@ async def GenerateStream(
228242
# re-tokenized when `prompt_token_ids` is supplied
229243
inputs=inputs,
230244
sampling_params=sampling_params,
231-
request_id=request_id
245+
request_id=request_id,
246+
**adapter_kwargs
232247
)
233248

234249
resp_options = request.params.response
@@ -442,6 +457,19 @@ async def _validate_and_convert_params(
442457

443458
return sampling_params, deadline
444459

460+
async def _validate_adapters(self,
461+
request: Union[SingleGenerationRequest,
462+
BatchedGenerationRequest],
463+
context: ServicerContext) \
464+
-> Dict[str, LoRARequest]:
465+
try:
466+
adapters = await validate_adapters(
467+
request=request, adapter_store=self.adapter_store)
468+
except ValueError as e:
469+
service_metrics.count_request_failure(FailureReasonLabel.VALIDATION)
470+
await context.abort(StatusCode.INVALID_ARGUMENT, str(e))
471+
return adapters
472+
445473
@staticmethod
446474
def _convert_reason(output: CompletionOutput, max_is_token_limit: bool,
447475
time_limit_reached: bool

vllm/entrypoints/grpc/validation.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import typing
12
from enum import Enum
23

34
from vllm import SamplingParams
@@ -39,8 +40,13 @@ class TGISValidationError(str, Enum):
3940

4041
# Additions that are _not_ in TGIS
4142
TopN = "top_n_tokens ({0}) must be <= {1}"
43+
AdapterNotFound = "can't retrieve adapter with id '{0}': {1}"
44+
AdaptersDisabled = "adapter_id supplied but no adapter store was configured"
45+
AdapterUnsupported = "adapter type {0} is not currently supported"
46+
InvalidAdapterID = ("Invalid adapter id '{0}', must contain only "
47+
"alphanumeric, _ and - and /")
4248

43-
def error(self, *args, **kwargs):
49+
def error(self, *args, **kwargs) -> typing.NoReturn:
4450
"""Raises a ValueError with a nicely formatted string"""
4551
raise ValueError(self.value.format(*args, **kwargs))
4652

vllm/tgis_utils/args.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ def add_tgis_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
8282
parser.add_argument('--tls-key-path', type=str)
8383
# map to ssl_ca_certs
8484
parser.add_argument('--tls-client-ca-cert-path', type=str)
85+
# add a path when peft adapters will be loaded from
86+
parser.add_argument('--adapter-cache', type=str)
8587

8688
# TODO check/add other args here
8789

vllm/tgis_utils/logs.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def log_response(
4141
response=response,
4242
params=request.params,
4343
prefix_id=request.prefix_id,
44+
adapter_id=request.adapter_id,
4445
engine_metrics=engine_metrics,
4546
start_time=start_time,
4647
kind_log=kind_log,
@@ -57,6 +58,7 @@ def log_error(request: Union[BatchedGenerationRequest,
5758
# of just logging the simple string representation of the error
5859
param_str = text_format.MessageToString(request.params, as_one_line=True)
5960
prefix_id = request.prefix_id
61+
adapter_id = request.adapter_id
6062

6163
if isinstance(request, BatchedGenerationRequest):
6264
method_str = "generate"
@@ -69,13 +71,14 @@ def log_error(request: Union[BatchedGenerationRequest,
6971
input_chars = sum(len(input_) for input_ in inputs)
7072

7173
span_str = (f"{method_str}{{input={short_input} prefix_id={prefix_id} "
72-
f"input_chars=[{input_chars}] params={param_str}")
74+
f"adapter_id={adapter_id} input_chars=[{input_chars}] "
75+
f"params={param_str}")
7376

7477
logger.error("%s: %s", span_str, exception_str)
7578

7679

7780
def _log_response(inputs: List[str], params: Parameters, prefix_id: str,
78-
response: GenerationResponse,
81+
adapter_id: str, response: GenerationResponse,
7982
engine_metrics: Optional[RequestMetrics], start_time: float,
8083
kind_log: str, method_str: str, logger: logging.Logger):
8184
"""Logs responses similar to how the TGIS server does"""
@@ -99,6 +102,7 @@ def _log_response(inputs: List[str], params: Parameters, prefix_id: str,
99102

100103
paramstr = text_format.MessageToString(params, as_one_line=True)
101104
span_str = (f"{method_str}{{input={short_input} prefix_id={prefix_id} "
105+
f"adapter_id={adapter_id} "
102106
f"input_chars=[{input_chars}] params={paramstr} "
103107
f"tokenization_time={tokenization_time * 1e3:.2f}ms "
104108
f"queue_time={queue_time * 1e3:.2f}ms "

0 commit comments

Comments
 (0)