diff --git a/airflow/providers/ydb/hooks/_vendor/dbapi/connection.py b/airflow/providers/ydb/hooks/_vendor/dbapi/connection.py
index 148b0a78f681..fa0941e99a71 100644
--- a/airflow/providers/ydb/hooks/_vendor/dbapi/connection.py
+++ b/airflow/providers/ydb/hooks/_vendor/dbapi/connection.py
@@ -57,9 +57,12 @@ def __init__(
         self.interactive_transaction: bool = False  # AUTOCOMMIT
         self.tx_mode: ydb.AbstractTransactionModeBuilder = ydb.SerializableReadWrite()
         self.tx_context: Optional[ydb.TxContext] = None
+        self.use_scan_query: bool = False

     def cursor(self):
-        return self._cursor_class(self.session_pool, self.tx_mode, self.tx_context, self.table_path_prefix)
+        return self._cursor_class(
+            self.driver, self.session_pool, self.tx_mode, self.tx_context, self.use_scan_query, self.table_path_prefix
+        )

     def describe(self, table_path: str) -> ydb.TableDescription:
         abs_table_path = posixpath.join(self.database, self.table_path_prefix, table_path)
@@ -115,9 +118,15 @@ def get_isolation_level(self) -> str:
         else:
             raise NotSupportedError(f"{self.tx_mode.name} is not supported")

+    def set_ydb_scan_query(self, value: bool) -> None:
+        self.use_scan_query = value
+
+    def get_ydb_scan_query(self) -> bool:
+        return self.use_scan_query
+
     def begin(self):
         self.tx_context = None
-        if self.interactive_transaction:
+        if self.interactive_transaction and not self.use_scan_query:
             session = self._maybe_await(self.session_pool.acquire)
             self.tx_context = session.transaction(self.tx_mode)
             self._maybe_await(self.tx_context.begin)
diff --git a/airflow/providers/ydb/hooks/_vendor/dbapi/cursor.py b/airflow/providers/ydb/hooks/_vendor/dbapi/cursor.py
index 7fb50e11c80a..22dfdaacb2c0 100644
--- a/airflow/providers/ydb/hooks/_vendor/dbapi/cursor.py
+++ b/airflow/providers/ydb/hooks/_vendor/dbapi/cursor.py
@@ -3,9 +3,18 @@
 import functools
 import hashlib
 import itertools
-import logging
 import posixpath
-from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
+from collections.abc import AsyncIterator
+from typing import (
+    Any,
+    Dict,
+    Generator,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Union,
+)

 import ydb
 import ydb.aio
@@ -21,8 +30,6 @@
     ProgrammingError,
 )

-logger = logging.getLogger(__name__)
-

 def get_column_type(type_obj: Any) -> str:
     return str(ydb.convert.type_to_native(type_obj))
@@ -77,14 +84,18 @@ def wrapper(*args, **kwargs):
 class Cursor:
     def __init__(
         self,
+        driver: Union[ydb.Driver, ydb.aio.Driver],
         session_pool: Union[ydb.SessionPool, ydb.aio.SessionPool],
         tx_mode: ydb.AbstractTransactionModeBuilder,
         tx_context: Optional[ydb.BaseTxContext] = None,
+        use_scan_query: bool = False,
         table_path_prefix: str = "",
     ):
+        self.driver = driver
         self.session_pool = session_pool
         self.tx_mode = tx_mode
         self.tx_context = tx_context
+        self.use_scan_query = use_scan_query
         self.description = None
         self.arraysize = 1
         self.rows = None
@@ -117,9 +128,10 @@ def get_table_names(self, abs_dir_path: str) -> List[str]:

     def execute(self, operation: YdbQuery, parameters: Optional[Mapping[str, Any]] = None):
         query = self._get_ydb_query(operation)
-        logger.info("execute sql: %s, params: %s", query, parameters)
         if operation.is_ddl:
             chunks = self._execute_ddl(query)
+        elif self.use_scan_query:
+            chunks = self._execute_scan_query(query, parameters)
         else:
             chunks = self._execute_dml(query, parameters)

@@ -164,6 +176,21 @@ def _make_data_query(
         name = hashlib.sha256(yql_with_params.encode("utf-8")).hexdigest()
         return ydb.DataQuery(yql_text, parameters_types, name=name)

+    @_handle_ydb_errors
+    def _execute_scan_query(
+        self, query: Union[ydb.DataQuery, str], parameters: Optional[Mapping[str, Any]] = None
+    ) -> Generator[ydb.convert.ResultSet, None, None]:
+        prepared_query = query
+        if isinstance(query, str) and parameters:
+            prepared_query: ydb.DataQuery = self._retry_operation_in_pool(self._prepare, query)
+
+        if isinstance(query, str):
+            scan_query = ydb.ScanQuery(query, None)
+        else:
+            scan_query = ydb.ScanQuery(prepared_query.yql_text, prepared_query.parameters_types)
+
+        return self._execute_scan_query_in_driver(scan_query, parameters)
+
     @_handle_ydb_errors
     def _execute_dml(
         self, query: Union[ydb.DataQuery, str], parameters: Optional[Mapping[str, Any]] = None
@@ -219,6 +246,15 @@ def _execute_in_session(
     ) -> ydb.convert.ResultSets:
         return session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True)

+    def _execute_scan_query_in_driver(
+        self,
+        scan_query: ydb.ScanQuery,
+        parameters: Optional[Mapping[str, Any]],
+    ) -> Generator[ydb.convert.ResultSet, None, None]:
+        chunk: ydb.ScanQueryResult
+        for chunk in self.driver.table_client.scan_query(scan_query, parameters):
+            yield chunk.result_set
+
     def _run_operation_in_tx(self, callee: collections.abc.Callable, *args, **kwargs):
         return callee(self.tx_context, *args, **kwargs)

@@ -264,7 +300,7 @@ def executescript(self, script):
         return self.execute(script)

     def fetchone(self):
-        return next(self.rows or [], None)
+        return next(self.rows or iter([]), None)

     def fetchmany(self, size=None):
         return list(itertools.islice(self.rows, size or self.arraysize))
@@ -328,6 +364,21 @@ async def _execute_in_session(
     ) -> ydb.convert.ResultSets:
         return await session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True)

+    def _execute_scan_query_in_driver(
+        self,
+        scan_query: ydb.ScanQuery,
+        parameters: Optional[Mapping[str, Any]],
+    ) -> Generator[ydb.convert.ResultSet, None, None]:
+        iterator: AsyncIterator[ydb.ScanQueryResult] = self._await(
+            self.driver.table_client.scan_query(scan_query, parameters)
+        )
+        while True:
+            try:
+                result = self._await(iterator.__anext__())
+                yield result.result_set
+            except StopAsyncIteration:
+                break
+
     def _run_operation_in_tx(self, callee: collections.abc.Coroutine, *args, **kwargs):
         return self._await(callee(self.tx_context, *args, **kwargs))

diff --git a/airflow/providers/ydb/hooks/_vendor/readme.md b/airflow/providers/ydb/hooks/_vendor/readme.md
index 3336923e0541..14a2585b69d5 100644
--- a/airflow/providers/ydb/hooks/_vendor/readme.md
+++ b/airflow/providers/ydb/hooks/_vendor/readme.md
@@ -1,3 +1,3 @@
-dbapi is extracted from https://github.com/ydb-platform/ydb-sqlalchemy/releases/tag/0.0.1b17 (Apache License 2.0) to avoid dependency on sqlalchemy package ver > 2.
+dbapi is extracted from https://github.com/ydb-platform/ydb-sqlalchemy/releases/tag/0.0.1b22 (Apache License 2.0) to avoid dependency on sqlalchemy package ver > 2.

 _vendor could be removed in favor of ydb-sqlalchemy package after switching Airflow core to sqlalchemy > 2 (related issue https://github.com/apache/airflow/issues/28723). Another option is to wait for separate package for ydb-dbapi: https://github.com/ydb-platform/ydb-sqlalchemy/issues/46 and switch to it afterwards.
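Note: both the synchronous and asynchronous `_execute_scan_query_in_driver` helpers added above stream results through the SDK's `table_client.scan_query` call. A minimal standalone sketch of that underlying call follows; the endpoint, database, and query text are assumptions made purely for illustration, and error handling is omitted:

    # Sketch of the SDK call wrapped by the new _execute_scan_query_in_driver helper
    # (synchronous side only). Endpoint, database, and the query are assumed values.
    import ydb

    driver = ydb.Driver(endpoint="grpc://localhost:2136", database="/local")
    driver.wait(timeout=5)

    # ScanQuery streams results in chunks; each chunk exposes a result_set,
    # which is what the vendored cursor yields to its fetch*() methods.
    scan_query = ydb.ScanQuery("SELECT 1 AS one", None)
    for chunk in driver.table_client.scan_query(scan_query, None):
        print(chunk.result_set.rows)

    driver.stop()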
diff --git a/airflow/providers/ydb/hooks/ydb.py b/airflow/providers/ydb/hooks/ydb.py
index e4580212d8d6..19740d40e6dd 100644
--- a/airflow/providers/ydb/hooks/ydb.py
+++ b/airflow/providers/ydb/hooks/ydb.py
@@ -95,9 +95,11 @@ def description(self):
 class YDBConnection:
     """YDB connection wrapper."""

-    def __init__(self, ydb_session_pool: Any, is_ddl: bool):
+    def __init__(self, ydb_session_pool: Any, is_ddl: bool, use_scan_query: bool):
         self.is_ddl = is_ddl
+        self.use_scan_query = use_scan_query
         self.delegatee: DbApiConnection = DbApiConnection(ydb_session_pool=ydb_session_pool)
+        self.delegatee.set_ydb_scan_query(use_scan_query)

     def cursor(self) -> YDBCursor:
         return YDBCursor(self.delegatee.cursor(), is_ddl=self.is_ddl)
@@ -134,9 +136,10 @@ class YDBHook(DbApiHook):
     supports_autocommit: bool = True
     supports_executemany: bool = True

-    def __init__(self, *args, is_ddl: bool = False, **kwargs) -> None:
+    def __init__(self, *args, is_ddl: bool = False, use_scan_query: bool = False, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.is_ddl = is_ddl
+        self.use_scan_query = use_scan_query

         conn: Connection = self.get_connection(self.get_conn_id())
         host: str | None = conn.host
@@ -234,7 +237,7 @@ def sqlalchemy_url(self) -> URL:

     def get_conn(self) -> YDBConnection:
         """Establish a connection to a YDB database."""
-        return YDBConnection(self.ydb_session_pool, is_ddl=self.is_ddl)
+        return YDBConnection(self.ydb_session_pool, is_ddl=self.is_ddl, use_scan_query=self.use_scan_query)

     @staticmethod
     def _serialize_cell(cell: object, conn: YDBConnection | None = None) -> Any:
diff --git a/airflow/providers/ydb/operators/ydb.py b/airflow/providers/ydb/operators/ydb.py
index 1867f227b4eb..787abf8cbd82 100644
--- a/airflow/providers/ydb/operators/ydb.py
+++ b/airflow/providers/ydb/operators/ydb.py
@@ -52,3 +52,33 @@ def __init__(
         kwargs["hook_params"] = {"is_ddl": is_ddl, **hook_params}

         super().__init__(conn_id=ydb_conn_id, sql=sql, parameters=parameters, **kwargs)
+
+
+class YDBScanQueryOperator(SQLExecuteQueryOperator):
+    """
+    Executes a scan query in a specific YDB database.
+
+    :param sql: the SQL code to be executed as a single string, or
+        a list of str (sql statements), or a reference to a template file.
+        Template references are recognized by str ending in '.sql'
+    :param ydb_conn_id: The :ref:`ydb conn id `
+        reference to a specific YDB cluster and database.
+    :param parameters: (optional) the parameters to render the SQL query with.
+    """
+
+    ui_color = "#ededed"
+
+    def __init__(
+        self,
+        sql: str | list[str],
+        ydb_conn_id: str = "ydb_default",
+        parameters: Mapping | Iterable | None = None,
+        **kwargs,
+    ) -> None:
+        if parameters is not None:
+            raise AirflowException("parameters are not supported yet")
+
+        hook_params = kwargs.pop("hook_params", {})
+        kwargs["hook_params"] = {"use_scan_query": True, **hook_params}
+
+        super().__init__(conn_id=ydb_conn_id, sql=sql, parameters=parameters, **kwargs)
diff --git a/docs/apache-airflow-providers-ydb/index.rst b/docs/apache-airflow-providers-ydb/index.rst
index 25edc87386c4..45cbd2b00b7a 100644
--- a/docs/apache-airflow-providers-ydb/index.rst
+++ b/docs/apache-airflow-providers-ydb/index.rst
@@ -34,7 +34,7 @@
     :caption: Guides

     Connection types
-    YDBExecuteQueryOperator types
+    Operator types

 .. toctree::
diff --git a/docs/apache-airflow-providers-ydb/operators/ydb_operator_howto_guide.rst b/docs/apache-airflow-providers-ydb/operators/ydb_operator_howto_guide.rst
index 3b7aaab44c2b..894be8101d5b 100644
--- a/docs/apache-airflow-providers-ydb/operators/ydb_operator_howto_guide.rst
+++ b/docs/apache-airflow-providers-ydb/operators/ydb_operator_howto_guide.rst
@@ -17,7 +17,7 @@

 .. _howto/operators:ydb:

-How-to Guide for YDB using YDBExecuteQueryOperator
+How-to Guide for using YDB Operators
 ==================================================

 Introduction
@@ -29,7 +29,7 @@ workflow. Airflow is essentially a graph (Directed Acyclic Graph) made up of tas
 A task defined or implemented by a operator is a unit of work in your data pipeline.

 The purpose of this guide is to define tasks involving interactions with a YDB database with
-the :class:`~airflow.providers.ydb.operators.YDBExecuteQueryOperator`.
+the :class:`~airflow.providers.ydb.operators.YDBExecuteQueryOperator` and :class:`~airflow.providers.ydb.operators.YDBScanQueryOperator`.

 Common database operations with YDBExecuteQueryOperator
 --------------------------------------------------------
@@ -162,6 +162,26 @@ by creating a sql file.
     )


+Executing Scan Queries with YDBScanQueryOperator
+-------------------------------------------------------
+
+YDBScanQueryOperator executes YDB scan queries, which are designed primarily for running ad hoc analytical queries. Parameters of the operator are:
+
+- ``sql`` - string with the query;
+- ``ydb_conn_id`` - YDB connection id. Default value is ``ydb_default``;
+- ``params`` - parameters to be injected into the query if it is a Jinja template; more details in :doc:`params `
+
+Example of using YDBScanQueryOperator:
+
+.. code-block:: python
+
+    get_birth_date_scan = YDBScanQueryOperator(
+        task_id="get_birth_date_scan",
+        sql="sql/birth_date.sql",
+        params={"begin_date": "2020-01-01", "end_date": "2020-12-31"},
+    )
+
+
 The complete YDB Operator DAG
 -----------------------------

@@ -176,7 +196,7 @@ When we put everything together, our DAG should look like this:
 Conclusion
 ----------

-In this how-to guide we explored the Apache Airflow YDBExecuteQueryOperator to connect to YDB database. Let's quickly highlight the key takeaways.
+In this how-to guide we explored the Apache Airflow YDBExecuteQueryOperator and YDBScanQueryOperator to connect to a YDB database. Let's quickly highlight the key takeaways.
 It is best practice to create subdirectory called ``sql`` in your ``dags`` directory where you can store your sql files.
 This will make your code more elegant and more maintainable.
 And finally, we looked at the templated version of sql script and usage of ``params`` attribute.
diff --git a/tests/system/providers/ydb/example_ydb.py b/tests/system/providers/ydb/example_ydb.py
index 8d43b6199abb..39156328f241 100644
--- a/tests/system/providers/ydb/example_ydb.py
+++ b/tests/system/providers/ydb/example_ydb.py
@@ -24,7 +24,7 @@
 from airflow import DAG
 from airflow.decorators import task
 from airflow.providers.ydb.hooks.ydb import YDBHook
-from airflow.providers.ydb.operators.ydb import YDBExecuteQueryOperator
+from airflow.providers.ydb.operators.ydb import YDBExecuteQueryOperator, YDBScanQueryOperator

 # [START ydb_operator_howto_guide]

@@ -101,12 +101,21 @@ def populate_pet_table_via_bulk_upsert():
     )
     # [END ydb_operator_howto_guide_get_birth_date]

+    # [START ydb_operator_howto_guide_get_birth_date_scan]
+    get_birth_date_scan = YDBScanQueryOperator(
+        task_id="get_birth_date_scan",
+        sql="SELECT * FROM pet WHERE birth_date BETWEEN '{{params.begin_date}}' AND '{{params.end_date}}'",
+        params={"begin_date": "2020-01-01", "end_date": "2020-12-31"},
+    )
+    # [END ydb_operator_howto_guide_get_birth_date_scan]
+
     (
         create_pet_table
         >> populate_pet_table
         >> populate_pet_table_via_bulk_upsert()
         >> get_all_pets
         >> get_birth_date
+        >> get_birth_date_scan
     )
     # [END ydb_operator_howto_guide]
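Usage note: besides the operator shown in the system test, the new `use_scan_query` flag can also be driven directly through the hook. The sketch below is illustrative only; the `ydb_conn_id` argument name, the existing `ydb_default` connection, and the `pet` table with a `pet_type` column are assumptions borrowed from the provider's conventions and the example DAG, not something mandated by this change:

    # Rough sketch: run an ad hoc analytical SELECT through YDBHook as a scan query.
    # Assumes an Airflow connection "ydb_default" pointing at a database that already
    # contains the "pet" table created by the example DAG above.
    from airflow.providers.ydb.hooks.ydb import YDBHook

    # use_scan_query=True is forwarded to the vendored DBAPI connection via
    # set_ydb_scan_query(), so cursor.execute() takes the scan-query code path.
    hook = YDBHook(ydb_conn_id="ydb_default", use_scan_query=True)

    # get_records() is the generic DbApiHook helper; the rows come from the
    # streamed scan-query result sets.
    rows = hook.get_records("SELECT pet_type, COUNT(*) AS cnt FROM pet GROUP BY pet_type")
    print(rows)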