
Commit 5fe5d31
authored Dec 22, 2023

Return common data structure in DBApi derived classes

The ADR for Airflow's DB API specifies it needs to return a named tuple SerializableRow or a list of them.

1 parent 33ee0b9 commit 5fe5d31
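Note: as a rough sketch of the contract this commit moves toward — results come back as plain tuples/namedtuples rather than driver-specific row objects, so callers can treat every SQL provider uniformly. The names below are illustrative only, not taken from the ADR text.

```python
from collections import namedtuple

# Illustrative stand-in for the named tuple the ADR describes; field names invented.
SerializableRow = namedtuple("SerializableRow", ["id", "value"])

rows = [SerializableRow(1, "a"), SerializableRow(2, "b")]
assert all(isinstance(row, tuple) for row in rows)  # namedtuples are plain tuples
print(rows[0].id, rows[0].value)  # fields stay addressable by name
```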

File tree

22 files changed: +191 -108


.pre-commit-config.yaml

+2 -2

@@ -385,8 +385,8 @@ repos:
         files: ^dev/breeze/src/airflow_breeze/utils/docker_command_utils\.py$|^scripts/ci/docker_compose/local\.yml$
         pass_filenames: false
         additional_dependencies: ['rich>=12.4.4']
-      - id: check-common-sql-dependency-make-serializable
-        name: Check dependency of SQL Providers with '_make_serializable'
+      - id: check-sql-dependency-common-data-structure
+        name: Check dependency of SQL Providers with common data structure
         entry: ./scripts/ci/pre_commit/pre_commit_check_common_sql_dependency.py
         language: python
         files: ^airflow/providers/.*/hooks/.*\.py$

STATIC_CODE_CHECKS.rst

+2 -2

@@ -170,8 +170,6 @@ require Breeze Docker image to be built locally.
 +-----------------------------------------------------------+--------------------------------------------------------------+---------+
 | check-cncf-k8s-only-for-executors | Check cncf.kubernetes imports used for executors only | |
 +-----------------------------------------------------------+--------------------------------------------------------------+---------+
-| check-common-sql-dependency-make-serializable | Check dependency of SQL Providers with '_make_serializable' | |
-+-----------------------------------------------------------+--------------------------------------------------------------+---------+
 | check-core-deprecation-classes | Verify usage of Airflow deprecation classes in core | |
 +-----------------------------------------------------------+--------------------------------------------------------------+---------+
 | check-daysago-import-from-utils | Make sure days_ago is imported from airflow.utils.dates | |
@@ -240,6 +238,8 @@ require Breeze Docker image to be built locally.
 +-----------------------------------------------------------+--------------------------------------------------------------+---------+
 | check-setup-order | Check order of dependencies in setup.cfg and setup.py | |
 +-----------------------------------------------------------+--------------------------------------------------------------+---------+
+| check-sql-dependency-common-data-structure | Check dependency of SQL Providers with common data structure | |
++-----------------------------------------------------------+--------------------------------------------------------------+---------+
 | check-start-date-not-used-in-defaults | start_date not to be defined in default_args in example_dags | |
 +-----------------------------------------------------------+--------------------------------------------------------------+---------+
 | check-system-tests-present | Check if system tests have required segments of code | |

airflow/providers/common/sql/hooks/sql.py

+32 -17

@@ -16,6 +16,8 @@
 # under the License.
 from __future__ import annotations

+import contextlib
+import warnings
 from contextlib import closing
 from datetime import datetime
 from typing import (
@@ -24,6 +26,7 @@
     Callable,
     Generator,
     Iterable,
+    List,
     Mapping,
     Protocol,
     Sequence,
@@ -36,7 +39,7 @@
 import sqlparse
 from sqlalchemy import create_engine

-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.hooks.base import BaseHook

 if TYPE_CHECKING:
@@ -122,10 +125,10 @@ class DbApiHook(BaseHook):
     """
     Abstract base class for sql hooks.

-    When subclassing, maintainers can override the `_make_serializable` method:
+    When subclassing, maintainers can override the `_make_common_data_structure` method:
     This method transforms the result of the handler method (typically `cursor.fetchall()`) into
-    JSON-serializable objects. Most of the time, the underlying SQL library already returns tuples from
-    its cursor, and the `_make_serializable` method can be ignored.
+    objects common across all Hooks derived from this class (tuples). Most of the time, the underlying SQL
+    library already returns tuples from its cursor, and the `_make_common_data_structure` method can be ignored.

     :param schema: Optional DB schema that overrides the schema specified in the connection. Make sure that
         if you change the schema parameter value in the constructor of the derived Hook, such change
@@ -308,7 +311,7 @@ def run(
         handler: Callable[[Any], T] = ...,
         split_statements: bool = ...,
         return_last: bool = ...,
-    ) -> T | list[T]:
+    ) -> tuple | list[tuple] | list[list[tuple] | tuple] | None:
         ...

     def run(
@@ -319,7 +322,7 @@ def run(
         handler: Callable[[Any], T] | None = None,
         split_statements: bool = False,
         return_last: bool = True,
-    ) -> T | list[T] | None:
+    ) -> tuple | list[tuple] | list[list[tuple] | tuple] | None:
         """Run a command or a list of commands.

         Pass a list of SQL statements to the sql parameter to get them to
@@ -395,7 +398,7 @@ def run(
                     self._run_command(cur, sql_statement, parameters)

                     if handler is not None:
-                        result = self._make_serializable(handler(cur))
+                        result = self._make_common_data_structure(handler(cur))
                         if return_single_query_results(sql, return_last, split_statements):
                             _last_result = result
                             _last_description = cur.description
@@ -415,19 +418,31 @@ def run(
         else:
             return results

-    @staticmethod
-    def _make_serializable(result: Any) -> Any:
-        """Ensure the data returned from an SQL command is JSON-serializable.
+    def _make_common_data_structure(self, result: T | Sequence[T]) -> tuple | list[tuple]:
+        """Ensure the data returned from an SQL command is a standard tuple or list[tuple].

         This method is intended to be overridden by subclasses of the `DbApiHook`. Its purpose is to
-        transform the result of an SQL command (typically returned by cursor methods) into a
-        JSON-serializable format.
+        transform the result of an SQL command (typically returned by cursor methods) into a common
+        data structure (a tuple or list[tuple]) across all DBApiHook derived Hooks, as defined in the
+        ADR-0002 of the sql provider.
+
+        If this method is not overridden, the result data is returned as-is. If the output of the cursor
+        is already a common data structure, this method should be ignored.
+        """
+        # Back-compatibility call for providers implementing old `_make_serializable` method.
+        with contextlib.suppress(AttributeError):
+            result = self._make_serializable(result=result)  # type: ignore[attr-defined]
+            warnings.warn(
+                "The `_make_serializable` method is deprecated and support will be removed in a future "
+                f"version of the common.sql provider. Please update the {self.__class__.__name__}'s provider "
+                "to a version based on common.sql >= 1.9.1.",
+                AirflowProviderDeprecationWarning,
+                stacklevel=2,
+            )

-        If this method is not overridden, the result data is returned as-is.
-        If the output of the cursor is already JSON-serializable, this method
-        should be ignored.
-        """
-        return result
+        if isinstance(result, Sequence):
+            return cast(List[tuple], result)
+        return cast(tuple, result)

     def _run_command(self, cur, sql_statement, parameters):
         """Run a statement using an already open cursor."""

airflow/providers/common/sql/provider.yaml

+1

@@ -24,6 +24,7 @@ description: |
 suspended: false
 source-date-epoch: 1701983370
 versions:
+  - 1.9.1
   - 1.9.0
   - 1.8.1
   - 1.8.0

airflow/providers/databricks/hooks/databricks_sql.py

+52 -13

@@ -16,19 +16,32 @@
 # under the License.
 from __future__ import annotations

+import warnings
+from collections import namedtuple
 from contextlib import closing
 from copy import copy
-from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, TypeVar, overload
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Iterable,
+    List,
+    Mapping,
+    Sequence,
+    TypeVar,
+    cast,
+    overload,
+)

 from databricks import sql  # type: ignore[attr-defined]
-from databricks.sql.types import Row

-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.common.sql.hooks.sql import DbApiHook, return_single_query_results
 from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook

 if TYPE_CHECKING:
     from databricks.sql.client import Connection
+    from databricks.sql.types import Row

 LIST_SQL_ENDPOINTS_ENDPOINT = ("GET", "api/2.0/sql/endpoints")

@@ -52,6 +65,10 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
         on every request
     :param catalog: An optional initial catalog to use. Requires DBR version 9.0+
     :param schema: An optional initial schema to use. Requires DBR version 9.0+
+    :param return_tuple: Return a ``namedtuple`` object instead of a ``databricks.sql.Row`` object. Default
+        to False. In a future release of the provider, this will become True by default. This parameter
+        ensures backward-compatibility during the transition phase to common tuple objects for all hooks based
+        on DbApiHook. This flag will also be removed in a future release.
     :param kwargs: Additional parameters internal to Databricks SQL Connector parameters
     """

@@ -68,6 +85,7 @@ def __init__(
         catalog: str | None = None,
         schema: str | None = None,
         caller: str = "DatabricksSqlHook",
+        return_tuple: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(databricks_conn_id, caller=caller)
@@ -80,8 +98,18 @@ def __init__(
         self.http_headers = http_headers
         self.catalog = catalog
         self.schema = schema
+        self.return_tuple = return_tuple
         self.additional_params = kwargs

+        if not self.return_tuple:
+            warnings.warn(
+                """Returning a raw `databricks.sql.Row` object is deprecated. A namedtuple will be
+                returned instead in a future release of the databricks provider. Set `return_tuple=True` to
+                enable this behavior.""",
+                AirflowProviderDeprecationWarning,
+                stacklevel=2,
+            )
+
     def _get_extra_config(self) -> dict[str, Any | None]:
         extra_params = copy(self.databricks_conn.extra_dejson)
         for arg in ["http_path", "session_configuration", *self.extra_parameters]:
@@ -167,7 +195,7 @@ def run(
         handler: Callable[[Any], T] = ...,
         split_statements: bool = ...,
         return_last: bool = ...,
-    ) -> T | list[T]:
+    ) -> tuple | list[tuple] | list[list[tuple] | tuple] | None:
         ...

     def run(
@@ -178,7 +206,7 @@ def run(
         handler: Callable[[Any], T] | None = None,
         split_statements: bool = True,
         return_last: bool = True,
-    ) -> T | list[T] | None:
+    ) -> tuple | list[tuple] | list[list[tuple] | tuple] | None:
         """
         Run a command or a list of commands.

@@ -223,7 +251,12 @@ def run(
             with closing(conn.cursor()) as cur:
                 self._run_command(cur, sql_statement, parameters)
                 if handler is not None:
-                    result = self._make_serializable(handler(cur))
+                    raw_result = handler(cur)
+                    if self.return_tuple:
+                        result = self._make_common_data_structure(raw_result)
+                    else:
+                        # Returning raw result is deprecated, and do not comply with current common.sql interface
+                        result = raw_result  # type: ignore[assignment]
                     if return_single_query_results(sql, return_last, split_statements):
                         results = [result]
                         self.descriptions = [cur.description]
@@ -241,14 +274,20 @@ def run(
         else:
             return results

-    @staticmethod
-    def _make_serializable(result):
-        """Transform the databricks Row objects into JSON-serializable lists."""
+    def _make_common_data_structure(self, result: Sequence[Row] | Row) -> list[tuple] | tuple:
+        """Transform the databricks Row objects into namedtuple."""
+        # Below ignored lines respect namedtuple docstring, but mypy do not support dynamically
+        # instantiated namedtuple, and will never do: https://github.com/python/mypy/issues/848
         if isinstance(result, list):
-            return [list(row) for row in result]
-        elif isinstance(result, Row):
-            return list(result)
-        return result
+            rows: list[Row] = result
+            rows_fields = rows[0].__fields__
+            rows_object = namedtuple("Row", rows_fields)  # type: ignore[misc]
+            return cast(List[tuple], [rows_object(*row) for row in rows])
+        else:
+            row: Row = result
+            row_fields = row.__fields__
+            row_object = namedtuple("Row", row_fields)  # type: ignore[misc]
+            return cast(tuple, row_object(*row))

     def bulk_dump(self, table, tmp_file):
         raise NotImplementedError()
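Note: the conversion above hinges on `__fields__` from the Databricks connector's `Row`. A self-contained sketch of the same mechanics, using an invented stand-in class so it runs without the connector installed:

```python
from collections import namedtuple


class FakeRow(tuple):
    """Stand-in for databricks.sql.types.Row; only __fields__ matters for the conversion."""

    __fields__ = ["id", "value"]


rows = [FakeRow((1, 2)), FakeRow((11, 12))]
RowTuple = namedtuple("Row", rows[0].__fields__)  # dynamic namedtuple, as the hook builds it
converted = [RowTuple(*row) for row in rows]
print(converted[0].id, converted[0].value)  # 1 2 -- driver-independent namedtuples
```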

airflow/providers/databricks/operators/databricks_sql.py

+1

@@ -113,6 +113,7 @@ def get_db_hook(self) -> DatabricksSqlHook:
             "catalog": self.catalog,
             "schema": self.schema,
             "caller": "DatabricksSqlOperator",
+            "return_tuple": True,
             **self.client_parameters,
             **self.hook_params,
         }

airflow/providers/databricks/provider.yaml

+1 -1

@@ -59,7 +59,7 @@ versions:

 dependencies:
   - apache-airflow>=2.6.0
-  - apache-airflow-providers-common-sql>=1.8.1
+  - apache-airflow-providers-common-sql>=1.9.1
   - requests>=2.27,<3
   # The connector 2.9.0 released on Aug 10, 2023 has a bug that it does not properly declare urllib3 and
   # it needs to be excluded. See https://github.com/databricks/databricks-sql-python/issues/190

airflow/providers/exasol/hooks/exasol.py

+3 -3

@@ -183,7 +183,7 @@ def run(
         handler: Callable[[Any], T] = ...,
         split_statements: bool = ...,
         return_last: bool = ...,
-    ) -> T | list[T]:
+    ) -> tuple | list[tuple] | list[list[tuple] | tuple] | None:
         ...

     def run(
@@ -194,7 +194,7 @@ def run(
         handler: Callable[[Any], T] | None = None,
         split_statements: bool = False,
         return_last: bool = True,
-    ) -> T | list[T] | None:
+    ) -> tuple | list[tuple] | list[list[tuple] | tuple] | None:
         """Run a command or a list of commands.

         Pass a list of SQL statements to the SQL parameter to get them to
@@ -232,7 +232,7 @@ def run(
             with closing(conn.execute(sql_statement, parameters)) as exa_statement:
                 self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters)
                 if handler is not None:
-                    result = handler(exa_statement)
+                    result = self._make_common_data_structure(handler(exa_statement))
                     if return_single_query_results(sql, return_last, split_statements):
                         _last_result = result
                         _last_columns = self.get_description(exa_statement)

airflow/providers/exasol/provider.yaml

+1 -1

@@ -52,7 +52,7 @@ versions:

 dependencies:
   - apache-airflow>=2.6.0
-  - apache-airflow-providers-common-sql>=1.3.1
+  - apache-airflow-providers-common-sql>=1.9.1
   - pyexasol>=0.5.1
   - pandas>=0.17.1

airflow/providers/odbc/hooks/odbc.py

+15 -17

@@ -17,10 +17,10 @@
 """This module contains ODBC hook."""
 from __future__ import annotations

-from typing import Any, NamedTuple
+from typing import Any, List, NamedTuple, Sequence, cast
 from urllib.parse import quote_plus

-import pyodbc
+from pyodbc import Connection, Row, connect

 from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.utils.helpers import merge_dicts
@@ -195,9 +195,9 @@ def connect_kwargs(self) -> dict:

         return merged_connect_kwargs

-    def get_conn(self) -> pyodbc.Connection:
+    def get_conn(self) -> Connection:
         """Returns a pyodbc connection object."""
-        conn = pyodbc.connect(self.odbc_connection_string, **self.connect_kwargs)
+        conn = connect(self.odbc_connection_string, **self.connect_kwargs)
         return conn

     @property
@@ -228,17 +228,15 @@ def get_sqlalchemy_connection(
         cnx = engine.connect(**(connect_kwargs or {}))
         return cnx

-    @staticmethod
-    def _make_serializable(result: list[pyodbc.Row] | pyodbc.Row | None) -> list[NamedTuple] | None:
-        """Transform the pyodbc.Row objects returned from an SQL command into JSON-serializable NamedTuple."""
+    def _make_common_data_structure(self, result: Sequence[Row] | Row) -> list[tuple] | tuple:
+        """Transform the pyodbc.Row objects returned from an SQL command into typed NamedTuples."""
         # Below ignored lines respect NamedTuple docstring, but mypy do not support dynamically
-        # instantiated Namedtuple, and will never do: https://github.com/python/mypy/issues/848
-        columns: list[tuple[str, type]] | None = None
-        if isinstance(result, list):
-            columns = [col[:2] for col in result[0].cursor_description]
-            row_object = NamedTuple("Row", columns)  # type: ignore[misc]
-            return [row_object(*row) for row in result]
-        elif isinstance(result, pyodbc.Row):
-            columns = [col[:2] for col in result.cursor_description]
-            return NamedTuple("Row", columns)(*result)  # type: ignore[misc, operator]
-        return result
+        # instantiated typed Namedtuple, and will never do: https://github.com/python/mypy/issues/848
+        field_names: list[tuple[str, type]] | None = None
+        if isinstance(result, Sequence):
+            field_names = [col[:2] for col in result[0].cursor_description]
+            row_object = NamedTuple("Row", field_names)  # type: ignore[misc]
+            return cast(List[tuple], [row_object(*row) for row in result])
+        else:
+            field_names = [col[:2] for col in result.cursor_description]
+            return cast(tuple, NamedTuple("Row", field_names)(*result))  # type: ignore[misc, operator]
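Note: in the pyodbc variant, `cursor_description` supplies `(name, type_code, ...)` per column, and `col[:2]` feeds `typing.NamedTuple`'s functional form. A runnable sketch with made-up description data:

```python
from typing import NamedTuple

# Invented sample of what pyodbc's cursor_description can look like.
cursor_description = [("id", int, None, 10), ("value", str, None, 255)]

field_names = [col[:2] for col in cursor_description]  # [("id", int), ("value", str)]
RowType = NamedTuple("Row", field_names)  # type: ignore[misc]
row = RowType(1, "hello")
print(row.id, row.value)  # typed, field-addressable, and still a tuple
```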

airflow/providers/odbc/provider.yaml

+1 -1

@@ -45,7 +45,7 @@ versions:

 dependencies:
   - apache-airflow>=2.6.0
-  - apache-airflow-providers-common-sql>=1.8.1
+  - apache-airflow-providers-common-sql>=1.9.1
   - pyodbc

 integrations:

airflow/providers/oracle/hooks/oracle.py

+1 -1

@@ -372,7 +372,7 @@ def callproc(
         identifier: str,
         autocommit: bool = False,
         parameters: list | dict | None = None,
-    ) -> list | dict | None:
+    ) -> list | dict | tuple | None:
         """
         Call the stored procedure identified by the provided string.

airflow/providers/snowflake/hooks/snowflake.py

+3 -3

@@ -323,7 +323,7 @@ def run(
         split_statements: bool = ...,
         return_last: bool = ...,
         return_dictionaries: bool = ...,
-    ) -> T | list[T]:
+    ) -> tuple | list[tuple] | list[list[tuple] | tuple] | None:
         ...

     def run(
@@ -335,7 +335,7 @@ def run(
         split_statements: bool = True,
         return_last: bool = True,
         return_dictionaries: bool = False,
-    ) -> T | list[T] | None:
+    ) -> tuple | list[tuple] | list[list[tuple] | tuple] | None:
         """Runs a command or a list of commands.

         Pass a list of SQL statements to the SQL parameter to get them to
@@ -388,7 +388,7 @@ def run(
                     self._run_command(cur, sql_statement, parameters)

                     if handler is not None:
-                        result = handler(cur)
+                        result = self._make_common_data_structure(handler(cur))
                         if return_single_query_results(sql, return_last, split_statements):
                             _last_result = result
                             _last_description = cur.description
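Note: the Snowflake connector's cursor already yields plain tuples, so routing `run` through `_make_common_data_structure` leaves the data unchanged; the base implementation (sketched below under that assumption, with an invented function name) only narrows the static type via casts.

```python
from typing import Any, List, Sequence, cast


def make_common(result: Any) -> Any:
    # Passthrough mirroring the base DbApiHook behavior for tuple-native drivers:
    # the casts narrow the static type without touching the data.
    if isinstance(result, Sequence):
        return cast(List[tuple], result)
    return cast(tuple, result)


print(make_common([(1, 2), (3, 4)]))  # [(1, 2), (3, 4)] -- unchanged
```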

airflow/providers/snowflake/provider.yaml

+1 -1

@@ -67,7 +67,7 @@ versions:

 dependencies:
   - apache-airflow>=2.6.0
-  - apache-airflow-providers-common-sql>=1.3.1
+  - apache-airflow-providers-common-sql>=1.9.1
   - snowflake-connector-python>=2.7.8
   - snowflake-sqlalchemy>=1.1.0

dev/breeze/src/airflow_breeze/pre_commit_ids.py

+1 -1

@@ -38,7 +38,6 @@
     "check-builtin-literals",
     "check-changelog-has-no-duplicates",
     "check-cncf-k8s-only-for-executors",
-    "check-common-sql-dependency-make-serializable",
     "check-core-deprecation-classes",
     "check-daysago-import-from-utils",
     "check-decorated-operator-implements-custom-name",
@@ -73,6 +72,7 @@
     "check-revision-heads-map",
     "check-safe-filter-usage-in-html",
     "check-setup-order",
+    "check-sql-dependency-common-data-structure",
     "check-start-date-not-used-in-defaults",
     "check-system-tests-present",
     "check-system-tests-tocs",

generated/provider_dependencies.json

+4 -4

@@ -291,7 +291,7 @@
   "databricks": {
     "deps": [
       "aiohttp>=3.6.3, <4",
-      "apache-airflow-providers-common-sql>=1.8.1",
+      "apache-airflow-providers-common-sql>=1.9.1",
       "apache-airflow>=2.6.0",
       "databricks-sql-connector>=2.0.0, <3.0.0, !=2.9.0",
       "requests>=2.27,<3"
@@ -364,7 +364,7 @@
   },
   "exasol": {
     "deps": [
-      "apache-airflow-providers-common-sql>=1.3.1",
+      "apache-airflow-providers-common-sql>=1.9.1",
       "apache-airflow>=2.6.0",
       "pandas>=0.17.1",
       "pyexasol>=0.5.1"
@@ -653,7 +653,7 @@
   },
   "odbc": {
     "deps": [
-      "apache-airflow-providers-common-sql>=1.8.1",
+      "apache-airflow-providers-common-sql>=1.9.1",
       "apache-airflow>=2.6.0",
       "pyodbc"
     ],
@@ -864,7 +864,7 @@
   },
   "snowflake": {
     "deps": [
-      "apache-airflow-providers-common-sql>=1.3.1",
+      "apache-airflow-providers-common-sql>=1.9.1",
       "apache-airflow>=2.6.0",
       "snowflake-connector-python>=2.7.8",
       "snowflake-sqlalchemy>=1.1.0"

images/breeze/output_static-checks.svg

+18 -18 (rendered image; diff not shown)

+1 -1

@@ -1 +1 @@
-1197108ac5d3038067e599375d5130dd
+6fb4fd65fb7d3b1430a7de7a17c85e22

scripts/ci/pre_commit/pre_commit_check_common_sql_dependency.py

+7 -7

@@ -31,9 +31,9 @@


 COMMON_SQL_PROVIDER_NAME: str = "apache-airflow-providers-common-sql"
-COMMON_SQL_PROVIDER_MIN_COMPATIBLE_VERSIONS: str = "1.8.1"
-COMMON_SQL_PROVIDER_LATEST_INCOMPATIBLE_VERSION: str = "1.8.0"
-MAKE_SERIALIZABLE_METHOD_NAME: str = "_make_serializable"
+COMMON_SQL_PROVIDER_MIN_COMPATIBLE_VERSIONS: str = "1.9.1"
+COMMON_SQL_PROVIDER_LATEST_INCOMPATIBLE_VERSION: str = "1.9.0"
+MAKE_COMMON_METHOD_NAME: str = "_make_common_data_structure"


 def get_classes(file_path: str) -> Iterable[ast.ClassDef]:
@@ -54,9 +54,9 @@ def is_subclass_of_dbapihook(node: ast.ClassDef) -> bool:


 def has_make_serializable_method(node: ast.ClassDef) -> bool:
-    """Return True if the given class implements `_make_serializable` method."""
+    """Return True if the given class implements `_make_common_data_structure` method."""
     for body_element in node.body:
-        if isinstance(body_element, ast.FunctionDef) and (body_element.name == MAKE_SERIALIZABLE_METHOD_NAME):
+        if isinstance(body_element, ast.FunctionDef) and (body_element.name == MAKE_COMMON_METHOD_NAME):
             return True
     return False

@@ -109,11 +109,11 @@ def check_sql_providers_dependency():
                 f"\n[yellow]Provider {provider_metadata['name']} must have "
                 f"'{COMMON_SQL_PROVIDER_NAME}>={COMMON_SQL_PROVIDER_MIN_COMPATIBLE_VERSIONS}' as "
                 f"dependency, because `{clazz.name}` overrides the "
-                f"`{MAKE_SERIALIZABLE_METHOD_NAME}` method."
+                f"`{MAKE_COMMON_METHOD_NAME}` method."
             )
     if error_count:
         console.print(
-            f"The `{MAKE_SERIALIZABLE_METHOD_NAME}` method was introduced in {COMMON_SQL_PROVIDER_NAME} "
+            f"The `{MAKE_COMMON_METHOD_NAME}` method was introduced in {COMMON_SQL_PROVIDER_NAME} "
            f"{COMMON_SQL_PROVIDER_MIN_COMPATIBLE_VERSIONS}. You cannot rely on an older version of this "
            "provider to override this method."
        )
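Note: this pre-commit script detects overrides by walking the AST rather than importing provider modules. A minimal sketch of that detection (the sample source string is invented):

```python
import ast

source = '''
class MyHook(DbApiHook):
    def _make_common_data_structure(self, result):
        return tuple(result)
'''

for node in ast.walk(ast.parse(source)):
    if isinstance(node, ast.ClassDef):
        overrides = any(
            isinstance(item, ast.FunctionDef) and item.name == "_make_common_data_structure"
            for item in node.body
        )
        print(node.name, overrides)  # MyHook True
```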

tests/providers/databricks/hooks/test_databricks_sql.py

+36 -13

@@ -18,6 +18,7 @@
 #
 from __future__ import annotations

+from collections import namedtuple
 from unittest import mock
 from unittest.mock import patch

@@ -58,139 +59,160 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]:
     return [(field,) for field in fields]


+# Serializable Row object similar to the one returned by the Hook
+SerializableRow = namedtuple("Row", ["id", "value"])  # type: ignore[name-match]
+
+
 @pytest.mark.parametrize(
-    "return_last, split_statements, sql, cursor_calls,"
+    "return_last, split_statements, sql, cursor_calls, return_tuple,"
     "cursor_descriptions, cursor_results, hook_descriptions, hook_results, ",
     [
         pytest.param(
             True,
             False,
             "select * from test.test",
             ["select * from test.test"],
+            False,
             [["id", "value"]],
             ([Row(id=1, value=2), Row(id=11, value=12)],),
             [[("id",), ("value",)]],
-            [[1, 2], [11, 12]],
+            [Row(id=1, value=2), Row(id=11, value=12)],
             id="The return_last set and no split statements set on single query in string",
         ),
         pytest.param(
             False,
             False,
             "select * from test.test;",
             ["select * from test.test"],
+            False,
             [["id", "value"]],
             ([Row(id=1, value=2), Row(id=11, value=12)],),
             [[("id",), ("value",)]],
-            [[1, 2], [11, 12]],
+            [Row(id=1, value=2), Row(id=11, value=12)],
             id="The return_last not set and no split statements set on single query in string",
         ),
         pytest.param(
             True,
             True,
             "select * from test.test;",
             ["select * from test.test"],
+            False,
             [["id", "value"]],
             ([Row(id=1, value=2), Row(id=11, value=12)],),
             [[("id",), ("value",)]],
-            [[1, 2], [11, 12]],
+            [Row(id=1, value=2), Row(id=11, value=12)],
             id="The return_last set and split statements set on single query in string",
         ),
         pytest.param(
             False,
             True,
             "select * from test.test;",
             ["select * from test.test"],
+            False,
             [["id", "value"]],
             ([Row(id=1, value=2), Row(id=11, value=12)],),
             [[("id",), ("value",)]],
-            [[[1, 2], [11, 12]]],
+            [[Row(id=1, value=2), Row(id=11, value=12)]],
             id="The return_last not set and split statements set on single query in string",
         ),
         pytest.param(
             True,
             True,
             "select * from test.test;select * from test.test2;",
             ["select * from test.test", "select * from test.test2"],
+            False,
             [["id", "value"], ["id2", "value2"]],
             ([Row(id=1, value=2), Row(id=11, value=12)], [Row(id=3, value=4), Row(id=13, value=14)]),
             [[("id2",), ("value2",)]],
-            [[3, 4], [13, 14]],
+            [Row(id=3, value=4), Row(id=13, value=14)],
             id="The return_last set and split statements set on multiple queries in string",
         ),
         pytest.param(
             False,
             True,
             "select * from test.test;select * from test.test2;",
             ["select * from test.test", "select * from test.test2"],
+            False,
             [["id", "value"], ["id2", "value2"]],
             ([Row(id=1, value=2), Row(id=11, value=12)], [Row(id=3, value=4), Row(id=13, value=14)]),
             [[("id",), ("value",)], [("id2",), ("value2",)]],
-            [[[1, 2], [11, 12]], [[3, 4], [13, 14]]],
+            [
+                [Row(id=1, value=2), Row(id=11, value=12)],
+                [Row(id=3, value=4), Row(id=13, value=14)],
+            ],
             id="The return_last not set and split statements set on multiple queries in string",
         ),
         pytest.param(
             True,
             True,
             ["select * from test.test;"],
             ["select * from test.test"],
+            False,
             [["id", "value"]],
             ([Row(id=1, value=2), Row(id=11, value=12)],),
             [[("id",), ("value",)]],
-            [[[1, 2], [11, 12]]],
+            [[Row(id=1, value=2), Row(id=11, value=12)]],
             id="The return_last set on single query in list",
         ),
         pytest.param(
             False,
             True,
             ["select * from test.test;"],
             ["select * from test.test"],
+            False,
             [["id", "value"]],
             ([Row(id=1, value=2), Row(id=11, value=12)],),
             [[("id",), ("value",)]],
-            [[[1, 2], [11, 12]]],
+            [[Row(id=1, value=2), Row(id=11, value=12)]],
             id="The return_last not set on single query in list",
         ),
         pytest.param(
             True,
             True,
             "select * from test.test;select * from test.test2;",
             ["select * from test.test", "select * from test.test2"],
+            False,
             [["id", "value"], ["id2", "value2"]],
             ([Row(id=1, value=2), Row(id=11, value=12)], [Row(id=3, value=4), Row(id=13, value=14)]),
             [[("id2",), ("value2",)]],
-            [[3, 4], [13, 14]],
+            [Row(id=3, value=4), Row(id=13, value=14)],
             id="The return_last set on multiple queries in list",
         ),
         pytest.param(
             False,
             True,
             "select * from test.test;select * from test.test2;",
             ["select * from test.test", "select * from test.test2"],
+            False,
             [["id", "value"], ["id2", "value2"]],
             ([Row(id=1, value=2), Row(id=11, value=12)], [Row(id=3, value=4), Row(id=13, value=14)]),
             [[("id",), ("value",)], [("id2",), ("value2",)]],
-            [[[1, 2], [11, 12]], [[3, 4], [13, 14]]],
+            [
+                [Row(id=1, value=2), Row(id=11, value=12)],
+                [Row(id=3, value=4), Row(id=13, value=14)],
+            ],
             id="The return_last not set on multiple queries not set",
         ),
         pytest.param(
             True,
             False,
             "select * from test.test",
             ["select * from test.test"],
+            True,
             [["id", "value"]],
             (Row(id=1, value=2),),
             [[("id",), ("value",)]],
-            [1, 2],
+            SerializableRow(1, 2),
             id="The return_last set and no split statements set on single query in string",
         ),
     ],
 )
 def test_query(
-    databricks_hook,
     return_last,
     split_statements,
     sql,
     cursor_calls,
+    return_tuple,
     cursor_descriptions,
     cursor_results,
     hook_descriptions,
@@ -227,6 +249,7 @@ def test_query(
         cursors.append(cur)
         connections.append(conn)
     mock_conn.side_effect = connections
+    databricks_hook = DatabricksSqlHook(sql_endpoint_name="Test", return_tuple=return_tuple)
     results = databricks_hook.run(
         sql=sql, handler=fetch_all_handler, return_last=return_last, split_statements=split_statements
     )

tests/providers/databricks/operators/test_databricks_sql.py

+2

@@ -133,6 +133,7 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc
     db_mock_class.assert_called_once_with(
         DEFAULT_CONN_ID,
         http_path=None,
+        return_tuple=True,
         session_configuration=None,
         sql_endpoint_name=None,
         http_headers=None,
@@ -276,6 +277,7 @@ def test_exec_write_file(
     db_mock_class.assert_called_once_with(
         DEFAULT_CONN_ID,
         http_path=None,
+        return_tuple=True,
         session_configuration=None,
         sql_endpoint_name=None,
         http_headers=None,

tests/providers/odbc/hooks/test_odbc.py

+6 -2

@@ -310,7 +310,9 @@ def test_pyodbc_mock(self):
         """
         assert hasattr(pyodbc.Row, "cursor_description")

-    def test_query_return_serializable_result_with_fetchall(self, pyodbc_row_mock):
+    def test_query_return_serializable_result_with_fetchall(
+        self, pyodbc_row_mock, monkeypatch, pyodbc_instancecheck
+    ):
         """
         Simulate a cursor.fetchall which returns an iterable of pyodbc.Row object, and check if this iterable
         get converted into a list of tuples.
@@ -322,7 +324,9 @@ def mock_handler(*_):
             return pyodbc_result

         hook = self.get_hook()
-        result = hook.run("SQL", handler=mock_handler)
+        with monkeypatch.context() as patcher:
+            patcher.setattr("pyodbc.Row", pyodbc_instancecheck)
+            result = hook.run("SQL", handler=mock_handler)
         assert hook_result == result

     def test_query_return_serializable_result_with_fetchone(
