From 3ff52f2a1ad121a012bd86dd17d4febc5616347f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Tue, 31 Dec 2024 02:34:08 +0200 Subject: [PATCH] Fixed config oopsie and increased test coverage --- pyproject.toml | 2 +- src/asphalt/sqlalchemy/_component.py | 9 ++++-- tests/test_component.py | 9 ++++-- tests/test_utils.py | 43 ++++++++++++++++++++++++++-- 4 files changed, 54 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7abac1c..361fe8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,12 +84,12 @@ testpaths = ["tests"] python_version = "3.9" strict = true explicit_package_bases = true +mypy_path = ["src", "tests", "examples"] [tool.coverage.run] source = ["asphalt.sqlalchemy"] relative_files = true branch = true -mypy_path = ["src", "tests", "examples"] [tool.coverage.report] show_missing = true diff --git a/src/asphalt/sqlalchemy/_component.py b/src/asphalt/sqlalchemy/_component.py index 4813efa..b877aa4 100644 --- a/src/asphalt/sqlalchemy/_component.py +++ b/src/asphalt/sqlalchemy/_component.py @@ -91,7 +91,10 @@ def __init__( engine_args: dict[str, Any] | None = None, session_args: dict[str, Any] | None = None, commit_executor_workers: int = 50, - ready_callback: Callable[[Engine, sessionmaker[Any]], Any] | str | None = None, + ready_callback: Callable[[Engine, sessionmaker[Any]], Any] + | Callable[[AsyncEngine, async_sessionmaker[Any]], Any] + | str + | None = None, poolclass: str | type[Pool] | None = None, ): self.commit_thread_limiter = CapacityLimiter(commit_executor_workers) @@ -112,7 +115,7 @@ def __init__( elif isinstance(bind, AsyncEngine): self._engine = self._async_bind = bind else: - raise TypeError(f"Incompatible bind argument: {qualified_name(bind)}") + raise TypeError(f"incompatible bind argument: {qualified_name(bind)}") else: if isinstance(url, dict): url = URL.create(**url) @@ -131,7 +134,7 @@ def __init__( if isinstance(poolclass, str): poolclass = resolve_reference(poolclass) - pool_class = cast("type[Pool]", poolclass) + pool_class = cast(type[Pool], poolclass) if prefer_async: try: self._engine = self._async_bind = create_async_engine( diff --git a/tests/test_component.py b/tests/test_component.py index 2117f46..e3b438f 100644 --- a/tests/test_component.py +++ b/tests/test_component.py @@ -34,6 +34,11 @@ pytestmark = pytest.mark.anyio +def test_bad_bind_argument() -> None: + with pytest.raises(TypeError, match="incompatible bind argument: str"): + SQLAlchemyComponent(bind="bad") # type: ignore[arg-type] + + async def test_component_start_sync() -> None: """Test that the component creates all the expected (synchronous) resources.""" url = URL.create("sqlite", database=":memory:") @@ -285,9 +290,7 @@ def listener(session: Session) -> None: engine = get_resource_nowait(AsyncEngine) dbsession = get_resource_nowait(AsyncSession) await dbsession.run_sync( - lambda session: Person.metadata.create_all( - session.bind # type: ignore[arg-type] - ) + lambda session: Person.metadata.create_all(session.bind) ) dbsession.add(Person(name="Test person")) diff --git a/tests/test_utils.py b/tests/test_utils.py index d7d1a05..98edd05 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,15 +1,18 @@ from __future__ import annotations -from collections.abc import Generator +from collections.abc import AsyncGenerator, Generator from typing import Any import pytest from sqlalchemy.engine import Connection, Engine +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine from sqlalchemy.sql.ddl import CreateSchema, DropSchema from sqlalchemy.sql.schema import Column, ForeignKey, MetaData, Table from sqlalchemy.sql.sqltypes import Integer -from asphalt.sqlalchemy import clear_database +from asphalt.sqlalchemy import clear_async_database, clear_database + +pytestmark = pytest.mark.anyio @pytest.fixture @@ -31,6 +34,27 @@ def connection(sync_engine: Engine) -> Generator[Connection, Any, None]: conn.execute(DropSchema("altschema")) +@pytest.fixture +async def async_connection( + async_engine: AsyncEngine, +) -> AsyncGenerator[AsyncConnection]: + async with async_engine.connect() as conn: + metadata = MetaData() + Table("table", metadata, Column("column1", Integer, primary_key=True)) + Table("table2", metadata, Column("fk_column", ForeignKey("table.column1"))) + if conn.dialect.name != "sqlite": + await conn.execute(CreateSchema("altschema")) + Table("table3", metadata, Column("fk_column", Integer), schema="altschema") + + await conn.run_sync(metadata.create_all) + + yield conn + + if conn.dialect.name != "sqlite": + await conn.run_sync(metadata.drop_all) + await conn.execute(DropSchema("altschema")) + + def test_clear_database(connection: Connection) -> None: clear_database( connection, ["altschema"] if connection.dialect.name != "sqlite" else [] @@ -43,3 +67,18 @@ def test_clear_database(connection: Connection) -> None: alt_metadata = MetaData(schema="altschema") alt_metadata.reflect(connection) assert len(alt_metadata.tables) == 0 + + +async def test_clear_async_database(async_connection: AsyncConnection) -> None: + await clear_async_database( + async_connection, + ["altschema"] if async_connection.dialect.name != "sqlite" else [], + ) + metadata = MetaData() + await async_connection.run_sync(metadata.reflect) + assert len(metadata.tables) == 0 + + if async_connection.dialect.name != "sqlite": + alt_metadata = MetaData(schema="altschema") + await async_connection.run_sync(alt_metadata.reflect) + assert len(alt_metadata.tables) == 0