Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an ability to use scan queries via new YDB operator #42311

Merged
merged 7 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions airflow/providers/ydb/hooks/_vendor/dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,12 @@ def __init__(
self.interactive_transaction: bool = False # AUTOCOMMIT
self.tx_mode: ydb.AbstractTransactionModeBuilder = ydb.SerializableReadWrite()
self.tx_context: Optional[ydb.TxContext] = None
self.use_scan_query: bool = False

def cursor(self):
return self._cursor_class(self.session_pool, self.tx_mode, self.tx_context, self.table_path_prefix)
return self._cursor_class(
self.driver, self.session_pool, self.tx_mode, self.tx_context, self.use_scan_query, self.table_path_prefix
)

def describe(self, table_path: str) -> ydb.TableDescription:
abs_table_path = posixpath.join(self.database, self.table_path_prefix, table_path)
Expand Down Expand Up @@ -115,9 +118,15 @@ def get_isolation_level(self) -> str:
else:
raise NotSupportedError(f"{self.tx_mode.name} is not supported")

def set_ydb_scan_query(self, value: bool) -> None:
self.use_scan_query = value

def get_ydb_scan_query(self) -> bool:
return self.use_scan_query

def begin(self):
self.tx_context = None
if self.interactive_transaction:
if self.interactive_transaction and not self.use_scan_query:
session = self._maybe_await(self.session_pool.acquire)
self.tx_context = session.transaction(self.tx_mode)
self._maybe_await(self.tx_context.begin)
Expand Down
63 changes: 57 additions & 6 deletions airflow/providers/ydb/hooks/_vendor/dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,18 @@
import functools
import hashlib
import itertools
import logging
import posixpath
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
from collections.abc import AsyncIterator
from typing import (
Any,
Dict,
Generator,
List,
Mapping,
Optional,
Sequence,
Union,
)

import ydb
import ydb.aio
Expand All @@ -21,8 +30,6 @@
ProgrammingError,
)

logger = logging.getLogger(__name__)


def get_column_type(type_obj: Any) -> str:
return str(ydb.convert.type_to_native(type_obj))
Expand Down Expand Up @@ -77,14 +84,18 @@ def wrapper(*args, **kwargs):
class Cursor:
def __init__(
self,
driver: Union[ydb.Driver, ydb.aio.Driver],
potiuk marked this conversation as resolved.
Show resolved Hide resolved
session_pool: Union[ydb.SessionPool, ydb.aio.SessionPool],
tx_mode: ydb.AbstractTransactionModeBuilder,
tx_context: Optional[ydb.BaseTxContext] = None,
use_scan_query: bool = False,
table_path_prefix: str = "",
):
self.driver = driver
self.session_pool = session_pool
self.tx_mode = tx_mode
self.tx_context = tx_context
self.use_scan_query = use_scan_query
self.description = None
self.arraysize = 1
self.rows = None
Expand Down Expand Up @@ -117,9 +128,10 @@ def get_table_names(self, abs_dir_path: str) -> List[str]:
def execute(self, operation: YdbQuery, parameters: Optional[Mapping[str, Any]] = None):
query = self._get_ydb_query(operation)

logger.info("execute sql: %s, params: %s", query, parameters)
if operation.is_ddl:
chunks = self._execute_ddl(query)
elif self.use_scan_query:
chunks = self._execute_scan_query(query, parameters)
else:
chunks = self._execute_dml(query, parameters)

Expand Down Expand Up @@ -164,6 +176,21 @@ def _make_data_query(
name = hashlib.sha256(yql_with_params.encode("utf-8")).hexdigest()
return ydb.DataQuery(yql_text, parameters_types, name=name)

@_handle_ydb_errors
def _execute_scan_query(
self, query: Union[ydb.DataQuery, str], parameters: Optional[Mapping[str, Any]] = None
) -> Generator[ydb.convert.ResultSet, None, None]:
prepared_query = query
if isinstance(query, str) and parameters:
prepared_query: ydb.DataQuery = self._retry_operation_in_pool(self._prepare, query)

if isinstance(query, str):
scan_query = ydb.ScanQuery(query, None)
else:
scan_query = ydb.ScanQuery(prepared_query.yql_text, prepared_query.parameters_types)

return self._execute_scan_query_in_driver(scan_query, parameters)

@_handle_ydb_errors
def _execute_dml(
self, query: Union[ydb.DataQuery, str], parameters: Optional[Mapping[str, Any]] = None
Expand Down Expand Up @@ -219,6 +246,15 @@ def _execute_in_session(
) -> ydb.convert.ResultSets:
return session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True)

def _execute_scan_query_in_driver(
self,
scan_query: ydb.ScanQuery,
parameters: Optional[Mapping[str, Any]],
) -> Generator[ydb.convert.ResultSet, None, None]:
chunk: ydb.ScanQueryResult
for chunk in self.driver.table_client.scan_query(scan_query, parameters):
yield chunk.result_set

def _run_operation_in_tx(self, callee: collections.abc.Callable, *args, **kwargs):
return callee(self.tx_context, *args, **kwargs)

Expand Down Expand Up @@ -264,7 +300,7 @@ def executescript(self, script):
return self.execute(script)

def fetchone(self):
return next(self.rows or [], None)
return next(self.rows or iter([]), None)

def fetchmany(self, size=None):
return list(itertools.islice(self.rows, size or self.arraysize))
Expand Down Expand Up @@ -328,6 +364,21 @@ async def _execute_in_session(
) -> ydb.convert.ResultSets:
return await session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True)

