diff --git a/apps/agentstack-sdk-py/src/agentstack_sdk/server/app.py b/apps/agentstack-sdk-py/src/agentstack_sdk/server/app.py index 070f69641..15fbfe32b 100644 --- a/apps/agentstack-sdk-py/src/agentstack_sdk/server/app.py +++ b/apps/agentstack-sdk-py/src/agentstack_sdk/server/app.py @@ -18,8 +18,6 @@ from a2a.types import AgentInterface, TransportProtocol from fastapi import APIRouter, Depends, FastAPI from fastapi.applications import AppType -from starlette.authentication import AuthenticationBackend -from starlette.middleware.authentication import AuthenticationMiddleware from starlette.types import Lifespan from agentstack_sdk.server.agent import Agent, Executor @@ -39,7 +37,6 @@ def create_app( dependencies: list[Depends] | None = None, # pyright: ignore [reportGeneralTypeIssues] override_interfaces: bool = True, task_timeout: timedelta = timedelta(minutes=10), - auth_backend: AuthenticationBackend | None = None, **kwargs, ) -> FastAPI: queue_manager = queue_manager or InMemoryQueueManager() @@ -78,10 +75,6 @@ def create_app( **kwargs, ) - if auth_backend: - rest_app.add_middleware(AuthenticationMiddleware, backend=auth_backend) - jsonrpc_app.add_middleware(AuthenticationMiddleware, backend=auth_backend) - rest_app.mount("/jsonrpc", jsonrpc_app) rest_app.include_router(APIRouter(lifespan=lifespan)) return rest_app diff --git a/apps/agentstack-sdk-py/src/agentstack_sdk/server/middleware/platform_auth_backend.py b/apps/agentstack-sdk-py/src/agentstack_sdk/server/middleware/platform_auth_backend.py index aa316f110..2eeff3d6e 100644 --- a/apps/agentstack-sdk-py/src/agentstack_sdk/server/middleware/platform_auth_backend.py +++ b/apps/agentstack-sdk-py/src/agentstack_sdk/server/middleware/platform_auth_backend.py @@ -7,6 +7,7 @@ from urllib.parse import urljoin from a2a.auth.user import User +from a2a.types import AgentCard, HTTPAuthSecurityScheme, SecurityScheme from async_lru import alru_cache from authlib.jose import JsonWebKey, JWTClaims, KeySet, jwt from authlib.jose.errors import JoseError @@ -15,7 +16,6 @@ from pydantic import Secret from starlette.authentication import ( AuthCredentials, - AuthenticationBackend, AuthenticationError, BaseUser, ) @@ -23,7 +23,7 @@ from typing_extensions import override from agentstack_sdk.platform import use_platform_client -from agentstack_sdk.types import JsonValue +from agentstack_sdk.types import JsonValue, SdkAuthenticationBackend logger = logging.getLogger(__name__) @@ -70,7 +70,7 @@ async def discover_jwks() -> KeySet: raise RuntimeError(f"JWKS discovery failed for url {url}") from e -class PlatformAuthBackend(AuthenticationBackend): +class PlatformAuthBackend(SdkAuthenticationBackend): def __init__(self, public_url: str | None = None, skip_audience_validation: bool | None = None) -> None: self.skip_audience_validation: bool = ( skip_audience_validation @@ -129,3 +129,16 @@ async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, Bas except Exception as e: logger.error(f"Authentication error: {e}") raise AuthenticationError(f"Authentication failed: {e}") from e + + @override + def update_card_security_schemes(self, agent_card: AgentCard) -> None: + agent_card.security_schemes = { + "platform_context_token": SecurityScheme( + HTTPAuthSecurityScheme( + scheme="bearer", + bearer_format="JWT", + description="Platform context token, issued by the AgentStack server using POST /api/v1/context/{context_id}/token.", + ) + ), + } + agent_card.security = [{"platform_context_token": []}] diff --git a/apps/agentstack-sdk-py/src/agentstack_sdk/server/server.py b/apps/agentstack-sdk-py/src/agentstack_sdk/server/server.py index af6c4b84f..84493c332 100644 --- a/apps/agentstack-sdk-py/src/agentstack_sdk/server/server.py +++ b/apps/agentstack-sdk-py/src/agentstack_sdk/server/server.py @@ -24,7 +24,7 @@ from fastapi.responses import PlainTextResponse from httpx import HTTPError, HTTPStatusError from pydantic import AnyUrl -from starlette.authentication import AuthenticationBackend, AuthenticationError +from starlette.authentication import AuthenticationError from starlette.middleware.authentication import AuthenticationMiddleware from starlette.requests import HTTPConnection from starlette.types import Lifespan @@ -39,6 +39,7 @@ from agentstack_sdk.server.store.memory_context_store import InMemoryContextStore from agentstack_sdk.server.telemetry import configure_telemetry as configure_telemetry_func from agentstack_sdk.server.utils import cancel_task +from agentstack_sdk.types import SdkAuthenticationBackend from agentstack_sdk.util.logging import configure_logger as configure_logger_func from agentstack_sdk.util.logging import logger @@ -131,7 +132,7 @@ async def serve( factory: bool = False, h11_max_incomplete_event_size: int | None = None, self_registration_client_factory: Callable[[], PlatformClient] | None = None, - auth_backend: AuthenticationBackend | None = None, + auth_backend: SdkAuthenticationBackend | None = None, ) -> None: if self.server: raise RuntimeError("The server is already running") @@ -201,9 +202,11 @@ async def _lifespan_fn(app: FastAPI) -> AsyncGenerator[None, None]: push_sender=push_sender, task_timeout=task_timeout, request_context_builder=request_context_builder, + auth_backend=auth_backend, ) if auth_backend: + auth_backend.update_card_security_schemes(self._agent.card) def on_error(connection: HTTPConnection, error: AuthenticationError) -> PlainTextResponse: return PlainTextResponse("Unauthorized", status_code=401) diff --git a/apps/agentstack-sdk-py/src/agentstack_sdk/types.py b/apps/agentstack-sdk-py/src/agentstack_sdk/types.py index 1c44b41b3..71cc218e1 100644 --- a/apps/agentstack-sdk-py/src/agentstack_sdk/types.py +++ b/apps/agentstack-sdk-py/src/agentstack_sdk/types.py @@ -1,8 +1,12 @@ # Copyright 2025 © BeeAI a Series of LF Projects, LLC # SPDX-License-Identifier: Apache-2.0 +import abc from typing import TYPE_CHECKING, TypeAlias +from a2a.types import AgentCard +from starlette.authentication import AuthenticationBackend + if TYPE_CHECKING: JsonValue: TypeAlias = list["JsonValue"] | dict[str, "JsonValue"] | str | bool | int | float | None JsonDict: TypeAlias = dict[str, JsonValue] @@ -13,3 +17,8 @@ JsonValue = TypeAliasType("JsonValue", "Union[dict[str, JsonValue], list[JsonValue], str, int, float, bool, None]") # noqa: UP007 JsonDict = TypeAliasType("JsonDict", "dict[str, JsonValue]") + + +class SdkAuthenticationBackend(AuthenticationBackend, abc.ABC): + @abc.abstractmethod + def update_card_security_schemes(self, agent_card: AgentCard) -> None: ... diff --git a/apps/agentstack-server/src/agentstack_server/api/routes/a2a.py b/apps/agentstack-server/src/agentstack_server/api/routes/a2a.py index 0c7ecd3d0..8940fd0f3 100644 --- a/apps/agentstack-server/src/agentstack_server/api/routes/a2a.py +++ b/apps/agentstack-server/src/agentstack_server/api/routes/a2a.py @@ -20,30 +20,14 @@ RequiresPermissions, authorized_user, ) -from agentstack_server.configuration import Configuration from agentstack_server.domain.models.permissions import AuthorizedUser from agentstack_server.service_layer.services.a2a import A2AServerResponse router = fastapi.APIRouter() -def create_proxy_agent_card( - agent_card: AgentCard, *, provider_id: UUID, request: Request, configuration: Configuration -) -> AgentCard: +def create_proxy_agent_card(agent_card: AgentCard, *, provider_id: UUID, request: Request) -> AgentCard: proxy_base = str(request.url_for(a2a_proxy_jsonrpc_transport.__name__, provider_id=provider_id)) - - proxy_security = [] - proxy_security_schemes = {} - if not configuration.auth.disable_auth: - # Note that we're purposefully not using oAuth but a more generic http scheme. - # This is because we don't want to declare the auth metadata but prefer discovery through related RFCs - # The http scheme also covers internal jwt tokens - proxy_security.append({"bearer": []}) - proxy_security_schemes["bearer"] = SecurityScheme(HTTPAuthSecurityScheme(scheme="bearer")) - if configuration.auth.basic.enabled: - proxy_security.append({"basic": []}) - proxy_security_schemes["basic"] = SecurityScheme(HTTPAuthSecurityScheme(scheme="basic")) - return agent_card.model_copy( update={ "preferred_transport": TransportProtocol.jsonrpc, @@ -52,8 +36,16 @@ def create_proxy_agent_card( AgentInterface(transport=TransportProtocol.http_json, url=urljoin(proxy_base, "http")), AgentInterface(transport=TransportProtocol.jsonrpc, url=proxy_base), ], - "security": proxy_security, - "security_schemes": proxy_security_schemes, + "security": {"platform_context_token": []}, + "security_schemes": { + "platform_context_token": SecurityScheme( + HTTPAuthSecurityScheme( + scheme="bearer", + bearer_format="JWT", + description="Platform context token, issued by the AgentStack server using POST /api/v1/context/{context_id}/token.", + ) + ), + }, } ) @@ -80,9 +72,7 @@ async def get_agent_card( user = RequiresPermissions(a2a_proxy={provider_id})(user) # try a2a proxy permissions provider = await provider_service.get_provider(provider_id=provider_id) - return create_proxy_agent_card( - provider.agent_card, provider_id=provider.id, request=request, configuration=configuration - ) + return create_proxy_agent_card(provider.agent_card, provider_id=provider.id, request=request) @router.post("/{provider_id}") @@ -98,9 +88,7 @@ async def a2a_proxy_jsonrpc_transport( user = RequiresPermissions(a2a_proxy={provider_id})(user) provider = await provider_service.get_provider(provider_id=provider_id) - agent_card = create_proxy_agent_card( - provider.agent_card, provider_id=provider.id, request=request, configuration=configuration - ) + agent_card = create_proxy_agent_card(provider.agent_card, provider_id=provider.id, request=request) handler = await a2a_proxy.get_request_handler(provider=provider, user=user.user) app = A2AFastAPIApplication(agent_card=agent_card, http_handler=handler) @@ -122,9 +110,7 @@ async def a2a_proxy_http_transport( ): user = RequiresPermissions(a2a_proxy={provider_id})(user) provider = await provider_service.get_provider(provider_id=provider_id) - agent_card = create_proxy_agent_card( - provider.agent_card, provider_id=provider.id, request=request, configuration=configuration - ) + agent_card = create_proxy_agent_card(provider.agent_card, provider_id=provider.id, request=request) handler = await a2a_proxy.get_request_handler(provider=provider, user=user.user) adapter = RESTAdapter(agent_card=agent_card, http_handler=handler) diff --git a/apps/agentstack-server/src/agentstack_server/api/routes/providers.py b/apps/agentstack-server/src/agentstack_server/api/routes/providers.py index 260ae62a3..2c4fd24f6 100644 --- a/apps/agentstack-server/src/agentstack_server/api/routes/providers.py +++ b/apps/agentstack-server/src/agentstack_server/api/routes/providers.py @@ -84,9 +84,7 @@ async def list_providers( for provider in await provider_service.list_providers(user=user.user, user_owned=user_owned, origin=origin): new_provider = provider.model_copy( update={ - "agent_card": create_proxy_agent_card( - provider.agent_card, provider_id=provider.id, request=request, configuration=configuration - ) + "agent_card": create_proxy_agent_card(provider.agent_card, provider_id=provider.id, request=request) } ) providers.append(EntityModel(new_provider)) @@ -106,9 +104,7 @@ async def get_provider( return EntityModel( provider.model_copy( update={ - "agent_card": create_proxy_agent_card( - provider.agent_card, provider_id=provider.id, request=request, configuration=configuration - ) + "agent_card": create_proxy_agent_card(provider.agent_card, provider_id=provider.id, request=request) } ) ) @@ -130,9 +126,7 @@ async def get_provider_by_location( return EntityModel( provider.model_copy( update={ - "agent_card": create_proxy_agent_card( - provider.agent_card, provider_id=provider.id, request=request, configuration=configuration - ) + "agent_card": create_proxy_agent_card(provider.agent_card, provider_id=provider.id, request=request) } ) )