Skip to content

Commit

Permalink
Fixed config oopsie and increased test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
agronholm committed Dec 31, 2024
1 parent cb88971 commit 3ff52f2
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,12 @@ testpaths = ["tests"]
python_version = "3.9"
strict = true
explicit_package_bases = true
mypy_path = ["src", "tests", "examples"]

[tool.coverage.run]
source = ["asphalt.sqlalchemy"]
relative_files = true
branch = true
mypy_path = ["src", "tests", "examples"]

[tool.coverage.report]
show_missing = true
Expand Down
9 changes: 6 additions & 3 deletions src/asphalt/sqlalchemy/_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,10 @@ def __init__(
engine_args: dict[str, Any] | None = None,
session_args: dict[str, Any] | None = None,
commit_executor_workers: int = 50,
ready_callback: Callable[[Engine, sessionmaker[Any]], Any] | str | None = None,
ready_callback: Callable[[Engine, sessionmaker[Any]], Any]
| Callable[[AsyncEngine, async_sessionmaker[Any]], Any]
| str
| None = None,
poolclass: str | type[Pool] | None = None,
):
self.commit_thread_limiter = CapacityLimiter(commit_executor_workers)
Expand All @@ -112,7 +115,7 @@ def __init__(
elif isinstance(bind, AsyncEngine):
self._engine = self._async_bind = bind
else:
raise TypeError(f"Incompatible bind argument: {qualified_name(bind)}")
raise TypeError(f"incompatible bind argument: {qualified_name(bind)}")
else:
if isinstance(url, dict):
url = URL.create(**url)
Expand All @@ -131,7 +134,7 @@ def __init__(
if isinstance(poolclass, str):
poolclass = resolve_reference(poolclass)

pool_class = cast("type[Pool]", poolclass)
pool_class = cast(type[Pool], poolclass)
if prefer_async:
try:
self._engine = self._async_bind = create_async_engine(
Expand Down
9 changes: 6 additions & 3 deletions tests/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@
pytestmark = pytest.mark.anyio


def test_bad_bind_argument() -> None:
with pytest.raises(TypeError, match="incompatible bind argument: str"):
SQLAlchemyComponent(bind="bad") # type: ignore[arg-type]


async def test_component_start_sync() -> None:
"""Test that the component creates all the expected (synchronous) resources."""
url = URL.create("sqlite", database=":memory:")
Expand Down Expand Up @@ -285,9 +290,7 @@ def listener(session: Session) -> None:
engine = get_resource_nowait(AsyncEngine)
dbsession = get_resource_nowait(AsyncSession)
await dbsession.run_sync(
lambda session: Person.metadata.create_all(
session.bind # type: ignore[arg-type]
)
lambda session: Person.metadata.create_all(session.bind)
)
dbsession.add(Person(name="Test person"))

Expand Down
43 changes: 41 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from __future__ import annotations

from collections.abc import Generator
from collections.abc import AsyncGenerator, Generator
from typing import Any

import pytest
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
from sqlalchemy.sql.ddl import CreateSchema, DropSchema
from sqlalchemy.sql.schema import Column, ForeignKey, MetaData, Table
from sqlalchemy.sql.sqltypes import Integer

from asphalt.sqlalchemy import clear_database
from asphalt.sqlalchemy import clear_async_database, clear_database

pytestmark = pytest.mark.anyio


@pytest.fixture
Expand All @@ -31,6 +34,27 @@ def connection(sync_engine: Engine) -> Generator[Connection, Any, None]:
conn.execute(DropSchema("altschema"))


@pytest.fixture
async def async_connection(
async_engine: AsyncEngine,
) -> AsyncGenerator[AsyncConnection]:
async with async_engine.connect() as conn:
metadata = MetaData()
Table("table", metadata, Column("column1", Integer, primary_key=True))
Table("table2", metadata, Column("fk_column", ForeignKey("table.column1")))
if conn.dialect.name != "sqlite":
await conn.execute(CreateSchema("altschema"))
Table("table3", metadata, Column("fk_column", Integer), schema="altschema")

await conn.run_sync(metadata.create_all)

yield conn

if conn.dialect.name != "sqlite":
await conn.run_sync(metadata.drop_all)
await conn.execute(DropSchema("altschema"))


def test_clear_database(connection: Connection) -> None:
clear_database(
connection, ["altschema"] if connection.dialect.name != "sqlite" else []
Expand All @@ -43,3 +67,18 @@ def test_clear_database(connection: Connection) -> None:
alt_metadata = MetaData(schema="altschema")
alt_metadata.reflect(connection)
assert len(alt_metadata.tables) == 0


async def test_clear_async_database(async_connection: AsyncConnection) -> None:
await clear_async_database(
async_connection,
["altschema"] if async_connection.dialect.name != "sqlite" else [],
)
metadata = MetaData()
await async_connection.run_sync(metadata.reflect)
assert len(metadata.tables) == 0

if async_connection.dialect.name != "sqlite":
alt_metadata = MetaData(schema="altschema")
await async_connection.run_sync(alt_metadata.reflect)
assert len(alt_metadata.tables) == 0

0 comments on commit 3ff52f2

Please sign in to comment.