def _execute_scan_query_in_driver(
self,
scan_query: ydb.ScanQuery,
parameters: Optional[Mapping[str, Any]],
) -> Generator[ydb.convert.ResultSet, None, None]:
iterator: AsyncIterator[ydb.ScanQueryResult] = self._await(
self.driver.table_client.scan_query(scan_query, parameters)
)
while True:
try:
result = self._await(iterator.__anext__())
yield result.result_set
except StopAsyncIteration:
break

def _run_operation_in_tx(self, callee: collections.abc.Coroutine, *args, **kwargs):
return self._await(callee(self.tx_context, *args, **kwargs))

Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/ydb/hooks/_vendor/readme.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
dbapi is extracted from https://github.com/ydb-platform/ydb-sqlalchemy/releases/tag/0.0.1b17 (Apache License 2.0) to avoid dependency on sqlalchemy package ver > 2.
dbapi is extracted from https://github.com/ydb-platform/ydb-sqlalchemy/releases/tag/0.0.1b22 (Apache License 2.0) to avoid dependency on sqlalchemy package ver > 2.
_vendor could be removed in favor of ydb-sqlalchemy package after switching Airflow core to sqlalchemy > 2 (related issue https://github.com/apache/airflow/issues/28723).
Another option is to wait for separate package for ydb-dbapi: https://github.com/ydb-platform/ydb-sqlalchemy/issues/46 and switch to it afterwards.
9 changes: 6 additions & 3 deletions airflow/providers/ydb/hooks/ydb.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,11 @@ def description(self):
class YDBConnection:
"""YDB connection wrapper."""

def __init__(self, ydb_session_pool: Any, is_ddl: bool):
def __init__(self, ydb_session_pool: Any, is_ddl: bool, use_scan_query: bool):
self.is_ddl = is_ddl
self.use_scan_query = use_scan_query
self.delegatee: DbApiConnection = DbApiConnection(ydb_session_pool=ydb_session_pool)
self.delegatee.set_ydb_scan_query(use_scan_query)

def cursor(self) -> YDBCursor:
return YDBCursor(self.delegatee.cursor(), is_ddl=self.is_ddl)
Expand Down Expand Up @@ -134,9 +136,10 @@ class YDBHook(DbApiHook):
supports_autocommit: bool = True
supports_executemany: bool = True

def __init__(self, *args, is_ddl: bool = False, **kwargs) -> None:
def __init__(self, *args, is_ddl: bool = False, use_scan_query: bool = False, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.is_ddl = is_ddl
self.use_scan_query = use_scan_query

conn: Connection = self.get_connection(self.get_conn_id())
host: str | None = conn.host
Expand Down Expand Up @@ -234,7 +237,7 @@ def sqlalchemy_url(self) -> URL:

def get_conn(self) -> YDBConnection:
"""Establish a connection to a YDB database."""
return YDBConnection(self.ydb_session_pool, is_ddl=self.is_ddl)
return YDBConnection(self.ydb_session_pool, is_ddl=self.is_ddl, use_scan_query=self.use_scan_query)

@staticmethod
def _serialize_cell(cell: object, conn: YDBConnection | None = None) -> Any:
Expand Down
30 changes: 30 additions & 0 deletions airflow/providers/ydb/operators/ydb.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,33 @@ def __init__(
kwargs["hook_params"] = {"is_ddl": is_ddl, **hook_params}

super().__init__(conn_id=ydb_conn_id, sql=sql, parameters=parameters, **kwargs)


class YDBScanQueryOperator(SQLExecuteQueryOperator):
"""
Executes scan query in a specific YDB database.

:param sql: the SQL code to be executed as a single string, or
a list of str (sql statements), or a reference to a template file.
Template references are recognized by str ending in '.sql'
:param ydb_conn_id: The :ref:`ydb conn id <howto/connection:ydb>`
reference to a specific YDB cluster and database.
:param parameters: (optional) the parameters to render the SQL query with.
"""

ui_color = "#ededed"

def __init__(
self,
sql: str | list[str],
ydb_conn_id: str = "ydb_default",
parameters: Mapping | Iterable | None = None,
**kwargs,
) -> None:
if parameters is not None:
raise AirflowException("parameters are not supported yet")

hook_params = kwargs.pop("hook_params", {})
kwargs["hook_params"] = {"use_scan_query": True, **hook_params}

super().__init__(conn_id=ydb_conn_id, sql=sql, parameters=parameters, **kwargs)
2 changes: 1 addition & 1 deletion docs/apache-airflow-providers-ydb/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
:caption: Guides

Connection types <connections/ydb>
YDBExecuteQueryOperator types <operators/ydb_operator_howto_guide>
Operator types <operators/ydb_operator_howto_guide>


.. toctree::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
.. _howto/operators:ydb:

How-to Guide for YDB using YDBExecuteQueryOperator
How-to Guide for using YDB Operators
==================================================

Introduction
Expand All @@ -29,7 +29,7 @@ workflow. Airflow is essentially a graph (Directed Acyclic Graph) made up of tas
A task defined or implemented by a operator is a unit of work in your data pipeline.

The purpose of this guide is to define tasks involving interactions with a YDB database with
the :class:`~airflow.providers.ydb.operators.YDBExecuteQueryOperator`.
the :class:`~airflow.providers.ydb.operators.YDBExecuteQueryOperator` and :class:`~airflow.providers.ydb.operators.YDBScanQueryOperator`.

Common database operations with YDBExecuteQueryOperator
-------------------------------------------------------
Expand Down Expand Up @@ -162,6 +162,26 @@ by creating a sql file.
)
Executing Scan Queries with YDBScanQueryOperator
-------------------------------------------------------

YDBScanQueryOperator executes YDB Scan Queries, which designed primarily for running analytical ad hoc queries. Parameters of the operators are:

- ``sql`` - string with query;
- ``conn_id`` - YDB connection id. Default value is ``ydb_default``;
- ``params`` - parameters to be injected into query if it is Jinja template, more details about :doc:`params <apache-airflow:core-concepts/params>`

Example of using YDBScanQueryOperator:

.. code-block:: python
get_birth_date_scan = YDBScanQueryOperator(
task_id="get_birth_date_scan",
sql="sql/birth_date.sql",
params={"begin_date": "2020-01-01", "end_date": "2020-12-31"},
)
The complete YDB Operator DAG
-----------------------------

Expand All @@ -176,7 +196,7 @@ When we put everything together, our DAG should look like this:
Conclusion
----------

In this how-to guide we explored the Apache Airflow YDBExecuteQueryOperator to connect to YDB database. Let's quickly highlight the key takeaways.
In this how-to guide we explored the Apache Airflow YDBExecuteQueryOperator and YDBScanQueryOperator to connect to YDB database. Let's quickly highlight the key takeaways.
It is best practice to create subdirectory called ``sql`` in your ``dags`` directory where you can store your sql files.
This will make your code more elegant and more maintainable.
And finally, we looked at the templated version of sql script and usage of ``params`` attribute.
11 changes: 10 additions & 1 deletion tests/system/providers/ydb/example_ydb.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from airflow import DAG
from airflow.decorators import task
from airflow.providers.ydb.hooks.ydb import YDBHook
from airflow.providers.ydb.operators.ydb import YDBExecuteQueryOperator
from airflow.providers.ydb.operators.ydb import YDBExecuteQueryOperator, YDBScanQueryOperator

# [START ydb_operator_howto_guide]

Expand Down Expand Up @@ -101,12 +101,21 @@ def populate_pet_table_via_bulk_upsert():
)
# [END ydb_operator_howto_guide_get_birth_date]

# [START ydb_operator_howto_guide_get_birth_date_scan]
get_birth_date_scan = YDBScanQueryOperator(
task_id="get_birth_date_scan",
sql="SELECT * FROM pet WHERE birth_date BETWEEN '{{params.begin_date}}' AND '{{params.end_date}}'",
params={"begin_date": "2020-01-01", "end_date": "2020-12-31"},
)
# [END ydb_operator_howto_guide_get_birth_date_scan]

(
create_pet_table
>> populate_pet_table
>> populate_pet_table_via_bulk_upsert()
>> get_all_pets
>> get_birth_date
>> get_birth_date_scan
)
# [END ydb_operator_howto_guide]

Expand Down