Skip to content

Commit

Permalink
Initial Casbin integration
Browse files Browse the repository at this point in the history
  • Loading branch information
robertbinning committed Oct 15, 2024
1 parent d6e5ae5 commit 98f6c50
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 3 deletions.
28 changes: 28 additions & 0 deletions backend/api/AbilitiesView.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,18 @@
from backend.managers.AbilitiesManager import AbilitiesManager
from backend.pagination import parse_pagination_params
import logging
from backend.managers.AuthManager import AuthManager
from connexion.exceptions import Unauthorized
logger = logging.getLogger(__name__)

class AbilitiesView:
def __init__(self):
self.am = AbilitiesManager()
self.auth_manager = AuthManager()

def check_permission(self, user_id, action):
if not self.auth_manager.check_permission(user_id, 'abilities', action):
raise Unauthorized('Permission denied')

def error_immutable(self):
return JSONResponse(status_code=400, content={"message": "Invalid Request: Abilities must be installed and are immutable; their metadata.json files cannot be edited via the API."})
Expand All @@ -21,12 +28,16 @@ async def delete(self, id: str):
return self.error_immutable()

def get(self, id=None):
user_id = self.get_current_user_id() # Implement this method to get the current user's ID
self.check_permission(user_id, 'read')
ability = self.am.get_ability(id)
if ability:
return JSONResponse(status_code=200, content=ability)
return JSONResponse(status_code=404, content={"message": "Ability not found"})

async def search(self, filter: str = None, range: str = None, sort: str = None):
user_id = self.get_current_user_id()
self.check_permission(user_id, 'read')
result = parse_pagination_params(filter, range, sort)
if isinstance(result, JSONResponse):
return result
Expand All @@ -49,6 +60,8 @@ async def search(self, filter: str = None, range: str = None, sort: str = None):
return JSONResponse(abilities, status_code=200, headers=headers)

async def install(self, id: str, version: str = None):
user_id = self.get_current_user_id()
self.check_permission(user_id, 'write')
try:
if self.am.install_ability(id, version):
return JSONResponse(status_code=200, content={"message": "Ability installed"})
Expand All @@ -58,6 +71,8 @@ async def install(self, id: str, version: str = None):
return JSONResponse(status_code=400, content={"message": str(e)})

async def upgrade(self, id: str, version: str = None):
user_id = self.get_current_user_id()
self.check_permission(user_id, 'write')
try:
if self.am.upgrade_ability(id, version):
return JSONResponse(status_code=200, content={"message": "Ability upgraded"})
Expand All @@ -67,6 +82,8 @@ async def upgrade(self, id: str, version: str = None):
return JSONResponse(status_code=400, content={"message": str(e)})

async def uninstall(self, id: str):
user_id = self.get_current_user_id()
self.check_permission(user_id, 'write')
try:
if self.am.uninstall_ability(id):
return JSONResponse(status_code=200, content={"message": "Ability uninstalled"})
Expand All @@ -76,6 +93,8 @@ async def uninstall(self, id: str):
return JSONResponse(status_code=400, content={"message": str(e)})

async def install_dependency(self, id: str, dependency_id: str):
user_id = self.get_current_user_id()
self.check_permission(user_id, 'write')
try:
await self.am.install_dependency(id, dependency_id)
return JSONResponse(status_code=202, content={"message": "Dependency install started"})
Expand All @@ -86,6 +105,8 @@ async def install_dependency(self, id: str, dependency_id: str):
return JSONResponse(status_code=500, content={"error": str(e)})

async def start(self, id: str):
user_id = self.get_current_user_id()
self.check_permission(user_id, 'write')
try:
result = self.am.start_ability(id)
if "error" in result:
Expand All @@ -96,6 +117,8 @@ async def start(self, id: str):
return JSONResponse(status_code=500, content={"error": str(e)})

