-
Notifications
You must be signed in to change notification settings - Fork 48
/
Copy pathduckdb.py
157 lines (132 loc) · 6.06 KB
/
duckdb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
from __future__ import annotations
import socket
import sqlalchemy
from duckdb_provider.hooks.duckdb_hook import DuckDBHook
from sqlalchemy import MetaData as SqlaMetaData
from sqlalchemy.sql.schema import Table as SqlaTable
from astro.constants import MergeConflictStrategy
from astro.databases.base import BaseDatabase
from astro.options import LoadOptions
from astro.table import BaseTable, Metadata
from astro.utils.compat.functools import cached_property
# Connection id used by DuckdbDatabase when the caller does not pass one;
# mirrors the DuckDB provider hook's own default connection name.
DEFAULT_CONN_ID: str = DuckDBHook.default_conn_name
class DuckdbDatabase(BaseDatabase):
    """Handle interactions with Duckdb databases."""

    def __init__(
        self,
        conn_id: str = DEFAULT_CONN_ID,
        table: BaseTable | None = None,
        load_options: LoadOptions | None = None,
    ):
        """
        :param conn_id: Airflow connection id of the DuckDB database.
        :param table: Optional table this database instance operates on.
        :param load_options: Optional load options used when loading data.
        """
        super().__init__(conn_id)
        self.table = table
        self.load_options = load_options

    @property
    def sql_type(self) -> str:
        """Return the SQL dialect identifier for this database: ``"duckdb"``."""
        return "duckdb"

    # We are caching this property to persist the DuckDB in-memory connection, to avoid
    # the problem described in
    # https://github.com/astronomer/astro-sdk/issues/1831
    @cached_property
    def connection(self) -> sqlalchemy.engine.base.Connection:  # skipcq PYL-W0236
        """Return a Sqlalchemy connection object for the given database."""
        return self.sqlalchemy_engine.connect()

    @cached_property
    def hook(self) -> DuckDBHook:
        """Retrieve Airflow hook to interface with the DuckDB database."""
        return DuckDBHook(duckdb_conn_id=self.conn_id)

    @property
    def default_metadata(self) -> Metadata:
        """Since Duckdb does not use Metadata, we return an empty Metadata instance."""
        return Metadata()

    # ---------------------------------------------------------
    # Table metadata
    # ---------------------------------------------------------
    @staticmethod
    def get_table_qualified_name(table: BaseTable) -> str:
        """
        Return the table qualified name.

        DuckDB tables are addressed here by their bare name (no database/schema prefix).

        :param table: The table we want to retrieve the qualified name for.
        """
        return str(table.name)

    def populate_table_metadata(self, table: BaseTable) -> BaseTable:
        """
        Fill in the table's defaults for this database.

        Since Duckdb does not have a concept of databases or schemas, the table's metadata is left
        untouched. The only change is that ``table.conn_id`` is set to this database's ``conn_id``
        when the table does not already have one. The same (possibly mutated) table object is
        returned.

        :param table: Table to populate.
        """
        table.conn_id = table.conn_id or self.conn_id
        return table

    def create_schema_if_needed(self, schema: str | None) -> None:
        """
        Since Duckdb does not have schemas, we do not need to set a schema here.

        :param schema: Ignored.
        """

    @staticmethod
    def get_merge_initialization_query(parameters: tuple) -> str:
        """
        Handles database-specific logic to handle index for DuckDB.

        Returns a ``CREATE UNIQUE INDEX merge_index`` statement over the given columns. The
        quadruple braces in the f-string render as a literal ``{{table}}`` placeholder in the
        returned SQL; presumably the caller substitutes the target table via templating — verify.

        :param parameters: Column names that make up the unique index.
        """
        joined_parameters = ",".join(parameters)
        return f"CREATE UNIQUE INDEX merge_index ON {{{{table}}}}({joined_parameters})"

    def merge_table(
        self,
        source_table: BaseTable,
        target_table: BaseTable,
        source_to_target_columns_map: dict[str, str],
        target_conflict_columns: list[str],
        if_conflicts: MergeConflictStrategy = "exception",
    ) -> None:
        """
        Merge the source table rows into a destination table.

        The argument `if_conflicts` allows the user to define how to handle conflicts.

        :param source_table: Contains the rows to be merged to the target_table
        :param target_table: Contains the destination table in which the rows will be merged
        :param source_to_target_columns_map: Dict mapping source_table column names (keys) to
            target_table column names (values)
        :param target_conflict_columns: List of cols where we expect to have a conflict while combining
        :param if_conflicts: The strategy to be applied if there are conflicts:
            ``"ignore"`` appends ``ON CONFLICT (...) DO NOTHING``;
            ``"update"`` appends ``ON CONFLICT (...) DO UPDATE SET`` overwriting every mapped target
            column with the incoming value; any other value (default ``"exception"``) adds no
            ``ON CONFLICT`` clause, leaving the database to report the constraint violation.
        """
        # NOTE(review): `Where true` is presumably there to avoid the parsing ambiguity between
        # `ON CONFLICT` and a join's `ON` clause after `FROM {source_table}` (SQLite documents the
        # same workaround for INSERT ... SELECT upserts).
        statement = "INSERT INTO {main_table} ({target_columns}) SELECT {append_columns} FROM {source_table} Where true"
        if if_conflicts == "ignore":
            statement += " ON CONFLICT ({merge_keys}) DO NOTHING"
        elif if_conflicts == "update":
            statement += " ON CONFLICT ({merge_keys}) DO UPDATE SET {update_statements}"

        # Keys are read from the source table, values are written to the target table.
        append_column_names = list(source_to_target_columns_map.keys())
        target_column_names = list(source_to_target_columns_map.values())
        # EXCLUDED refers to the incoming row that triggered the conflict.
        update_statements = [f"{col_name}=EXCLUDED.{col_name}" for col_name in target_column_names]

        # Placeholders not present in `statement` are simply ignored by str.format.
        query = statement.format(
            target_columns=",".join(target_column_names),
            main_table=self.get_table_qualified_name(target_table),
            append_columns=",".join(append_column_names),
            source_table=self.get_table_qualified_name(source_table),
            update_statements=",".join(update_statements),
            merge_keys=",".join(target_conflict_columns),
        )
        self.run_sql(sql=query)

    def get_sqla_table(self, table: BaseTable) -> SqlaTable:
        """
        Return SQLAlchemy table instance

        The table definition is reflected from the database through ``self.sqlalchemy_engine``.

        :param table: Astro Table to be converted to SQLAlchemy table instance
        """
        return SqlaTable(table.name, SqlaMetaData(), autoload_with=self.sqlalchemy_engine)

    def openlineage_dataset_name(self, table: BaseTable) -> str:
        """
        Returns the open lineage dataset name as per
        https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md

        The name is ``<connection host>.<table name>``.
        Example: /tmp/local.duckdb.table_name

        NOTE(review): assumes the connection's ``host`` holds the DuckDB database path; if it is
        unset the result starts with ``"None."`` — confirm whether that case needs handling.
        """
        conn = self.hook.get_connection(self.conn_id)
        return f"{conn.host}.{table.name}"

    def openlineage_dataset_namespace(self) -> str:
        """
        Returns the open lineage dataset namespace as per
        https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md

        The namespace is built from the IP address this machine's hostname resolves to (not the
        connection's host) and the connection's port, falling back to 22 when no port is set.
        Example: file://127.0.0.1:22
        """
        conn = self.hook.get_connection(self.conn_id)
        port = conn.port or 22
        return f"file://{socket.gethostbyname(socket.gethostname())}:{port}"

    def openlineage_dataset_uri(self, table: BaseTable) -> str:
        """
        Returns the open lineage dataset uri as per
        https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md

        Namespace and name are concatenated without a separator, so an absolute database path
        supplies the ``/`` itself.
        Example: file://127.0.0.1:22/tmp/local.duckdb.table_name

        NOTE(review): a relative database path would yield e.g. ``...:22local.duckdb.table_name``
        — confirm whether only absolute paths are expected.
        """
        return f"{self.openlineage_dataset_namespace()}{self.openlineage_dataset_name(table=table)}"