Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions proto/generation.proto
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ message BatchedGenerationRequest {
string model_id = 1;
optional string prefix_id = 2;
repeated GenerationRequest requests = 3;
optional string lora_id = 4;

Parameters params = 10;
}
Expand All @@ -37,6 +38,7 @@ message SingleGenerationRequest {
string model_id = 1;
optional string prefix_id = 2;
GenerationRequest request = 3;
optional string lora_id = 4;

Parameters params = 10;
}
Expand Down
64 changes: 64 additions & 0 deletions vllm/entrypoints/grpc/adapters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Contains code to map api requests for adapters (e.g. peft prefixes, LoRA)
into valid LLM engine requests"""
import dataclasses
import os
from typing import Dict, Optional, Tuple, Union

from vllm.entrypoints.grpc.pb.generation_pb2 import (BatchedGenerationRequest,
SingleGenerationRequest)
from vllm.entrypoints.grpc.validation import TGISValidationError
from vllm.lora.request import LoRARequest


@dataclasses.dataclass
class AdapterStore:
    """Bookkeeping for adapters that are loaded from a local directory.

    Holds the directory adapters are read from, plus the mapping from
    adapter name to the unique integer ID assigned to that name.
    """
    # Path to local store of adapters to load from
    cache_path: str
    # Maps adapter names to unique integer IDs. Defaults to a fresh, empty
    # dict per instance so callers no longer have to pass `{}` explicitly.
    unique_id_map: Dict[str, int] = dataclasses.field(default_factory=dict)
    # Next integer ID to assign to an adapter name not yet in the map
    next_unique_id: int = 1


def validate_adapters(
    request: Union[SingleGenerationRequest, BatchedGenerationRequest],
    lora_adapter_store: Optional[AdapterStore]
) -> Tuple[Optional[LoRARequest], None]:
    """Takes the adapter names from the request and constructs a valid
    engine request if one is set.

    Args:
        request: the incoming generation request; its ``lora_id`` and
            ``prefix_id`` fields are read.
        lora_adapter_store: the store that adapters are loaded from, or
            None when no store was configured.

    Returns:
        A ``(lora_request, None)`` tuple. ``lora_request`` is None when the
        request carries no ``lora_id``. The second slot is a placeholder
        for the incoming PromptAdapterRequest.

    Raises:
        ValueError: if a ``lora_id`` is supplied but no store is configured,
            if the ``lora_id`` points outside the store or at a missing or
            invalid adapter directory, or if a ``prefix_id`` is supplied.
    """
    lora_id = request.lora_id
    if lora_id:
        if not lora_adapter_store:
            # using raise/format instead of .error so mypy knows this raises
            raise ValueError(TGISValidationError.LoraDisabled.value.format())

        local_lora_path = os.path.join(lora_adapter_store.cache_path, lora_id)

        # lora_id comes straight from the request. os.path.join discards the
        # cache path when lora_id is absolute, and ".." segments can climb
        # out of it, so make sure the target is strictly inside the store
        # before touching the filesystem.
        cache_root = os.path.abspath(lora_adapter_store.cache_path)
        resolved_lora_path = os.path.abspath(local_lora_path)
        # os.path.join(root, "") appends exactly one trailing separator, so
        # "/cache-other" is not mistaken for a child of "/cache"
        if not resolved_lora_path.startswith(os.path.join(cache_root, "")):
            raise ValueError(
                TGISValidationError.LoraAdapterNotFound.value.format(
                    lora_id, "adapter id must be inside the adapter store"))

        # Do a bit of up-front validation so that we don't ask the engine
        # to try to load an invalid adapter
        if not os.path.exists(local_lora_path):
            TGISValidationError.LoraAdapterNotFound.error(
                lora_id, "directory does not exist")
        if not os.path.exists(
                os.path.join(local_lora_path, "adapter_config.json")):
            TGISValidationError.LoraAdapterNotFound.error(
                lora_id, "invalid adapter: no adapter_config.json found")

        # We need to track a unique integer for vLLM to identify the lora
        # adapters
        if lora_id not in lora_adapter_store.unique_id_map:
            lora_adapter_store.unique_id_map[
                lora_id] = lora_adapter_store.next_unique_id
            lora_adapter_store.next_unique_id += 1
        unique_id = lora_adapter_store.unique_id_map[lora_id]
        lora_request = LoRARequest(lora_name=lora_id,
                                   lora_int_id=unique_id,
                                   lora_local_path=local_lora_path)
    else:
        lora_request = None

    if request.prefix_id:
        # TODO: hook up PromptAdapterRequest once implemented in the engine
        raise ValueError("prefix_id not implemented yet")

    # Second return slot left here for the incoming PromptAdapterRequest
    # See https://github.com/vllm-project/vllm/pull/4645/files
    return lora_request, None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about flattening this a bit?

Suggested change
lora_id = request.lora_id
if lora_id:
if not lora_adapter_store:
# using raise/format instead of .error so mypy knows this raises
raise ValueError(TGISValidationError.LoraDisabled.value.format())
local_lora_path = os.path.join(lora_adapter_store.cache_path, lora_id)
# Do a bit of up-front validation so that we don't ask the engine
# to try to load an invalid adapter
if not os.path.exists(local_lora_path):
TGISValidationError.LoraAdapterNotFound.error(
lora_id, "directory does not exist")
if not os.path.exists(
os.path.join(local_lora_path, "adapter_config.json")):
TGISValidationError.LoraAdapterNotFound.error(
lora_id, "invalid adapter: no adapter_config.json found")
# We need to track a unique integer for vLLM to identify the lora
# adapters
if lora_id not in lora_adapter_store.unique_id_map:
lora_adapter_store.unique_id_map[
lora_id] = lora_adapter_store.next_unique_id
lora_adapter_store.next_unique_id += 1
unique_id = lora_adapter_store.unique_id_map[lora_id]
lora_request = LoRARequest(lora_name=lora_id,
lora_int_id=unique_id,
lora_local_path=local_lora_path)
else:
lora_request = None
if request.prefix_id:
# TODO: hook up PromptAdapterRequest once implemented in the engine
raise ValueError("prefix_id not implemented yet")
# Second return slot left here for the incoming PromptAdapterRequest
# See https://github.com/vllm-project/vllm/pull/4645/files
return lora_request, None
if request.prefix_id:
# TODO: hook up PromptAdapterRequest once implemented in the engine
raise ValueError("prefix_id not implemented yet")
lora_id = request.lora_id
if not lora_id:
return None, None
if not lora_adapter_store:
# using raise/format instead of .error so mypy knows this raises
raise ValueError(TGISValidationError.LoraDisabled.value.format())
local_lora_path = os.path.join(lora_adapter_store.cache_path, lora_id)
# Do a bit of up-front validation so that we don't ask the engine
# to try to load an invalid adapter
if not os.path.exists(local_lora_path):
TGISValidationError.LoraAdapterNotFound.error(
lora_id, "directory does not exist")
if not os.path.exists(
os.path.join(local_lora_path, "adapter_config.json")):
TGISValidationError.LoraAdapterNotFound.error(
lora_id, "invalid adapter: no adapter_config.json found")
# We need to track a unique integer for vLLM to identify the lora
# adapters
if lora_id not in lora_adapter_store.unique_id_map:
lora_adapter_store.unique_id_map[
lora_id] = lora_adapter_store.next_unique_id
lora_adapter_store.next_unique_id += 1
unique_id = lora_adapter_store.unique_id_map[lora_id]
lora_request = LoRARequest(lora_name=lora_id,
lora_int_id=unique_id,
lora_local_path=local_lora_path)
# Second return slot left here for the incoming PromptAdapterRequest
# See https://github.com/vllm-project/vllm/pull/4645/files
return lora_request, None

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hah, I un-nested but then re-nested so that the file checking and opening will only happen if the adapter wasn't already loaded

34 changes: 32 additions & 2 deletions vllm/entrypoints/grpc/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from vllm import (AsyncLLMEngine, CompletionOutput, RequestOutput,
SamplingParams)
from vllm.config import ModelConfig
from vllm.entrypoints.grpc.adapters import AdapterStore, validate_adapters
from vllm.entrypoints.grpc.pb import generation_pb2_grpc # type: ignore
# yapf: disable
from vllm.entrypoints.grpc.pb.generation_pb2 import (BatchedGenerationRequest,
Expand All @@ -32,6 +33,7 @@
from vllm.entrypoints.grpc.validation import validate_input, validate_params
from vllm.entrypoints.openai.serving_completion import merge_async_iterators
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob
from vllm.tgis_utils import logs
from vllm.tgis_utils.logits_processors import (ExpDecayLengthPenaltyWarper,
Expand Down Expand Up @@ -116,9 +118,17 @@ def __init__(self, engine: AsyncLLMEngine, args: argparse.Namespace):
self.skip_special_tokens = not args.output_special_tokens
self.default_include_stop_seqs = args.default_include_stop_seqs

self.lora_adapter_store: Optional[AdapterStore] = None
if args.lora_adapter_cache:
self.lora_adapter_store = AdapterStore(
cache_path=args.lora_adapter_cache,
unique_id_map={}
)

async def _post_init(self):
self.config = await self.engine.get_model_config()
self.tokenizer_group = await self.engine.get_tokenizer_group()
# self.tokenizer_group = await self.engine.get_tokenizer_group()
self.tokenizer_group = self.engine.engine.tokenizer
self.tokenizer = await self.engine.get_tokenizer()

# Swap in the special TGIS stats logger
Expand All @@ -144,6 +154,9 @@ async def Generate(self, request: BatchedGenerationRequest,

generators = []
max_is_token_limit = [False] * request_count

lora_request, _ = await self._validate_adapters(request, context)

for i, req in enumerate(request.requests):
input_ids, max_is_token_limit[i]\
= await self._validate_prompt_and_tokenize(
Expand All @@ -154,7 +167,8 @@ async def Generate(self, request: BatchedGenerationRequest,
self.engine.generate(prompt=req.text,
sampling_params=sampling_params,
request_id=f"{request_id}-{i}",
prompt_token_ids=input_ids),
prompt_token_ids=input_ids,
lora_request=lora_request),
)

# TODO handle cancellation
Expand Down Expand Up @@ -210,13 +224,16 @@ async def GenerateStream(
sampling_params, truncate_input_tokens, request.request.text,
context)

lora_request, _ = await self._validate_adapters(request, context)

result_generator = self.engine.generate(
# prompt is supplied for observability, the text is not
# re-tokenized when `prompt_token_ids` is supplied
prompt=request.request.text,
sampling_params=sampling_params,
request_id=request_id,
prompt_token_ids=input_ids,
lora_request=lora_request
)

resp_options = request.params.response
Expand Down Expand Up @@ -423,6 +440,19 @@ async def _validate_and_convert_params(

return sampling_params, deadline

async def _validate_adapters(self,
                             request: Union[SingleGenerationRequest,
                                            BatchedGenerationRequest],
                             context: ServicerContext) \
        -> Tuple[Optional[LoRARequest], None]:
    """Run adapter validation for a request against this server's LoRA
    adapter store.

    A ValueError from validation is counted as a validation failure and
    the RPC is aborted with INVALID_ARGUMENT, carrying the error text.
    Otherwise the validated adapter tuple is returned unchanged.
    """
    try:
        validated = validate_adapters(request, self.lora_adapter_store)
    except ValueError as err:
        service_metrics.count_request_failure(FailureReasonLabel.VALIDATION)
        # NOTE(review): this relies on context.abort raising; if it ever
        # returned, the return below would hit an unbound local - confirm
        await context.abort(StatusCode.INVALID_ARGUMENT, str(err))
    return validated

@staticmethod
def _convert_reason(output: CompletionOutput, max_is_token_limit: bool,
time_limit_reached: bool
Expand Down
2 changes: 2 additions & 0 deletions vllm/entrypoints/grpc/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class TGISValidationError(str, Enum):

# Additions that are _not_ in TGIS
TopN = "top_n_tokens ({0}) must be <= {1}"
LoraAdapterNotFound = "can't retrieve LoRA adapter with id '{0}': {1}"
LoraDisabled = "lora_id supplied but no lora adapter store was configured"

def error(self, *args, **kwargs):
"""Raises a ValueError with a nicely formatted string"""
Expand Down
2 changes: 2 additions & 0 deletions vllm/tgis_utils/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def add_tgis_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument('--tls-key-path', type=str)
# map to ssl_ca_certs
parser.add_argument('--tls-client-ca-cert-path', type=str)
# add a path where lora adapters will be loaded from
parser.add_argument('--lora-adapter-cache', type=str)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

open to ideas on naming here


# TODO check/add other args here

Expand Down
8 changes: 6 additions & 2 deletions vllm/tgis_utils/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def log_response(
response=response,
params=request.params,
prefix_id=request.prefix_id,
lora_id=request.lora_id,
engine_metrics=engine_metrics,
start_time=start_time,
kind_log=kind_log,
Expand All @@ -57,6 +58,7 @@ def log_error(request: Union[BatchedGenerationRequest,
# of just logging the simple string representation of the error
param_str = text_format.MessageToString(request.params, as_one_line=True)
prefix_id = request.prefix_id
lora_id = request.lora_id

if isinstance(request, BatchedGenerationRequest):
method_str = "generate"
Expand All @@ -69,13 +71,14 @@ def log_error(request: Union[BatchedGenerationRequest,
input_chars = sum(len(input_) for input_ in inputs)

span_str = (f"{method_str}{{input={short_input} prefix_id={prefix_id} "
f"input_chars=[{input_chars}] params={param_str}")
f"lora_id={lora_id} input_chars=[{input_chars}] "
f"params={param_str}")

logger.error("%s: %s", span_str, exception_str)


def _log_response(inputs: List[str], params: Parameters, prefix_id: str,
response: GenerationResponse,
lora_id: str, response: GenerationResponse,
engine_metrics: Optional[RequestMetrics], start_time: float,
kind_log: str, method_str: str, logger: logging.Logger):
"""Logs responses similar to how the TGIS server does"""
Expand All @@ -99,6 +102,7 @@ def _log_response(inputs: List[str], params: Parameters, prefix_id: str,

paramstr = text_format.MessageToString(params, as_one_line=True)
span_str = (f"{method_str}{{input={short_input} prefix_id={prefix_id} "
f"lora_id={lora_id} "
f"input_chars=[{input_chars}] params={paramstr} "
f"tokenization_time={tokenization_time * 1e3:.2f}ms "
f"queue_time={queue_time * 1e3:.2f}ms "
Expand Down