Skip to content

Commit 077b5e1

Browse files
Merge pull request #15240 from BerriAI/litellm_dev_10_06_2025_p1
Azure - passthrough support with router models
2 parents 8b357c2 + 94a34dd commit 077b5e1

File tree

7 files changed

+463
-132
lines changed

7 files changed

+463
-132
lines changed

litellm/llms/azure/common_utils.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -665,10 +665,6 @@ def _base_validate_azure_environment(
665665
) -> dict:
666666
litellm_params = litellm_params or GenericLiteLLMParams()
667667

668-
# If api-key is already in headers, preserve it
669-
if "api-key" in headers:
670-
return headers
671-
672668
api_key = (
673669
litellm_params.api_key
674670
or litellm.api_key
@@ -693,7 +689,7 @@ def _base_validate_azure_environment(
693689
def _get_base_azure_url(
694690
api_base: Optional[str],
695691
litellm_params: Optional[Union[GenericLiteLLMParams, Dict[str, Any]]],
696-
route: Literal["/openai/responses", "/openai/vector_stores"],
692+
route: Union[Literal["/openai/responses", "/openai/vector_stores"], str],
697693
default_api_version: Optional[Union[str, Literal["latest", "preview"]]] = None,
698694
) -> str:
699695
"""
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from typing import TYPE_CHECKING, List, Optional, Tuple
2+
3+
import httpx
4+
5+
from litellm.llms.azure.common_utils import BaseAzureLLM
6+
from litellm.llms.base_llm.passthrough.transformation import BasePassthroughConfig
7+
from litellm.secret_managers.main import get_secret_str
8+
from litellm.types.llms.openai import AllMessageValues
9+
from litellm.types.router import GenericLiteLLMParams
10+
11+
if TYPE_CHECKING:
12+
from httpx import URL
13+
14+
15+
class AzurePassthroughConfig(BasePassthroughConfig):
    """Passthrough configuration for Azure OpenAI.

    Builds the complete Azure target URL for a passthrough request (including
    api-version handling via ``BaseAzureLLM._get_base_azure_url``) and resolves
    Azure credentials from litellm params or the ``AZURE_API_BASE`` /
    ``AZURE_API_KEY`` environment secrets.
    """

    def is_streaming_request(self, endpoint: str, request_data: dict) -> bool:
        """Return True only when the request body actually enables streaming.

        Note: checking ``"stream" in request_data`` (key presence) would treat
        an explicit ``{"stream": false}`` as a streaming request, so we test
        the value's truthiness instead.
        """
        return bool(request_data.get("stream", False))

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        endpoint: str,
        request_query_params: Optional[dict],
        litellm_params: dict,
    ) -> Tuple["URL", str]:
        """Resolve the full Azure URL for ``endpoint``.

        Returns a tuple of (complete httpx URL, base target URL).

        Raises:
            Exception: if no Azure api base is configured (neither the
                ``api_base`` argument nor the ``AZURE_API_BASE`` secret).
        """
        base_target_url = self.get_api_base(api_base)

        if base_target_url is None:
            raise Exception("Azure api base not found")

        # Router requests arrive with the public model-group name embedded in
        # the endpoint path; swap it for the deployment's actual model name.
        # NOTE(review): str.replace substitutes every substring occurrence of
        # model_group, not just a path segment — confirm model-group names
        # cannot collide with other parts of the endpoint.
        litellm_metadata = litellm_params.get("litellm_metadata") or {}
        model_group = litellm_metadata.get("model_group")
        if model_group and model_group in endpoint:
            endpoint = endpoint.replace(model_group, model)

        # Delegates api-version defaulting/appending to the shared Azure helper.
        complete_url = BaseAzureLLM._get_base_azure_url(
            api_base=base_target_url,
            litellm_params=litellm_params,
            route=endpoint,
            default_api_version=litellm_params.get("api_version"),
        )
        return (
            httpx.URL(complete_url),
            base_target_url,
        )

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Return request headers with Azure auth applied.

        The explicit ``api_key`` argument takes precedence over any key already
        present in ``litellm_params``.
        """
        return BaseAzureLLM._base_validate_azure_environment(
            headers=headers,
            litellm_params=GenericLiteLLMParams(
                **{**litellm_params, "api_key": api_key}
            ),
        )

    @staticmethod
    def get_api_base(
        api_base: Optional[str] = None,
    ) -> Optional[str]:
        """Return ``api_base`` or fall back to the ``AZURE_API_BASE`` secret."""
        return api_base or get_secret_str("AZURE_API_BASE")

    @staticmethod
    def get_api_key(
        api_key: Optional[str] = None,
    ) -> Optional[str]:
        """Return ``api_key`` or fall back to the ``AZURE_API_KEY`` secret."""
        return api_key or get_secret_str("AZURE_API_KEY")

    @staticmethod
    def get_base_model(model: str) -> Optional[str]:
        """Azure model names are used as-is; no alias mapping is applied."""
        return model

    # The original also defined get_models(), but it only forwarded to
    # super().get_models(api_key, api_base) unchanged; the inherited
    # implementation from BasePassthroughConfig is used directly instead.

litellm/passthrough/main.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -242,12 +242,14 @@ def llm_passthrough_route(
242242
request_query_params=request_query_params,
243243
litellm_params=litellm_params_dict,
244244
)
245-
246-
# need to encode the id of application-inference-profile for bedrock
245+
246+
# [TODO: Refactor to bedrockpassthroughconfig] need to encode the id of application-inference-profile for bedrock
247247
if custom_llm_provider == "bedrock" and "application-inference-profile" in endpoint:
248-
encoded_url_str = CommonUtils.encode_bedrock_runtime_modelid_arn(str(updated_url))
248+
encoded_url_str = CommonUtils.encode_bedrock_runtime_modelid_arn(
249+
str(updated_url)
250+
)
249251
updated_url = httpx.URL(encoded_url_str)
250-
252+
251253
# Add or update query parameters
252254
provider_api_key = provider_config.get_api_key(api_key)
253255

0 commit comments

Comments (0)