Skip to content

Commit 4dadafc

Browse files
authored
feat: ensure connection type is unique (#29)
Adds type aliases for each connection type to help with the signature namespace.
1 parent db59d09 commit 4dadafc

File tree

29 files changed

+424
-388
lines changed

29 files changed

+424
-388
lines changed

sqlspec/adapters/adbc/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from sqlspec.adapters.adbc.config import AdbcConfig
2-
from sqlspec.adapters.adbc.driver import AdbcDriver
2+
from sqlspec.adapters.adbc.driver import AdbcConnection, AdbcDriver
33

44
__all__ = (
55
"AdbcConfig",
6+
"AdbcConnection",
67
"AdbcDriver",
78
)

sqlspec/adapters/adbc/config.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@
22
from dataclasses import dataclass, field
33
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast
44

5-
from adbc_driver_manager.dbapi import Connection
6-
7-
from sqlspec.adapters.adbc.driver import AdbcDriver
5+
from sqlspec.adapters.adbc.driver import AdbcConnection, AdbcDriver
86
from sqlspec.base import NoPoolSyncConfig
97
from sqlspec.exceptions import ImproperConfigurationError
108
from sqlspec.typing import Empty, EmptyType
@@ -18,7 +16,7 @@
1816

1917

2018
@dataclass
21-
class AdbcConfig(NoPoolSyncConfig["Connection", "AdbcDriver"]):
19+
class AdbcConfig(NoPoolSyncConfig["AdbcConnection", "AdbcDriver"]):
2220
"""Configuration for ADBC connections.
2321
2422
This class provides configuration options for ADBC database connections using the
@@ -33,20 +31,16 @@ class AdbcConfig(NoPoolSyncConfig["Connection", "AdbcDriver"]):
3331
"""Additional database-specific connection parameters"""
3432
conn_kwargs: "Optional[dict[str, Any]]" = None
3533
"""Additional database-specific connection parameters"""
36-
connection_type: "type[Connection]" = field(init=False, default_factory=lambda: Connection)
34+
connection_type: "type[AdbcConnection]" = field(init=False, default_factory=lambda: AdbcConnection)
3735
"""Type of the connection object"""
3836
driver_type: "type[AdbcDriver]" = field(init=False, default_factory=lambda: AdbcDriver) # type: ignore[type-abstract,unused-ignore]
3937
"""Type of the driver object"""
40-
pool_instance: None = field(init=False, default=None)
38+
pool_instance: None = field(init=False, default=None, hash=False)
4139
"""No connection pool is used for ADBC connections"""
42-
_is_in_memory: bool = field(init=False, default=False)
43-
"""Flag indicating if the connection is for an in-memory database"""
4440

4541
def _set_adbc(self) -> str: # noqa: PLR0912
4642
"""Identify the driver type based on the URI (if provided) or preset driver name.
4743
48-
Also sets the `_is_in_memory` flag for specific in-memory URIs.
49-
5044
Raises:
5145
ImproperConfigurationError: If the driver name is not recognized or supported.
5246
@@ -143,7 +137,7 @@ def connection_config_dict(self) -> "dict[str, Any]":
143137
config["conn_kwargs"] = conn_kwargs
144138
return config
145139

146-
def _get_connect_func(self) -> "Callable[..., Connection]":
140+
def _get_connect_func(self) -> "Callable[..., AdbcConnection]":
147141
self._set_adbc()
148142
driver_path = cast("str", self.driver_name)
149143
try:
@@ -166,7 +160,7 @@ def _get_connect_func(self) -> "Callable[..., Connection]":
166160
raise ImproperConfigurationError(msg)
167161
return connect_func # type: ignore[no-any-return]
168162

169-
def create_connection(self) -> "Connection":
163+
def create_connection(self) -> "AdbcConnection":
170164
"""Create and return a new database connection using the specific driver.
171165
172166
Returns:
@@ -189,7 +183,7 @@ def create_connection(self) -> "Connection":
189183
raise ImproperConfigurationError(msg) from e
190184

191185
@contextmanager
192-
def provide_connection(self, *args: "Any", **kwargs: "Any") -> "Generator[Connection, None, None]":
186+
def provide_connection(self, *args: "Any", **kwargs: "Any") -> "Generator[AdbcConnection, None, None]":
193187
"""Create and provide a database connection using the specific driver.
194188
195189
Yields:

sqlspec/adapters/adbc/driver.py

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
if TYPE_CHECKING:
1717
from sqlspec.typing import ArrowTable, ModelDTOT, StatementParameterType, T
1818

19-
__all__ = ("AdbcDriver",)
19+
__all__ = ("AdbcConnection", "AdbcDriver")
2020

2121
logger = logging.getLogger("sqlspec")
2222

@@ -33,24 +33,26 @@
3333
re.VERBOSE | re.DOTALL,
3434
)
3535

