|
| 1 | +from typing import TYPE_CHECKING, List, Optional, Tuple |
| 2 | + |
| 3 | +import httpx |
| 4 | + |
| 5 | +from litellm.llms.azure.common_utils import BaseAzureLLM |
| 6 | +from litellm.llms.base_llm.passthrough.transformation import BasePassthroughConfig |
| 7 | +from litellm.secret_managers.main import get_secret_str |
| 8 | +from litellm.types.llms.openai import AllMessageValues |
| 9 | +from litellm.types.router import GenericLiteLLMParams |
| 10 | + |
| 11 | +if TYPE_CHECKING: |
| 12 | + from httpx import URL |
| 13 | + |
| 14 | + |
| 15 | +class AzurePassthroughConfig(BasePassthroughConfig): |
| 16 | + def is_streaming_request(self, endpoint: str, request_data: dict) -> bool: |
| 17 | + return "stream" in request_data |
| 18 | + |
| 19 | + def get_complete_url( |
| 20 | + self, |
| 21 | + api_base: Optional[str], |
| 22 | + api_key: Optional[str], |
| 23 | + model: str, |
| 24 | + endpoint: str, |
| 25 | + request_query_params: Optional[dict], |
| 26 | + litellm_params: dict, |
| 27 | + ) -> Tuple["URL", str]: |
| 28 | + base_target_url = self.get_api_base(api_base) |
| 29 | + |
| 30 | + if base_target_url is None: |
| 31 | + raise Exception("Azure api base not found") |
| 32 | + |
| 33 | + litellm_metadata = litellm_params.get("litellm_metadata") or {} |
| 34 | + model_group = litellm_metadata.get("model_group") |
| 35 | + if model_group and model_group in endpoint: |
| 36 | + endpoint = endpoint.replace(model_group, model) |
| 37 | + |
| 38 | + complete_url = BaseAzureLLM._get_base_azure_url( |
| 39 | + api_base=base_target_url, |
| 40 | + litellm_params=litellm_params, |
| 41 | + route=endpoint, |
| 42 | + default_api_version=litellm_params.get("api_version"), |
| 43 | + ) |
| 44 | + return ( |
| 45 | + httpx.URL(complete_url), |
| 46 | + base_target_url, |
| 47 | + ) |
| 48 | + |
| 49 | + def validate_environment( |
| 50 | + self, |
| 51 | + headers: dict, |
| 52 | + model: str, |
| 53 | + messages: List[AllMessageValues], |
| 54 | + optional_params: dict, |
| 55 | + litellm_params: dict, |
| 56 | + api_key: Optional[str] = None, |
| 57 | + api_base: Optional[str] = None, |
| 58 | + ) -> dict: |
| 59 | + return BaseAzureLLM._base_validate_azure_environment( |
| 60 | + headers=headers, |
| 61 | + litellm_params=GenericLiteLLMParams( |
| 62 | + **{**litellm_params, "api_key": api_key} |
| 63 | + ), |
| 64 | + ) |
| 65 | + |
| 66 | + @staticmethod |
| 67 | + def get_api_base( |
| 68 | + api_base: Optional[str] = None, |
| 69 | + ) -> Optional[str]: |
| 70 | + return api_base or get_secret_str("AZURE_API_BASE") |
| 71 | + |
| 72 | + @staticmethod |
| 73 | + def get_api_key( |
| 74 | + api_key: Optional[str] = None, |
| 75 | + ) -> Optional[str]: |
| 76 | + return api_key or get_secret_str("AZURE_API_KEY") |
| 77 | + |
| 78 | + @staticmethod |
| 79 | + def get_base_model(model: str) -> Optional[str]: |
| 80 | + return model |
| 81 | + |
| 82 | + def get_models( |
| 83 | + self, api_key: Optional[str] = None, api_base: Optional[str] = None |
| 84 | + ) -> List[str]: |
| 85 | + return super().get_models(api_key, api_base) |
0 commit comments