Skip to content

Commit 0912b51

Browse files
committed
WIP: Azure AI Client
* Added: object-level usage data * Added: doc string * Added: check existing response_format value * Added: _validate_config and _create_client
1 parent f441aa5 commit 0912b51

File tree

5 files changed

+131
-32
lines changed

5 files changed

+131
-32
lines changed

python/packages/autogen-core/docs/src/reference/index.md

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ python/autogen_ext.agents.video_surfer
4747
python/autogen_ext.agents.video_surfer.tools
4848
python/autogen_ext.models.openai
4949
python/autogen_ext.models.replay
50+
python/autogen_ext.models.azure
5051
python/autogen_ext.tools.langchain
5152
python/autogen_ext.code_executors.local
5253
python/autogen_ext.code_executors.docker
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
autogen\_ext.models.azure
2+
==========================
3+
4+
5+
.. automodule:: autogen_ext.models.azure
6+
:members:
7+
:undoc-members:
8+
:show-inheritance:
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from ._azure_ai_client import AzureAIChatCompletionClient
2+
from .config import AzureAIChatCompletionClientConfig
23

3-
__all__ = ["AzureAIChatCompletionClient"]
4+
__all__ = ["AzureAIChatCompletionClient", "AzureAIChatCompletionClientConfig"]

python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py

+93-16
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import re
3+
import warnings
34
from asyncio import Task
45
from typing import Sequence, Optional, Mapping, Any, List, Unpack, Dict, cast
56
from inspect import getfullargspec
@@ -154,25 +155,95 @@ def assert_valid_name(name: str) -> str:
154155

155156

156157
class AzureAIChatCompletionClient(ChatCompletionClient):
158+
"""
159+
Chat completion client for models hosted on Azure AI Foundry or GitHub Models.
160+
See `here <https://learn.microsoft.com/en-us/azure/ai-studio/reference/reference-model-inference-chat-completions>`_ for more info.
161+
162+
Args:
163+
endpoint (str): The endpoint to use. **Required.**
164+
credential (Union[AzureKeyCredential, AsyncTokenCredential]): The credential to use. **Required.**
166+
model_capabilities (ModelCapabilities): The capabilities of the model. **Required.**
166+
model (str): The name of the model. **Required if model is hosted on GitHub Models.**
167+
frequency_penalty: (optional,float)
168+
presence_penalty: (optional,float)
169+
temperature: (optional,float)
170+
top_p: (optional,float)
171+
max_tokens: (optional,int)
172+
response_format: (optional,ChatCompletionsResponseFormat)
173+
stop: (optional,List[str])
174+
tools: (optional,List[ChatCompletionsToolDefinition])
175+
tool_choice: (optional,Union[str, ChatCompletionsToolChoicePreset, ChatCompletionsNamedToolChoice]])
176+
seed: (optional,int)
177+
model_extras: (optional,Dict[str, Any])
178+
179+
To use this client, you must install the `azure-ai-inference` extension:
180+
181+
.. code-block:: bash
182+
183+
pip install 'autogen-ext[azure-ai-inference]==0.4.0.dev11'
184+
185+
The following code snippet shows how to use the client:
186+
187+
.. code-block:: python
188+
189+
from azure.core.credentials import AzureKeyCredential
190+
from autogen_ext.models.azure import AzureAIChatCompletionClient
191+
from autogen_core.models import UserMessage
192+
193+
client = AzureAIChatCompletionClient(
194+
endpoint="endpoint",
195+
credential=AzureKeyCredential("api_key"),
196+
model_capabilities={
197+
"json_output": False,
198+
"function_calling": False,
199+
"vision": False,
200+
},
201+
)
202+
203+
result = await client.create([UserMessage(content="What is the capital of France?", source="user")]) # type: ignore
204+
print(result)
205+
206+
"""
207+
157208
def __init__(self, **kwargs: Unpack[AzureAIChatCompletionClientConfig]):
158-
if "endpoint" not in kwargs:
209+
config = self._validate_config(kwargs)
210+
self._model_capabilities = config["model_capabilities"]
211+
self._client = self._create_client(config)
212+
self._create_args = self._prepare_create_args(config)
213+
214+
self._actual_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
215+
self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
216+
217+
@staticmethod
218+
def _validate_config(config: Dict) -> AzureAIChatCompletionClientConfig:
219+
if "endpoint" not in config:
159220
raise ValueError("endpoint is required for AzureAIChatCompletionClient")
160-
if "credential" not in kwargs:
221+
if "credential" not in config:
161222
raise ValueError("credential is required for AzureAIChatCompletionClient")
162-
if "model_capabilities" not in kwargs:
223+
if "model_capabilities" not in config:
163224
raise ValueError("model_capabilities is required for AzureAIChatCompletionClient")
164-
if _is_github_model(kwargs['endpoint']) and "model" not in kwargs:
225+
if _is_github_model(config["endpoint"]) and "model" not in config:
165226
raise ValueError("model is required for when using a Github model with AzureAIChatCompletionClient")
166-
167-
# TODO: Change
168-
_endpoint = kwargs.pop("endpoint")
169-
_credential = kwargs.pop("credential")
170-
self._model_capabilities = kwargs.pop("model_capabilities")
171-
self._create_args = kwargs.copy()
172-
173-
self._client = ChatCompletionsClient(_endpoint, _credential, **self._create_args)
174-
self._actual_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
175-
self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
227+
return config
228+
229+
@staticmethod
230+
def _create_client(config: AzureAIChatCompletionClientConfig):
231+
return ChatCompletionsClient(**config)
232+
233+
@staticmethod
234+
def _prepare_create_args(config: Mapping[str, Any]) -> Mapping[str, Any]:
235+
create_args = {k: v for k, v in config.items() if k in create_kwargs}
236+
return create_args
237+
# self._endpoint = config.pop("endpoint")
238+
# self._credential = config.pop("credential")
239+
# self._model_capabilities = config.pop("model_capabilities")
240+
# self._create_args = config.copy()
241+
242+
def add_usage(self, usage: RequestUsage):
243+
self._total_usage = RequestUsage(
244+
self._total_usage.prompt_tokens + usage.prompt_tokens,
245+
self._total_usage.completion_tokens + usage.completion_tokens,
246+
)
176247