36+
AdbcConnection = Connection
37+
3638

3739
class AdbcDriver(
38-
SyncArrowBulkOperationsMixin["Connection"],
39-
SQLTranslatorMixin["Connection"],
40-
SyncDriverAdapterProtocol["Connection"],
40+
SyncArrowBulkOperationsMixin["AdbcConnection"],
41+
SQLTranslatorMixin["AdbcConnection"],
42+
SyncDriverAdapterProtocol["AdbcConnection"],
4143
):
4244
"""ADBC Sync Driver Adapter."""
4345

44-
connection: Connection
46+
connection: AdbcConnection
4547
__supports_arrow__: ClassVar[bool] = True
4648

47-
def __init__(self, connection: "Connection") -> None:
49+
def __init__(self, connection: "AdbcConnection") -> None:
4850
"""Initialize the ADBC driver adapter."""
4951
self.connection = connection
5052
self.dialect = self._get_dialect(connection)
5153

5254
@staticmethod
53-
def _get_dialect(connection: "Connection") -> str: # noqa: PLR0911
55+
def _get_dialect(connection: "AdbcConnection") -> str: # noqa: PLR0911
5456
"""Get the database dialect based on the driver name.
5557
5658
Args:
@@ -75,11 +77,11 @@ def _get_dialect(connection: "Connection") -> str: # noqa: PLR0911
7577
return "postgres" # default to postgresql dialect
7678

7779
@staticmethod
78-
def _cursor(connection: "Connection", *args: Any, **kwargs: Any) -> "Cursor":
80+
def _cursor(connection: "AdbcConnection", *args: Any, **kwargs: Any) -> "Cursor":
7981
return connection.cursor(*args, **kwargs)
8082

8183
@contextmanager
82-
def _with_cursor(self, connection: "Connection") -> Generator["Cursor", None, None]:
84+
def _with_cursor(self, connection: "AdbcConnection") -> Generator["Cursor", None, None]:
8385
cursor = self._cursor(connection)
8486
try:
8587
yield cursor
@@ -172,7 +174,7 @@ def select(
172174
parameters: "Optional[StatementParameterType]" = None,
173175
/,
174176
*,
175-
connection: "Optional[Connection]" = None,
177+
connection: "Optional[AdbcConnection]" = None,
176178
schema_type: None = None,
177179
**kwargs: Any,
178180
) -> "Sequence[dict[str, Any]]": ...
@@ -183,7 +185,7 @@ def select(
183185
parameters: "Optional[StatementParameterType]" = None,
184186
/,
185187
*,
186-
connection: "Optional[Connection]" = None,
188+
connection: "Optional[AdbcConnection]" = None,
187189
schema_type: "type[ModelDTOT]",
188190
**kwargs: Any,
189191
) -> "Sequence[ModelDTOT]": ...
@@ -193,7 +195,7 @@ def select(
193195
parameters: Optional["StatementParameterType"] = None,
194196
/,
195197
*,
196-
connection: Optional["Connection"] = None,
198+
connection: Optional["AdbcConnection"] = None,
197199
schema_type: "Optional[type[ModelDTOT]]" = None,
198200
**kwargs: Any,
199201
) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]":
@@ -223,7 +225,7 @@ def select_one(
223225
parameters: "Optional[StatementParameterType]" = None,
224226
/,
225227
*,
226-
connection: "Optional[Connection]" = None,
228+
connection: "Optional[AdbcConnection]" = None,
227229
schema_type: None = None,
228230
**kwargs: Any,
229231
) -> "dict[str, Any]": ...
@@ -234,7 +236,7 @@ def select_one(
234236
parameters: "Optional[StatementParameterType]" = None,
235237
/,
236238
*,
237-
connection: "Optional[Connection]" = None,
239+
connection: "Optional[AdbcConnection]" = None,
238240
schema_type: "type[ModelDTOT]",
239241
**kwargs: Any,
240242
) -> "ModelDTOT": ...
@@ -244,7 +246,7 @@ def select_one(
244246
parameters: Optional["StatementParameterType"] = None,
245247
/,
246248
*,
247-
connection: Optional["Connection"] = None,
249+
connection: Optional["AdbcConnection"] = None,
248250
schema_type: "Optional[type[ModelDTOT]]" = None,
249251
**kwargs: Any,
250252
) -> "Union[ModelDTOT, dict[str, Any]]":
@@ -271,7 +273,7 @@ def select_one_or_none(
271273
parameters: "Optional[StatementParameterType]" = None,
272274
/,
273275
*,
274-
connection: "Optional[Connection]" = None,
276+
connection: "Optional[AdbcConnection]" = None,
275277
schema_type: None = None,
276278
**kwargs: Any,
277279
) -> "Optional[dict[str, Any]]": ...
@@ -282,7 +284,7 @@ def select_one_or_none(
282284
parameters: "Optional[StatementParameterType]" = None,
283285
/,
284286
*,
285-
connection: "Optional[Connection]" = None,
287+
connection: "Optional[AdbcConnection]" = None,
286288
schema_type: "type[ModelDTOT]",
287289
**kwargs: Any,
288290
) -> "Optional[ModelDTOT]": ...
@@ -292,7 +294,7 @@ def select_one_or_none(
292294
parameters: Optional["StatementParameterType"] = None,
293295
/,
294296
*,
295-
connection: Optional["Connection"] = None,
297+
connection: Optional["AdbcConnection"] = None,
296298
schema_type: "Optional[type[ModelDTOT]]" = None,
297299
**kwargs: Any,
298300
) -> "Optional[Union[ModelDTOT, dict[str, Any]]]":
@@ -320,7 +322,7 @@ def select_value(
320322
parameters: "Optional[StatementParameterType]" = None,
321323
/,
322324
*,
323-
connection: "Optional[Connection]" = None,
325+
connection: "Optional[AdbcConnection]" = None,
324326
schema_type: None = None,
325327
**kwargs: Any,
326328
) -> "Any": ...
@@ -331,7 +333,7 @@ def select_value(
331333
parameters: "Optional[StatementParameterType]" = None,
332334
/,
333335
*,
334-
connection: "Optional[Connection]" = None,
336+
connection: "Optional[AdbcConnection]" = None,
335337
schema_type: "type[T]",
336338
**kwargs: Any,
337339
) -> "T": ...
@@ -341,7 +343,7 @@ def select_value(
341343
parameters: Optional["StatementParameterType"] = None,
342344
/,
343345
*,
344-
connection: Optional["Connection"] = None,
346+
connection: Optional["AdbcConnection"] = None,
345347
schema_type: "Optional[type[T]]" = None,
346348
**kwargs: Any,
347349
) -> "Union[T, Any]":
@@ -367,7 +369,7 @@ def select_value_or_none(
367369
parameters: "Optional[StatementParameterType]" = None,
368370
/,
369371
*,
370-
connection: "Optional[Connection]" = None,
372+
connection: "Optional[AdbcConnection]" = None,
371373
schema_type: None = None,
372374
**kwargs: Any,
373375
) -> "Optional[Any]": ...
@@ -378,7 +380,7 @@ def select_value_or_none(
378380
parameters: "Optional[StatementParameterType]" = None,
379381
/,
380382
*,
381-
connection: "Optional[Connection]" = None,
383+
connection: "Optional[AdbcConnection]" = None,
382384
schema_type: "type[T]",
383385
**kwargs: Any,
384386
) -> "Optional[T]": ...
@@ -388,7 +390,7 @@ def select_value_or_none(
388390
parameters: Optional["StatementParameterType"] = None,
389391
/,
390392
*,
391-
connection: Optional["Connection"] = None,
393+
connection: Optional["AdbcConnection"] = None,
392394
schema_type: "Optional[type[T]]" = None,
393395
**kwargs: Any,
394396
) -> "Optional[Union[T, Any]]":
@@ -414,7 +416,7 @@ def insert_update_delete(
414416
parameters: Optional["StatementParameterType"] = None,
415417
/,
416418
*,
417-
connection: Optional["Connection"] = None,
419+
connection: Optional["AdbcConnection"] = None,
418420
**kwargs: Any,
419421
) -> int:
420422
"""Insert, update, or delete data from the database.
@@ -436,7 +438,7 @@ def insert_update_delete_returning(
436438
parameters: "Optional[StatementParameterType]" = None,
437439
/,
438440
*,
439-
connection: "Optional[Connection]" = None,
441+
connection: "Optional[AdbcConnection]" = None,
440442
schema_type: None = None,
441443
**kwargs: Any,
442444
) -> "dict[str, Any]": ...
@@ -447,7 +449,7 @@ def insert_update_delete_returning(
447449
parameters: "Optional[StatementParameterType]" = None,
448450
/,
449451
*,
450-
connection: "Optional[Connection]" = None,
452+
connection: "Optional[AdbcConnection]" = None,
451453
schema_type: "type[ModelDTOT]",
452454
**kwargs: Any,
453455
) -> "ModelDTOT": ...
@@ -457,7 +459,7 @@ def insert_update_delete_returning(
457459
parameters: Optional["StatementParameterType"] = None,
458460
/,
459461
*,
460-
connection: Optional["Connection"] = None,
462+
connection: Optional["AdbcConnection"] = None,
461463
schema_type: "Optional[type[ModelDTOT]]" = None,
462464
**kwargs: Any,
463465
) -> "Optional[Union[dict[str, Any], ModelDTOT]]":
@@ -490,7 +492,7 @@ def execute_script(
490492
parameters: Optional["StatementParameterType"] = None,
491493
/,
492494
*,
493-
connection: Optional["Connection"] = None,
495+
connection: Optional["AdbcConnection"] = None,
494496
**kwargs: Any,
495497
) -> str:
496498
"""Execute a script.
@@ -513,7 +515,7 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType]
513515
parameters: "Optional[StatementParameterType]" = None,
514516
/,
515517
*,
516-
connection: "Optional[Connection]" = None,
518+
connection: "Optional[AdbcConnection]" = None,
517519
**kwargs: Any,
518520
) -> "ArrowTable":
519521
"""Execute a SQL query and return results as an Apache Arrow Table.
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from sqlspec.adapters.aiosqlite.config import AiosqliteConfig
2-
from sqlspec.adapters.aiosqlite.driver import AiosqliteDriver
2+
from sqlspec.adapters.aiosqlite.driver import AiosqliteConnection, AiosqliteDriver
33

44
__all__ = (
55
"AiosqliteConfig",
6+
"AiosqliteConnection",
67
"AiosqliteDriver",
78
)

0 commit comments

Comments
 (0)