Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generic adapter support in the grpc server #32

Merged
merged 8 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions proto/generation.proto
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ message BatchedGenerationRequest {
string model_id = 1;
optional string prefix_id = 2;
repeated GenerationRequest requests = 3;
optional string lora_id = 4;
joerunde marked this conversation as resolved.
Show resolved Hide resolved

Parameters params = 10;
}
Expand All @@ -37,6 +38,7 @@ message SingleGenerationRequest {
string model_id = 1;
optional string prefix_id = 2;
GenerationRequest request = 3;
optional string lora_id = 4;

Parameters params = 10;
}
Expand Down
64 changes: 64 additions & 0 deletions vllm/entrypoints/grpc/adapters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Contains code to map api requests for adapters (e.g. peft prefixes, LoRA)
into valid LLM engine requests"""
import dataclasses
import os
from typing import Dict, Optional, Tuple, Union

from vllm.entrypoints.grpc.pb.generation_pb2 import (BatchedGenerationRequest,
SingleGenerationRequest)
from vllm.entrypoints.grpc.validation import TGISValidationError
from vllm.lora.request import LoRARequest


@dataclasses.dataclass
class AdapterStore:
    """Tracks adapters available in a local cache directory and the
    unique integer IDs assigned to each adapter name for the engine."""
    # Path to local store of adapters to load from
    cache_path: str
    # Maps adapter names to unique integer IDs
    unique_id_map: Dict[str, int]
    # Next integer ID to hand out; bumped each time a new adapter
    # name is first seen
    next_unique_id: int = 1


def validate_adapters(
    request: Union[SingleGenerationRequest, BatchedGenerationRequest],
    lora_adapter_store: Optional[AdapterStore]
) -> Tuple[Optional[LoRARequest], None]:
    """Takes the adapter names from the request and constructs a valid
    engine request if one is set.

    Args:
        request: incoming gRPC request, read for `prefix_id` and `lora_id`.
        lora_adapter_store: mutable store of known adapters; None when
            LoRA support was not configured on the server.

    Returns:
        (LoRARequest or None, None). The second slot is reserved for the
        incoming PromptAdapterRequest, see
        https://github.com/vllm-project/vllm/pull/4645/files

    Raises:
        ValueError: if the requested adapter does not exist, is invalid,
            or adapter support is disabled.
    """
    if request.prefix_id:
        # TODO: hook up PromptAdapterRequest once implemented in the engine
        raise ValueError("prefix_id not implemented yet")

    lora_id = request.lora_id
    if not lora_id:
        return None, None

    if not lora_adapter_store:
        # using raise/format instead of .error so mypy knows this raises
        raise ValueError(TGISValidationError.LoraDisabled.value.format())

    local_lora_path = os.path.join(lora_adapter_store.cache_path, lora_id)

    # We need to track a unique integer for vLLM to identify the lora
    # adapters. Only hit the filesystem the first time an adapter is
    # requested: once it has an ID assigned it already passed validation.
    unique_id = lora_adapter_store.unique_id_map.get(lora_id)
    if unique_id is None:
        # Do a bit of up-front validation so that we don't ask the engine
        # to try to load an invalid adapter
        if not os.path.exists(local_lora_path):
            TGISValidationError.LoraAdapterNotFound.error(
                lora_id, "directory does not exist")
        if not os.path.exists(
                os.path.join(local_lora_path, "adapter_config.json")):
            TGISValidationError.LoraAdapterNotFound.error(
                lora_id, "invalid adapter: no adapter_config.json found")

        unique_id = lora_adapter_store.next_unique_id
        lora_adapter_store.unique_id_map[lora_id] = unique_id
        lora_adapter_store.next_unique_id += 1

    lora_request = LoRARequest(lora_name=lora_id,
                               lora_int_id=unique_id,
                               lora_local_path=local_lora_path)

    # Second return slot left here for the incoming PromptAdapterRequest
    # See https://github.com/vllm-project/vllm/pull/4645/files
    return lora_request, None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about flattening this a bit?

Suggested change
lora_id = request.lora_id
if lora_id:
if not lora_adapter_store:
# using raise/format instead of .error so mypy knows this raises
raise ValueError(TGISValidationError.LoraDisabled.value.format())
local_lora_path = os.path.join(lora_adapter_store.cache_path, lora_id)
# Do a bit of up-front validation so that we don't ask the engine
# to try to load an invalid adapter
if not os.path.exists(local_lora_path):
TGISValidationError.LoraAdapterNotFound.error(
lora_id, "directory does not exist")
if not os.path.exists(
os.path.join(local_lora_path, "adapter_config.json")):
TGISValidationError.LoraAdapterNotFound.error(
lora_id, "invalid adapter: no adapter_config.json found")
# We need to track a unique integer for vLLM to identify the lora
# adapters
if lora_id not in lora_adapter_store.unique_id_map:
lora_adapter_store.unique_id_map[
lora_id] = lora_adapter_store.next_unique_id
lora_adapter_store.next_unique_id += 1
unique_id = lora_adapter_store.unique_id_map[lora_id]
lora_request = LoRARequest(lora_name=lora_id,
lora_int_id=unique_id,
lora_local_path=local_lora_path)
else:
lora_request = None
if request.prefix_id:
# TODO: hook up PromptAdapterRequest once implemented in the engine
raise ValueError("prefix_id not implemented yet")
# Second return slot left here for the incoming PromptAdapterRequest
# See https://github.com/vllm-project/vllm/pull/4645/files
return lora_request, None
if request.prefix_id:
# TODO: hook up PromptAdapterRequest once implemented in the engine
raise ValueError("prefix_id not implemented yet")
lora_id = request.lora_id
if not lora_id:
return None, None
if not lora_adapter_store:
# using raise/format instead of .error so mypy knows this raises
raise ValueError(TGISValidationError.LoraDisabled.value.format())
local_lora_path = os.path.join(lora_adapter_store.cache_path, lora_id)
# Do a bit of up-front validation so that we don't ask the engine
# to try to load an invalid adapter
if not os.path.exists(local_lora_path):
TGISValidationError.LoraAdapterNotFound.error(
lora_id, "directory does not exist")
if not os.path.exists(
os.path.join(local_lora_path, "adapter_config.json")):
TGISValidationError.LoraAdapterNotFound.error(
lora_id, "invalid adapter: no adapter_config.json found")
# We need to track a unique integer for vLLM to identify the lora
# adapters
if lora_id not in lora_adapter_store.unique_id_map:
lora_adapter_store.unique_id_map[
lora_id] = lora_adapter_store.next_unique_id
lora_adapter_store.next_unique_id += 1
unique_id = lora_adapter_store.unique_id_map[lora_id]
lora_request = LoRARequest(lora_name=lora_id,
lora_int_id=unique_id,
lora_local_path=local_lora_path)
# Second return slot left here for the incoming PromptAdapterRequest
# See https://github.com/vllm-project/vllm/pull/4645/files
return lora_request, None

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hah, I un-nested but then re-nested so that the file checking and opening will only happen if the adapter wasn't already loaded

34 changes: 32 additions & 2 deletions vllm/entrypoints/grpc/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from vllm import (AsyncLLMEngine, CompletionOutput, RequestOutput,
SamplingParams)
from vllm.config import ModelConfig
from vllm.entrypoints.grpc.adapters import AdapterStore, validate_adapters
from vllm.entrypoints.grpc.pb import generation_pb2_grpc # type: ignore
# yapf: disable
from vllm.entrypoints.grpc.pb.generation_pb2 import (BatchedGenerationRequest,
Expand All @@ -32,6 +33,7 @@
from vllm.entrypoints.grpc.validation import validate_input, validate_params
from vllm.entrypoints.openai.serving_completion import merge_async_iterators
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob
from vllm.tgis_utils import logs
from vllm.tgis_utils.logits_processors import (ExpDecayLengthPenaltyWarper,
Expand Down Expand Up @@ -116,9 +118,17 @@ def __init__(self, engine: AsyncLLMEngine, args: argparse.Namespace):
self.skip_special_tokens = not args.output_special_tokens
self.default_include_stop_seqs = args.default_include_stop_seqs

self.lora_adapter_store: Optional[AdapterStore] = None
if args.lora_adapter_cache:
self.lora_adapter_store = AdapterStore(
cache_path=args.lora_adapter_cache,
unique_id_map={}
)

async def _post_init(self):
self.config = await self.engine.get_model_config()
self.tokenizer_group = await self.engine.get_tokenizer_group()
# self.tokenizer_group = await self.engine.get_tokenizer_group()
joerunde marked this conversation as resolved.
Show resolved Hide resolved
self.tokenizer_group = self.engine.engine.tokenizer
joerunde marked this conversation as resolved.
Show resolved Hide resolved
self.tokenizer = await self.engine.get_tokenizer()

# Swap in the special TGIS stats logger
Expand All @@ -144,6 +154,9 @@ async def Generate(self, request: BatchedGenerationRequest,

generators = []
max_is_token_limit = [False] * request_count

lora_request, _ = await self._validate_adapters(request, context)

for i, req in enumerate(request.requests):
input_ids, max_is_token_limit[i]\
= await self._validate_prompt_and_tokenize(
Expand All @@ -154,7 +167,8 @@ async def Generate(self, request: BatchedGenerationRequest,
self.engine.generate(prompt=req.text,
sampling_params=sampling_params,
request_id=f"{request_id}-{i}",
prompt_token_ids=input_ids),
prompt_token_ids=input_ids,
lora_request=lora_request),
)

# TODO handle cancellation
Expand Down Expand Up @@ -210,13 +224,16 @@ async def GenerateStream(
sampling_params, truncate_input_tokens, request.request.text,
context)

lora_request, _ = await self._validate_adapters(request, context)

result_generator = self.engine.generate(
# prompt is supplied for observability, the text is not
# re-tokenized when `prompt_token_ids` is supplied
prompt=request.request.text,
sampling_params=sampling_params,
request_id=request_id,
prompt_token_ids=input_ids,
lora_request=lora_request
)

resp_options = request.params.response
Expand Down Expand Up @@ -423,6 +440,19 @@ async def _validate_and_convert_params(

return sampling_params, deadline

    async def _validate_adapters(self,
                                 request: Union[SingleGenerationRequest,
                                                BatchedGenerationRequest],
                                 context: ServicerContext) \
            -> Tuple[Optional[LoRARequest], None]:
        """Resolve adapters named in the request into engine-level adapter
        requests, aborting the RPC with INVALID_ARGUMENT on validation
        failure.

        Returns the (LoRARequest-or-None, None) tuple produced by
        validate_adapters; the second slot is reserved for prompt adapters.
        """
        try:
            adapters = validate_adapters(
                request=request, lora_adapter_store=self.lora_adapter_store)
        except ValueError as e:
            service_metrics.count_request_failure(FailureReasonLabel.VALIDATION)
            # NOTE(review): relies on context.abort raising to terminate the
            # RPC so that `adapters` is never unbound at the return below —
            # grpc.aio documents abort as always raising; confirm.
            await context.abort(StatusCode.INVALID_ARGUMENT, str(e))
        return adapters

@staticmethod
def _convert_reason(output: CompletionOutput, max_is_token_limit: bool,
time_limit_reached: bool
Expand Down
2 changes: 2 additions & 0 deletions vllm/entrypoints/grpc/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class TGISValidationError(str, Enum):

# Additions that are _not_ in TGIS
TopN = "top_n_tokens ({0}) must be <= {1}"
LoraAdapterNotFound = "can't retrieve LoRA adapter with id '{0}': {1}"
LoraDisabled = "lora_id supplied but no lora adapter store was configured"

def error(self, *args, **kwargs):
"""Raises a ValueError with a nicely formatted string"""
Expand Down
2 changes: 2 additions & 0 deletions vllm/tgis_utils/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def add_tgis_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument('--tls-key-path', type=str)
# map to ssl_ca_certs
parser.add_argument('--tls-client-ca-cert-path', type=str)
# add a path where lora adapters will be loaded from
parser.add_argument('--lora-adapter-cache', type=str)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

open to ideas on naming here


# TODO check/add other args here

Expand Down
8 changes: 6 additions & 2 deletions vllm/tgis_utils/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def log_response(
response=response,
params=request.params,
prefix_id=request.prefix_id,
lora_id=request.lora_id,
engine_metrics=engine_metrics,
start_time=start_time,
kind_log=kind_log,
Expand All @@ -57,6 +58,7 @@ def log_error(request: Union[BatchedGenerationRequest,
# of just logging the simple string representation of the error
param_str = text_format.MessageToString(request.params, as_one_line=True)
prefix_id = request.prefix_id
lora_id = request.lora_id

if isinstance(request, BatchedGenerationRequest):
method_str = "generate"
Expand All @@ -69,13 +71,14 @@ def log_error(request: Union[BatchedGenerationRequest,
input_chars = sum(len(input_) for input_ in inputs)

span_str = (f"{method_str}{{input={short_input} prefix_id={prefix_id} "
f"input_chars=[{input_chars}] params={param_str}")
f"lora_id={lora_id} input_chars=[{input_chars}] "
f"params={param_str}")

logger.error("%s: %s", span_str, exception_str)


def _log_response(inputs: List[str], params: Parameters, prefix_id: str,
response: GenerationResponse,
lora_id: str, response: GenerationResponse,
engine_metrics: Optional[RequestMetrics], start_time: float,
kind_log: str, method_str: str, logger: logging.Logger):
"""Logs responses similar to how the TGIS server does"""
Expand All @@ -99,6 +102,7 @@ def _log_response(inputs: List[str], params: Parameters, prefix_id: str,

paramstr = text_format.MessageToString(params, as_one_line=True)
span_str = (f"{method_str}{{input={short_input} prefix_id={prefix_id} "
f"lora_id={lora_id} "
f"input_chars=[{input_chars}] params={paramstr} "
f"tokenization_time={tokenization_time * 1e3:.2f}ms "
f"queue_time={queue_time * 1e3:.2f}ms "
Expand Down
Loading