async def stop(self, id: str):
user_id = self.get_current_user_id()
self.check_permission(user_id, 'write')
try:
result = self.am.stop_ability(id)
if "error" in result:
Expand All @@ -104,3 +127,8 @@ async def stop(self, id: str):
return JSONResponse(status_code=200, content={"message": "Ability stopped"})
except Exception as e:
return JSONResponse(status_code=500, content={"error": str(e)})

def get_current_user_id(self):
# Implement this method to get the current user's ID from the JWT token
# You may need to modify your JWT handling to include the user ID
pass
26 changes: 24 additions & 2 deletions backend/managers/AuthManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
from webauthn.helpers.cose import COSEAlgorithmIdentifier
from connexion.exceptions import Unauthorized
from common.utils import get_env_key
import casbin
import os
from pathlib import Path

# set up logging
from common.log import get_logger
Expand Down Expand Up @@ -81,6 +84,21 @@ def __init__(self):
with self._lock:
if not hasattr(self, '_initialized'):
self._initialized = True
self._load_rbac_model()

def _load_rbac_model(self):
model_path = Path(__file__).parent.parent / 'rbac_model.conf'
policy_path = Path(__file__).parent.parent / 'rbac_policy.csv'
self.enforcer = casbin.Enforcer(str(model_path), str(policy_path))

def check_permission(self, sub, obj, act):
return self.enforcer.enforce(sub, obj, act)

def add_role_for_user(self, user, role):
self.enforcer.add_grouping_policy(user, role)

def get_roles_for_user(self, user):
return self.enforcer.get_roles_for_user(user)

async def webauthn_register_options(self, email_id: str):
async with db_session_context() as session:
Expand Down Expand Up @@ -146,6 +164,9 @@ async def webauthn_register(self, challenge: str, email_id: str, user_id: str, r
await session.commit()
await session.refresh(new_user)
user = new_user

# Add default role for new user
self.add_role_for_user(str(user.id), 'user')

base64url_cred_id = base64.urlsafe_b64encode(res.credential_id).decode("utf-8").rstrip("=")
base64url_public_key = base64.urlsafe_b64encode(res.credential_public_key).decode("utf-8").rstrip("=")
Expand All @@ -158,7 +179,8 @@ async def webauthn_register(self, challenge: str, email_id: str, user_id: str, r
payload = {
"sub": user.id,
"iat": datetime.now(timezone.utc),
"exp": datetime.now(timezone.utc) + timedelta(days=1)
"exp": datetime.now(timezone.utc) + timedelta(days=1),
"roles": self.get_roles_for_user(str(user.id)) # Include roles in the token
}

token = generate_jwt(payload)
Expand Down Expand Up @@ -252,4 +274,4 @@ async def delete_session(self, token: str):
async with db_session_context() as session:
stmt = delete(Session).where(Session.token == token)
await session.execute(stmt)
await session.commit()
await session.commit()
1 change: 1 addition & 0 deletions backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ structlog
webauthn
greenlet
pyjwt
casbin
13 changes: 12 additions & 1 deletion frontend/src/authProvider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ import { AuthProvider } from "react-admin";
import { login, register, logout } from "./apis/auth";
import { jwtDecode } from "jwt-decode";

interface CustomJwtPayload {
roles: string[];
exp?: number;
}

export const authProvider: AuthProvider = {
// called when the user attempts to log in
login: async ({ email, isRegistering }) => {
Expand Down Expand Up @@ -51,5 +56,11 @@ export const authProvider: AuthProvider = {
return Promise.reject()
},
// called when the user navigates to a new location, to check for permissions / roles
getPermissions: () => Promise.resolve(),
getPermissions: () => {
const token = localStorage.getItem("token");
if (!token) return Promise.reject();

const decodedToken = jwtDecode<CustomJwtPayload>(token);
return Promise.resolve(decodedToken.roles);
},
};

0 comments on commit 98f6c50

Please sign in to comment.