diff --git a/invenio_db/ext.py b/invenio_db/ext.py index 88f723e..981b908 100644 --- a/invenio_db/ext.py +++ b/invenio_db/ext.py @@ -13,7 +13,9 @@ import logging import os import random +import re import time +import warnings from importlib.metadata import PackageNotFoundError from importlib.metadata import version as package_version from importlib.resources import files @@ -146,13 +148,32 @@ def init_db(self, app, entry_point_group="invenio_db.models", **kwargs): app.config.setdefault("SQLALCHEMY_ECHO", False) # Needed for before/after_flush/commit/rollback events app.config.setdefault("SQLALCHEMY_TRACK_MODIFICATIONS", True) - app.config.setdefault( - "SQLALCHEMY_ENGINE_OPTIONS", - # Ensure the database is using the UTC timezone for interpreting timestamps (Postgres only). - # This overrides any default setting (e.g. in postgresql.conf). Invenio expects the DB to receive - # and provide UTC timestamps in all cases, so it's important that this doesn't get changed. - {"connect_args": {"options": "-c timezone=UTC"}}, - ) + + # Check if the DB is PostgreSQL. We don't include the `://` since the driver name + # usually follows the `postgres` (e.g. `postgres+psycopg2`), and we don't know 100% + # what the driver will be. + is_postgres = app.config.get("SQLALCHEMY_DATABASE_URI").startswith("postgres") + if is_postgres: + current_engine_options = app.config.get("SQLALCHEMY_ENGINE_OPTIONS") + options_override = "-c timezone=UTC" + + if current_engine_options is None: + app.config.setdefault( + "SQLALCHEMY_ENGINE_OPTIONS", + {"connect_args": {"options": options_override}}, + ) + else: + options_value = current_engine_options.get("connect_args", {}).get( + "options", "" + ) + + if not re.search(rf"{re.escape(options_override)}( |$)", options_value): + warnings.warn( + "It looks like you are manually setting `SQLALCHEMY_ENGINE_OPTIONS` without specifying a UTC timezone value for PostgreSQL. " + "To avoid unexpected behaviour, InvenioDB won't add an override to these options to set the time zone to UTC. " + "Please note that PostgreSQL databases used with Invenio must be in UTC. If your database or connection is configured with a non-UTC " + "timezone, please change this before continuing to avoid unexpected behaviour." + ) # Initialize Flask-SQLAlchemy extension. database = kwargs.get("db", db) diff --git a/invenio_db/shared.py b/invenio_db/shared.py index 3101012..90e59ba 100644 --- a/invenio_db/shared.py +++ b/invenio_db/shared.py @@ -13,10 +13,7 @@ from flask_sqlalchemy import SQLAlchemy as FlaskSQLAlchemy from sqlalchemy import Column, MetaData, event, util -from sqlalchemy.engine import Engine -from sqlalchemy.sql import text from sqlalchemy.types import DateTime, TypeDecorator -from werkzeug.local import LocalProxy NAMING_CONVENTION = util.immutabledict( { @@ -129,79 +126,6 @@ def __getattr__(self, name): return super().__getattr__(name) - def apply_driver_hacks(self, app, sa_url, options): - """Call before engine creation.""" - # Don't forget to apply hacks defined on parent object. - super(SQLAlchemy, self).apply_driver_hacks(app, sa_url, options) - - if sa_url.drivername == "sqlite": - connect_args = options.setdefault("connect_args", {}) - - if "isolation_level" not in connect_args: - # disable pysqlite's emitting of the BEGIN statement entirely. - # also stops it from emitting COMMIT before any DDL. - connect_args["isolation_level"] = None - - if not event.contains(Engine, "connect", do_sqlite_connect): - event.listen(Engine, "connect", do_sqlite_connect) - if not event.contains(Engine, "begin", do_sqlite_begin): - event.listen(Engine, "begin", do_sqlite_begin) - - from sqlite3 import register_adapter - - def adapt_proxy(proxy): - """Get current object and try to adapt it again.""" - return proxy._get_current_object() - - register_adapter(LocalProxy, adapt_proxy) - - elif sa_url.drivername == "postgresql+psycopg2": # pragma: no cover - from psycopg2.extensions import adapt, register_adapter - - def adapt_proxy(proxy): - """Get current object and try to adapt it again.""" - return adapt(proxy._get_current_object()) - - register_adapter(LocalProxy, adapt_proxy) - - elif sa_url.drivername == "mysql+pymysql": # pragma: no cover - from pymysql import converters - - def escape_local_proxy(val, mapping): - """Get current object and try to adapt it again.""" - return converters.escape_item( - val._get_current_object(), - self.engine.dialect.encoding, - mapping=mapping, - ) - - converters.conversions[LocalProxy] = escape_local_proxy - converters.encoders[LocalProxy] = escape_local_proxy - - return sa_url, options - - -def do_sqlite_connect(dbapi_connection, connection_record): - """Ensure SQLite checks foreign key constraints. - - For further details see "Foreign key support" sections on - https://docs.sqlalchemy.org/en/latest/dialects/sqlite.html#foreign-key-support - """ - # Enable foreign key constraint checking - cursor = dbapi_connection.cursor() - cursor.execute("PRAGMA foreign_keys=ON") - cursor.close() - - -def do_sqlite_begin(dbapi_connection): - """Ensure SQLite transaction are started properly. - - For further details see "Foreign key support" sections on - https://docs.sqlalchemy.org/en/rel_1_0/dialects/sqlite.html#pysqlite-serializable # noqa - """ - # emit our own BEGIN - dbapi_connection.execute(text("BEGIN")) - db = SQLAlchemy(metadata=metadata) """Shared database instance using Flask-SQLAlchemy extension.