diff --git a/.gitignore b/.gitignore index c2d352a..3aa0787 100644 --- a/.gitignore +++ b/.gitignore @@ -141,4 +141,6 @@ cython_debug/ # static files generated from Django application using `collectstatic` media static -.vscode/ \ No newline at end of file +.vscode/ +.idea/ +.envrc diff --git a/app/database.py b/app/database.py index 73fc456..c5ac940 100644 --- a/app/database.py +++ b/app/database.py @@ -1,13 +1,82 @@ +from pathlib import Path + +from flask import _app_ctx_stack, current_app, has_app_context from sqlalchemy import create_engine +from sqlalchemy.engine.url import URL from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import scoped_session, sessionmaker +from werkzeug.local import LocalProxy + +# Constants +DB_EXTENSION_KEY = "sa_session" +SQLALCHEMY_DATABASE_URL = URL("sqlite", database=str(Path.cwd().joinpath("test.db"))) +# SQLALCHEMY_DATABASE_URL = URL( +# "postgresql", +# username="user", +# password="password", +# host="postgresd", +# database="db", +# ) -SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db" -# SQLALCHEMY_DATABASE_URL = "postgresql://user:password@postgresserver/db" +# I see people adding this here and not sure if it's needed since `NullPool` is the default for SqLite. +# Does need to be added when `StaticPool` is used. engine = create_engine( SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} ) + SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) +# The mixin could have passed the mixin here as well. +# Base = declarative_base(cls=DictMixIn) Base = declarative_base() + + +def init_db(app): + """ + + Create ``scoped_session`` only if we are initializing within flask because sessions + will not close properly outside of the context. + We can add optional ``query_property`` to the ``Base``. + + Parameters + ---------- + app : flask.Flask + + """ + + # Import models locally just to ensure they get added to the registry on app load + from app import models + + db_session = scoped_session(SessionLocal, scopefunc=_app_ctx_stack.__ident_func__) + Base.query = db_session.query_property() + Base.metadata.create_all(bind=engine) + + if DB_EXTENSION_KEY not in app.extensions: + app.extensions[DB_EXTENSION_KEY] = db_session + app.db = db_session + + @app.teardown_appcontext + def remove_db_session(exception=None): + """Terminates all connections, transactions or stale, in session and checks them back into pool""" + db_session.remove() + + +# A little too advanced maybe? skip? +def _get_db(): + """ + + Returns + ------- + sqlalchemy.orm.Session + session obj stored in global Flask App. + """ + if has_app_context(): + assert ( + DB_EXTENSION_KEY in current_app.extensions + ), "`db_session` might not have been registered with the current app" + return current_app.extensions[DB_EXTENSION_KEY] + raise RuntimeError("No application context found.") + + +db = LocalProxy(_get_db) diff --git a/app/main.py b/app/main.py index 10c3401..8307295 100644 --- a/app/main.py +++ b/app/main.py @@ -1,28 +1,24 @@ -from flask import Flask, _app_ctx_stack, jsonify, url_for +from flask import Flask, jsonify, url_for from flask_cors import CORS -from sqlalchemy.orm import scoped_session -from . import models -from .database import SessionLocal, engine - -models.Base.metadata.create_all(bind=engine) +from app import models +from app.database import db, init_db app = Flask(__name__) CORS(app) -app.session = scoped_session(SessionLocal, scopefunc=_app_ctx_stack.__ident_func__) +init_db(app) @app.route("/") def main(): - return f"See the data at {url_for('show_records')}" + href = url_for("show_records") + return f'See the data at {href}' @app.route("/records/") def show_records(): - records = app.session.query(models.Record).all() - return jsonify([record.to_dict() for record in records]) - -@app.teardown_appcontext -def remove_session(*args, **kwargs): - app.session.remove() + # records = app.db.query(models.Record).all() # use db attribute + # records = models.Record.query.all() # use query_property + records = db.query(models.Record).all() # or use proxy + return jsonify([record.to_dict() for record in records]) diff --git a/app/models.py b/app/models.py index 003b395..890a804 100644 --- a/app/models.py +++ b/app/models.py @@ -1,7 +1,9 @@ +import datetime + from sqlalchemy import Column, Integer, String from sqlalchemy.types import Date + from .database import Base -import datetime class DictMixIn: @@ -16,7 +18,7 @@ def to_dict(self): } -class Record(Base, DictMixIn): +class Record(DictMixIn, Base): __tablename__ = "Records" id = Column(Integer, primary_key=True, index=True)