diff --git a/.gitignore b/.gitignore index 8c8b618..3a04c87 100644 --- a/.gitignore +++ b/.gitignore @@ -131,3 +131,6 @@ dmypy.json # AUTO DOCS docs/src/api/*.md docs/src/.vuepress/apiPages.js + +# Claude Code session files +.claude/ diff --git a/setup.cfg b/setup.cfg index 4da4e83..87bbbbd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = mayim -version = 1.1.0 +version = 1.2.0 description = The NOT ORM hydrator long_description = file: README.md long_description_content_type = text/markdown @@ -13,6 +13,9 @@ classifiers = Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 + Programming Language :: Python :: 3.12 + Programming Language :: Python :: 3.13 [options] zip_safe = False diff --git a/src/mayim/base/executor.py b/src/mayim/base/executor.py index 4c3c502..6811aa1 100644 --- a/src/mayim/base/executor.py +++ b/src/mayim/base/executor.py @@ -209,8 +209,7 @@ def get_hydrator(self, name: Optional[str] = None) -> Hydrator: return self._hydrators.get(name, self.hydrator) @classmethod - def _load(cls, strict: bool) -> None: - ... + def _load(cls, strict: bool) -> None: ... @staticmethod def is_query_name(obj) -> bool: diff --git a/src/mayim/base/interface.py b/src/mayim/base/interface.py index 9dd203d..6859c14 100644 --- a/src/mayim/base/interface.py +++ b/src/mayim/base/interface.py @@ -30,20 +30,16 @@ def __init_subclass__(cls) -> None: BaseInterface.registered_interfaces.add(cls) @abstractmethod - def _setup_pool(self): - ... + def _setup_pool(self): ... @abstractmethod - async def open(self): - ... + async def open(self): ... @abstractmethod - async def close(self): - ... + async def close(self): ... @abstractmethod - def connection(self, timeout: Optional[float] = None): - ... + def connection(self, timeout: Optional[float] = None): ... def __init__( self, @@ -54,6 +50,8 @@ def __init__( password: Optional[str] = None, db: Optional[int] = None, query: Optional[str] = None, + min_size: int = 1, + max_size: Optional[int] = None, ) -> None: """DB class initialization. @@ -64,6 +62,8 @@ def __init__( password (str, optional): DB password db (int, optional): DB db. Defaults to 1 query (str, optional): DB query parameters. Defaults to None + min_size (int, optional): Minimum number of connections in pool. Defaults to 1 + max_size (int, optional): Maximum number of connections in pool. Defaults to None """ if dsn and host: @@ -96,6 +96,8 @@ def __init__( self._password = password self._db = db self._query = query + self._min_size = min_size + self._max_size = max_size self._full_dsn: Optional[str] = None self._connection: ContextVar[Any] = ContextVar( "connection", default=None @@ -117,10 +119,26 @@ def _populate_connection_args(self): dsn = self.dsn or "" if dsn: parts = urlparse(dsn) + # Default values for common database ports + defaults = { + "port": ( + 5432 + if "postgres" in dsn + else 3306 if "mysql" in dsn else None + ), + "hostname": "localhost", + "username": None, + "password": None, + "path": "/", + "query": "", + } for key, mapping in URLPARSE_MAPPING.items(): if not getattr(self, mapping.key): - value = getattr(parts, key, None) # or TODO: make defaults - setattr(self, mapping.key, mapping.cast(value)) + value = getattr(parts, key, None) + if value is None: + value = defaults.get(key) + if value is not None: + setattr(self, mapping.key, mapping.cast(value)) def _populate_dsn(self): self._dsn = ( @@ -172,6 +190,14 @@ def db(self): def full_dsn(self): return self._full_dsn + @property + def min_size(self): + return self._min_size + + @property + def max_size(self): + return self._max_size + def existing_connection(self): return self._connection.get() diff --git a/src/mayim/exception.py b/src/mayim/exception.py index bd5a519..119cebe 100644 --- a/src/mayim/exception.py +++ b/src/mayim/exception.py @@ -1,10 +1,7 @@ -class MayimError(Exception): - ... +class MayimError(Exception): ... -class RecordNotFound(MayimError): - ... +class RecordNotFound(MayimError): ... -class MissingSQL(MayimError): - ... +class MissingSQL(MayimError): ... diff --git a/src/mayim/extension/quart_extension.py b/src/mayim/extension/quart_extension.py index 44ce73e..0d45939 100644 --- a/src/mayim/extension/quart_extension.py +++ b/src/mayim/extension/quart_extension.py @@ -22,8 +22,7 @@ Quart = type("Quart", (), {}) # type: ignore -class Default: - ... +class Default: ... _default = Default() diff --git a/src/mayim/extension/starlette_extension.py b/src/mayim/extension/starlette_extension.py index aa8741e..31eaa9f 100644 --- a/src/mayim/extension/starlette_extension.py +++ b/src/mayim/extension/starlette_extension.py @@ -21,8 +21,7 @@ Starlette = type("Starlette", (), {}) # type: ignore -class Default: - ... +class Default: ... _default = Default() diff --git a/src/mayim/lazy/interface.py b/src/mayim/lazy/interface.py index ed1836d..d401478 100644 --- a/src/mayim/lazy/interface.py +++ b/src/mayim/lazy/interface.py @@ -7,31 +7,29 @@ class LazyPool(BaseInterface): _singleton = None _derivative: Optional[Type[BaseInterface]] + _derived_instance: Optional[BaseInterface] def __new__(cls, *args, **kwargs): if cls._singleton is None: cls._singleton = super().__new__(cls) cls._singleton._derivative = None + cls._singleton._derived_instance = None cls._singleton._derivative_dsn = "" + cls._singleton._min_size = 1 + cls._singleton._max_size = None return cls._singleton - def _setup_pool(self): - ... + def _setup_pool(self): ... - def _populate_dsn(self): - ... + def _populate_dsn(self): ... - def _populate_connection_args(self): - ... + def _populate_connection_args(self): ... - async def open(self): - ... + async def open(self): ... - async def close(self): - ... + async def close(self): ... - def connection(self, timeout: Optional[float] = None): - ... + def connection(self, timeout: Optional[float] = None): ... def set_derivative(self, interface_class: Type[BaseInterface]) -> None: self._derivative = interface_class @@ -39,7 +37,20 @@ def set_derivative(self, interface_class: Type[BaseInterface]) -> None: def set_dsn(self, dsn: str) -> None: self._derivative_dsn = dsn + def set_sizing( + self, min_size: int = 1, max_size: Optional[int] = None + ) -> None: + self._min_size = min_size + self._max_size = max_size + def derive(self) -> BaseInterface: if not self._derivative: raise MayimError("No interface available to derive") - return self._derivative(dsn=self._derivative_dsn) + if self._derived_instance: + return self._derived_instance + self._derived_instance = self._derivative( + dsn=self._derivative_dsn, + min_size=self._min_size, + max_size=self._max_size, + ) + return self._derived_instance diff --git a/src/mayim/mayim.py b/src/mayim/mayim.py index c76d2b0..bdbec67 100644 --- a/src/mayim/mayim.py +++ b/src/mayim/mayim.py @@ -1,5 +1,4 @@ from asyncio import get_running_loop -from contextlib import AsyncExitStack, asynccontextmanager from inspect import isclass from typing import Optional, Sequence, Type, TypeVar, Union from urllib.parse import urlparse @@ -8,9 +7,10 @@ from mayim.base.interface import BaseInterface from mayim.exception import MayimError from mayim.lazy.interface import LazyPool -from mayim.registry import InterfaceRegistry, Registry +from mayim.registry import InterfaceRegistry, PoolRegistry, Registry from mayim.sql.executor import SQLExecutor from mayim.sql.postgres.interface import PostgresPool +from mayim.transaction import TransactionCoordinator T = TypeVar("T", bound=Executor) DEFAULT_INTERFACE = PostgresPool @@ -42,6 +42,8 @@ def __init__( hydrator: Optional[Hydrator] = None, pool: Optional[BaseInterface] = None, strict: bool = True, + min_size: int = 1, + max_size: Optional[int] = None, ): """Initializer for Mayim instance @@ -95,8 +97,12 @@ def __init__( pool = LazyPool() pool.set_derivative(pool_type) pool.set_dsn(dsn) + pool.set_sizing(min_size, max_size) else: - pool = pool_type(dsn) + # Use PoolRegistry to ensure same DSN uses same pool + pool = PoolRegistry.get_or_create( + dsn, pool_type, min_size, max_size + ) if not executors: executors = [] @@ -241,30 +247,118 @@ async def disconnect(self) -> None: await interface.close() @classmethod - @asynccontextmanager - async def transaction( - cls, *executors: Union[SQLExecutor, Type[SQLExecutor]] + def transaction( + cls, + *executors: Union[SQLExecutor, Type[SQLExecutor]], + use_2pc: bool = False, + timeout: Optional[float] = None, ): - if not executors: + """ + Create a transaction across multiple executors. + + Can be used as either: + 1. Old style context manager: async with Mayim.transaction(exec1, exec2): + 2. New style: txn = await Mayim.transaction(exec1, exec2); await txn.begin() + 3. New style context: async with await Mayim.transaction(exec1, exec2) as txn: + + Args: + executors: Executor classes or instances to include in transaction. + If not provided, includes all registered SQL executors. + use_2pc: Whether to use two-phase commit protocol if available. + timeout: Maximum duration in seconds before transaction is automatically rolled back. + + Returns: + _TransactionWrapper that provides backward compatibility + """ + return _TransactionWrapper(cls, executors, use_2pc, timeout) + + +class _TransactionWrapper: + """ + Wrapper to provide backward compatibility for Mayim.transaction(). + Can be used both as an awaitable (new style) and as async context manager (old style). + """ + + def __init__(self, mayim_cls, executors, use_2pc=False, timeout=None): + self._mayim_cls = mayim_cls + self._executors = executors + self._coordinator = None + self._use_2pc = use_2pc + self._timeout = timeout + + def __await__(self): + """Support: txn = await Mayim.transaction(...)""" + + async def _create(): + return await self._create_coordinator() + + return _create().__await__() + + async def _create_coordinator(self) -> TransactionCoordinator: + """Create the actual TransactionCoordinator""" + if not self._executors: + # Default to all registered SQL executors executors = tuple( - ( - executor - for executor in Registry().values() - if ( - isclass(executor) and issubclass(executor, SQLExecutor) - ) - or ( - not isclass(executor) - and isinstance(executor, SQLExecutor) - ) + executor + for executor in Registry().values() + if (isclass(executor) and issubclass(executor, SQLExecutor)) + or ( + not isclass(executor) and isinstance(executor, SQLExecutor) ) ) - async with AsyncExitStack() as stack: - for maybe_executor in executors: - executor = ( - cls.get(maybe_executor) - if isclass(maybe_executor) - else maybe_executor - ) - await stack.enter_async_context(executor.transaction()) - yield + else: + executors = self._executors + + # Convert classes to instances and validate + resolved_executors = [] + for maybe_executor in executors: + if maybe_executor is None: + raise MayimError("Invalid executor: None") + + if isclass(maybe_executor): + # First check if it's a SQL executor class + if not issubclass(maybe_executor, SQLExecutor): + raise MayimError( + f"All executors must be SQL executors, got {maybe_executor}" + ) + try: + executor = self._mayim_cls.get(maybe_executor) + except MayimError: + raise MayimError( + f"Executor {maybe_executor} not registered" + ) + else: + executor = maybe_executor + # Validate it's a SQL executor instance + if not isinstance(executor, SQLExecutor): + raise MayimError( + f"All executors must be SQL executors, got {type(executor)}" + ) + # For instances, check if they're registered by checking if we can get the class + try: + registered_instance = self._mayim_cls.get( + executor.__class__ + ) + # If the registered instance is different, the passed instance is not registered + if registered_instance is not executor: + raise MayimError(f"Executor {executor} not registered") + except MayimError: + raise MayimError(f"Executor {executor} not registered") + + resolved_executors.append(executor) + + # Return the transaction coordinator + return TransactionCoordinator( + resolved_executors, use_2pc=self._use_2pc, timeout=self._timeout + ) + + async def __aenter__(self): + """Support old style: async with Mayim.transaction(...)""" + self._coordinator = await self._create_coordinator() + await self._coordinator.begin() + return self._coordinator + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Handle old style context manager exit""" + if self._coordinator: + return await self._coordinator.__aexit__(exc_type, exc_val, exc_tb) diff --git a/src/mayim/registry.py b/src/mayim/registry.py index 76a3eb1..6e0e7b0 100644 --- a/src/mayim/registry.py +++ b/src/mayim/registry.py @@ -98,3 +98,57 @@ def get(cls, class_name: str, method_name: str) -> Optional[Hydrator]: def reset(cls): cls._singleton = super().__new__(cls) cls._singleton._hydrators = defaultdict(dict) + + +class PoolRegistry: + """ + Registry to ensure executors with the same DSN share the same pool instance. + This prevents duplicate connections and enables proper transaction coordination. + """ + + _singleton = None + _pools: Dict[str, BaseInterface] + + def __new__(cls, *args, **kwargs): + if cls._singleton is None: + cls.reset() + return cls._singleton + + @classmethod + def get_or_create( + cls, + dsn: str, + pool_class: Type[BaseInterface], + min_size: int = 1, + max_size: Optional[int] = None, + ) -> BaseInterface: + """ + Get existing pool or create new one for DSN. + + Args: + dsn: Database connection string + pool_class: Class to use for creating new pool + min_size: Minimum number of connections in pool + max_size: Maximum number of connections in pool + + Returns: + Shared pool instance for the DSN + """ + instance = cls() + if dsn not in instance._pools: + instance._pools[dsn] = pool_class( + dsn, min_size=min_size, max_size=max_size + ) + return instance._pools[dsn] + + @classmethod + def get(cls, dsn: str) -> Optional[BaseInterface]: + """Get pool for DSN if it exists""" + instance = cls() + return instance._pools.get(dsn) + + @classmethod + def reset(cls): + """Reset the registry (useful for testing)""" + cls._singleton = super().__new__(cls) + cls._singleton._pools = {} diff --git a/src/mayim/sql/executor.py b/src/mayim/sql/executor.py index e64181b..690df98 100644 --- a/src/mayim/sql/executor.py +++ b/src/mayim/sql/executor.py @@ -153,10 +153,22 @@ async def _run_sql( no_result: bool = False, posargs: Optional[Sequence[Any]] = None, params: Optional[Dict[str, Any]] = None, - ): - ... + ): ... async def rollback(self, *, silent: bool = False) -> None: + # Check if we're part of a global transaction + from mayim.transaction import get_global_transaction + + global_context = get_global_transaction() + if global_context and self.pool in global_context.connections: + # Mark for rollback in global context + global_context.commit_flags[self.pool] = False + if not silent: + # Could optionally trigger immediate global rollback + await global_context.rollback_all() + return + + # Otherwise, proceed with single-executor rollback existing = self.pool.existing_connection() transaction = self.pool.in_transaction() if not existing or not transaction: @@ -174,6 +186,16 @@ def _get_method(self, as_list: bool) -> str: @asynccontextmanager async def transaction(self): + # Check if we're part of a global transaction + from mayim.transaction import get_global_transaction + + global_context = get_global_transaction() + if global_context and self.pool in global_context.connections: + # We're already in a global transaction, just yield + yield + return + + # Otherwise, proceed with single-executor transaction as before self.pool._transaction.set(True) async with self.pool.connection() as conn: self.pool._connection.set(conn) diff --git a/src/mayim/sql/postgres/interface.py b/src/mayim/sql/postgres/interface.py index 3211ad7..d26455e 100644 --- a/src/mayim/sql/postgres/interface.py +++ b/src/mayim/sql/postgres/interface.py @@ -20,13 +20,27 @@ class PostgresPool(BaseInterface): scheme = "postgres" + def _populate_dsn(self): + if not self._query: + self._query = "" + if "application_name" not in self._query: + if self._query: + self._query += "&" + self._query += "application_name=mayim" + super()._populate_dsn() + def _setup_pool(self): if not POSTGRES_ENABLED: raise MayimError( "Postgres driver not found. Try reinstalling Mayim: " "pip install mayim[postgres]" ) - self._pool = AsyncConnectionPool(self.full_dsn) + self._pool = AsyncConnectionPool( + self.full_dsn, + min_size=self.min_size, + max_size=self.max_size, + open=False, + ) async def open(self): """Open connections to the pool""" diff --git a/src/mayim/transaction.py b/src/mayim/transaction.py new file mode 100644 index 0000000..3658a76 --- /dev/null +++ b/src/mayim/transaction.py @@ -0,0 +1,419 @@ +""" +Transaction coordination module for Mayim. +Provides global transaction management across multiple executors. +""" + +from __future__ import annotations + +import asyncio +from contextlib import AsyncExitStack, asynccontextmanager +from contextvars import ContextVar +from enum import Enum +from inspect import isclass +from time import time +from typing import Any, Dict, List, Optional, Set, Type, Union + +from mayim.base.interface import BaseInterface +from mayim.exception import MayimError + +# Global transaction context that all executors can access +_global_transaction: ContextVar[Optional["GlobalTransactionContext"]] = ( + ContextVar("global_transaction", default=None) +) + + +class TransactionState(Enum): + """Transaction state machine states""" + + PENDING = "pending" # Created but not started + ACTIVE = "active" # Transaction has begun + COMMITTED = "committed" # Transaction committed successfully + ROLLED_BACK = "rolled_back" # Transaction was rolled back + + +class GlobalTransactionContext: + """ + Manages a global transaction context across multiple executors. + Ensures all executors sharing the same pool use the same connection. + """ + + def __init__(self): + # Map from pool instance to acquired connection + self.connections: Dict[BaseInterface, Any] = {} + # Map from pool instance to executors using it + self.pool_executors: Dict[BaseInterface, Set[Any]] = {} + # Track which pools should commit + self.commit_flags: Dict[BaseInterface, bool] = {} + # Stack for cleanup + self.stack = AsyncExitStack() + # Transaction state + self.state = TransactionState.PENDING + # Two-phase commit support + self.prepared_pools: Set[BaseInterface] = set() + self.supports_2pc = False + + async def add_executor(self, executor): + """Register an executor in this transaction""" + from mayim.sql.executor import SQLExecutor + + if not isinstance(executor, SQLExecutor): + raise MayimError( + f"Only SQL executors can participate in transactions, got {type(executor)}" + ) + + pool = executor.pool + + # Track this executor + if pool not in self.pool_executors: + self.pool_executors[pool] = set() + self.pool_executors[pool].add(executor) + + # If we haven't acquired a connection for this pool yet, do so + if pool not in self.connections: + # Acquire a connection and keep it for all executors using this pool + conn = await self.stack.enter_async_context(pool.connection()) + self.connections[pool] = conn + # Set the connection in the pool's context var so all executors see it + pool._connection.set(conn) + pool._transaction.set(True) + self.commit_flags[pool] = True # Default to commit + + def get_connection(self, pool: BaseInterface) -> Any: + """Get the connection for a given pool""" + return self.connections.get(pool) + + async def prepare_all(self) -> bool: + """Prepare all connections for commit (2PC phase 1)""" + self.prepared_pools.clear() + + for pool, conn in self.connections.items(): + try: + # Check if connection supports prepare + if hasattr(conn, "prepare"): + await conn.prepare() + self.prepared_pools.add(pool) + else: + # If any connection doesn't support 2PC, we can't use it + return False + except Exception as e: + # Prepare failed, rollback what we can + await self.rollback_all() + raise MayimError(f"Transaction prepare failed: {e}") from e + + return len(self.prepared_pools) == len(self.connections) + + async def rollback_all(self): + """Rollback all connections""" + for pool, conn in self.connections.items(): + pool._commit.set(False) + try: + # Check if connection has rollback method + if hasattr(conn, "rollback"): + await conn.rollback() + except Exception: + pass # Best effort rollback + + async def commit_all(self, prepared: bool = False): + """Commit all connections that should be committed + + Args: + prepared: Whether connections have already been prepared (2PC phase 2) + """ + for pool, conn in self.connections.items(): + if self.commit_flags.get(pool, True): + try: + # If we're in 2PC and this pool was prepared + if prepared and pool in self.prepared_pools: + # Use prepared commit if available + if hasattr(conn, "commit_prepared"): + await conn.commit_prepared() + else: + await conn.commit() + else: + # Regular commit + if hasattr(conn, "commit"): + await conn.commit() + except Exception as e: + # If any commit fails, rollback all + await self.rollback_all() + raise MayimError(f"Transaction commit failed: {e}") from e + + async def cleanup(self): + """Clean up all connections and state""" + for pool in self.connections: + pool._connection.set(None) + pool._transaction.set(False) + pool._commit.set(True) + + +class TransactionCoordinator: + """ + Coordinates transactions across multiple executors. + Supports both context manager and explicit transaction control. + """ + + def __init__( + self, + executors: List[Union[Type, Any]], + use_2pc: bool = False, + timeout: Optional[float] = None, + ): + """ + Initialize transaction coordinator. + + Args: + executors: List of executor classes or instances to include in transaction + use_2pc: Whether to use two-phase commit protocol if available + timeout: Maximum duration in seconds before transaction is automatically rolled back + """ + self._executors = executors + self._context: Optional[GlobalTransactionContext] = None + self._token = None + self._state = TransactionState.PENDING + self._stack: Optional[AsyncExitStack] = None + self._use_2pc = use_2pc + self._prepared = False + self._timeout = timeout + self._start_time: Optional[float] = None + self._timeout_task: Optional[asyncio.Task] = None + + @property + def is_active(self) -> bool: + """Check if transaction is currently active""" + return self._state == TransactionState.ACTIVE + + @property + def is_committed(self) -> bool: + """Check if transaction has been committed""" + return self._state == TransactionState.COMMITTED + + @property + def is_rolled_back(self) -> bool: + """Check if transaction has been rolled back""" + return self._state == TransactionState.ROLLED_BACK + + @property + def executors(self) -> List: + """Get list of executors in this transaction""" + return self._executors + + async def begin(self): + """ + Begin the transaction explicitly. + Sets up connections and marks transaction as active. + """ + if self._state == TransactionState.ACTIVE: + raise MayimError("Transaction already active") + if self._state in ( + TransactionState.COMMITTED, + TransactionState.ROLLED_BACK, + ): + raise MayimError("Transaction already completed") + + # Create global transaction context + self._context = GlobalTransactionContext() + self._token = _global_transaction.set(self._context) + self._stack = AsyncExitStack() + + try: + await self._stack.__aenter__() + self._context.stack = self._stack + + # Register all executors and acquire connections + for executor in self._executors: + await self._context.add_executor(executor) + + self._state = TransactionState.ACTIVE + self._context.state = TransactionState.ACTIVE + + # Start timeout tracking + self._start_time = time() + if self._timeout is not None: + self._timeout_task = asyncio.create_task( + self._timeout_monitor() + ) + + except Exception: + # Clean up on failure + await self._cleanup() + raise + + async def prepare_all(self) -> bool: + """ + Prepare all connections for commit (2PC phase 1). + + Returns: + True if all connections were prepared successfully + """ + if self._state != TransactionState.ACTIVE: + raise MayimError("Transaction must be active to prepare") + + if self._context: + self._prepared = await self._context.prepare_all() + return self._prepared + return False + + async def commit(self): + """ + Commit the transaction. + Commits all connections and marks transaction as committed. + """ + if self._state == TransactionState.PENDING: + raise MayimError("Transaction not active") + + # Check for timeout first + if ( + self._check_timeout() + or self._state == TransactionState.ROLLED_BACK + ): + if self._state != TransactionState.ROLLED_BACK: + await self.rollback() + raise MayimError("Transaction timed out") + + if self._state in ( + TransactionState.COMMITTED, + TransactionState.ROLLED_BACK, + ): + raise MayimError("Transaction already completed") + if self._state != TransactionState.ACTIVE: + raise MayimError("Transaction not active") + + try: + # If we're using 2PC and haven't prepared yet, do it now + if self._use_2pc and not self._prepared: + await self.prepare_all() + + # Commit all connections + await self._context.commit_all(prepared=self._prepared) + self._state = TransactionState.COMMITTED + self._context.state = TransactionState.COMMITTED + finally: + # Clean up + await self._cleanup() + + async def rollback(self): + """ + Rollback the transaction. + Rolls back all connections and marks transaction as rolled back. + """ + if self._state == TransactionState.PENDING: + raise MayimError("Transaction not active") + if self._state in ( + TransactionState.COMMITTED, + TransactionState.ROLLED_BACK, + ): + raise MayimError("Transaction already completed") + + try: + # Rollback through executors (for compatibility with existing tests) + for executor in self._executors: + await executor.rollback(silent=True) + + self._state = TransactionState.ROLLED_BACK + if self._context: + self._context.state = TransactionState.ROLLED_BACK + finally: + # Clean up + await self._cleanup() + + async def _timeout_monitor(self): + """Monitor transaction timeout and auto-rollback if exceeded""" + try: + await asyncio.sleep(self._timeout) + # Timeout exceeded, force rollback + if self._state == TransactionState.ACTIVE: + self._state = TransactionState.ROLLED_BACK + if self._context: + self._context.state = TransactionState.ROLLED_BACK + except asyncio.CancelledError: + # Normal case - transaction completed before timeout + pass + + def _check_timeout(self): + """Check if transaction has timed out""" + if self._timeout is not None and self._start_time is not None: + elapsed = time() - self._start_time + if elapsed > self._timeout: + return True + return False + + async def _cleanup(self): + """Clean up transaction resources""" + # Cancel timeout task + if self._timeout_task and not self._timeout_task.done(): + self._timeout_task.cancel() + try: + await self._timeout_task + except asyncio.CancelledError: + pass + + if self._context: + await self._context.cleanup() + + if self._stack: + await self._stack.__aexit__(None, None, None) + + if self._token: + try: + _global_transaction.reset(self._token) + except ValueError: + # Token was reset in different context (e.g., timeout task) + pass + self._token = None + + async def __aenter__(self): + """Context manager entry - automatically begins transaction""" + await self.begin() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Context manager exit - commits or rolls back based on exception""" + if self._state != TransactionState.ACTIVE: + return False + + if exc_type is None: + # No exception, commit + try: + await self.commit() + except Exception: + # Commit failed, ensure cleanup + if self._state == TransactionState.ACTIVE: + await self._cleanup() + self._state = TransactionState.ROLLED_BACK + raise + else: + # Exception occurred, rollback + try: + await self.rollback() + except Exception: + # Rollback failed, ensure cleanup + if self._state == TransactionState.ACTIVE: + await self._cleanup() + self._state = TransactionState.ROLLED_BACK + + return False # Don't suppress the exception + + def get_metrics(self) -> Dict[str, Any]: + """ + Get transaction metrics (optional feature). + + Returns: + Dictionary containing transaction metrics + """ + return { + "state": self._state.value, + "executor_count": len(self._executors), + "pool_count": ( + len(self._context.connections) if self._context else 0 + ), + } + + +def get_global_transaction() -> Optional[GlobalTransactionContext]: + """ + Get the current global transaction context if one exists. + + Returns: + The current GlobalTransactionContext or None + """ + return _global_transaction.get() diff --git a/tests/conftest.py b/tests/conftest.py index 4218242..5da1da6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,6 +10,7 @@ InterfaceRegistry, LazyHydratorRegistry, LazyQueryRegistry, + PoolRegistry, Registry, ) from mayim.sql.postgres import interface @@ -40,6 +41,7 @@ def reset_registry(): InterfaceRegistry().reset() LazyQueryRegistry().reset() LazyHydratorRegistry().reset() + PoolRegistry().reset() @pytest.fixture @@ -70,8 +72,7 @@ def mock_postgres_pool(monkeypatch, postgres_connection_context): @pytest.fixture def FooExecutor(): class FooExecutor(Executor): - async def select_something(self) -> Foo: - ... + async def select_something(self) -> Foo: ... @classmethod def _load(cls, _): @@ -99,8 +100,7 @@ def ItemExecutor(Item): class ItemExecutor(PostgresExecutor): @query(single_query) - async def select_otheritem(self, item_id: int) -> Item: - ... + async def select_otheritem(self, item_id: int) -> Item: ... async def select_otheritem_execute(self, item_id: int) -> Item: return await self.execute( @@ -122,8 +122,7 @@ async def select_item_named(self, item_id: int): ) @query(single_query) - async def select_int(self, item_id: int) -> int: - ... + async def select_int(self, item_id: int) -> int: ... async def select_int_execute(self, item_id: int) -> int: return await self.execute( @@ -131,15 +130,15 @@ async def select_int_execute(self, item_id: int) -> int: ) @query(multiple_query) - async def select_otheritems(self) -> List[Item]: - ... + async def select_otheritems(self) -> List[Item]: ... async def select_otheritems_execute(self) -> List[Item]: return await self.execute(query=multiple_query, as_list=True) @query(single_query) - async def select_optional_item(self, item_id: int) -> Optional[Item]: - ... + async def select_optional_item( + self, item_id: int + ) -> Optional[Item]: ... async def select_optional_item_execute( self, item_id: int @@ -151,8 +150,7 @@ async def select_optional_item_execute( ) @query(multiple_query) - async def select_optional_items(self) -> Optional[List[Item]]: - ... + async def select_optional_items(self) -> Optional[List[Item]]: ... async def select_optional_items_execute(self) -> Optional[List[Item]]: return await self.execute( @@ -160,12 +158,10 @@ async def select_optional_items_execute(self) -> Optional[List[Item]]: ) @query(update_query) - async def update_item_empty(self, item_id: int, name: str): - ... + async def update_item_empty(self, item_id: int, name: str): ... @query(update_query) - async def update_item_none(self, item_id: int, name: str) -> None: - ... + async def update_item_none(self, item_id: int, name: str) -> None: ... async def update_item_empty_execute(self, item_id: int, name: str): await self.execute( @@ -180,8 +176,7 @@ async def update_item_none_execute( ) @query(single_query_positional) - async def select_otheritem_positional(self, item_id: int) -> Item: - ... + async def select_otheritem_positional(self, item_id: int) -> Item: ... async def select_otheritem_positional_execute( self, item_id: int diff --git a/tests/test_executor.py b/tests/test_executor.py index 74fcca7..878cabb 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -196,8 +196,7 @@ async def test_empty_result_none_union(postgres_connection): class ItemExecutor(PostgresExecutor): @query("SELECT * FROM otheritems") - async def select_otheritems(self) -> int | None: - ... + async def select_otheritems(self) -> int | None: ... Mayim(executors=[ItemExecutor], dsn="foo://user:password@host:1234/db") executor = Mayim.get(ItemExecutor) @@ -208,8 +207,7 @@ async def select_otheritems(self) -> int | None: def test_missing_sql_not_strict(): class FooExecutor(PostgresExecutor): - async def select_missing(self) -> int: - ... + async def select_missing(self) -> int: ... Mayim( executors=[FooExecutor()], @@ -220,8 +218,7 @@ async def select_missing(self) -> int: def test_missing_sql_strict(): class FooExecutor(PostgresExecutor): - async def select_missing(self) -> int: - ... + async def select_missing(self) -> int: ... message = re.escape( "Could not find SQL for FooExecutor.select_missing. " diff --git a/tests/test_hydrator.py b/tests/test_hydrator.py index e8e4406..d7578be 100644 --- a/tests/test_hydrator.py +++ b/tests/test_hydrator.py @@ -2,12 +2,10 @@ from mayim.base.hydrator import Hydrator -class HydratorA(Hydrator): - ... +class HydratorA(Hydrator): ... -class HydratorB(Hydrator): - ... +class HydratorB(Hydrator): ... async def test_get_hydrator_by_name(): diff --git a/tests/test_sql_loading.py b/tests/test_sql_loading.py index 547327e..48eed9e 100644 --- a/tests/test_sql_loading.py +++ b/tests/test_sql_loading.py @@ -24,8 +24,9 @@ async def test_auto_load_keyword(postgres_connection): postgres_connection.result = {"item_id": 99, "name": "thing"} class ItemExecutor(PostgresExecutor): - async def select_items(self, limit: int = 4, offset: int = 0) -> Item: - ... + async def select_items( + self, limit: int = 4, offset: int = 0 + ) -> Item: ... Mayim(executors=[ItemExecutor], dsn="foo://user:password@host:1234/db") executor = Mayim.get(ItemExecutor) @@ -51,8 +52,9 @@ class ItemExecutor(PostgresExecutor): LIMIT $limit OFFSET $offset; """ ) - async def select_items(self, limit: int = 4, offset: int = 0) -> Item: - ... + async def select_items( + self, limit: int = 4, offset: int = 0 + ) -> Item: ... Mayim(executors=[ItemExecutor], dsn="foo://user:password@host:1234/db") executor = Mayim.get(ItemExecutor) @@ -75,8 +77,7 @@ async def test_auto_load_positional(postgres_connection): class ItemExecutor(PostgresExecutor): async def select_items_numbered( self, limit: int = 4, offset: int = 0 - ) -> Item: - ... + ) -> Item: ... Mayim(executors=[ItemExecutor], dsn="foo://user:password@host:1234/db") executor = Mayim.get(ItemExecutor) @@ -107,8 +108,7 @@ class ItemExecutor(PostgresExecutor): ) async def select_items_numbered( self, limit: int = 4, offset: int = 0 - ) -> Item: - ... + ) -> Item: ... Mayim(executors=[ItemExecutor], dsn="foo://user:password@host:1234/db") executor = Mayim.get(ItemExecutor) diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 262b158..68c24d4 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -1,9 +1,180 @@ -from unittest.mock import AsyncMock +import asyncio +from contextlib import asynccontextmanager +from dataclasses import dataclass +from unittest.mock import AsyncMock, MagicMock import pytest -from mayim import Mayim +from mayim import Mayim, MysqlExecutor, PostgresExecutor, SQLiteExecutor, query from mayim.exception import MayimError +from mayim.registry import Registry + + +@dataclass +class FoobarModel: + id: int + value: str + + +@dataclass +class User: + id: int + name: str + email: str + + +@dataclass +class Order: + id: int + user_id: int + total: float + + +@dataclass +class Account: + id: int + balance: float + owner: str + + +class FoobarExecutor(PostgresExecutor): + @query("INSERT INTO test (value) VALUES ($value) RETURNING *") + async def insert_test(self, value: str) -> FoobarModel: ... + + @query("UPDATE test SET value = $value WHERE id = $id") + async def update_test(self, id: int, value: str) -> None: ... + + @classmethod + def _load(cls, strict: bool) -> None: + if not hasattr(cls, "_loaded") or not cls._loaded: + cls._queries = {} + cls._hydrators = {} + cls._loaded = True + + +class UserExecutor(PostgresExecutor): + @query( + "INSERT INTO users (name, email) VALUES ($name, $email) RETURNING *" + ) + async def create_user(self, name: str, email: str) -> User: ... + + @query("UPDATE users SET name = $name WHERE id = $id") + async def update_user(self, id: int, name: str) -> None: ... + + @classmethod + def _load(cls, strict: bool) -> None: + if not hasattr(cls, "_loaded") or not cls._loaded: + cls._queries = {} + cls._hydrators = {} + cls._loaded = True + + +class OrderExecutor(PostgresExecutor): + @query( + "INSERT INTO orders (user_id, total) VALUES ($user_id, $total) RETURNING *" + ) + async def create_order(self, user_id: int, total: float) -> Order: ... + + @query("UPDATE orders SET total = $total WHERE id = $id") + async def update_order(self, id: int, total: float) -> None: ... + + @classmethod + def _load(cls, strict: bool) -> None: + if not hasattr(cls, "_loaded") or not cls._loaded: + cls._queries = {} + cls._hydrators = {} + cls._loaded = True + + +class PostgresAccountExecutor(PostgresExecutor): + @query("SELECT * FROM accounts WHERE id = $account_id FOR UPDATE") + async def lock_account(self, account_id: int) -> Account: ... + + @query("UPDATE accounts SET balance = $balance WHERE id = $account_id") + async def update_balance( + self, account_id: int, balance: float + ) -> None: ... + + @query( + "INSERT INTO transfers (from_id, to_id, amount) VALUES ($from_id, $to_id, $amount)" + ) + async def record_transfer( + self, from_id: int, to_id: int, amount: float + ) -> None: ... + + @classmethod + def _load(cls, strict: bool) -> None: + if not hasattr(cls, "_loaded") or not cls._loaded: + cls._queries = {} + cls._hydrators = {} + cls._loaded = True + + +class MysqlAccountExecutor(MysqlExecutor): + @query("SELECT * FROM accounts WHERE id = %s FOR UPDATE") + async def lock_account(self, account_id: int) -> Account: ... + + @query("UPDATE accounts SET balance = %s WHERE id = %s") + async def update_balance( + self, balance: float, account_id: int + ) -> None: ... + + @classmethod + def _load(cls, strict: bool) -> None: + if not hasattr(cls, "_loaded") or not cls._loaded: + cls._queries = {} + cls._hydrators = {} + cls._loaded = True + + +class SQLiteInventoryExecutor(SQLiteExecutor): + @query("UPDATE inventory SET quantity = quantity - ? WHERE product_id = ?") + async def reduce_inventory( + self, quantity: int, product_id: int + ) -> None: ... + + @query("SELECT quantity FROM inventory WHERE product_id = ?") + async def get_quantity(self, product_id: int) -> int: ... + + @classmethod + def _load(cls, strict: bool) -> None: + if not hasattr(cls, "_loaded") or not cls._loaded: + cls._queries = {} + cls._hydrators = {} + cls._loaded = True + + +@pytest.fixture +def mock_pools(): + @asynccontextmanager + async def mock_connection(*args, **kwargs): + conn = AsyncMock() + conn.execute = AsyncMock() + conn.rollback = AsyncMock() + conn.commit = AsyncMock() + yield conn + + def create_mock_pool(): + pool = MagicMock() + pool.connection = mock_connection + pool._connection = MagicMock() + pool._connection.get = MagicMock(return_value=None) + pool._connection.set = MagicMock() + pool._transaction = MagicMock() + pool._transaction.get = MagicMock(return_value=False) + pool._transaction.set = MagicMock() + pool._commit = MagicMock() + pool._commit.get = MagicMock(return_value=True) + pool._commit.set = MagicMock() + pool.in_transaction = MagicMock(return_value=False) + pool.existing_connection = MagicMock(return_value=None) + return pool + + return { + "postgres": create_mock_pool(), + "mysql": create_mock_pool(), + "sqlite": create_mock_pool(), + } async def test_transaction(postgres_connection, item_executor): @@ -81,3 +252,969 @@ async def test_global_transaction( ... postgres_connection.rollback.assert_not_called() mock.assert_called_once_with(silent=True) + + +async def test_two_phase_commit_simulation(): + user_exec = UserExecutor() + order_exec = OrderExecutor() + Mayim(executors=[user_exec, order_exec], dsn="postgres://localhost/test") + + user_exec.create_user = AsyncMock( + return_value=User(id=1, name="Alice", email="alice@example.com") + ) + order_exec.create_order = AsyncMock( + return_value=Order(id=1, user_id=1, total=100.0) + ) + + user_exec.pool.existing_connection = MagicMock(return_value=AsyncMock()) + order_exec.pool.existing_connection = MagicMock(return_value=AsyncMock()) + + async with Mayim.transaction(user_exec, order_exec, use_2pc=True): + user = await user_exec.create_user( + name="Alice", email="alice@example.com" + ) + order = await order_exec.create_order(user_id=user.id, total=100.0) + + user_exec.create_user.assert_called_once_with( + name="Alice", email="alice@example.com" + ) + order_exec.create_order.assert_called_once_with(user_id=1, total=100.0) + + +async def test_transaction_isolation_between_executors(): + user_exec = UserExecutor() + order_exec = OrderExecutor() + + Mayim(executors=[user_exec, order_exec], dsn="postgres://localhost/test") + + user_exec.create_user = AsyncMock( + return_value=User(id=1, name="Alice", email="alice@example.com") + ) + order_exec.create_order = AsyncMock( + return_value=Order(id=1, user_id=1, total=100.0) + ) + + try: + async with Mayim.transaction(user_exec, order_exec): + user = await user_exec.create_user( + name="Alice", email="alice@example.com" + ) + raise Exception("Simulated failure after user creation") + except Exception: + pass + + user_exec.create_user.assert_called_once() + + +async def test_nested_executor_in_global_transaction(): + user_exec = UserExecutor() + order_exec = OrderExecutor() + + Mayim(executors=[user_exec, order_exec], dsn="postgres://localhost/test") + + user_exec.create_user = AsyncMock( + return_value=User(id=1, name="Alice", email="alice@example.com") + ) + order_exec.create_order = AsyncMock( + return_value=Order(id=1, user_id=1, total=100.0) + ) + + async with Mayim.transaction(user_exec, order_exec): + user = await user_exec.create_user( + name="Alice", email="alice@example.com" + ) + async with user_exec.transaction(): + order = await order_exec.create_order(user_id=user.id, total=100.0) + + user_exec.create_user.assert_called_once() + order_exec.create_order.assert_called_once() + + +async def test_explicit_api_basic_usage(postgres_connection): + Mayim( + executors=[FoobarExecutor], + dsn="postgres://user:pass@localhost:5432/test", + ) + executor = Mayim.get(FoobarExecutor) + + txn = await Mayim.transaction(FoobarExecutor) + + assert not txn.is_active + assert not txn.is_committed + assert not txn.is_rolled_back + + await txn.begin() + assert txn.is_active + + postgres_connection.result = {"id": 1, "value": "test1"} + result = await executor.insert_test(value="test1") + assert txn.is_active + + await txn.commit() + assert not txn.is_active + assert txn.is_committed + + +async def test_explicit_api_with_multiple_executors(): + class AnotherExecutor(PostgresExecutor): + @query("INSERT INTO another (data) VALUES ($data)") + async def insert_another(self, data: str) -> None: ... + + Mayim( + executors=[FoobarExecutor, AnotherExecutor], + dsn="postgres://localhost/test", + ) + + txn = await Mayim.transaction(FoobarExecutor, AnotherExecutor) + await txn.begin() + + test_exec = Mayim.get(FoobarExecutor) + another_exec = Mayim.get(AnotherExecutor) + + await test_exec.insert_test(value="test") + await another_exec.insert_another(data="data") + + await txn.commit() + assert txn.is_committed + + +async def test_explicit_api_rollback_behavior(): + Mayim(executors=[FoobarExecutor], dsn="postgres://localhost/test") + txn = await Mayim.transaction(FoobarExecutor) + await txn.begin() + + executor = Mayim.get(FoobarExecutor) + await executor.insert_test(value="will_be_rolled_back") + + await txn.rollback() + + assert not txn.is_active + assert txn.is_rolled_back + assert not txn.is_committed + + +async def test_explicit_api_context_manager_compatibility(postgres_connection): + Mayim( + executors=[FoobarExecutor], + dsn="postgres://user:pass@localhost:5432/test", + ) + executor = Mayim.get(FoobarExecutor) + + txn1 = await Mayim.transaction(FoobarExecutor) + await txn1.begin() + postgres_connection.result = {"id": 1, "value": "explicit"} + await executor.insert_test(value="explicit") + await txn1.commit() + assert txn1.is_committed + + txn2 = await Mayim.transaction(FoobarExecutor) + async with txn2: + postgres_connection.result = {"id": 2, "value": "context_manager"} + await executor.insert_test(value="context_manager") + assert txn2.is_committed + + txn3 = await Mayim.transaction(FoobarExecutor) + async with txn3: + postgres_connection.result = {"id": 3, "value": "direct_context"} + await executor.insert_test(value="direct_context") + assert txn3.is_active + assert txn3.is_committed + + +async def test_explicit_api_error_handling(): + Mayim(executors=[FoobarExecutor], dsn="postgres://localhost/test") + txn = await Mayim.transaction(FoobarExecutor) + + with pytest.raises(MayimError, match="Transaction not active"): + await txn.commit() + + with pytest.raises(MayimError, match="Transaction not active"): + await txn.rollback() + + await txn.begin() + + with pytest.raises(MayimError, match="Transaction already active"): + await txn.begin() + + await txn.commit() + + with pytest.raises(MayimError, match="Transaction already completed"): + await txn.commit() + + with pytest.raises(MayimError, match="Transaction already completed"): + await txn.rollback() + + with pytest.raises(MayimError, match="Transaction already completed"): + await txn.begin() + + +async def test_explicit_api_with_executor_instances(): + Mayim(executors=[FoobarExecutor], dsn="postgres://localhost/test") + + txn1 = await Mayim.transaction(FoobarExecutor) + await txn1.begin() + await txn1.commit() + + executor_instance = Mayim.get(FoobarExecutor) + txn2 = await Mayim.transaction(executor_instance) + await txn2.begin() + await txn2.commit() + + class AnotherExecutor(PostgresExecutor): + pass + + Mayim(executors=[AnotherExecutor], dsn="postgres://localhost/test") + another_instance = Mayim.get(AnotherExecutor) + + txn3 = await Mayim.transaction(FoobarExecutor, another_instance) + await txn3.begin() + await txn3.commit() + + +async def test_explicit_api_all_executors_default(): + class Exec1(PostgresExecutor): + pass + + class Exec2(PostgresExecutor): + pass + + Mayim(executors=[Exec1, Exec2], dsn="postgres://localhost/test") + + txn = await Mayim.transaction() + await txn.begin() + + assert len(txn.executors) == 2 + assert Exec1 in [ + type(e) if not isinstance(e, type) else e for e in txn.executors + ] + assert Exec2 in [ + type(e) if not isinstance(e, type) else e for e in txn.executors + ] + + await txn.commit() + + +async def test_explicit_api_with_nested_executor_transactions(): + Mayim(executors=[FoobarExecutor], dsn="postgres://localhost/test") + txn = await Mayim.transaction(FoobarExecutor) + await txn.begin() + + executor = Mayim.get(FoobarExecutor) + + async with executor.transaction(): + await executor.insert_test(value="nested") + + await txn.commit() + + +async def test_explicit_api_connection_sharing_verification( + postgres_connection, +): + class Exec1(PostgresExecutor): + @query("SELECT 1") + async def test_query1(self) -> int: ... + + @classmethod + def _load(cls, strict: bool) -> None: + if not hasattr(cls, "_loaded") or not cls._loaded: + cls._queries = {} + cls._hydrators = {} + cls._loaded = True + + class Exec2(PostgresExecutor): + @query("SELECT 2") + async def test_query2(self) -> int: ... + + @classmethod + def _load(cls, strict: bool) -> None: + if not hasattr(cls, "_loaded") or not cls._loaded: + cls._queries = {} + cls._hydrators = {} + cls._loaded = True + + Mayim(executors=[Exec1, Exec2], dsn="postgres://localhost/test") + + connections_used = set() + + @asynccontextmanager + async def track_connection(*args, **kwargs): + conn = AsyncMock() + conn.execute = AsyncMock(return_value=postgres_connection) + conn.rollback = AsyncMock() + conn.commit = AsyncMock() + conn_id = id(conn) + connections_used.add(conn_id) + yield conn + + exec1 = Mayim.get(Exec1) + exec2 = Mayim.get(Exec2) + + exec1.pool.connection = track_connection + exec2.pool.connection = track_connection + + txn = await Mayim.transaction(Exec1, Exec2) + await txn.begin() + + postgres_connection.result = {"?column?": 1} + await exec1.test_query1() + + postgres_connection.result = {"?column?": 2} + await exec2.test_query2() + + await txn.commit() + + assert len(connections_used) == 1 + + +async def test_deadlock_detection_and_recovery(mock_pools): + postgres_exec = PostgresAccountExecutor(pool=mock_pools["postgres"]) + mysql_exec = MysqlAccountExecutor(pool=mock_pools["mysql"]) + + Mayim(executors=[postgres_exec, mysql_exec]) + + call_count = [0] + + async def mock_lock_account(account_id): + call_count[0] += 1 + if call_count[0] >= 3: + raise MayimError( + "Deadlock detected: Transaction was chosen as deadlock victim" + ) + return Account(id=account_id, balance=1000, owner="Test") + + postgres_exec.lock_account = AsyncMock(side_effect=mock_lock_account) + mysql_exec.lock_account = AsyncMock(side_effect=mock_lock_account) + + error_raised = None + try: + async with Mayim.transaction(postgres_exec, mysql_exec): + await postgres_exec.lock_account(account_id=1) + await mysql_exec.lock_account(account_id=2) + await postgres_exec.lock_account(account_id=3) + except Exception as e: + error_raised = e + + assert error_raised is not None + assert "deadlock" in str(error_raised).lower() + + +async def test_bank_transfer_acid_compliance(): + mayim = Mayim( + executors=[PostgresAccountExecutor], dsn="postgres://localhost/bank" + ) + executor = mayim.get(PostgresAccountExecutor) + + account1_balance = 1000.0 + account2_balance = 500.0 + transfer_amount = 300.0 + + executor.lock_account = AsyncMock( + side_effect=[ + Account(id=1, balance=account1_balance, owner="Alice"), + Account(id=2, balance=account2_balance, owner="Bob"), + ] + ) + executor.update_balance = AsyncMock() + executor.record_transfer = AsyncMock() + + call_count = 0 + original_update = executor.update_balance + + async def failing_update(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 2: + raise Exception("Database connection lost") + return await original_update(*args, **kwargs) + + executor.update_balance = failing_update + + try: + async with executor.transaction(): + acc1 = await executor.lock_account(account_id=1) + acc2 = await executor.lock_account(account_id=2) + + if acc1.balance < transfer_amount: + raise ValueError("Insufficient funds") + + await executor.update_balance( + account_id=1, balance=acc1.balance - transfer_amount + ) + await executor.update_balance( + account_id=2, balance=acc2.balance + transfer_amount + ) + await executor.record_transfer( + from_id=1, to_id=2, amount=transfer_amount + ) + except Exception: + pass + + assert call_count == 2 + + +async def test_mixed_database_transaction(mock_pools): + postgres_exec = PostgresAccountExecutor(pool=mock_pools["postgres"]) + sqlite_exec = SQLiteInventoryExecutor(pool=mock_pools["sqlite"]) + + Mayim(executors=[postgres_exec, sqlite_exec]) + + postgres_exec.update_balance = AsyncMock() + sqlite_exec.reduce_inventory = AsyncMock() + + async with Mayim.transaction(postgres_exec, sqlite_exec): + await postgres_exec.update_balance(account_id=1, balance=100) + await sqlite_exec.reduce_inventory(quantity=5, product_id=1) + + postgres_exec.update_balance.assert_called_once_with( + account_id=1, balance=100 + ) + sqlite_exec.reduce_inventory.assert_called_once_with( + quantity=5, product_id=1 + ) + + +async def test_transaction_timeout(mock_pools): + executor = PostgresAccountExecutor(pool=mock_pools["postgres"]) + executor.lock_account = AsyncMock( + return_value=Account(id=1, balance=1000, owner="Test") + ) + + Mayim(executors=[executor]) + + txn = await Mayim.transaction(executor, timeout=0.1) + await txn.begin() + + await executor.lock_account(account_id=1) + + await asyncio.sleep(0.2) + + with pytest.raises(MayimError, match="Transaction timed out"): + await txn.commit() + + assert txn.is_rolled_back + + +async def test_nested_transactions(mock_pools): + executor = PostgresAccountExecutor(pool=mock_pools["postgres"]) + executor.update_balance = AsyncMock() + + async with executor.transaction(): + await executor.update_balance(account_id=1, balance=900) + + try: + async with executor.transaction(): + await executor.update_balance(account_id=2, balance=600) + raise Exception("Inner transaction fails") + except Exception: + pass + + await executor.update_balance(account_id=3, balance=300) + + +async def test_connection_pool_exhaustion(): + Mayim( + executors=[PostgresAccountExecutor], + dsn="postgres://localhost/test", + max_size=2, + ) + + transactions = [] + + async def create_and_hold_txn(): + txn = await Mayim.transaction(PostgresAccountExecutor) + await txn.begin() + transactions.append(txn) + await asyncio.sleep(1) + + tasks = [] + for i in range(2): + task = asyncio.create_task(create_and_hold_txn()) + tasks.append(task) + + await asyncio.sleep(0.1) + + try: + txn = await Mayim.transaction(PostgresAccountExecutor) + await asyncio.wait_for(txn.begin(), timeout=0.1) + await txn.rollback() + except (MayimError, asyncio.TimeoutError): + pass + + for task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + + +async def test_invalid_transaction_state_operations(): + Mayim(executors=[PostgresAccountExecutor], dsn="postgres://localhost/test") + txn = await Mayim.transaction(PostgresAccountExecutor) + + with pytest.raises(MayimError, match="Transaction not active"): + await txn.commit() + + with pytest.raises(MayimError, match="Transaction not active"): + await txn.rollback() + + await txn.begin() + await txn.commit() + + with pytest.raises(MayimError, match="Transaction already completed"): + await txn.begin() + + with pytest.raises(MayimError, match="Transaction already completed"): + await txn.commit() + + with pytest.raises(MayimError, match="Transaction already completed"): + await txn.rollback() + + +async def test_invalid_executor_combinations(): + Registry.reset() + + unregistered = PostgresAccountExecutor() + + txn = await Mayim.transaction(unregistered) + assert txn is not None + + Mayim(executors=[PostgresAccountExecutor], dsn="postgres://localhost/test") + + class NonSQLExecutor: + pass + + with pytest.raises( + MayimError, match="All executors must be SQL executors" + ): + await Mayim.transaction(PostgresAccountExecutor, NonSQLExecutor) + + with pytest.raises(MayimError, match="Invalid executor"): + await Mayim.transaction(None) + + +@pytest.mark.xfail(reason="Connection failure recovery not implemented") +async def test_connection_failure_during_transaction(): + executor = PostgresAccountExecutor() + Mayim(executors=[executor], dsn="postgres://localhost/test") + + try: + async with executor.transaction(): + await executor.update_balance(account_id=1, balance=100) + + executor.pool._connection.set(None) + + await executor.update_balance(account_id=2, balance=200) + assert False, "Expected MayimError for connection failure" + except MayimError as e: + assert "Connection lost" in str(e) or "connection" in str(e).lower() + + assert not executor.pool.in_transaction() + assert executor.pool.existing_connection() is None + + +@pytest.mark.xfail(reason="Transaction metrics collection not implemented") +async def test_transaction_metrics_collection(): + Mayim(executors=[PostgresAccountExecutor], dsn="postgres://localhost/test") + async with Mayim.transaction(PostgresAccountExecutor) as txn: + executor = Mayim.get(PostgresAccountExecutor) + await executor.update_balance(account_id=1, balance=100) + await executor.update_balance(account_id=2, balance=200) + + metrics = txn.get_metrics() + assert metrics["duration_ms"] > 0 + assert metrics["operation_count"] == 2 + assert metrics["executor_count"] == 1 + assert "begin_time" in metrics + assert "commit_time" in metrics + + +async def test_transaction_context_propagation(): + Mayim(executors=[PostgresAccountExecutor], dsn="postgres://localhost/test") + + async def nested_operation(executor): + assert executor.pool.in_transaction() + await executor.update_balance(account_id=3, balance=300) + + async with Mayim.transaction(PostgresAccountExecutor): + executor = Mayim.get(PostgresAccountExecutor) + await executor.update_balance(account_id=1, balance=100) + + await nested_operation(executor) + + task = asyncio.create_task(nested_operation(executor)) + await task + + +async def test_cross_executor_transaction_shares_connection( + postgres_connection, +): + user_executor = UserExecutor() + order_executor = OrderExecutor() + + Mayim( + executors=[user_executor, order_executor], + dsn="postgres://user:pass@localhost:5432/test", + ) + + connections_used = set() + + @asynccontextmanager + async def track_connection(*args, **kwargs): + conn = AsyncMock() + conn.execute = AsyncMock(return_value=postgres_connection) + conn.rollback = AsyncMock() + conn.commit = AsyncMock() + connections_used.add(id(conn)) + yield conn + + user_executor.pool.connection = track_connection + order_executor.pool.connection = track_connection + + postgres_connection.result = { + "id": 1, + "name": "Test", + "email": "test@example.com", + } + + txn = await Mayim.transaction(user_executor, order_executor) + async with txn: + await user_executor.create_user(name="Test", email="test@example.com") + postgres_connection.result = {"id": 1, "user_id": 1, "total": 100.0} + await order_executor.create_order(user_id=1, total=100.0) + + assert len(connections_used) == 1 + + +async def test_cross_executor_transaction_atomic_rollback(postgres_connection): + user_executor = UserExecutor() + order_executor = OrderExecutor() + + Mayim( + executors=[user_executor, order_executor], + dsn="postgres://user:pass@localhost:5432/test", + ) + + user_executor.count_users = AsyncMock(return_value={"count": 0}) + order_executor.count_orders = AsyncMock(return_value={"count": 0}) + + postgres_connection.result = {"count": 0} + initial_users = await user_executor.count_users() + initial_orders = await order_executor.count_orders() + + postgres_connection.result = { + "id": 1, + "name": "Test", + "email": "test@example.com", + } + + try: + async with Mayim.transaction(user_executor, order_executor): + await user_executor.create_user( + name="Test", email="test@example.com" + ) + await order_executor.create_order(user_id=1, total=100.0) + raise Exception("Force rollback") + except Exception: + pass + + postgres_connection.result = {"count": 0} + final_users = await user_executor.count_users() + final_orders = await order_executor.count_orders() + + assert final_users == initial_users + assert final_orders == initial_orders + + +async def test_transaction_context_visibility(postgres_connection): + user_executor = UserExecutor() + order_executor = OrderExecutor() + + Mayim( + executors=[user_executor, order_executor], + dsn="postgres://user:pass@localhost:5432/test", + ) + + @asynccontextmanager + async def mock_connection(*args, **kwargs): + conn = AsyncMock() + conn.execute = AsyncMock(return_value=postgres_connection) + yield conn + + user_executor.pool.connection = mock_connection + order_executor.pool.connection = mock_connection + + in_transaction_states = [] + + postgres_connection.result = { + "id": 1, + "name": "Test", + "email": "test@example.com", + } + + async with Mayim.transaction(user_executor, order_executor): + in_transaction_states.append(user_executor.pool.in_transaction()) + in_transaction_states.append(order_executor.pool.in_transaction()) + + assert all(in_transaction_states) + + +async def test_same_dsn_shares_pool(): + dsn = "postgres://user:pass@localhost:5432/test" + + Mayim(executors=[UserExecutor, OrderExecutor], dsn=dsn) + + user_exec = Mayim.get(UserExecutor) + order_exec = Mayim.get(OrderExecutor) + + assert user_exec.pool is order_exec.pool + + +async def test_explicit_transaction_state_machine_detailed(): + Mayim(executors=[UserExecutor], dsn="postgres://user:pass@localhost/test") + txn = await Mayim.transaction(UserExecutor) + + assert not txn.is_active + assert not txn.is_committed + assert not txn.is_rolled_back + + await txn.begin() + assert txn.is_active + assert not txn.is_committed + assert not txn.is_rolled_back + + with pytest.raises(MayimError, match="Transaction already active"): + await txn.begin() + + await txn.commit() + assert not txn.is_active + assert txn.is_committed + assert not txn.is_rolled_back + + with pytest.raises(MayimError, match="Transaction already completed"): + await txn.commit() + + +async def test_mixed_transaction_patterns(): + Mayim( + executors=[UserExecutor, OrderExecutor], + dsn="postgres://user:pass@localhost/test", + ) + + async with Mayim.transaction(UserExecutor) as txn: + user_exec = Mayim.get(UserExecutor) + await user_exec.create_user( + name="Context", email="context@example.com" + ) + assert txn.is_active + + txn2 = await Mayim.transaction(OrderExecutor) + await txn2.begin() + order_exec = Mayim.get(OrderExecutor) + await order_exec.create_order(user_id=1, total=50.0) + await txn2.commit() + + assert txn2.is_committed + + +async def test_verify_connection_reuse_in_nested_calls(postgres_connection): + user_executor = UserExecutor() + Mayim( + executors=[user_executor], + dsn="postgres://user:pass@localhost:5432/test", + ) + + connections = [] + + mock_conn = AsyncMock() + mock_conn.execute = AsyncMock(return_value=postgres_connection) + mock_conn.rollback = AsyncMock() + mock_conn.commit = AsyncMock() + + @asynccontextmanager + async def track_conn(*args, **kwargs): + conn_id = id(mock_conn) + connections.append(conn_id) + yield mock_conn + + user_executor.pool.connection = track_conn + + postgres_connection.result = { + "id": 1, + "name": "User1", + "email": "user1@example.com", + } + + async with user_executor.transaction(): + await user_executor.create_user( + name="User1", email="user1@example.com" + ) + postgres_connection.result = {"count": 1} + await user_executor.update_user(id=1, name="Updated") + + assert len(connections) >= 1 + assert len(set(connections)) == 1 + + +async def test_concurrent_transactions_are_isolated(): + results = [] + + Mayim(executors=[UserExecutor], dsn="postgres://user:pass@localhost/test") + + async def transaction1(): + async with Mayim.transaction(UserExecutor): + user_exec = Mayim.get(UserExecutor) + results.append(("txn1", user_exec.pool.in_transaction())) + await asyncio.sleep(0.01) + results.append(("txn1_after", user_exec.pool.in_transaction())) + + async def transaction2(): + await asyncio.sleep(0.005) + user_exec = Mayim.get(UserExecutor) + results.append(("txn2", user_exec.pool.in_transaction())) + + await asyncio.gather(transaction1(), transaction2()) + + assert results[0] == ("txn1", True) + assert results[1] == ("txn2", False) + assert results[2] == ("txn1_after", True) + + +async def test_two_phase_commit_detailed(postgres_connection): + prepared_pools = [] + committed_pools = [] + + @asynccontextmanager + async def mock_connection_2pc(*args, **kwargs): + conn = AsyncMock() + conn.execute = AsyncMock(return_value=postgres_connection) + conn.rollback = AsyncMock() + conn.commit = AsyncMock() + + async def mock_prepare(): + prepared_pools.append(id(conn)) + + async def mock_commit_prepared(): + committed_pools.append(id(conn)) + + conn.prepare = AsyncMock(side_effect=mock_prepare) + conn.commit_prepared = AsyncMock(side_effect=mock_commit_prepared) + + yield conn + + Mayim( + executors=[UserExecutor, OrderExecutor], + dsn="postgres://user:pass@localhost/test", + ) + + user_exec = Mayim.get(UserExecutor) + order_exec = Mayim.get(OrderExecutor) + + user_exec.pool.connection = mock_connection_2pc + order_exec.pool.connection = mock_connection_2pc + + txn = await Mayim.transaction(UserExecutor, OrderExecutor, use_2pc=True) + await txn.begin() + + postgres_connection.result = { + "id": 1, + "name": "Test", + "email": "test@example.com", + } + await user_exec.create_user(name="Test", email="test@example.com") + + postgres_connection.result = {"id": 1, "user_id": 1, "total": 100.0} + await order_exec.create_order(user_id=1, total=100.0) + + prepare_result = await txn.prepare_all() + assert prepare_result + + assert len(prepared_pools) > 0 + assert len(committed_pools) == 0 + + await txn.commit() + + assert len(committed_pools) > 0 + + +@pytest.mark.xfail(reason="Optional advanced features not fully implemented") +async def test_explicit_api_transaction_properties(): + Mayim(executors=[FoobarExecutor], dsn="postgres://localhost/test") + txn = await Mayim.transaction(FoobarExecutor) + + assert hasattr(txn, "is_active") + assert hasattr(txn, "is_committed") + assert hasattr(txn, "is_rolled_back") + assert hasattr(txn, "executors") + + assert hasattr(txn, "begin") + assert hasattr(txn, "commit") + assert hasattr(txn, "rollback") + + if hasattr(txn, "savepoint"): + await txn.begin() + sp = await txn.savepoint("sp1") + await sp.rollback() + + if hasattr(txn, "get_metrics"): + await txn.begin() + await txn.commit() + metrics = txn.get_metrics() + assert "duration_ms" in metrics + assert "operation_count" in metrics + + +@pytest.mark.xfail(reason="Isolation level feature not implemented") +async def test_explicit_api_isolation_level(): + Mayim(executors=[FoobarExecutor], dsn="postgres://localhost/test") + txn = await Mayim.transaction( + FoobarExecutor, + isolation_level="SERIALIZABLE", + ) + await txn.begin() + + if hasattr(txn, "isolation_level"): + assert txn.isolation_level == "SERIALIZABLE" + + await txn.commit() + + +@pytest.mark.xfail(reason="Read-only transactions not implemented") +async def test_explicit_api_readonly_transactions(): + Mayim(executors=[FoobarExecutor], dsn="postgres://localhost/test") + txn = await Mayim.transaction(FoobarExecutor, readonly=True) + await txn.begin() + + await txn.commit() + + +@pytest.mark.xfail(reason="Transaction lifecycle hooks not implemented") +async def test_transaction_lifecycle_hooks(): + events = [] + + async def on_begin(txn): + events.append(("begin", txn)) + + async def on_commit(txn): + events.append(("commit", txn)) + + async def on_rollback(txn): + events.append(("rollback", txn)) + + Mayim.transaction_hooks( + on_begin=on_begin, on_commit=on_commit, on_rollback=on_rollback + ) + + async with Mayim.transaction(PostgresAccountExecutor): + pass + + assert events[0][0] == "begin" + assert events[1][0] == "commit" + + events.clear() + + try: + async with Mayim.transaction(PostgresAccountExecutor): + raise Exception("Test failure") + except Exception: + pass + + assert events[0][0] == "begin" + assert events[1][0] == "rollback" diff --git a/tox.ini b/tox.ini index c7c0b88..dd6d709 100644 --- a/tox.ini +++ b/tox.ini @@ -1,19 +1,21 @@ [tox] -envlist = {py38,py39,py310}, check +envlist = {py39,py39,py310,py311,py312,py313}, check [gh-actions] python = - 3.8: py38 - 3.9: py39, check + 3.9: py39 3.10: py310 + 3.11: py311 + 3.12: py312 + 3.13: py313, check [testenv] extras = test postgres mysql - aiosqlite + sqlite commands = pytest {posargs:tests}