Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -141,4 +141,6 @@ cython_debug/
# static files generated from Django application using `collectstatic`
media
static
.vscode/
.vscode/
.idea/
.envrc
75 changes: 72 additions & 3 deletions app/database.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,82 @@
from pathlib import Path

from flask import _app_ctx_stack, current_app, has_app_context
from sqlalchemy import create_engine
from sqlalchemy.engine.url import URL
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import scoped_session, sessionmaker
from werkzeug.local import LocalProxy

# Constants
DB_EXTENSION_KEY = "sa_session"
SQLALCHEMY_DATABASE_URL = URL("sqlite", database=str(Path.cwd().joinpath("test.db")))
# SQLALCHEMY_DATABASE_URL = URL(
# "postgresql",
# username="user",
# password="password",
# host="postgresd",
# database="db",
# )

SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"
# SQLALCHEMY_DATABASE_URL = "postgresql://user:password@postgresserver/db"

# I see people adding this here and not sure if it's needed since `NullPool` is the default for SqLite.
# Does need to be added when `StaticPool` is used.
engine = create_engine(
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)

SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# The mixin could have passed the mixin here as well.
# Base = declarative_base(cls=DictMixIn)
Base = declarative_base()


def init_db(app):
"""

Create ``scoped_session`` only if we are initializing within flask because sessions
will not close properly outside of the context.
We can add optional ``query_property`` to the ``Base``.

Parameters
----------
app : flask.Flask

"""

# Import models locally just to ensure they get added to the registry on app load
from app import models

db_session = scoped_session(SessionLocal, scopefunc=_app_ctx_stack.__ident_func__)
Base.query = db_session.query_property()
Base.metadata.create_all(bind=engine)

if DB_EXTENSION_KEY not in app.extensions:
app.extensions[DB_EXTENSION_KEY] = db_session
app.db = db_session

@app.teardown_appcontext
def remove_db_session(exception=None):
"""Terminates all connections, transactions or stale, in session and checks them back into pool"""
db_session.remove()


# A little too advanced maybe? skip?
def _get_db():
"""

Returns
-------
sqlalchemy.orm.Session
session obj stored in global Flask App.
"""
if has_app_context():
assert (
DB_EXTENSION_KEY in current_app.extensions
), "`db_session` might not have been registered with the current app"
return current_app.extensions[DB_EXTENSION_KEY]
raise RuntimeError("No application context found.")


db = LocalProxy(_get_db)
24 changes: 10 additions & 14 deletions app/main.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,24 @@
from flask import Flask, _app_ctx_stack, jsonify, url_for
from flask import Flask, jsonify, url_for
from flask_cors import CORS
from sqlalchemy.orm import scoped_session

from . import models
from .database import SessionLocal, engine

models.Base.metadata.create_all(bind=engine)
from app import models
from app.database import db, init_db

app = Flask(__name__)
CORS(app)
app.session = scoped_session(SessionLocal, scopefunc=_app_ctx_stack.__ident_func__)
init_db(app)


@app.route("/")
def main():
return f"See the data at {url_for('show_records')}"
href = url_for("show_records")
return f'See the data at <a href="{href}">{href}</a>'


@app.route("/records/")
def show_records():
records = app.session.query(models.Record).all()
return jsonify([record.to_dict() for record in records])


@app.teardown_appcontext
def remove_session(*args, **kwargs):
app.session.remove()
# records = app.db.query(models.Record).all() # use db attribute
# records = models.Record.query.all() # use query_property
records = db.query(models.Record).all() # or use proxy
return jsonify([record.to_dict() for record in records])
6 changes: 4 additions & 2 deletions app/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import datetime

from sqlalchemy import Column, Integer, String
from sqlalchemy.types import Date

from .database import Base
import datetime


class DictMixIn:
Expand All @@ -16,7 +18,7 @@ def to_dict(self):
}


class Record(Base, DictMixIn):
class Record(DictMixIn, Base):
__tablename__ = "Records"

id = Column(Integer, primary_key=True, index=True)
Expand Down