Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions apps/agentstack-sdk-py/src/agentstack_sdk/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
from a2a.types import AgentInterface, TransportProtocol
from fastapi import APIRouter, Depends, FastAPI
from fastapi.applications import AppType
from starlette.authentication import AuthenticationBackend
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.types import Lifespan

from agentstack_sdk.server.agent import Agent, Executor
Expand All @@ -39,7 +37,6 @@ def create_app(
dependencies: list[Depends] | None = None, # pyright: ignore [reportGeneralTypeIssues]
override_interfaces: bool = True,
task_timeout: timedelta = timedelta(minutes=10),
auth_backend: AuthenticationBackend | None = None,
**kwargs,
) -> FastAPI:
queue_manager = queue_manager or InMemoryQueueManager()
Expand Down Expand Up @@ -78,10 +75,6 @@ def create_app(
**kwargs,
)

if auth_backend:
rest_app.add_middleware(AuthenticationMiddleware, backend=auth_backend)
jsonrpc_app.add_middleware(AuthenticationMiddleware, backend=auth_backend)

rest_app.mount("/jsonrpc", jsonrpc_app)
rest_app.include_router(APIRouter(lifespan=lifespan))
return rest_app
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from urllib.parse import urljoin

from a2a.auth.user import User
from a2a.types import AgentCard, HTTPAuthSecurityScheme, SecurityScheme
from async_lru import alru_cache
from authlib.jose import JsonWebKey, JWTClaims, KeySet, jwt
from authlib.jose.errors import JoseError
Expand All @@ -15,15 +16,14 @@
from pydantic import Secret
from starlette.authentication import (
AuthCredentials,
AuthenticationBackend,
AuthenticationError,
BaseUser,
)
from starlette.requests import HTTPConnection
from typing_extensions import override

from agentstack_sdk.platform import use_platform_client
from agentstack_sdk.types import JsonValue
from agentstack_sdk.types import JsonValue, SdkAuthenticationBackend

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -70,7 +70,7 @@ async def discover_jwks() -> KeySet:
raise RuntimeError(f"JWKS discovery failed for url {url}") from e


class PlatformAuthBackend(AuthenticationBackend):
class PlatformAuthBackend(SdkAuthenticationBackend):
def __init__(self, public_url: str | None = None, skip_audience_validation: bool | None = None) -> None:
self.skip_audience_validation: bool = (
skip_audience_validation
Expand Down Expand Up @@ -129,3 +129,16 @@ async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, Bas
except Exception as e:
logger.error(f"Authentication error: {e}")
raise AuthenticationError(f"Authentication failed: {e}") from e

@override
def update_card_security_schemes(self, agent_card: AgentCard) -> None:
agent_card.security_schemes = {
"platform_context_token": SecurityScheme(
HTTPAuthSecurityScheme(
scheme="bearer",
bearer_format="JWT",
description="Platform context token, issued by the AgentStack server using POST /api/v1/context/{context_id}/token.",
)
),
}
agent_card.security = [{"platform_context_token": []}]
7 changes: 5 additions & 2 deletions apps/agentstack-sdk-py/src/agentstack_sdk/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from fastapi.responses import PlainTextResponse
from httpx import HTTPError, HTTPStatusError
from pydantic import AnyUrl
from starlette.authentication import AuthenticationBackend, AuthenticationError
from starlette.authentication import AuthenticationError
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.requests import HTTPConnection
from starlette.types import Lifespan
Expand All @@ -39,6 +39,7 @@
from agentstack_sdk.server.store.memory_context_store import InMemoryContextStore
from agentstack_sdk.server.telemetry import configure_telemetry as configure_telemetry_func
from agentstack_sdk.server.utils import cancel_task
from agentstack_sdk.types import SdkAuthenticationBackend
from agentstack_sdk.util.logging import configure_logger as configure_logger_func
from agentstack_sdk.util.logging import logger

Expand Down Expand Up @@ -131,7 +132,7 @@ async def serve(
factory: bool = False,
h11_max_incomplete_event_size: int | None = None,
self_registration_client_factory: Callable[[], PlatformClient] | None = None,
auth_backend: AuthenticationBackend | None = None,
auth_backend: SdkAuthenticationBackend | None = None,
) -> None:
if self.server:
raise RuntimeError("The server is already running")
Expand Down Expand Up @@ -201,9 +202,11 @@ async def _lifespan_fn(app: FastAPI) -> AsyncGenerator[None, None]:
push_sender=push_sender,
task_timeout=task_timeout,
request_context_builder=request_context_builder,
auth_backend=auth_backend,
)

if auth_backend:
auth_backend.update_card_security_schemes(self._agent.card)

def on_error(connection: HTTPConnection, error: AuthenticationError) -> PlainTextResponse:
return PlainTextResponse("Unauthorized", status_code=401)
Expand Down
9 changes: 9 additions & 0 deletions apps/agentstack-sdk-py/src/agentstack_sdk/types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
# SPDX-License-Identifier: Apache-2.0

import abc
from typing import TYPE_CHECKING, TypeAlias

from a2a.types import AgentCard
from starlette.authentication import AuthenticationBackend

if TYPE_CHECKING:
JsonValue: TypeAlias = list["JsonValue"] | dict[str, "JsonValue"] | str | bool | int | float | None
JsonDict: TypeAlias = dict[str, JsonValue]
Expand All @@ -13,3 +17,8 @@

