diff --git a/python/packages/autogen-core/docs/src/reference/index.md b/python/packages/autogen-core/docs/src/reference/index.md
index 869ffc2347cf..fdaf598c0029 100644
--- a/python/packages/autogen-core/docs/src/reference/index.md
+++ b/python/packages/autogen-core/docs/src/reference/index.md
@@ -51,6 +51,7 @@ python/autogen_ext.teams.magentic_one
 python/autogen_ext.models.cache
 python/autogen_ext.models.openai
 python/autogen_ext.models.replay
+python/autogen_ext.models.azure
 python/autogen_ext.models.semantic_kernel
 python/autogen_ext.tools.langchain
 python/autogen_ext.tools.graphrag
diff --git a/python/packages/autogen-core/docs/src/reference/python/autogen_ext.models.azure.rst b/python/packages/autogen-core/docs/src/reference/python/autogen_ext.models.azure.rst
new file mode 100644
index 000000000000..64c16a5a57d4
--- /dev/null
+++ b/python/packages/autogen-core/docs/src/reference/python/autogen_ext.models.azure.rst
@@ -0,0 +1,8 @@
+autogen\_ext.models.azure
+==========================
+
+
+.. automodule:: autogen_ext.models.azure
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/python/packages/autogen-ext/pyproject.toml b/python/packages/autogen-ext/pyproject.toml
index 20579c99baec..0e404db2d20f 100644
--- a/python/packages/autogen-ext/pyproject.toml
+++ b/python/packages/autogen-ext/pyproject.toml
@@ -20,7 +20,11 @@ dependencies = [
 
 [project.optional-dependencies]
 langchain = ["langchain_core~= 0.3.3"]
-azure = ["azure-core", "azure-identity"]
+azure = [
+    "azure-ai-inference>=1.0.0b7",
+    "azure-core",
+    "azure-identity",
+]
 docker = ["docker~=7.0"]
 openai = ["openai>=1.52.2", "tiktoken>=0.8.0", "aiofiles"]
 file-surfer = [
@@ -52,7 +56,6 @@ diskcache = [
 redis = [
     "redis>=5.2.1"
 ]
-
 grpc = [
     "grpcio~=1.62.0", # TODO: update this once we have a stable version.
 ]
@@ -60,47 +63,36 @@ jupyter-executor = [
     "ipykernel>=6.29.5",
     "nbclient>=0.10.2",
 ]
-
 semantic-kernel-core = [
     "semantic-kernel>=1.17.1",
 ]
-
 semantic-kernel-google = [
     "semantic-kernel[google]>=1.17.1",
 ]
-
 semantic-kernel-hugging-face = [
     "semantic-kernel[hugging_face]>=1.17.1",
 ]
-
 semantic-kernel-mistralai = [
     "semantic-kernel[mistralai]>=1.17.1",
 ]
-
 semantic-kernel-ollama = [
     "semantic-kernel[ollama]>=1.17.1",
 ]
-
 semantic-kernel-onnx = [
     "semantic-kernel[onnx]>=1.17.1",
 ]
-
 semantic-kernel-anthropic = [
     "semantic-kernel[anthropic]>=1.17.1",
 ]
-
 semantic-kernel-pandas = [
     "semantic-kernel[pandas]>=1.17.1",
 ]
-
 semantic-kernel-aws = [
     "semantic-kernel[aws]>=1.17.1",
 ]
-
 semantic-kernel-dapr = [
     "semantic-kernel[dapr]>=1.17.1",
 ]
-
 semantic-kernel-all = [
     "semantic-kernel[google,hugging_face,mistralai,ollama,onnx,anthropic,usearch,pandas,aws,dapr]>=1.17.1",
 ]
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/azure/__init__.py b/python/packages/autogen-ext/src/autogen_ext/models/azure/__init__.py
new file mode 100644
index 000000000000..2dc7b9c70a98
--- /dev/null
+++ b/python/packages/autogen-ext/src/autogen_ext/models/azure/__init__.py
@@ -0,0 +1,4 @@
+from ._azure_ai_client import AzureAIChatCompletionClient
+from .config import AzureAIChatCompletionClientConfig
+
+__all__ = ["AzureAIChatCompletionClient", "AzureAIChatCompletionClientConfig"]
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py
new file mode 100644
index 000000000000..7e36a869862a
--- /dev/null
+++ b/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py
@@ -0,0 +1,501 @@
+import asyncio
+import re
+import warnings
+from asyncio import Task
+from typing import Any, Dict, List, Mapping, Optional, Sequence, cast
+from inspect import getfullargspec
+from azure.ai.inference.aio import ChatCompletionsClient
+from azure.ai.inference.models import (
+    ChatCompletions,
+    CompletionsFinishReason,
+    ChatCompletionsToolCall,
+    ChatCompletionsToolDefinition,
+    FunctionDefinition,
+    ContentItem,
+    TextContentItem,
+    ImageContentItem,
+    ImageUrl,
+    ImageDetailLevel,
+    StreamingChatChoiceUpdate,
+    StreamingChatCompletionsUpdate,
+    SystemMessage as AzureSystemMessage,
+    UserMessage as AzureUserMessage,
+    AssistantMessage as AzureAssistantMessage,
+    ToolMessage as AzureToolMessage,
+    FunctionCall as AzureFunctionCall,
+)
+from typing_extensions import AsyncGenerator, Union, Unpack
+
+from autogen_core import CancellationToken, FunctionCall, Image
+from autogen_core.models import (
+    ChatCompletionClient,
+    LLMMessage,
+    CreateResult,
+    ModelInfo,
+    RequestUsage,
+    UserMessage,
+    SystemMessage,
+    AssistantMessage,
+    FunctionExecutionResultMessage,
+    FinishReasons,
+)
+from autogen_core.tools import Tool, ToolSchema
+from autogen_ext.models.azure.config import AzureAIChatCompletionClientConfig, GITHUB_MODELS_ENDPOINT
+
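+# Only keyword arguments accepted by the SDK's `ChatCompletionsClient.complete`
+# may be forwarded as create args; the allowed set is derived from its signature.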
+create_kwargs = set(getfullargspec(ChatCompletionsClient.complete).kwonlyargs)
+
+
+def _is_github_model(endpoint: str) -> bool:
+    return endpoint == GITHUB_MODELS_ENDPOINT
+
+
+def convert_tools(tools: Sequence[Tool | ToolSchema]) -> List[ChatCompletionsToolDefinition]:
+    result: List[ChatCompletionsToolDefinition] = []
+    for tool in tools:
+        if isinstance(tool, Tool):
+            tool_schema = tool.schema.copy()
+        else:
+            assert isinstance(tool, dict)
+            tool_schema = tool.copy()
+        # tool_schema["parameters"] = {k:v for k,v in tool_schema["parameters"].items()}
+        # azure_ai_schema = {k:v for k,v in tool_schema["parameters"].items()}
+
+        # Remove the "title" keys that pydantic adds to generated parameter
+        # schemas; they are not part of the schema the service expects.
+        if "parameters" in tool_schema:
+            for value in tool_schema["parameters"].get("properties", {}).values():
+                value.pop("title", None)
+
+        result.append(
+            ChatCompletionsToolDefinition(
+                function=FunctionDefinition(
+                    name=tool_schema["name"],
+                    description=(tool_schema["description"] if "description" in tool_schema else ""),
+                    parameters=(tool_schema["parameters"]) if "parameters" in tool_schema else {},
+                ),
+            ),
+        )
+    return result
+
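+# For illustration, a tool schema with a single string parameter converts
+# roughly to the following SDK object (a sketch with hypothetical names):
+#
+#   ChatCompletionsToolDefinition(
+#       function=FunctionDefinition(
+#           name="get_weather",
+#           description="Get the weather for a city.",
+#           parameters={"type": "object", "properties": {"city": {"type": "string"}}},
+#       )
+#   )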
+
+def _func_call_to_azure(message: FunctionCall) -> ChatCompletionsToolCall:
+    return ChatCompletionsToolCall(
+        id=message.id,
+        function=AzureFunctionCall(arguments=message.arguments, name=message.name),
+    )
+
+
+def _system_message_to_azure(message: SystemMessage) -> AzureSystemMessage:
+    return AzureSystemMessage(content=message.content)
+
+
+def _user_message_to_azure(message: UserMessage) -> AzureUserMessage:
+    assert_valid_name(message.source)
+    if isinstance(message.content, str):
+        return AzureUserMessage(content=message.content)
+    else:
+        parts: List[ContentItem] = []
+        for part in message.content:
+            if isinstance(part, str):
+                parts.append(TextContentItem(text=part))
+            elif isinstance(part, Image):
+                # TODO: support url based images
+                # TODO: support specifying details
+                parts.append(ImageContentItem(image_url=ImageUrl(url=part.data_uri, detail=ImageDetailLevel.AUTO)))
+            else:
+                raise ValueError(f"Unknown content type: {message.content}")
+        return AzureUserMessage(content=parts)
+
+
+def _assistant_message_to_azure(message: AssistantMessage) -> AzureAssistantMessage:
+    assert_valid_name(message.source)
+    if isinstance(message.content, list):
+        return AzureAssistantMessage(
+            tool_calls=[_func_call_to_azure(x) for x in message.content],
+        )
+    else:
+        return AzureAssistantMessage(content=message.content)
+
+
+def _tool_message_to_azure(message: FunctionExecutionResultMessage) -> Sequence[AzureToolMessage]:
+    return [AzureToolMessage(content=x.content, tool_call_id=x.call_id) for x in message.content]
+
+
+AzureMessage = Union[AzureSystemMessage, AzureUserMessage, AzureAssistantMessage, AzureToolMessage]
+
+
+def to_azure_message(message: LLMMessage) -> Sequence[AzureMessage]:
+    """Convert an autogen LLMMessage into the corresponding Azure AI message(s)."""
+    if isinstance(message, SystemMessage):
+        return [_system_message_to_azure(message)]
+    elif isinstance(message, UserMessage):
+        return [_user_message_to_azure(message)]
+    elif isinstance(message, AssistantMessage):
+        return [_assistant_message_to_azure(message)]
+    else:
+        return _tool_message_to_azure(message)
+
+
+def normalize_name(name: str) -> str:
+    """
+    LLMs sometimes produce function names that violate their own format requirements;
+    this replaces invalid characters with "_" and truncates to 64 characters.
+
+    Prefer `assert_valid_name` for validating user configuration or input.
+    """
+    return re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:64]
+
+
+def assert_valid_name(name: str) -> str:
+    """
+    Ensure that configured names are valid; raises ValueError if not.
+
+    For munging LLM responses, use `normalize_name` to ensure LLM-specified names don't break the API.
+    """
+    if not re.match(r"^[a-zA-Z0-9_-]+$", name):
+        raise ValueError(f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed.")
+    if len(name) > 64:
+        raise ValueError(f"Invalid name: {name}. Name must be less than 64 characters.")
+    return name
+
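+# For illustration: normalize_name("get weather!") returns "get_weather_",
+# while assert_valid_name("get weather!") raises ValueError.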
+
+def normalize_stop_reason(stop_reason: str | None) -> FinishReasons:
+    if stop_reason is None:
+        return "unknown"
+
+    stop_reason = stop_reason.lower()
+
+    KNOWN_STOP_MAPPINGS: Dict[str, FinishReasons] = {
+        "stop": "stop",
+        "length": "length",
+        "content_filter": "content_filter",
+        "function_calls": "function_calls",
+        "end_turn": "stop",
+        "tool_calls": "function_calls",
+    }
+
+    return KNOWN_STOP_MAPPINGS.get(stop_reason, "unknown")
+
+
+class AzureAIChatCompletionClient(ChatCompletionClient):
+    """
+    Chat completion client for models hosted on Azure AI Foundry or GitHub Models.
+    See `here <https://learn.microsoft.com/en-us/azure/ai-studio/reference/reference-model-inference-chat-completions>`_ for more info.
+
+    Args:
+        endpoint (str): The endpoint to use. **Required.**
+        credential (AzureKeyCredential | AsyncTokenCredential): The credential to use. **Required.**
+        model_info (ModelInfo): The capabilities of the model. **Required.**
+        model (str): The name of the model. **Required if the model is hosted on GitHub Models.**
+        frequency_penalty (optional, float):
+        presence_penalty (optional, float):
+        temperature (optional, float):
+        top_p (optional, float):
+        max_tokens (optional, int):
+        response_format (optional, Union[str, JsonSchemaFormat]):
+        stop (optional, List[str]):
+        tools (optional, List[ChatCompletionsToolDefinition]):
+        tool_choice (optional, Union[str, ChatCompletionsToolChoicePreset, ChatCompletionsNamedToolChoice]):
+        seed (optional, int):
+        model_extras (optional, Dict[str, Any]):
+
+    To use this client, you must install the `azure` extra:
+
+        .. code-block:: bash
+
+            pip install 'autogen-ext[azure]==0.4.0.dev11'
+
+    The following code snippet shows how to use the client:
+
+        .. code-block:: python
+
+            from azure.core.credentials import AzureKeyCredential
+            from autogen_ext.models.azure import AzureAIChatCompletionClient
+            from autogen_core.models import UserMessage
+
+            client = AzureAIChatCompletionClient(
+                endpoint="endpoint",
+                credential=AzureKeyCredential("api_key"),
+                model_info={
+                    "family": "unknown",
+                    "json_output": False,
+                    "function_calling": False,
+                    "vision": False,
+                },
+            )
+
+            result = await client.create([UserMessage(content="What is the capital of France?", source="user")])  # type: ignore
+            print(result)
+
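+    To use the same client with GitHub Models, point it at the GitHub Models
+    endpoint and pass a ``model`` name. A minimal sketch, assuming a
+    ``GITHUB_TOKEN`` environment variable and a model name available to your
+    account:
+
+        .. code-block:: python
+
+            import os
+
+            from azure.core.credentials import AzureKeyCredential
+            from autogen_ext.models.azure import AzureAIChatCompletionClient
+
+            client = AzureAIChatCompletionClient(
+                endpoint="https://models.inference.ai.azure.com",
+                credential=AzureKeyCredential(os.environ["GITHUB_TOKEN"]),
+                model="gpt-4o-mini",  # hypothetical model name
+                model_info={
+                    "family": "unknown",
+                    "json_output": False,
+                    "function_calling": False,
+                    "vision": False,
+                },
+            )
+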
+    """
+
+    def __init__(self, **kwargs: Unpack[AzureAIChatCompletionClientConfig]):
+        config = self._validate_config(kwargs)
+        self._model_info = config["model_info"]
+        self._client = self._create_client(config)
+        self._create_args = self._prepare_create_args(config)
+
+        self._actual_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
+        self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
+
+    @staticmethod
+    def _validate_config(config: Dict[str, Any]) -> AzureAIChatCompletionClientConfig:
+        if "endpoint" not in config:
+            raise ValueError("endpoint is required for AzureAIChatCompletionClient")
+        if "credential" not in config:
+            raise ValueError("credential is required for AzureAIChatCompletionClient")
+        if "model_info" not in config:
+            raise ValueError("model_info is required for AzureAIChatCompletionClient")
+        if _is_github_model(config["endpoint"]) and "model" not in config:
+            raise ValueError("model is required for when using a Github model with AzureAIChatCompletionClient")
+        return config
+
+    @staticmethod
+    def _create_client(config: AzureAIChatCompletionClientConfig) -> ChatCompletionsClient:
+        return ChatCompletionsClient(**config)
+
+    @staticmethod
+    def _prepare_create_args(config: Mapping[str, Any]) -> Dict[str, Any]:
+        create_args = {k: v for k, v in config.items() if k in create_kwargs}
+        return create_args
+
+    def add_usage(self, usage: RequestUsage):
+        self._total_usage = RequestUsage(
+            self._total_usage.prompt_tokens + usage.prompt_tokens,
+            self._total_usage.completion_tokens + usage.completion_tokens,
+        )
+
+    async def create(
+        self,
+        messages: Sequence[LLMMessage],
+        tools: Sequence[Tool | ToolSchema] = [],
+        json_output: Optional[bool] = None,
+        extra_create_args: Mapping[str, Any] = {},
+        cancellation_token: Optional[CancellationToken] = None,
+    ) -> CreateResult:
+        extra_create_args_keys = set(extra_create_args.keys())
+        if not create_kwargs.issuperset(extra_create_args_keys):
+            raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}")
+
+        # Copy the create args and overwrite anything in extra_create_args
+        create_args = self._create_args.copy()
+        create_args.update(extra_create_args)
+
+        if self.model_info["vision"] is False:
+            for message in messages:
+                if isinstance(message, UserMessage):
+                    if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content):
+                        raise ValueError("Model does not support vision and image was provided")
+
+        if json_output is not None:
+            if self.model_info["json_output"] is False and json_output is True:
+                raise ValueError("Model does not support JSON output")
+
+            if json_output is True and "response_format" not in create_args:
+                create_args["response_format"] = "json-object"
+
+        if self.model_info["function_calling"] is False and len(tools) > 0:
+            raise ValueError("Model does not support function calling")
+
+        azure_messages_nested = [to_azure_message(msg) for msg in messages]
+        azure_messages = [item for sublist in azure_messages_nested for item in sublist]
+
+        task: Task[ChatCompletions]
+
+        if len(tools) > 0:
+            converted_tools = convert_tools(tools)
+            task = asyncio.create_task(
+                self._client.complete(messages=azure_messages, tools=converted_tools, **create_args)
+            )
+        else:
+            task = asyncio.create_task(
+                self._client.complete(
+                    messages=azure_messages,
+                    **create_args,
+                )
+            )
+
+        if cancellation_token is not None:
+            cancellation_token.link_future(task)
+
+        result: ChatCompletions = await task
+
+        usage = RequestUsage(
+            prompt_tokens=result.usage.prompt_tokens if result.usage else 0,
+            completion_tokens=result.usage.completion_tokens if result.usage else 0,
+        )
+
+        choice = result.choices[0]
+        if choice.finish_reason == CompletionsFinishReason.TOOL_CALLS:
+            assert choice.message.tool_calls is not None
+
+            content = [
+                FunctionCall(
+                    id=x.id,
+                    arguments=x.function.arguments,
+                    name=normalize_name(x.function.name),
+                )
+                for x in choice.message.tool_calls
+            ]
+            finish_reason = "function_calls"
+        else:
+            finish_reason = choice.finish_reason
+            content = choice.message.content or ""
+
+        response = CreateResult(
+            finish_reason=normalize_stop_reason(finish_reason),  # type: ignore
+            content=content,
+            usage=usage,
+            cached=False,
+        )
+
+        self.add_usage(usage)
+
+        return response
+
+    async def create_stream(
+        self,
+        messages: Sequence[LLMMessage],
+        tools: Sequence[Tool | ToolSchema] = [],
+        json_output: Optional[bool] = None,
+        extra_create_args: Mapping[str, Any] = {},
+        cancellation_token: Optional[CancellationToken] = None,
+    ) -> AsyncGenerator[Union[str, CreateResult], None]:
+        extra_create_args_keys = set(extra_create_args.keys())
+        if not create_kwargs.issuperset(extra_create_args_keys):
+            raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}")
+
+        create_args = self._create_args.copy()
+        create_args.update(extra_create_args)
+
+        if self.model_info["vision"] is False:
+            for message in messages:
+                if isinstance(message, UserMessage):
+                    if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content):
+                        raise ValueError("Model does not support vision and image was provided")
+
+        if json_output is not None:
+            if self.model_info["json_output"] is False and json_output is True:
+                raise ValueError("Model does not support JSON output")
+
+            if json_output is True and "response_format" not in create_args:
+                create_args["response_format"] = "json-object"
+
+        if self.model_info["function_calling"] is False and len(tools) > 0:
+            raise ValueError("Model does not support function calling")
+
+        azure_messages_nested = [to_azure_message(msg) for msg in messages]
+        azure_messages = [item for sublist in azure_messages_nested for item in sublist]
+
+
+        if len(tools) > 0:
+            converted_tools = convert_tools(tools)
+            task = asyncio.create_task(
+                self._client.complete(messages=azure_messages, tools=converted_tools, stream=True, **create_args)
+            )
+        else:
+            task = asyncio.create_task(
+                self._client.complete(messages=azure_messages, stream=True, **create_args)
+            )
+
+        if cancellation_token is not None:
+            cancellation_token.link_future(task)
+
+        # result: ChatCompletions = await task
+        finish_reason = None
+        content_deltas: List[str] = []
+        full_tool_calls: Dict[str, FunctionCall] = {}
+        prompt_tokens = 0
+        completion_tokens = 0
+        chunk: Optional[StreamingChatCompletionsUpdate] = None
+        choice: Optional[StreamingChatChoiceUpdate] = None
+        async for chunk in await task:
+            choice = chunk.choices[0] if len(chunk.choices) > 0 else cast(StreamingChatChoiceUpdate, None)
+            if choice.finish_reason is not None:
+                finish_reason = choice.finish_reason.value
+
+            # We first try to load the content
+            if choice.delta.content is not None:
+                content_deltas.append(choice.delta.content)
+                yield choice.delta.content
+            # Otherwise, we try to load the tool calls
+            if choice.delta.tool_calls is not None:
+                for tool_call_chunk in choice.delta.tool_calls:
+                    if "index" in tool_call_chunk:
+                        idx = tool_call_chunk["index"]
+                    else:
+                        idx = tool_call_chunk.id
+                    if idx not in full_tool_calls:
+                        full_tool_calls[idx] = FunctionCall(id="", arguments="", name="")
+
+                    if tool_call_chunk.id is not None:
+                        full_tool_calls[idx].id += tool_call_chunk.id
+
+                    if tool_call_chunk.function is not None:
+                        if tool_call_chunk.function.name is not None:
+                            full_tool_calls[idx].name += tool_call_chunk.function.name
+                        if tool_call_chunk.function.arguments is not None:
+                            full_tool_calls[idx].arguments += tool_call_chunk.function.arguments
+
+        if chunk and chunk.usage:
+            prompt_tokens = chunk.usage.prompt_tokens
+
+        if finish_reason is None:
+            raise ValueError("No stop reason found")
+
+        if choice and choice.finish_reason is CompletionsFinishReason.TOOL_CALLS:
+            finish_reason = "function_calls"
+
+        content: Union[str, List[FunctionCall]]
+
+        # If any text content was streamed, the result is text; otherwise it is
+        # the accumulated tool calls.
+        if len(content_deltas) > 0:
+            content = "".join(content_deltas)
+            if chunk and chunk.usage:
+                completion_tokens = chunk.usage.completion_tokens
+            else:
+                completion_tokens = 0
+        else:
+            content = list(full_tool_calls.values())
+
+        usage = RequestUsage(
+            completion_tokens=completion_tokens,
+            prompt_tokens=prompt_tokens,
+        )
+
+        result = CreateResult(
+            finish_reason=normalize_stop_reason(finish_reason),  # type: ignore
+            content=content,
+            usage=usage,
+            cached=False,
+        )
+
+        self.add_usage(usage)
+
+        yield result
+
+    def actual_usage(self) -> RequestUsage:
+        return self._actual_usage
+
+    def total_usage(self) -> RequestUsage:
+        return self._total_usage
+
+    def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
+        # TODO: token counting is not yet implemented for Azure AI models.
+        return 0
+
+    def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
+        # TODO: token counting is not yet implemented for Azure AI models.
+        return 0
+
+    @property
+    def capabilities(self) -> ModelInfo:  # type: ignore
+        warnings.warn("capabilities is deprecated, use model_info instead", DeprecationWarning, stacklevel=2)
+        return self._model_info
+
+    @property
+    def model_info(self) -> ModelInfo:
+        return self._model_info
+
+    def __del__(self):
+        # TODO: This is a hack to close the open client
+        try:
+            asyncio.get_running_loop().create_task(self._client.close())
+        except RuntimeError:
+            asyncio.run(self._client.close())
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/azure/config/__init__.py b/python/packages/autogen-ext/src/autogen_ext/models/azure/config/__init__.py
new file mode 100644
index 000000000000..492f868fc20d
--- /dev/null
+++ b/python/packages/autogen-ext/src/autogen_ext/models/azure/config/__init__.py
@@ -0,0 +1,38 @@
+from typing import TypedDict, Union, Optional, List, Dict, Any
+from azure.ai.inference.models import (
+    JsonSchemaFormat,
+    ChatCompletionsToolDefinition,
+    ChatCompletionsToolChoicePreset,
+    ChatCompletionsNamedToolChoice,
+)
+
+from azure.core.credentials import AzureKeyCredential
+from azure.core.credentials_async import AsyncTokenCredential
+from autogen_core.models import ModelInfo
+
+GITHUB_MODELS_ENDPOINT = "https://models.inference.ai.azure.com"
+
+
+class AzureAIClientArguments(TypedDict, total=False):
+    endpoint: str
+    credential: Union[AzureKeyCredential, AsyncTokenCredential]
+    model_info: ModelInfo
+
+
+class AzureAICreateArguments(TypedDict, total=False):
+    frequency_penalty: Optional[float]
+    presence_penalty: Optional[float]
+    temperature: Optional[float]
+    top_p: Optional[float]
+    max_tokens: Optional[int]
+    response_format: Optional[Union[str, JsonSchemaFormat]]
+    stop: Optional[List[str]]
+    tools: Optional[List[ChatCompletionsToolDefinition]]
+    tool_choice: Optional[Union[str, ChatCompletionsToolChoicePreset, ChatCompletionsNamedToolChoice]]
+    seed: Optional[int]
+    model: Optional[str]
+    model_extras: Optional[Dict[str, Any]]
+
+
+class AzureAIChatCompletionClientConfig(AzureAIClientArguments, AzureAICreateArguments):
+    pass
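+
+
+# A minimal configuration sketch (placeholder values, not working credentials):
+#
+#   from azure.core.credentials import AzureKeyCredential
+#
+#   config: AzureAIChatCompletionClientConfig = {
+#       "endpoint": "<your-endpoint>",
+#       "credential": AzureKeyCredential("<api-key>"),
+#       "model_info": {
+#           "family": "unknown",
+#           "json_output": False,
+#           "function_calling": False,
+#           "vision": False,
+#       },
+#       "temperature": 0.7,
+#   }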
diff --git a/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py b/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py
new file mode 100644
index 000000000000..e18d5ea2280b
--- /dev/null
+++ b/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py
@@ -0,0 +1,174 @@
+import asyncio
+from datetime import datetime
+from typing import AsyncGenerator, Any
+
+import pytest
+from azure.ai.inference.aio import ChatCompletionsClient
+from azure.ai.inference.models import (
+    ChatChoice,
+    ChatCompletions,
+    ChatResponseMessage,
+    CompletionsUsage,
+    StreamingChatChoiceUpdate,
+    StreamingChatCompletionsUpdate,
+    StreamingChatResponseMessageUpdate,
+)
+
+from azure.core.credentials import AzureKeyCredential
+
+from autogen_core import CancellationToken
+from autogen_core.models import UserMessage
+from autogen_ext.models.azure import AzureAIChatCompletionClient
+
+
+async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[StreamingChatCompletionsUpdate, None]:
+    mock_chunks_content = ["Hello", " Another Hello", " Yet Another Hello"]
+
+    mock_chunks = [
+        StreamingChatChoiceUpdate(
+            index=0,
+            finish_reason="stop",
+            delta=StreamingChatResponseMessageUpdate(role="assistant", content=chunk_content),
+        )
+        for chunk_content in mock_chunks_content
+    ]
+
+    for mock_chunk in mock_chunks:
+        await asyncio.sleep(0.1)
+        yield StreamingChatCompletionsUpdate(
+            id="id",
+            choices=[mock_chunk],
+            created=datetime.now(),
+            model="model",
+            usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
+        )
+
+
+async def _mock_create(
+    *args: Any, **kwargs: Any
+) -> ChatCompletions | AsyncGenerator[StreamingChatCompletionsUpdate, None]:
+    stream = kwargs.get("stream", False)
+
+    if not stream:
+        await asyncio.sleep(0.1)
+        return ChatCompletions(
+            id="id",
+            created=datetime.now(),
+            model="model",
+            choices=[
+                ChatChoice(
+                    index=0, finish_reason="stop", message=ChatResponseMessage(content="Hello", role="assistant")
+                )
+            ],
+            usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
+        )
+    else:
+        return _mock_create_stream(*args, **kwargs)
+
+
+@pytest.mark.asyncio
+async def test_azure_ai_chat_completion_client() -> None:
+    client = AzureAIChatCompletionClient(
+        endpoint="endpoint",
+        credential=AzureKeyCredential("api_key"),
+        model_info={
+            "family": "unknown",
+            "json_output": False,
+            "function_calling": False,
+            "vision": False,
+        },
+        model="model",
+    )
+    assert client
+
+
+@pytest.mark.asyncio
+async def test_azure_ai_chat_completion_client_create(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create)
+    client = AzureAIChatCompletionClient(
+        endpoint="endpoint",
+        credential=AzureKeyCredential("api_key"),
+        model_info={
+            "family": "unknown",
+            "json_output": False,
+            "function_calling": False,
+            "vision": False,
+        },
+    )
+    result = await client.create(messages=[UserMessage(content="Hello", source="user")])
+    assert result.content == "Hello"
+
+
+@pytest.mark.asyncio
+async def test_azure_ai_chat_completion_client_create_stream(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create)
+    chunks = []
+    client = AzureAIChatCompletionClient(
+        endpoint="endpoint",
+        credential=AzureKeyCredential("api_key"),
+        model_info={
+            "family": "unknown",
+            "json_output": False,
+            "function_calling": False,
+            "vision": False,
+        },
+    )
+    async for chunk in client.create_stream(messages=[UserMessage(content="Hello", source="user")]):
+        chunks.append(chunk)
+
+    assert chunks[0] == "Hello"
+    assert chunks[1] == " Another Hello"
+    assert chunks[2] == " Yet Another Hello"
+
+
+@pytest.mark.asyncio
+async def test_azure_ai_chat_completion_client_create_cancel(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create)
+    cancellation_token = CancellationToken()
+    client = AzureAIChatCompletionClient(
+        endpoint="endpoint",
+        credential=AzureKeyCredential("api_key"),
+        model_info={
+            "family": "unknown",
+            "json_output": False,
+            "function_calling": False,
+            "vision": False,
+        },
+    )
+    task = asyncio.create_task(
+        client.create(messages=[UserMessage(content="Hello", source="user")], cancellation_token=cancellation_token)
+    )
+    cancellation_token.cancel()
+    with pytest.raises(asyncio.CancelledError):
+        await task
+
+
+@pytest.mark.asyncio
+async def test_azure_ai_chat_completion_client_create_stream_cancel(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create)
+    cancellation_token = CancellationToken()
+    client = AzureAIChatCompletionClient(
+        endpoint="endpoint",
+        credential=AzureKeyCredential("api_key"),
+        model_info={
+            "family": "unknown",
+            "json_output": False,
+            "function_calling": False,
+            "vision": False,
+        },
+    )
+    stream = client.create_stream(
+        messages=[UserMessage(content="Hello", source="user")], cancellation_token=cancellation_token
+    )
+    cancellation_token.cancel()
+    with pytest.raises(asyncio.CancelledError):
+        async for _ in stream:
+            pass