Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions apps/agentstack-server/src/agentstack_server/api/auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class ParsedToken(BaseModel):
context_permissions: Permissions
context_id: UUID
user_id: UUID
role_version: int
raw: dict[str, Any]


Expand All @@ -101,6 +102,7 @@ def issue_internal_jwt(
global_permissions: Permissions,
context_permissions: Permissions,
configuration: Configuration,
role_version: int,
) -> tuple[str, AwareDatetime]:
assert configuration.auth.jwt_secret_key
secret_key = configuration.auth.jwt_secret_key.get_secret_value()
Expand All @@ -119,6 +121,7 @@ def issue_internal_jwt(
"global": global_permissions.model_dump(mode="json"),
"context": context_permissions.model_dump(mode="json"),
},
"token_version": role_version,
}
return jwt.encode(header, payload, key=secret_key), expires_at

Expand All @@ -134,6 +137,7 @@ def verify_internal_jwt(token: str, configuration: Configuration) -> ParsedToken
"exp": {"essential": True},
"iss": {"essential": True, "value": "agentstack-server"},
"aud": {"essential": True, "value": "agentstack-server"},
"token_version": {"essential": True},
},
)
context_id = UUID(payload["resource"][0].replace("context:", ""))
Expand All @@ -142,6 +146,7 @@ def verify_internal_jwt(token: str, configuration: Configuration) -> ParsedToken
context_permissions=Permissions.model_validate(payload["scope"]["context"]),
context_id=context_id,
user_id=UUID(payload["sub"]),
role_version=int(payload["token_version"]),
raw=payload,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,14 @@ async def authorized_user(
try:
parsed_token = verify_internal_jwt(bearer_auth.credentials, configuration=configuration)
user = await user_service.get_user(parsed_token.user_id)

token_version = parsed_token.role_version
if token_version < user.role_version:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token invalidated due to role change",
)

