|
| 1 | +"""Contains code to map api requests for adapters (e.g. peft prefixes, LoRA) |
| 2 | +into valid LLM engine requests""" |
| 3 | +import asyncio |
| 4 | +import concurrent.futures |
| 5 | +import dataclasses |
| 6 | +import json |
| 7 | +import os |
| 8 | +import re |
| 9 | +from pathlib import Path |
| 10 | +from typing import Dict, Optional, Union |
| 11 | + |
| 12 | +from vllm.entrypoints.grpc.pb.generation_pb2 import (BatchedGenerationRequest, |
| 13 | + SingleGenerationRequest) |
| 14 | +from vllm.entrypoints.grpc.validation import TGISValidationError |
| 15 | +from vllm.lora.request import LoRARequest |
| 16 | + |
# Adapter ids may only consist of word characters, '-' and '/'.
VALID_ADAPTER_ID_PATTERN = re.compile("[/\\w\\-]+")

# Executor used for loading adapter files from disk. It starts out unset and
# is created on first use by validate_adapters().
global_thread_pool: Optional[concurrent.futures.ThreadPoolExecutor] = None
| 20 | + |
| 21 | + |
@dataclasses.dataclass
class AdapterMetadata:
    """Record kept for each adapter once it has been looked up.

    Attributes:
        unique_id: Unique integer for vllm to identify the adapter.
        adapter_type: The string name of the peft adapter type, e.g. LORA.
        full_path: Path string stored for the adapter's files.
    """
    unique_id: int
    adapter_type: str
    full_path: str
| 27 | + |
| 28 | + |
@dataclasses.dataclass
class AdapterStore:
    """Registry of adapters that have already been looked up.

    Attributes:
        cache_path: Path to the local store of adapters to load from.
        adapters: Cached metadata, keyed by adapter id.
        next_unique_id: Integer id handed to the next adapter that gets
            cached; starts at 1.
    """
    cache_path: str
    adapters: Dict[str, AdapterMetadata]
    next_unique_id: int = 1
| 34 | + |
| 35 | + |
async def validate_adapters(
        request: Union[SingleGenerationRequest, BatchedGenerationRequest],
        adapter_store: Optional[AdapterStore]) -> Dict[str, LoRARequest]:
    """Turn the adapter named in a request into engine.generate() kwargs.

    The first time an adapter id is seen, its id is validated, its config
    file is read on a worker thread, and the resulting metadata is cached in
    ``adapter_store.adapters`` under a freshly allocated integer id. Later
    requests for the same id reuse the cached metadata.

    Args:
        request: The incoming request; only ``request.adapter_id`` is read.
        adapter_store: The adapter registry, or a falsy value when adapters
            are not enabled.

    Returns:
        An empty dict when the request names no adapter, otherwise
        ``{"lora_request": LoRARequest(...)}`` for a LORA adapter.

    Raises:
        Failures are reported through ``TGISValidationError.<kind>.error()``
        (adapters disabled, invalid id, adapter not found, unsupported
        type), which is expected to raise.
    """
    global global_thread_pool
    adapter_id = request.adapter_id

    if adapter_id and not adapter_store:
        TGISValidationError.AdaptersDisabled.error()

    if not adapter_id or not adapter_store:
        return {}

    # If not already cached, we need to validate that files exist and
    # grab the type out of the adapter_config.json file
    if (adapter_metadata := adapter_store.adapters.get(adapter_id)) is None:
        _reject_bad_adapter_id(adapter_id)
        local_adapter_path = os.path.join(adapter_store.cache_path, adapter_id)

        loop = asyncio.get_running_loop()
        if global_thread_pool is None:
            global_thread_pool = concurrent.futures.ThreadPoolExecutor(
                max_workers=2)

        adapter_type = await loop.run_in_executor(global_thread_pool,
                                                  _get_adapter_type_from_file,
                                                  adapter_id,
                                                  local_adapter_path)

        # Control was yielded during the await, so another request for the
        # same adapter may have filled the cache in the meantime. Reuse its
        # entry so one adapter never ends up with two different ids.
        adapter_metadata = adapter_store.adapters.get(adapter_id)
        if adapter_metadata is None:
            adapter_metadata = AdapterMetadata(
                unique_id=adapter_store.next_unique_id,
                adapter_type=adapter_type,
                full_path=local_adapter_path)
            # Advance the counter so the next adapter gets a distinct id;
            # without this every adapter would be registered with id 1.
            adapter_store.next_unique_id += 1
            adapter_store.adapters[adapter_id] = adapter_metadata

    # Build the proper vllm request object
    if adapter_metadata.adapter_type == "LORA":
        lora_request = LoRARequest(lora_name=adapter_id,
                                   lora_int_id=adapter_metadata.unique_id,
                                   lora_local_path=adapter_metadata.full_path)
        return {"lora_request": lora_request}

    # All other types unsupported
    TGISValidationError.AdapterUnsupported.error(adapter_metadata.adapter_type)
| 86 | + |
| 87 | + |
| 88 | +def _get_adapter_type_from_file(adapter_id: str, adapter_path: str) -> str: |
| 89 | + """This function does all the filesystem access required to deduce the type |
| 90 | + of the adapter. It's run in a separate thread pool executor so that file |
| 91 | + access does not block the main event loop.""" |
| 92 | + if not os.path.exists(adapter_path): |
| 93 | + TGISValidationError.AdapterNotFound.error(adapter_id, |
| 94 | + "directory does not exist") |
| 95 | + |
| 96 | + adapter_config_path = os.path.join(adapter_path, "adapter_config.json") |
| 97 | + if not os.path.exists(adapter_config_path): |
| 98 | + TGISValidationError.AdapterNotFound.error( |
| 99 | + adapter_id, "invalid adapter: no adapter_config.json found") |
| 100 | + |
| 101 | + # NB: blocks event loop |
| 102 | + with open(adapter_config_path) as adapter_config_file: |
| 103 | + adapter_config = json.load(adapter_config_file) |
| 104 | + |
| 105 | + return adapter_config.get("peft_type", None) |
| 106 | + |
| 107 | + |
def _reject_bad_adapter_id(adapter_id: str) -> None:
    """Report an invalid-id error for adapter ids that are unsafe as paths.

    An id is rejected when it contains characters outside
    VALID_ADAPTER_ID_PATTERN, or when joining it onto a directory would
    produce a path outside that directory."""
    if VALID_ADAPTER_ID_PATTERN.fullmatch(adapter_id) is None:
        TGISValidationError.InvalidAdapterID.error(adapter_id)

    # Path traversal check: join the id onto a placeholder root and make
    # sure the normalized result is still underneath that root.
    sandbox = Path("/some/file/root")
    normalized = os.path.normpath(sandbox / adapter_id)
    required_prefix = f"{sandbox}/"
    if not normalized.startswith(required_prefix):
        TGISValidationError.InvalidAdapterID.error(adapter_id)
0 commit comments