|
1 | 1 | import asyncio
|
2 | 2 | import re
|
| 3 | +import warnings |
3 | 4 | from asyncio import Task
|
4 | 5 | from typing import Sequence, Optional, Mapping, Any, List, Unpack, Dict, cast
|
5 | 6 | from inspect import getfullargspec
|
@@ -154,25 +155,95 @@ def assert_valid_name(name: str) -> str:
|
154 | 155 |
|
155 | 156 |
|
156 | 157 | class AzureAIChatCompletionClient(ChatCompletionClient):
|
| 158 | + """ |
| 159 | + Chat completion client for models hosted on Azure AI Foundry or GitHub Models. |
| 160 | + See `here <https://learn.microsoft.com/en-us/azure/ai-studio/reference/reference-model-inference-chat-completions>`_ for more info. |
| 161 | +
|
| 162 | + Args: |
| 163 | + endpoint (str): The endpoint to use. **Required.** |
| 164 | +        credential (Union[AzureKeyCredential, AsyncTokenCredential]): The credential to use. **Required.** |
| 165 | + model_capabilities (ModelCapabilities): The capabilities of the model. **Required.** |
| 166 | + model (str): The name of the model. **Required if model is hosted on GitHub Models.** |
| 167 | + frequency_penalty: (optional,float) |
| 168 | + presence_penalty: (optional,float) |
| 169 | + temperature: (optional,float) |
| 170 | + top_p: (optional,float) |
| 171 | + max_tokens: (optional,int) |
| 172 | + response_format: (optional,ChatCompletionsResponseFormat) |
| 173 | + stop: (optional,List[str]) |
| 174 | + tools: (optional,List[ChatCompletionsToolDefinition]) |
| 175 | +        tool_choice: (optional,Union[str, ChatCompletionsToolChoicePreset, ChatCompletionsNamedToolChoice]) |
| 176 | + seed: (optional,int) |
| 177 | + model_extras: (optional,Dict[str, Any]) |
| 178 | +
|
| 179 | + To use this client, you must install the `azure-ai-inference` extension: |
| 180 | +
|
| 181 | + .. code-block:: bash |
| 182 | +
|
| 183 | + pip install 'autogen-ext[azure-ai-inference]==0.4.0.dev11' |
| 184 | +
|
| 185 | + The following code snippet shows how to use the client: |
| 186 | +
|
| 187 | + .. code-block:: python |
| 188 | +
|
| 189 | + from azure.core.credentials import AzureKeyCredential |
| 190 | + from autogen_ext.models.azure import AzureAIChatCompletionClient |
| 191 | + from autogen_core.models import UserMessage |
| 192 | +
|
| 193 | + client = AzureAIChatCompletionClient( |
| 194 | + endpoint="endpoint", |
| 195 | + credential=AzureKeyCredential("api_key"), |
| 196 | + model_capabilities={ |
| 197 | + "json_output": False, |
| 198 | + "function_calling": False, |
| 199 | + "vision": False, |
| 200 | + }, |
| 201 | + ) |
| 202 | +
|
| 203 | + result = await client.create([UserMessage(content="What is the capital of France?", source="user")]) # type: ignore |
| 204 | + print(result) |
| 205 | +
|
| 206 | + """ |
| 207 | + |
def __init__(self, **kwargs: Unpack[AzureAIChatCompletionClientConfig]):
    """Validate the supplied configuration, build the underlying
    azure-ai-inference client, and initialize the usage counters.
    """
    validated = self._validate_config(kwargs)
    self._model_capabilities = validated["model_capabilities"]
    self._client = self._create_client(validated)
    self._create_args = self._prepare_create_args(validated)

    # Running token counters; _total_usage is accumulated via add_usage().
    self._actual_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
    self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