token = AuthorizedUser(
user=user,
global_permissions=parsed_token.global_permissions,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ async def generate_context_token(
global_permissions=global_grant,
context_permissions=context_grant,
configuration=configuration,
role_version=user.user.role_version,
)
return ContextTokenResponse(token=token, expires_at=expires_at)

Expand Down
49 changes: 49 additions & 0 deletions apps/agentstack-server/src/agentstack_server/api/routes/users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import Annotated
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel

from agentstack_server.api.dependencies import UserServiceDependency, authorized_user
from agentstack_server.domain.models.permissions import AuthorizedUser
from agentstack_server.domain.models.user import UserRole

logger = logging.getLogger(__name__)

router = APIRouter(tags=["users"])


class ChangeRoleRequest(BaseModel):
new_role: UserRole


class ChangeRoleResponse(BaseModel):
user_id: UUID
new_role: UserRole
role_version: int


@router.put("/users/{user_id}/role", response_model=ChangeRoleResponse)
async def change_user_role(
user_id: UUID,
request: ChangeRoleRequest,
user: Annotated[AuthorizedUser, Depends(authorized_user)],
user_service: UserServiceDependency,
) -> ChangeRoleResponse:
if not user.user.role == UserRole.ADMIN:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin permission required")

if user_id == user.user.id:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot change own role")

updated_user = await user_service.change_role(user_id=user_id, new_role=request.new_role)

return ChangeRoleResponse(
user_id=updated_user.id,
new_role=updated_user.role,
role_version=updated_user.role_version,
)
2 changes: 2 additions & 0 deletions apps/agentstack-server/src/agentstack_server/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from agentstack_server.api.routes.providers import router as provider_router
from agentstack_server.api.routes.user import router as user_router
from agentstack_server.api.routes.user_feedback import router as user_feedback_router
from agentstack_server.api.routes.users import router as users_router
from agentstack_server.api.routes.variables import router as variables_router
from agentstack_server.api.routes.vector_stores import router as vector_stores_router
from agentstack_server.api.utils import format_openai_error
Expand Down Expand Up @@ -118,6 +119,7 @@ async def custom_http_exception_handler(request: Request, exc: Exception):
def mount_routes(app: FastAPI):
server_router = APIRouter()
server_router.include_router(user_router, prefix="/user")
server_router.include_router(users_router, prefix="/users")
server_router.include_router(a2a_router, prefix="/a2a")
server_router.include_router(mcp_router, prefix="/mcp")
server_router.include_router(provider_router, prefix="/providers", tags=["providers"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,6 @@ class User(BaseModel):
id: UUID = Field(default_factory=uuid4)
role: UserRole = UserRole.USER
email: EmailStr
role_version: int = 1
role_updated_at: AwareDatetime | None = None
created_at: AwareDatetime = Field(default_factory=utc_now)
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ async def create(self, *, user: User) -> None: ...
async def get(self, *, user_id: UUID) -> User: ...
async def get_by_email(self, *, email: str) -> User: ...
async def delete(self, *, user_id: UUID) -> int: ...
async def update(self, *, user: User) -> None: ...
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
# SPDX-License-Identifier: Apache-2.0

"""add role versioning to users

Revision ID: 4jowyo7q9m66
Revises: ef8769062e65
Create Date: 2025-12-18 14:00:00.000000

"""

from collections.abc import Sequence

import sqlalchemy as sa
from alembic import op

revision: str = "4jowyo7q9m66"
down_revision: str | None = "ef8769062e65"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
"""Upgrade schema."""
op.add_column("users", sa.Column("role_version", sa.Integer(), nullable=False, server_default="1"))
op.add_column("users", sa.Column("role_updated_at", sa.DateTime(timezone=True), nullable=True))


def downgrade() -> None:
"""Downgrade schema."""
op.drop_column("users", "role_updated_at")
op.drop_column("users", "role_version")
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from kink import inject
from sqlalchemy import UUID as SQL_UUID
from sqlalchemy import Column, DateTime, Row, String, Table
from sqlalchemy import Column, DateTime, Integer, Row, String, Table
from sqlalchemy.ext.asyncio import AsyncConnection

from agentstack_server.domain.models.user import User, UserRole
Expand All @@ -21,6 +21,8 @@
Column("email", String(256), nullable=False, unique=True),
Column("created_at", DateTime(timezone=True), nullable=False),
Column("role", sql_enum(UserRole), nullable=False),
Column("role_version", Integer, nullable=False, server_default="1"),
Column("role_updated_at", DateTime(timezone=True), nullable=True),
)


Expand All @@ -31,10 +33,7 @@ def __init__(self, connection: AsyncConnection):

async def create(self, *, user: User) -> None:
query = users_table.insert().values(
id=user.id,
email=user.email,
created_at=user.created_at,
role=user.role,
id=user.id, email=user.email, created_at=user.created_at, role=user.role, role_version=user.role_version
)
await self.connection.execute(query)

Expand All @@ -45,6 +44,8 @@ def _to_user(self, row: Row):
"email": row.email,
"created_at": row.created_at,
"role": row.role,
"role_version": row.role_version,
"role_updated_at": row.role_updated_at,
}
)

Expand Down Expand Up @@ -73,3 +74,17 @@ async def list(self):
query = users_table.select()
async for row in await self.connection.stream(query):
yield self._to_user(row)

async def update(self, *, user: User) -> None:
query = (
users_table.update()
.where(users_table.c.id == user.id)
.values(
role=user.role,
role_version=user.role_version,
role_updated_at=user.role_updated_at,
)
)
result = await self.connection.execute(query)
if not result.rowcount:
raise EntityNotFoundError(entity="user", id=user.id)
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from agentstack_server.domain.repositories.env import EnvStoreEntity
from agentstack_server.exceptions import UsageLimitExceededError
from agentstack_server.service_layer.unit_of_work import IUnitOfWorkFactory
from agentstack_server.utils.utils import utc_now

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -58,3 +59,18 @@ async def list_user_env(self, *, user: User) -> dict[str, str]:
async with self._uow() as uow:
env = await uow.env.get_all(parent_entity=EnvStoreEntity.USER, parent_entity_ids=[user.id])
return env[user.id]

async def change_role(self, user_id: UUID, new_role: UserRole) -> User:
async with self._uow() as uow:
user = await uow.users.get(user_id=user_id)

if user.role == new_role:
raise ValueError("User already has this role")

user.role = new_role
user.role_version += 1
user.role_updated_at = utc_now()

await uow.users.update(user=user)
await uow.commit()
return user
Loading