diff --git a/apps/agentstack-sdk-py/examples/tool_call_approval_agent.py b/apps/agentstack-sdk-py/examples/tool_call_approval_agent.py index accebf869..a3c470c17 100644 --- a/apps/agentstack-sdk-py/examples/tool_call_approval_agent.py +++ b/apps/agentstack-sdk-py/examples/tool_call_approval_agent.py @@ -8,13 +8,12 @@ from mcp.client.streamable_http import streamablehttp_client from mcp.types import TextContent -from agentstack_sdk.a2a.extensions.tools.call import ( - ToolCallExtensionParams, - ToolCallExtensionServer, - ToolCallExtensionSpec, - ToolCallRequest, +from agentstack_sdk.a2a.extensions.interactions.approval import ( + ApprovalExtensionParams, + ApprovalExtensionServer, + ApprovalExtensionSpec, + ToolCallApprovalRequest, ) -from agentstack_sdk.a2a.extensions.tools.exceptions import ToolCallRejectionError from agentstack_sdk.server import Server from agentstack_sdk.server.context import RunContext @@ -25,7 +24,7 @@ async def tool_call_approval_agent( message: Message, context: RunContext, - mcp_tool_call: Annotated[ToolCallExtensionServer, ToolCallExtensionSpec(params=ToolCallExtensionParams())], + mcp_tool_call: Annotated[ApprovalExtensionServer, ApprovalExtensionSpec(params=ApprovalExtensionParams())], ): async with ( streamablehttp_client(url="https://hf.co/mcp") as (read, write, _), @@ -41,18 +40,18 @@ async def tool_call_approval_agent( raise RuntimeError("Could not find whoami_tool on the server") arguments = {} - try: - await mcp_tool_call.request_tool_call_approval( - ToolCallRequest.from_mcp_tool(whoami_tool, arguments, server=session_init_result.serverInfo), - context=context, - ) + response = await mcp_tool_call.request_approval( + ToolCallApprovalRequest.from_mcp_tool(whoami_tool, arguments, server=session_init_result.serverInfo), + context=context, + ) + if response.approved: result = await session.call_tool("hf_whoami", arguments) content = result.content[0] if isinstance(content, TextContent): yield content.text else: yield "Tool call succeeded" - except ToolCallRejectionError: + else: yield "Tool call has been rejected by the client" diff --git a/apps/agentstack-sdk-py/examples/tool_call_approval_client.py b/apps/agentstack-sdk-py/examples/tool_call_approval_client.py index 64d6f9926..a6a9ce697 100644 --- a/apps/agentstack-sdk-py/examples/tool_call_approval_client.py +++ b/apps/agentstack-sdk-py/examples/tool_call_approval_client.py @@ -9,24 +9,24 @@ import httpx import agentstack_sdk.a2a.extensions -from agentstack_sdk.a2a.extensions.tools.call import ToolCallResponse +from agentstack_sdk.a2a.extensions.interactions.approval import ApprovalResponse async def run(base_url: str = "http://127.0.0.1:10000"): async with httpx.AsyncClient(timeout=30) as httpx_client: card = await a2a.client.A2ACardResolver(httpx_client, base_url=base_url).get_agent_card() - tool_call_spec = agentstack_sdk.a2a.extensions.ToolCallExtensionSpec.from_agent_card(card) + approval_spec = agentstack_sdk.a2a.extensions.ApprovalExtensionSpec.from_agent_card(card) - if not tool_call_spec: - raise ValueError(f"Agent at {base_url} does not support MCP Tool Call extension") + if not approval_spec: + raise ValueError(f"Agent at {base_url} does not support approval extension") - tool_call_extension_client = agentstack_sdk.a2a.extensions.ToolCallExtensionClient(tool_call_spec) + approval_extension_client = agentstack_sdk.a2a.extensions.ApprovalExtensionClient(approval_spec) message = a2a.types.Message( message_id=str(uuid.uuid4()), role=a2a.types.Role.user, parts=[a2a.types.Part(root=a2a.types.TextPart(text="Howdy!"))], - metadata=tool_call_extension_client.metadata(), + metadata=approval_extension_client.metadata(), ) client = a2a.client.ClientFactory(a2a.client.ClientConfig(httpx_client=httpx_client, polling=True)).create( @@ -45,13 +45,13 @@ async def run(base_url: str = "http://127.0.0.1:10000"): if not task.status.message: raise RuntimeError("Missing message") - approval_request = tool_call_extension_client.parse_request(message=task.status.message) + approval_request = approval_extension_client.parse_request(message=task.status.message) print("Agent has requested a tool call") print(approval_request) choice = input("Approve (Y/n): ") - response = ToolCallResponse(action="accept" if choice.lower() == "y" else "reject") - message = tool_call_extension_client.create_response_message(task_id=task.id, response=response) + response = ApprovalResponse(decision="approve" if choice.lower() == "y" else "reject") + message = approval_extension_client.create_response_message(task_id=task.id, response=response) else: break diff --git a/apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/__init__.py b/apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/__init__.py index 6fcc7e3d5..f17926356 100644 --- a/apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/__init__.py +++ b/apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/__init__.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from .auth import * +from .interactions import * from .services import * from .tools import * from .ui import * diff --git a/apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/interactions/__init__.py b/apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/interactions/__init__.py new file mode 100644 index 000000000..505dd19a0 --- /dev/null +++ b/apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/interactions/__init__.py @@ -0,0 +1,4 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC +# SPDX-License-Identifier: Apache-2.0 + +from .approval import * diff --git a/apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/interactions/approval.py b/apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/interactions/approval.py new file mode 100644 index 000000000..0856d466c --- /dev/null +++ b/apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/interactions/approval.py @@ -0,0 +1,125 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import uuid +from types import NoneType +from typing import TYPE_CHECKING, Annotated, Any, Literal + +import a2a.types +from mcp import Implementation, Tool +from pydantic import BaseModel, Discriminator, Field, TypeAdapter + +from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec +from agentstack_sdk.a2a.types import AgentMessage, InputRequired + +if TYPE_CHECKING: + from agentstack_sdk.server.context import RunContext + + +class ApprovalRejectionError(RuntimeError): + pass + + +class GenericApprovalRequest(BaseModel): + action: Literal["generic"] = "generic" + + title: str | None = Field(None, description="A human-readable title for the action being approved.") + description: str | None = Field(None, description="A human-readable description of the action being approved.") + + +class ToolCallServer(BaseModel): + name: str = Field(description="The programmatic name of the server.") + title: str | None = Field(description="A human-readable title for the server.") + version: str = Field(description="The version of the server.") + + +class ToolCallApprovalRequest(BaseModel): + action: Literal["tool-call"] = "tool-call" + + title: str | None = Field(None, description="A human-readable title for the tool call being approved.") + description: str | None = Field(None, description="A human-readable description of the tool call being approved.") + name: str = Field(description="The programmatic name of the tool.") + input: dict[str, Any] | None = Field(description="The input for the tool.") + server: ToolCallServer | None = Field(None, description="The server executing the tool.") + + @staticmethod + def from_mcp_tool( + tool: Tool, input: dict[str, Any] | None, server: Implementation | None = None + ) -> ToolCallApprovalRequest: + return ToolCallApprovalRequest( + name=tool.name, + title=tool.annotations.title if tool.annotations else None, + description=tool.description, + input=input, + server=ToolCallServer(name=server.name, title=server.title, version=server.version) if server else None, + ) + + +ApprovalRequest = Annotated[GenericApprovalRequest | ToolCallApprovalRequest, Discriminator("action")] + + +class ApprovalResponse(BaseModel): + decision: Literal["approve", "reject"] + + @property + def approved(self) -> bool: + return self.decision == "approve" + + def raise_on_rejection(self) -> None: + if self.decision == "reject": + raise ApprovalRejectionError("Approval request has been rejected") + + +class ApprovalExtensionParams(BaseModel): + pass + + +class ApprovalExtensionSpec(BaseExtensionSpec[ApprovalExtensionParams]): + URI: str = "https://a2a-extensions.agentstack.beeai.dev/interactions/approval/v1" + + +class ApprovalExtensionMetadata(BaseModel): + pass + + +class ApprovalExtensionServer(BaseExtensionServer[ApprovalExtensionSpec, ApprovalExtensionMetadata]): + def create_request_message(self, *, request: ApprovalRequest): + return AgentMessage(text="Approval requested", metadata={self.spec.URI: request.model_dump(mode="json")}) + + def parse_response(self, *, message: a2a.types.Message): + if not message.metadata or not (data := message.metadata.get(self.spec.URI)): + raise ValueError("Approval response data is missing") + return ApprovalResponse.model_validate(data) + + async def request_approval( + self, + request: ApprovalRequest, + *, + context: RunContext, + ) -> ApprovalResponse: + message = self.create_request_message(request=request) + message = await context.yield_async(InputRequired(message=message)) + if not message: + raise RuntimeError("Yield did not return a message") + return self.parse_response(message=message) + + +class ApprovalExtensionClient(BaseExtensionClient[ApprovalExtensionSpec, NoneType]): + def create_response_message(self, *, response: ApprovalResponse, task_id: str | None): + return a2a.types.Message( + message_id=str(uuid.uuid4()), + role=a2a.types.Role.user, + parts=[], + task_id=task_id, + metadata={self.spec.URI: response.model_dump(mode="json")}, + ) + + def parse_request(self, *, message: a2a.types.Message): + if not message.metadata or not (data := message.metadata.get(self.spec.URI)): + raise ValueError("Approval request data is missing") + return TypeAdapter(ApprovalRequest).validate_python(data) + + def metadata(self) -> dict[str, Any]: + return {self.spec.URI: ApprovalExtensionMetadata().model_dump(mode="json")} diff --git a/apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/tools/call.py b/apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/tools/call.py index 6a0f8ea85..c9283b9f4 100644 --- a/apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/tools/call.py +++ b/apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/tools/call.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Literal import a2a.types +from deprecated import deprecated from mcp import Tool from mcp.types import Implementation from pydantic import BaseModel, Field @@ -26,6 +27,7 @@ class ToolCallServer(BaseModel): version: str = Field(description="The version of the server.") +@deprecated(reason="Use ToolCallApprovalRequest instead") class ToolCallRequest(BaseModel): name: str = Field(description="The programmatic name of the tool.") title: str | None = Field(None, description="A human-readable title for the tool.") @@ -64,6 +66,7 @@ class ToolCallExtensionMetadata(BaseModel): pass +@deprecated(reason="Use ApprovalExtensionServer instead") class ToolCallExtensionServer(BaseExtensionServer[ToolCallExtensionSpec, ToolCallExtensionMetadata]): def create_request_message(self, *, request: ToolCallRequest): return AgentMessage( diff --git a/apps/agentstack-sdk-ts/src/client/a2a/extensions/handle-task-status-update.ts b/apps/agentstack-sdk-ts/src/client/a2a/extensions/handle-task-status-update.ts index 3b4b438f6..74afa0126 100644 --- a/apps/agentstack-sdk-ts/src/client/a2a/extensions/handle-task-status-update.ts +++ b/apps/agentstack-sdk-ts/src/client/a2a/extensions/handle-task-status-update.ts @@ -6,6 +6,8 @@ import type { TaskStatusUpdateEvent } from '@a2a-js/sdk'; import type { FormRender } from './common/form'; +import type { ApprovalRequest } from './interactions/approval'; +import { approvalExtension } from './interactions/approval'; import type { SecretDemands } from './services/secrets'; import { secretsMessageExtension } from './services/secrets'; import { FormRequestExtension } from './ui/form-request'; @@ -15,11 +17,13 @@ import { extractUiExtensionData } from './utils'; const secretsMessageExtensionExtractor = extractUiExtensionData(secretsMessageExtension); const oauthRequestExtensionExtractor = extractUiExtensionData(oauthRequestExtension); const FormRequestExtensionExtractor = extractUiExtensionData(FormRequestExtension); +const approvalExtensionExtractor = extractUiExtensionData(approvalExtension); export enum TaskStatusUpdateType { SecretRequired = 'secret-required', FormRequired = 'form-required', OAuthRequired = 'oauth-required', + ApprovalRequired = 'approval-required', } export interface SecretRequiredResult { @@ -37,7 +41,16 @@ export interface OAuthRequiredResult { url: string; } -export type TaskStatusUpdateResult = SecretRequiredResult | FormRequiredResult | OAuthRequiredResult; +export interface ApprovalRequiredResult { + type: TaskStatusUpdateType.ApprovalRequired; + request: ApprovalRequest; +} + +export type TaskStatusUpdateResult = + | SecretRequiredResult + | FormRequiredResult + | OAuthRequiredResult + | ApprovalRequiredResult; export const handleTaskStatusUpdate = (event: TaskStatusUpdateEvent): TaskStatusUpdateResult[] => { const results: TaskStatusUpdateResult[] = []; @@ -61,6 +74,7 @@ export const handleTaskStatusUpdate = (event: TaskStatusUpdateEvent): TaskStatus } } else if (event.status.state === 'input-required') { const formRequired = FormRequestExtensionExtractor(event.status.message?.metadata); + const approvalRequired = approvalExtensionExtractor(event.status.message?.metadata); if (formRequired) { results.push({ @@ -68,6 +82,13 @@ export const handleTaskStatusUpdate = (event: TaskStatusUpdateEvent): TaskStatus form: formRequired, }); } + + if (approvalRequired) { + results.push({ + type: TaskStatusUpdateType.ApprovalRequired, + request: approvalRequired, + }); + } } return results; diff --git a/apps/agentstack-sdk-ts/src/client/a2a/extensions/interactions/approval.ts b/apps/agentstack-sdk-ts/src/client/a2a/extensions/interactions/approval.ts new file mode 100644 index 000000000..097073cb9 --- /dev/null +++ b/apps/agentstack-sdk-ts/src/client/a2a/extensions/interactions/approval.ts @@ -0,0 +1,50 @@ +/** + * Copyright 2025 © BeeAI a Series of LF Projects, LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import z from 'zod'; + +import type { A2AUiExtension } from '../types'; + +const URI = 'https://a2a-extensions.agentstack.beeai.dev/interactions/approval/v1'; + +export const genericApprovalRequestSchema = z.object({ + action: z.literal('generic'), + title: z.string().nullish().describe('A human-readable title for the action being approved.'), + description: z.string().nullish().describe('A human-readable description of the action being approved.'), +}); +export type GenericApprovalRequest = z.infer; + +export const toolCallApprovalRequestSchema = z.object({ + action: z.literal('tool-call'), + title: z.string().nullish().describe('A human-readable title for the tool call being approved.'), + description: z.string().nullish().describe('A human-readable description of the tool call being approved.'), + name: z.string().describe('The programmatic name of the tool.'), + input: z.object().nullish().describe('The input for the tool.'), + server: z + .object({ + name: z.string().describe('The programmatic name of the server.'), + title: z.string().nullish().describe('A human-readable title for the server.'), + version: z.string().describe('The version of the server.'), + }) + .nullish() + .describe('The server executing the tool.'), +}); +export type ToolCallApprovalRequest = z.infer; + +export const approvalRequestSchema = z.discriminatedUnion('action', [ + genericApprovalRequestSchema, + toolCallApprovalRequestSchema, +]); +export type ApprovalRequest = z.infer; + +export const approvalResultSchema = z.object({ + decision: z.enum(['approve', 'reject']), +}); +export type ApprovalResult = z.infer; + +export const approvalExtension: A2AUiExtension = { + getMessageMetadataSchema: () => z.object({ [URI]: approvalRequestSchema }).partial(), + getUri: () => URI, +}; diff --git a/docs/development/agent-integration/tool-calls.mdx b/docs/development/agent-integration/tool-calls.mdx index 07c0d1c40..e35332403 100644 --- a/docs/development/agent-integration/tool-calls.mdx +++ b/docs/development/agent-integration/tool-calls.mdx @@ -5,18 +5,18 @@ description: "Have tool calls approved by the user before execution" Many agent frameworks support the ability to request user approval before executing certain actions. This is especially useful when an agent is calling external tools that may have significant effects or costs associated with their usage. -The Tool Call extension provides a mechanism for implementing this functionality over A2A connection. +The Approval extension provides a mechanism for implementing this functionality over A2A connection. ## Usage - - Inject the `ToolCallExtension` into your agent function using the `Annotated` + + Inject the `ApprovalExtension` into your agent function using the `Annotated` type hint. - Use `request_tool_call_approval()` method to request tool call approval from the A2A client side. + Use `request_approval()` method to request tool call approval from the A2A client side. @@ -32,11 +32,11 @@ from a2a.types import ( ) from agentstack_sdk.server import Server from agentstack_sdk.server.context import RunContext -from agentstack_sdk.a2a.extensions.tools.call import ( - ToolCallExtensionParams, - ToolCallExtensionServer, - ToolCallExtensionSpec, - ToolCallRequest, +from agentstack_sdk.a2a.extensions.interactions.approval import ( + ApprovalExtensionParams, + ApprovalExtensionServer, + ApprovalExtensionSpec, + ToolCallApprovalRequest, ) from agentstack_sdk.a2a.extensions.tools.exceptions import ToolCallRejectionError from beeai_framework.agents.requirement import RequirementAgent @@ -53,18 +53,16 @@ server = Server() async def tool_call_agent( input: Message, context: RunContext, - mcp_tool_call: Annotated[ToolCallExtensionServer, ToolCallExtensionSpec(params=ToolCallExtensionParams())], + approval_ext: Annotated[ApprovalExtensionServer, ApprovalExtensionSpec(params=ApprovalExtensionParams())], ): async def handler(tool: Tool, input: dict[str, Any]) -> bool: - try: - await mcp_tool_call.request_tool_call_approval( - # using MCP Tool data model as intermediary to simplify conversion - ToolCallRequest.from_mcp_tool(_tool_factory(tool), input=input), # type: ignore - context=context, - ) - return True - except ToolCallRejectionError: - return False + + response = await approval_ext.request_approval( + # using MCP Tool data model as intermediary to simplify conversion + ToolCallApprovalRequest.from_mcp_tool(_tool_factory(tool), input=input), # type: ignore + context=context, + ) + return response.approved think_tool = ThinkTool() agent = RequirementAgent(