@staticmethod
def _validate_config(config: Dict[str, Any]) -> AzureAIChatCompletionClientConfig:
    """Check that all required configuration keys are present.

    Args:
        config: Raw keyword arguments forwarded from ``__init__``.

    Returns:
        The same mapping, typed as ``AzureAIChatCompletionClientConfig``.

    Raises:
        ValueError: If ``endpoint``, ``credential`` or ``model_capabilities``
            is missing, or if ``model`` is missing while the endpoint is a
            GitHub Models endpoint.
    """
    if "endpoint" not in config:
        raise ValueError("endpoint is required for AzureAIChatCompletionClient")
    if "credential" not in config:
        raise ValueError("credential is required for AzureAIChatCompletionClient")
    if "model_capabilities" not in config:
        raise ValueError("model_capabilities is required for AzureAIChatCompletionClient")
    # GitHub Models serve many models behind a single shared endpoint, so the
    # model name cannot be inferred and must be supplied explicitly.
    if _is_github_model(config["endpoint"]) and "model" not in config:
        raise ValueError("model is required when using a GitHub model with AzureAIChatCompletionClient")
    return config
@staticmethod
def _create_client(config: AzureAIChatCompletionClientConfig):
    """Instantiate the azure-ai-inference ``ChatCompletionsClient`` from *config*."""
    client = ChatCompletionsClient(**config)
    return client
| 232 | + |
@staticmethod
def _prepare_create_args(config: Mapping[str, Any]) -> Dict[str, Any]:
    """Filter *config* down to the keys accepted by the completion call.

    Args:
        config: The full, validated client configuration.

    Returns:
        A new dict containing only the keys listed in ``create_kwargs``
        (endpoint/credential/capability keys are dropped).
    """
    return {k: v for k, v in config.items() if k in create_kwargs}
| 241 | + |
def add_usage(self, usage: RequestUsage) -> None:
    """Accumulate *usage* into the client's running total token counts.

    Args:
        usage: Token usage reported for a single completed request.
    """
    # Keyword arguments for consistency with the RequestUsage construction
    # elsewhere in this file.
    self._total_usage = RequestUsage(
        prompt_tokens=self._total_usage.prompt_tokens + usage.prompt_tokens,
        completion_tokens=self._total_usage.completion_tokens + usage.completion_tokens,
    )
176 | 247 |
|
177 | 248 | async def create(
|
178 | 249 | self,
|
@@ -200,7 +271,7 @@ async def create(
|
200 | 271 | if self.capabilities["json_output"] is False and json_output is True:
|
201 | 272 | raise ValueError("Model does not support JSON output")
|
202 | 273 |
|
203 |
| - if json_output is True: |
| 274 | + if json_output is True and "response_format" not in create_args: |
204 | 275 | create_args["response_format"] = ChatCompletionsResponseFormatJSON()
|
205 | 276 |
|
206 | 277 | if self.capabilities["json_output"] is False and json_output is True:
|
@@ -259,6 +330,9 @@ async def create(
|
259 | 330 | usage=usage,
|
260 | 331 | cached=False,
|
261 | 332 | )
|
| 333 | + |
| 334 | + self.add_usage(usage) |
| 335 | + |
262 | 336 | return response
|
263 | 337 |
|
264 | 338 | async def create_stream(
|
@@ -286,7 +360,7 @@ async def create_stream(
|
286 | 360 | if self.capabilities["json_output"] is False and json_output is True:
|
287 | 361 | raise ValueError("Model does not support JSON output")
|
288 | 362 |
|
289 |
| - if json_output is True: |
| 363 | + if json_output is True and "response_format" not in create_args: |
290 | 364 | create_args["response_format"] = ChatCompletionsResponseFormatJSON()
|
291 | 365 |
|
292 | 366 | if self.capabilities["json_output"] is False and json_output is True:
|
@@ -380,6 +454,9 @@ async def create_stream(
|
380 | 454 | usage=usage,
|
381 | 455 | cached=False,
|
382 | 456 | )
|
| 457 | + |
| 458 | + self.add_usage(usage) |
| 459 | + |
383 | 460 | yield result
|
384 | 461 |
|
385 | 462 | def actual_usage(self) -> RequestUsage:
|
|
0 commit comments