Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions apps/agentstack-sdk-py/examples/tool_call_approval_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@
from mcp.client.streamable_http import streamablehttp_client
from mcp.types import TextContent

from agentstack_sdk.a2a.extensions.tools.call import (
ToolCallExtensionParams,
ToolCallExtensionServer,
ToolCallExtensionSpec,
ToolCallRequest,
from agentstack_sdk.a2a.extensions.interactions.approval import (
ApprovalExtensionParams,
ApprovalExtensionServer,
ApprovalExtensionSpec,
ToolCallApprovalRequest,
)
from agentstack_sdk.a2a.extensions.tools.exceptions import ToolCallRejectionError
from agentstack_sdk.server import Server
from agentstack_sdk.server.context import RunContext

Expand All @@ -25,7 +24,7 @@
async def tool_call_approval_agent(
message: Message,
context: RunContext,
mcp_tool_call: Annotated[ToolCallExtensionServer, ToolCallExtensionSpec(params=ToolCallExtensionParams())],
mcp_tool_call: Annotated[ApprovalExtensionServer, ApprovalExtensionSpec(params=ApprovalExtensionParams())],
):
async with (
streamablehttp_client(url="https://hf.co/mcp") as (read, write, _),
Expand All @@ -41,18 +40,18 @@ async def tool_call_approval_agent(
raise RuntimeError("Could not find whoami_tool on the server")

arguments = {}
try:
await mcp_tool_call.request_tool_call_approval(
ToolCallRequest.from_mcp_tool(whoami_tool, arguments, server=session_init_result.serverInfo),
context=context,
)
response = await mcp_tool_call.request_approval(
ToolCallApprovalRequest.from_mcp_tool(whoami_tool, arguments, server=session_init_result.serverInfo),
context=context,
)
if response.approved:
result = await session.call_tool("hf_whoami", arguments)
content = result.content[0]
if isinstance(content, TextContent):
yield content.text
else:
yield "Tool call succeeded"
except ToolCallRejectionError:
else:
yield "Tool call has been rejected by the client"


Expand Down
18 changes: 9 additions & 9 deletions apps/agentstack-sdk-py/examples/tool_call_approval_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,24 @@
import httpx

import agentstack_sdk.a2a.extensions
from agentstack_sdk.a2a.extensions.tools.call import ToolCallResponse
from agentstack_sdk.a2a.extensions.interactions.approval import ApprovalResponse


async def run(base_url: str = "http://127.0.0.1:10000"):
async with httpx.AsyncClient(timeout=30) as httpx_client:
card = await a2a.client.A2ACardResolver(httpx_client, base_url=base_url).get_agent_card()
tool_call_spec = agentstack_sdk.a2a.extensions.ToolCallExtensionSpec.from_agent_card(card)
approval_spec = agentstack_sdk.a2a.extensions.ApprovalExtensionSpec.from_agent_card(card)

if not tool_call_spec:
raise ValueError(f"Agent at {base_url} does not support MCP Tool Call extension")
if not approval_spec:
raise ValueError(f"Agent at {base_url} does not support approval extension")

tool_call_extension_client = agentstack_sdk.a2a.extensions.ToolCallExtensionClient(tool_call_spec)
approval_extension_client = agentstack_sdk.a2a.extensions.ApprovalExtensionClient(approval_spec)

message = a2a.types.Message(
message_id=str(uuid.uuid4()),
role=a2a.types.Role.user,
parts=[a2a.types.Part(root=a2a.types.TextPart(text="Howdy!"))],
metadata=tool_call_extension_client.metadata(),
metadata=approval_extension_client.metadata(),
)

client = a2a.client.ClientFactory(a2a.client.ClientConfig(httpx_client=httpx_client, polling=True)).create(
Expand All @@ -45,13 +45,13 @@ async def run(base_url: str = "http://127.0.0.1:10000"):
if not task.status.message:
raise RuntimeError("Missing message")

approval_request = tool_call_extension_client.parse_request(message=task.status.message)
approval_request = approval_extension_client.parse_request(message=task.status.message)

print("Agent has requested a tool call")
print(approval_request)
choice = input("Approve (Y/n): ")
response = ToolCallResponse(action="accept" if choice.lower() == "y" else "reject")
message = tool_call_extension_client.create_response_message(task_id=task.id, response=response)
response = ApprovalResponse(decision="approve" if choice.lower() == "y" else "reject")
message = approval_extension_client.create_response_message(task_id=task.id, response=response)
else:
break

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from .auth import *
from .interactions import *
from .services import *
from .tools import *
from .ui import *
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
# SPDX-License-Identifier: Apache-2.0

from .approval import *
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import uuid
from types import NoneType
from typing import TYPE_CHECKING, Annotated, Any, Literal

import a2a.types
from mcp import Implementation, Tool
from pydantic import BaseModel, Discriminator, Field, TypeAdapter

from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
from agentstack_sdk.a2a.types import AgentMessage, InputRequired

if TYPE_CHECKING:
from agentstack_sdk.server.context import RunContext


class ApprovalRejectionError(RuntimeError):
pass


class GenericApprovalRequest(BaseModel):
action: Literal["generic"] = "generic"

title: str | None = Field(None, description="A human-readable title for the action being approved.")
description: str | None = Field(None, description="A human-readable description of the action being approved.")


class ToolCallServer(BaseModel):
name: str = Field(description="The programmatic name of the server.")
title: str | None = Field(description="A human-readable title for the server.")
version: str = Field(description="The version of the server.")


class ToolCallApprovalRequest(BaseModel):
action: Literal["tool-call"] = "tool-call"

title: str | None = Field(None, description="A human-readable title for the tool call being approved.")
description: str | None = Field(None, description="A human-readable description of the tool call being approved.")
name: str = Field(description="The programmatic name of the tool.")
input: dict[str, Any] | None = Field(description="The input for the tool.")
server: ToolCallServer | None = Field(None, description="The server executing the tool.")

@staticmethod
def from_mcp_tool(
tool: Tool, input: dict[str, Any] | None, server: Implementation | None = None
) -> ToolCallApprovalRequest:
return ToolCallApprovalRequest(
name=tool.name,
title=tool.annotations.title if tool.annotations else None,
description=tool.description,
input=input,
server=ToolCallServer(name=server.name, title=server.title, version=server.version) if server else None,
)


ApprovalRequest = Annotated[GenericApprovalRequest | ToolCallApprovalRequest, Discriminator("action")]


class ApprovalResponse(BaseModel):
decision: Literal["approve", "reject"]

@property
def approved(self) -> bool:
return self.decision == "approve"

def raise_on_rejection(self) -> None:
if self.decision == "reject":
raise ApprovalRejectionError("Approval request has been rejected")


class ApprovalExtensionParams(BaseModel):
pass


class ApprovalExtensionSpec(BaseExtensionSpec[ApprovalExtensionParams]):
URI: str = "https://a2a-extensions.agentstack.beeai.dev/interactions/approval/v1"


class ApprovalExtensionMetadata(BaseModel):
pass


class ApprovalExtensionServer(BaseExtensionServer[ApprovalExtensionSpec, ApprovalExtensionMetadata]):
def create_request_message(self, *, request: ApprovalRequest):
return AgentMessage(text="Approval requested", metadata={self.spec.URI: request.model_dump(mode="json")})

def parse_response(self, *, message: a2a.types.Message):
if not message.metadata or not (data := message.metadata.get(self.spec.URI)):
raise ValueError("Approval response data is missing")
return ApprovalResponse.model_validate(data)

async def request_approval(
self,
request: ApprovalRequest,
*,
context: RunContext,
) -> ApprovalResponse:
message = self.create_request_message(request=request)
message = await context.yield_async(InputRequired(message=message))
if not message:
raise RuntimeError("Yield did not return a message")
return self.parse_response(message=message)


class ApprovalExtensionClient(BaseExtensionClient[ApprovalExtensionSpec, NoneType]):
def create_response_message(self, *, response: ApprovalResponse, task_id: str | None):
return a2a.types.Message(
message_id=str(uuid.uuid4()),
role=a2a.types.Role.user,
parts=[],
task_id=task_id,
metadata={self.spec.URI: response.model_dump(mode="json")},
)

def parse_request(self, *, message: a2a.types.Message):
if not message.metadata or not (data := message.metadata.get(self.spec.URI)):
raise ValueError("Approval request data is missing")
return TypeAdapter(ApprovalRequest).validate_python(data)

def metadata(self) -> dict[str, Any]:
return {self.spec.URI: ApprovalExtensionMetadata().model_dump(mode="json")}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import TYPE_CHECKING, Any, Literal

import a2a.types
from deprecated import deprecated
from mcp import Tool
from mcp.types import Implementation
from pydantic import BaseModel, Field
Expand All @@ -26,6 +27,7 @@ class ToolCallServer(BaseModel):
version: str = Field(description="The version of the server.")


@deprecated(reason="Use ToolCallApprovalRequest instead")
class ToolCallRequest(BaseModel):
name: str = Field(description="The programmatic name of the tool.")
title: str | None = Field(None, description="A human-readable title for the tool.")
Expand Down Expand Up @@ -64,6 +66,7 @@ class ToolCallExtensionMetadata(BaseModel):
pass


@deprecated(reason="Use ApprovalExtensionServer instead")
class ToolCallExtensionServer(BaseExtensionServer[ToolCallExtensionSpec, ToolCallExtensionMetadata]):
def create_request_message(self, *, request: ToolCallRequest):
return AgentMessage(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import type { TaskStatusUpdateEvent } from '@a2a-js/sdk';

import type { FormRender } from './common/form';
import type { ApprovalRequest } from './interactions/approval';
import { approvalExtension } from './interactions/approval';
import type { SecretDemands } from './services/secrets';
import { secretsMessageExtension } from './services/secrets';
import { FormRequestExtension } from './ui/form-request';
Expand All @@ -15,11 +17,13 @@ import { extractUiExtensionData } from './utils';
const secretsMessageExtensionExtractor = extractUiExtensionData(secretsMessageExtension);
const oauthRequestExtensionExtractor = extractUiExtensionData(oauthRequestExtension);
const FormRequestExtensionExtractor = extractUiExtensionData(FormRequestExtension);
const approvalExtensionExtractor = extractUiExtensionData(approvalExtension);

export enum TaskStatusUpdateType {
SecretRequired = 'secret-required',
FormRequired = 'form-required',
OAuthRequired = 'oauth-required',
ApprovalRequired = 'approval-required',
}

export interface SecretRequiredResult {
Expand All @@ -37,7 +41,16 @@ export interface OAuthRequiredResult {
url: string;
}

export type TaskStatusUpdateResult = SecretRequiredResult | FormRequiredResult | OAuthRequiredResult;
export interface ApprovalRequiredResult {
type: TaskStatusUpdateType.ApprovalRequired;
request: ApprovalRequest;
}

export type TaskStatusUpdateResult =
| SecretRequiredResult
| FormRequiredResult
| OAuthRequiredResult
| ApprovalRequiredResult;

export const handleTaskStatusUpdate = (event: TaskStatusUpdateEvent): TaskStatusUpdateResult[] => {
const results: TaskStatusUpdateResult[] = [];
Expand All @@ -61,13 +74,21 @@ export const handleTaskStatusUpdate = (event: TaskStatusUpdateEvent): TaskStatus
}
} else if (event.status.state === 'input-required') {
const formRequired = FormRequestExtensionExtractor(event.status.message?.metadata);
const approvalRequired = approvalExtensionExtractor(event.status.message?.metadata);

if (formRequired) {
results.push({
type: TaskStatusUpdateType.FormRequired,
form: formRequired,
});
}

if (approvalRequired) {
results.push({
type: TaskStatusUpdateType.ApprovalRequired,
request: approvalRequired,
});
}
}

return results;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/**
* Copyright 2025 © BeeAI a Series of LF Projects, LLC
* SPDX-License-Identifier: Apache-2.0
*/

import z from 'zod';

import type { A2AUiExtension } from '../types';

const URI = 'https://a2a-extensions.agentstack.beeai.dev/interactions/approval/v1';

export const genericApprovalRequestSchema = z.object({
action: z.literal('generic'),
title: z.string().nullish().describe('A human-readable title for the action being approved.'),
description: z.string().nullish().describe('A human-readable description of the action being approved.'),
});
export type GenericApprovalRequest = z.infer<typeof genericApprovalRequestSchema>;

export const toolCallApprovalRequestSchema = z.object({
action: z.literal('tool-call'),
title: z.string().nullish().describe('A human-readable title for the tool call being approved.'),
description: z.string().nullish().describe('A human-readable description of the tool call being approved.'),
name: z.string().describe('The programmatic name of the tool.'),
input: z.object().nullish().describe('The input for the tool.'),
server: z
.object({
name: z.string().describe('The programmatic name of the server.'),
title: z.string().nullish().describe('A human-readable title for the server.'),
version: z.string().describe('The version of the server.'),
})
.nullish()
.describe('The server executing the tool.'),
});
export type ToolCallApprovalRequest = z.infer<typeof toolCallApprovalRequestSchema>;

export const approvalRequestSchema = z.discriminatedUnion('action', [
genericApprovalRequestSchema,
toolCallApprovalRequestSchema,
]);
export type ApprovalRequest = z.infer<typeof approvalRequestSchema>;

export const approvalResultSchema = z.object({
decision: z.enum(['approve', 'reject']),
});
export type ApprovalResult = z.infer<typeof approvalResultSchema>;

export const approvalExtension: A2AUiExtension<typeof URI, ApprovalRequest> = {
getMessageMetadataSchema: () => z.object({ [URI]: approvalRequestSchema }).partial(),
getUri: () => URI,
};
Loading