diff --git a/.gitignore b/.gitignore index e78a91c..9ffe834 100644 --- a/.gitignore +++ b/.gitignore @@ -139,6 +139,7 @@ venv.bak/ # Rope project settings .ropeproject +.vscode # mkdocs documentation /site diff --git a/src/labs/api.py b/src/labs/api.py index e6ab57c..3466790 100644 --- a/src/labs/api.py +++ b/src/labs/api.py @@ -10,10 +10,10 @@ API endpoints are built and served using the FastAPI micro-framework. """ +from contextlib import asynccontextmanager from . import __title__, __version__ from fastapi import FastAPI, Request, status, WebSocket -from fastapi.responses import JSONResponse from fastapi.routing import APIRoute from .settings import settings @@ -39,12 +39,25 @@ def generate_operation_id(route: APIRoute) -> str: return route.name +@asynccontextmanager +async def lifespan(_fastapi: FastAPI): + if broker.is_worker_process: + # TaskIQ configurartion so we can share FastAPI dependencies in tasks + await broker.startup() + + yield + + if broker.is_worker_process: + # On shutdown, we need to shutdown the broker + await broker.shutdown() + + """A FastAPI application that serves handlers """ app = FastAPI( title=__title__, version=__version__, - description=settings.api_router.__doc__, + description=str(settings.api_router.__doc__), docs_url=settings.api_router.path_docs, root_path=settings.api_router.path_root, terms_of_service=settings.api_router.terms_of_service, @@ -52,6 +65,7 @@ def generate_operation_id(route: APIRoute) -> str: license_info=settings.api_router.license_info, openapi_tags=settings.api_router.open_api_tags, generate_unique_id_function=generate_operation_id, + lifespan=lifespan ) @@ -67,23 +81,7 @@ async def websocket_endpoint(websocket: WebSocket): app.include_router(router_root) -# TaskIQ configurartion so we can share FastAPI dependencies in tasks -@app.on_event("startup") -async def app_startup(): - if not broker.is_worker_process: - await broker.startup() - -# On shutdown, we need to shutdown the broker - - -@app.on_event("shutdown") -async def app_shutdown(): - if not broker.is_worker_process: - await broker.shutdown() - # Default handler - - @app.get( "/", status_code=status.HTTP_200_OK, @@ -93,5 +91,5 @@ async def root(request: Request) -> RootResponse: """ return RootResponse( message="Welcome to the {} API".format(__name__), - root_path=request.scope.get("root_path") + root_path=str(request.scope.get("root_path")) ) diff --git a/src/labs/db.py b/src/labs/db.py index 6ac6607..c6d5312 100644 --- a/src/labs/db.py +++ b/src/labs/db.py @@ -6,10 +6,11 @@ """ +from typing import AsyncGenerator from sqlalchemy.ext.asyncio import create_async_engine,\ - AsyncSession + AsyncSession, async_sessionmaker, AsyncAttrs from sqlalchemy.orm import DeclarativeBase,\ - configure_mappers, sessionmaker + configure_mappers from .settings import settings @@ -24,19 +25,16 @@ configure_mappers() # Get an async session from the engine +AsyncSessionFactory = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) -async def get_async_session() -> AsyncSession: - async_session = sessionmaker( - engine, class_=AsyncSession, expire_on_commit=False - ) - async with async_session() as session: +async def get_async_session() -> AsyncGenerator[AsyncSession, None]: + async with AsyncSessionFactory() as session: yield session -# Used by the ORM layer to describe models - -class Base(DeclarativeBase): +# Used by the ORM layer to describe models +class Base(DeclarativeBase, AsyncAttrs): """ SQLAlchemy 2.0 style declarative base class https://bit.ly/3WE3Srg diff --git a/src/labs/email.py b/src/labs/email.py index dd6f773..7b8b29f 100644 --- a/src/labs/email.py +++ b/src/labs/email.py @@ -14,11 +14,22 @@ Redmail docs are located at https://red-mail.readthedocs.io/ """ import os -from redmail import EmailSender +from redmail.email.sender import EmailSender from .settings import settings -sender = EmailSender( +# Custom factory to be able to set the sender globally +class EmailSenderFactory(EmailSender): + @property + def sender(self): + return self.sender + + @sender.setter + def sender(self, sender: str): + self.sender = sender + + +sender = EmailSenderFactory( host=settings.smtp.host, port=settings.smtp.port, username=settings.smtp.user.get_secret_value(),