177248
async def create(
178249
self,
@@ -200,7 +271,7 @@ async def create(
200271
if self.capabilities["json_output"] is False and json_output is True:
201272
raise ValueError("Model does not support JSON output")
202273

203-
if json_output is True:
274+
if json_output is True and "response_format" not in create_args:
204275
create_args["response_format"] = ChatCompletionsResponseFormatJSON()
205276

206277
if self.capabilities["json_output"] is False and json_output is True:
@@ -259,6 +330,9 @@ async def create(
259330
usage=usage,
260331
cached=False,
261332
)
333+
334+
self.add_usage(usage)
335+
262336
return response
263337

264338
async def create_stream(
@@ -286,7 +360,7 @@ async def create_stream(
286360
if self.capabilities["json_output"] is False and json_output is True:
287361
raise ValueError("Model does not support JSON output")
288362

289-
if json_output is True:
363+
if json_output is True and "response_format" not in create_args:
290364
create_args["response_format"] = ChatCompletionsResponseFormatJSON()
291365

292366
if self.capabilities["json_output"] is False and json_output is True:
@@ -380,6 +454,9 @@ async def create_stream(
380454
usage=usage,
381455
cached=False,
382456
)
457+
458+
self.add_usage(usage)
459+
383460
yield result
384461

385462
def actual_usage(self) -> RequestUsage:

python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py

+27-15
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,20 @@
77
ChatCompletionsClient,
88
)
99

10+
1011
from azure.ai.inference.models import (
1112
ChatChoice,
1213
ChatResponseMessage,
1314
CompletionsUsage,
15+
ChatCompletionsResponseFormatJSON,
16+
)
1417

18+
from azure.ai.inference.models import (
19+
ChatCompletions,
20+
StreamingChatCompletionsUpdate,
21+
StreamingChatChoiceUpdate,
22+
StreamingChatResponseMessageUpdate,
1523
)
16-
from azure.ai.inference.models import (ChatCompletions,
17-
StreamingChatCompletionsUpdate, StreamingChatChoiceUpdate,
18-
StreamingChatResponseMessageUpdate)
1924

2025
from azure.core.credentials import AzureKeyCredential
2126

@@ -32,7 +37,8 @@ async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[Strea
3237
index=0,
3338
finish_reason="stop",
3439
delta=StreamingChatResponseMessageUpdate(role="assistant", content=chunk_content),
35-
) for chunk_content in mock_chunks_content
40+
)
41+
for chunk_content in mock_chunks_content
3642
]
3743

3844
for mock_chunk in mock_chunks:
@@ -46,20 +52,20 @@ async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[Strea
4652
)
4753

4854

49-
async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletions | AsyncGenerator[StreamingChatCompletionsUpdate, None]:
55+
async def _mock_create(
56+
*args: Any, **kwargs: Any
57+
) -> ChatCompletions | AsyncGenerator[StreamingChatCompletionsUpdate, None]:
5058
stream = kwargs.get("stream", False)
5159

5260
if not stream:
5361
await asyncio.sleep(0.1)
5462
return ChatCompletions(
5563
id="id",
5664
created=datetime.now(),
57-
model='model',
65+
model="model",
5866
choices=[
5967
ChatChoice(
60-
index=0,
61-
finish_reason="stop",
62-
message=ChatResponseMessage(content="Hello", role="assistant")
68+
index=0, finish_reason="stop", message=ChatResponseMessage(content="Hello", role="assistant")
6369
)
6470
],
6571
usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
@@ -68,28 +74,29 @@ async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletions | AsyncGene
6874
return _mock_create_stream(*args, **kwargs)
6975

7076

71-
7277
@pytest.mark.asyncio
7378
async def test_azure_ai_chat_completion_client() -> None:
7479
client = AzureAIChatCompletionClient(
7580
endpoint="endpoint",
7681
credential=AzureKeyCredential("api_key"),
77-
model_capabilities = {
82+
model_capabilities={
7883
"json_output": False,
7984
"function_calling": False,
8085
"vision": False,
8186
},
87+
model="model",
8288
)
8389
assert client
8490

91+
8592
@pytest.mark.asyncio
8693
async def test_azure_ai_chat_completion_client_create(monkeypatch: pytest.MonkeyPatch) -> None:
8794
# monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
8895
monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create)
8996
client = AzureAIChatCompletionClient(
9097
endpoint="endpoint",
9198
credential=AzureKeyCredential("api_key"),
92-
model_capabilities = {
99+
model_capabilities={
93100
"json_output": False,
94101
"function_calling": False,
95102
"vision": False,
@@ -98,14 +105,15 @@ async def test_azure_ai_chat_completion_client_create(monkeypatch: pytest.Monkey
98105
result = await client.create(messages=[UserMessage(content="Hello", source="user")])
99106
assert result.content == "Hello"
100107

108+
101109
@pytest.mark.asyncio
102-
async def test_azure_ai_chat_completion_client_create_stream(monkeypatch:pytest.MonkeyPatch) -> None:
110+
async def test_azure_ai_chat_completion_client_create_stream(monkeypatch: pytest.MonkeyPatch) -> None:
103111
monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create)
104112
chunks = []
105113
client = AzureAIChatCompletionClient(
106114
endpoint="endpoint",
107115
credential=AzureKeyCredential("api_key"),
108-
model_capabilities = {
116+
model_capabilities={
109117
"json_output": False,
110118
"function_calling": False,
111119
"vision": False,
@@ -118,6 +126,7 @@ async def test_azure_ai_chat_completion_client_create_stream(monkeypatch:pytest.
118126
assert chunks[1] == " Another Hello"
119127
assert chunks[2] == " Yet Another Hello"
120128

129+
121130
@pytest.mark.asyncio
122131
async def test_azure_ai_chat_completion_client_create_cancel(monkeypatch: pytest.MonkeyPatch) -> None:
123132
monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create)
@@ -138,6 +147,7 @@ async def test_azure_ai_chat_completion_client_create_cancel(monkeypatch: pytest
138147
with pytest.raises(asyncio.CancelledError):
139148
await task
140149

150+
141151
@pytest.mark.asyncio
142152
async def test_azure_ai_chat_completion_client_create_stream_cancel(monkeypatch: pytest.MonkeyPatch) -> None:
143153
monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create)
@@ -151,7 +161,9 @@ async def test_azure_ai_chat_completion_client_create_stream_cancel(monkeypatch:
151161
"vision": False,
152162
},
153163
)
154-
stream=client.create_stream(messages=[UserMessage(content="Hello", source="user")], cancellation_token=cancellation_token)
164+
stream = client.create_stream(
165+
messages=[UserMessage(content="Hello", source="user")], cancellation_token=cancellation_token
166+
)
155167
cancellation_token.cancel()
156168
with pytest.raises(asyncio.CancelledError):
157169
async for _ in stream:

0 commit comments

Comments (0)