JsonValue = TypeAliasType("JsonValue", "Union[dict[str, JsonValue], list[JsonValue], str, int, float, bool, None]") # noqa: UP007
JsonDict = TypeAliasType("JsonDict", "dict[str, JsonValue]")


class SdkAuthenticationBackend(AuthenticationBackend, abc.ABC):
@abc.abstractmethod
def update_card_security_schemes(self, agent_card: AgentCard) -> None: ...
42 changes: 14 additions & 28 deletions apps/agentstack-server/src/agentstack_server/api/routes/a2a.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,14 @@
RequiresPermissions,
authorized_user,
)
from agentstack_server.configuration import Configuration
from agentstack_server.domain.models.permissions import AuthorizedUser
from agentstack_server.service_layer.services.a2a import A2AServerResponse

router = fastapi.APIRouter()


def create_proxy_agent_card(
agent_card: AgentCard, *, provider_id: UUID, request: Request, configuration: Configuration
) -> AgentCard:
def create_proxy_agent_card(agent_card: AgentCard, *, provider_id: UUID, request: Request) -> AgentCard:
proxy_base = str(request.url_for(a2a_proxy_jsonrpc_transport.__name__, provider_id=provider_id))

proxy_security = []
proxy_security_schemes = {}
if not configuration.auth.disable_auth:
# Note that we're purposefully not using oAuth but a more generic http scheme.
# This is because we don't want to declare the auth metadata but prefer discovery through related RFCs
# The http scheme also covers internal jwt tokens
proxy_security.append({"bearer": []})
proxy_security_schemes["bearer"] = SecurityScheme(HTTPAuthSecurityScheme(scheme="bearer"))
if configuration.auth.basic.enabled:
proxy_security.append({"basic": []})
proxy_security_schemes["basic"] = SecurityScheme(HTTPAuthSecurityScheme(scheme="basic"))
Comment on lines -35 to -45
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this truly gone? I thought user token is also accepted, at least it has a2a_proxy scope right?


return agent_card.model_copy(
update={
"preferred_transport": TransportProtocol.jsonrpc,
Expand All @@ -52,8 +36,16 @@ def create_proxy_agent_card(
AgentInterface(transport=TransportProtocol.http_json, url=urljoin(proxy_base, "http")),
AgentInterface(transport=TransportProtocol.jsonrpc, url=proxy_base),
],
"security": proxy_security,
"security_schemes": proxy_security_schemes,
"security": {"platform_context_token": []},
"security_schemes": {
"platform_context_token": SecurityScheme(
HTTPAuthSecurityScheme(
scheme="bearer",
bearer_format="JWT",
description="Platform context token, issued by the AgentStack server using POST /api/v1/context/{context_id}/token.",
)
),
},
}
)

Expand All @@ -80,9 +72,7 @@ async def get_agent_card(
user = RequiresPermissions(a2a_proxy={provider_id})(user) # try a2a proxy permissions

provider = await provider_service.get_provider(provider_id=provider_id)
return create_proxy_agent_card(
provider.agent_card, provider_id=provider.id, request=request, configuration=configuration
)
return create_proxy_agent_card(provider.agent_card, provider_id=provider.id, request=request)


@router.post("/{provider_id}")
Expand All @@ -98,9 +88,7 @@ async def a2a_proxy_jsonrpc_transport(
user = RequiresPermissions(a2a_proxy={provider_id})(user)

provider = await provider_service.get_provider(provider_id=provider_id)
agent_card = create_proxy_agent_card(
provider.agent_card, provider_id=provider.id, request=request, configuration=configuration
)
agent_card = create_proxy_agent_card(provider.agent_card, provider_id=provider.id, request=request)

handler = await a2a_proxy.get_request_handler(provider=provider, user=user.user)
app = A2AFastAPIApplication(agent_card=agent_card, http_handler=handler)
Expand All @@ -122,9 +110,7 @@ async def a2a_proxy_http_transport(
):
user = RequiresPermissions(a2a_proxy={provider_id})(user)
provider = await provider_service.get_provider(provider_id=provider_id)
agent_card = create_proxy_agent_card(
provider.agent_card, provider_id=provider.id, request=request, configuration=configuration
)
agent_card = create_proxy_agent_card(provider.agent_card, provider_id=provider.id, request=request)

handler = await a2a_proxy.get_request_handler(provider=provider, user=user.user)
adapter = RESTAdapter(agent_card=agent_card, http_handler=handler)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,7 @@ async def list_providers(
for provider in await provider_service.list_providers(user=user.user, user_owned=user_owned, origin=origin):
new_provider = provider.model_copy(
update={
"agent_card": create_proxy_agent_card(
provider.agent_card, provider_id=provider.id, request=request, configuration=configuration
)
"agent_card": create_proxy_agent_card(provider.agent_card, provider_id=provider.id, request=request)
}
)
providers.append(EntityModel(new_provider))
Expand All @@ -106,9 +104,7 @@ async def get_provider(
return EntityModel(
provider.model_copy(
update={
"agent_card": create_proxy_agent_card(
provider.agent_card, provider_id=provider.id, request=request, configuration=configuration
)
"agent_card": create_proxy_agent_card(provider.agent_card, provider_id=provider.id, request=request)
}
)
)
Expand All @@ -130,9 +126,7 @@ async def get_provider_by_location(
return EntityModel(
provider.model_copy(
update={
"agent_card": create_proxy_agent_card(
provider.agent_card, provider_id=provider.id, request=request, configuration=configuration
)
"agent_card": create_proxy_agent_card(provider.agent_card, provider_id=provider.id, request=request)
}
)
)
Expand Down
Loading