Skip to content

Commit

Permalink
- fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
seanharrison committed Sep 23, 2024
1 parent b28b494 commit 6716c8d
Show file tree
Hide file tree
Showing 9 changed files with 34 additions and 18 deletions.
2 changes: 1 addition & 1 deletion sqly/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .dialect import Dialect
from .query import Q
from .sql import SQL, ASQL
from .sql import ASQL, SQL
4 changes: 2 additions & 2 deletions sqly/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ def migrate(migration_key, database_url=None, dialect=None, dryrun=False):
sys.exit(1)

# force psycopg instead of asyncpg for migrations
if dialect == 'asyncpg':
dialect = 'psycopg'
if dialect == "asyncpg":
dialect = "psycopg"

dialect = Dialect(dialect)
adaptor = dialect.adaptor()
Expand Down
2 changes: 1 addition & 1 deletion sqly/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,4 +121,4 @@ def must_async(self) -> bool:

@property
def can_async(self) -> bool:
return self in [self.ASYNCPG, self.PSYCOPG]
return self in [self.ASYNCPG, self.PSYCOPG]
6 changes: 4 additions & 2 deletions sqly/lib.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import asyncio
import inspect


def walk(iterator):
"""
Walk a nested iterator and yield items in a single stream.
Expand All @@ -24,11 +25,12 @@ def run(f):
f = asyncio.run(f)
return f


def gen(g):
async def unasync(g):
return [item async for item in g]

if inspect.isasyncgen(g):
return asyncio.run(unasync(g))
else:
return list(g)

2 changes: 1 addition & 1 deletion sqly/migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import networkx as nx
import yaml

from . import queries, lib
from . import lib, queries
from .dialect import Dialect
from .query import Q
from .sql import ASQL, SQL
Expand Down
5 changes: 3 additions & 2 deletions sqly/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,9 @@ def UPDATE(relation: str, fields: Iterable, filters: Iterable[str]) -> str:
return " ".join(query)


def UPSERT(relation: str, fields: Iterable[str], key: Iterable[str], returning=False) -> str:
def UPSERT(
relation: str, fields: Iterable[str], key: Iterable[str], returning=False
) -> str:
query = [
INSERT(relation, fields, returning=False),
f"ON CONFLICT ({Q.fields(key)})",
Expand All @@ -117,7 +119,6 @@ def UPSERT(relation: str, fields: Iterable[str], key: Iterable[str], returning=F
return " ".join(query)



def DELETE(relation: str, filters: Iterable[str]) -> str:
"""
Build a DELETE query with the following form:
Expand Down
2 changes: 1 addition & 1 deletion sqly/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from dataclasses import dataclass
from typing import Any, Iterator, Mapping, Optional

from . import queries
from .dialect import Dialect, ParamFormat
from .lib import walk
from .query import Q
from . import queries


@dataclass
Expand Down
1 change: 0 additions & 1 deletion tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
("sqlite", f"file://{PATH}/test.db"),
("psycopg", POSTGRESQL_URL),
# ("asyncpg", POSTGRESQL_URL),

# (
# "mysql",
# dedent(
Expand Down
28 changes: 21 additions & 7 deletions tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import pytest

from sqly.dialect import Dialect
from sqly import lib
from sqly.dialect import Dialect
from sqly.sql import SQL
from tests import fixtures

Expand Down Expand Up @@ -35,23 +35,33 @@ def test_execute_query_ok(dialect_name, database_url):
connection = lib.run(adaptor.connect(database_url))

# execute a query that should have visible results
cursor = lib.run(connection.execute("CREATE TABLE widgets (id int, sku varchar)"))
cursor = lib.run(
connection.execute("CREATE TABLE widgets (id int, sku varchar)")
)

print(f"{connection=}")
widget = {"id": 1, "sku": "COG-01"}
# - the following table exists (and using the cursor to execute is fine)
lib.run(sql.execute(cursor, "INSERT INTO widgets (id, sku) VALUES (:id, :sku)", widget))
lib.run(
sql.execute(
cursor, "INSERT INTO widgets (id, sku) VALUES (:id, :sku)", widget
)
)

print(f"{connection=}")
# - the row is in the table
row = lib.run(sql.select_one(cursor, "SELECT * from widgets WHERE id=:id", widget))
row = lib.run(
sql.select_one(cursor, "SELECT * from widgets WHERE id=:id", widget)
)
assert row == widget

# after we rollback, the table doesn't exist (NOTE: This might not work on all
# databases, because not all have transactional DDL. )
lib.run(connection.rollback())
with pytest.raises(Exception):
row = lib.run(sql.select_one(connection, "SELECT * from widgets WHERE id=:id", widget))
row = lib.run(
sql.select_one(connection, "SELECT * from widgets WHERE id=:id", widget)
)
# If the DDL wasn't transactional, the row still doesn't exist - is None
assert row

Expand Down Expand Up @@ -146,14 +156,18 @@ def test_cursor_as_connection(dialect_name, database_url):
# connection = adaptor.connect(**conn_info)
# else:
connection = lib.run(adaptor.connect(database_url))
cursor = lib.run(sql.execute(connection, "CREATE TABLE WIDGETS (id int, sku varchar)"))
cursor = lib.run(
sql.execute(connection, "CREATE TABLE WIDGETS (id int, sku varchar)")
)
lib.run(connection.commit())
with pytest.raises(Exception, match="foo"):
lib.run(sql.execute(cursor, "INSERT INTO foo VALUES (1, 2)"))
lib.run(connection.rollback())

widget = {"id": 1, "sku": "COG-01"}
cursor2 = lib.run(sql.execute(cursor, "INSERT INTO widgets VALUES (:id, :sku)", widget))
cursor2 = lib.run(
sql.execute(cursor, "INSERT INTO widgets VALUES (:id, :sku)", widget)
)
assert cursor2 == cursor
record = lib.run(sql.select_one(cursor, "SELECT * FROM widgets"))
assert record == widget
Expand Down

0 comments on commit 6716c8d

Please sign in to comment.