32 changes: 27 additions & 5 deletions bmsdna/lakeapi/context/df_base.py
@@ -3,7 +3,18 @@

from sqlglot import Dialect
from bmsdna.lakeapi.core.types import FileTypes, OperatorType
from typing import Sequence, Literal, Optional, List, Tuple, Any, TYPE_CHECKING, Union
from typing import (
Callable,
Sequence,
Literal,
Optional,
List,
Tuple,
Any,
TYPE_CHECKING,
Union,
cast,
)
import pyarrow as pa
import sqlglot.expressions as ex

@@ -35,6 +46,7 @@ def get_sql(
limit: int | None = None,
*,
dialect: str | Dialect,
modifier: Optional[Callable[[ex.Query], ex.Query]] = None,
) -> str:
if not isinstance(sql_or_pypika, str) and dialect == "tsql":
from_ = sql_or_pypika.args.get(
@@ -60,10 +72,20 @@
)
)
if isinstance(sql_or_pypika, str):
return sql_or_pypika
if not modifier:
return sql_or_pypika
else:
import sqlglot

sql_or_pypika = cast(
ex.Query, sqlglot.parse_one(sql_or_pypika, dialect=dialect)
)

if len(sql_or_pypika.expressions) == 0:
sql_or_pypika = sql_or_pypika.select("*")
assert not isinstance(sql_or_pypika, str)
if modifier:
sql_or_pypika = modifier(sql_or_pypika)
return sql_or_pypika.sql(dialect=dialect)
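
Note on the new `modifier` hook: plain SQL strings are now parsed with sqlglot before serialization, so string and builder inputs can be rewritten by the same callback. A minimal caller sketch, assuming `get_sql` is imported from `bmsdna.lakeapi.context.df_base` and using a hypothetical `add_limit` modifier:

```python
import sqlglot.expressions as ex

from bmsdna.lakeapi.context.df_base import get_sql


def add_limit(q: ex.Query) -> ex.Query:
    # Hypothetical modifier: cap the result set before rendering to SQL.
    return q.limit(100)


# Strings go through sqlglot.parse_one first, so the modifier applies either way:
sql = get_sql("SELECT a, b FROM tbl", dialect="duckdb", modifier=add_limit)
# -> roughly: SELECT a, b FROM tbl LIMIT 100
```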


@@ -111,7 +133,7 @@ async def to_json(self):

return pydantic.TypeAdapter(list[dict]).dump_json(full_list)

async def to_ndjson(self):
async def to_ndjson(self) -> str:
import polars as pl

result_strings = []
@@ -250,7 +272,7 @@ def get_pyarrow_dataset(
case "delta":
from bmsdna.lakeapi.utils.meta_cache import get_deltalake_meta

meta = get_deltalake_meta(uri)
meta = get_deltalake_meta(self.engine_name == "polars", uri)
assert meta.protocol is not None
if meta.protocol["minReaderVersion"] > 1:
raise ValueError(
@@ -368,7 +390,7 @@ def get_modified_date(
try:
from bmsdna.lakeapi.utils.meta_cache import get_deltalake_meta

meta = get_deltalake_meta(uri)
meta = get_deltalake_meta(self.engine_name == "polars", uri)
return meta.last_write_time
except FileNotFoundError:
return None
47 changes: 41 additions & 6 deletions bmsdna/lakeapi/context/df_duckdb.py
@@ -13,7 +13,7 @@
import duckdb
import pyarrow.dataset
import sqlglot.expressions as ex
from sqlglot import from_, parse_one
import sqlglot as sg
import os
from datetime import timezone
from bmsdna.lakeapi.core.config import SearchConfig
@@ -33,6 +33,10 @@
AZURE_LOADED_SCRIPTS: list[str] = []


def _to_json(query: ex.Query):
return sg.select("to_json(t)", dialect="duckdb").from_(query.subquery("t"))
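
For reference, `_to_json` wraps the original query as a subquery and serializes each row with DuckDB's `to_json`; a quick sanity check of the generated SQL (output shown approximately):

```python
import sqlglot as sg

inner = sg.parse_one("SELECT a, b FROM tbl", dialect="duckdb")
print(_to_json(inner).sql(dialect="duckdb"))
# Approximately: SELECT to_json(t) FROM (SELECT a, b FROM tbl) AS t
```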


def _get_temp_table_name():
return "temp_" + str(uuid4()).replace("-", "")

@@ -55,11 +59,11 @@ def columns(self):

def query_builder(self) -> ex.Select:
if not isinstance(self.original_sql, str):
return from_(self.original_sql.subquery())
return sg.from_(self.original_sql.subquery())
else:
return from_(
return sg.from_(
cast(
ex.Select, parse_one(self.original_sql, dialect="duckdb")
ex.Select, sg.parse_one(self.original_sql, dialect="duckdb")
).subquery()
)

@@ -119,9 +123,26 @@ async def write_parquet(self, file_name: str):

await run_in_threadpool(self.con.execute, full_query)

async def to_ndjson(self) -> str:
query = get_sql(self.original_sql, dialect="duckdb")

query = get_sql(self.original_sql, dialect="duckdb", modifier=_to_json)
await run_in_threadpool(self.con.execute, query)
res = []
while chunk := self.con.fetchmany(self.chunk_size):
for item in chunk:
res.append(item[0])
return "\n".join(res)

async def write_nd_json(self, file_name: str):
if not ENABLE_COPY_TO:
return await super().write_nd_json(file_name)
query = get_sql(self.original_sql, dialect="duckdb", modifier=_to_json)
await run_in_threadpool(self.con.execute, query)
with open(file_name, "w", encoding="utf-8") as f:
while chunk := self.con.fetchmany(self.chunk_size):
for item in chunk:
f.write(item[0] + "\n")
return
query = get_sql(self.original_sql, dialect="duckdb")
uuidstr = _get_temp_table_name()
full_query = f"""CREATE TEMP VIEW {uuidstr} AS {query};
@@ -144,8 +165,22 @@ async def write_csv(self, file_name: str, *, separator: str):

async def write_json(self, file_name: str):
if not ENABLE_COPY_TO:
return await super().write_json(file_name)
query = get_sql(self.original_sql, dialect="duckdb", modifier=_to_json)
await run_in_threadpool(self.con.execute, query)
with open(file_name, "w", encoding="utf-8") as f:
f.write("[\n")
first = True
while chunk := self.con.fetchmany(self.chunk_size):
for item in chunk:
if not first:
f.write(",\n")
else:
first = False
f.write(item[0])
f.write("\n]")
return
query = get_sql(self.original_sql, dialect="duckdb")

uuidstr = _get_temp_table_name()
full_query = f"""CREATE TEMP VIEW {uuidstr} AS {query};
COPY (SELECT * FROM {uuidstr})
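
When `ENABLE_COPY_TO` is off, both writers above stream rows through Python instead of DuckDB's `COPY`; the JSON-array variant reduces to this standalone pattern (a sketch, assuming `con` is a duckdb connection and `query` yields one JSON string per row):

```python
def write_json_array(con, query: str, file_name: str, chunk_size: int = 1000) -> None:
    con.execute(query)
    with open(file_name, "w", encoding="utf-8") as f:
        f.write("[\n")
        first = True
        # fetchmany keeps memory bounded: one chunk of rows at a time.
        while chunk := con.fetchmany(chunk_size):
            for row in chunk:
                if not first:
                    f.write(",\n")
                first = False
                f.write(row[0])
        f.write("\n]")
```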
2 changes: 0 additions & 2 deletions bmsdna/lakeapi/context/df_odbc.py
@@ -19,8 +19,6 @@
from uuid import uuid4
from .source_uri import SourceUri

ENABLE_COPY_TO = os.environ.get("ENABLE_COPY_TO", "0") == "1"


def _get_temp_table_name():
return "temp_" + str(uuid4()).replace("-", "")
7 changes: 5 additions & 2 deletions bmsdna/lakeapi/context/source_uri.py
@@ -106,7 +106,10 @@ def exists(self) -> bool:
return fs.exists(fs_path)

def copy_to_local(
self, local_path: str, delta_table: Union[bool, Literal["meta"]] = False
self,
local_path: str,
delta_table: Union[bool, Literal["meta"]] = False,
use_polars=False,
):
local_uri = SourceUri(
uri=local_path,
@@ -136,7 +139,7 @@ def copy_to_local(
if delta_table:
from bmsdna.lakeapi.utils.meta_cache import get_deltalake_meta

meta = get_deltalake_meta(self)
meta = get_deltalake_meta(use_polars, self)
vnr = meta.version
if local_versions.get(self.uri) == vnr:
return local_uri
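
Downstream, `copy_to_local` forwards the new flag to the metadata lookup; a hedged usage sketch (path and flag values illustrative only):

```python
# Mirror a remote delta table locally, running the version check via polars:
local_uri = uri.copy_to_local("/tmp/lakeapi-cache/tbl", delta_table=True, use_polars=True)
```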
4 changes: 3 additions & 1 deletion bmsdna/lakeapi/core/config.py
@@ -229,7 +229,9 @@ def _from_dict(cls, config: Dict, basic_config: BasicConfig, accounts: dict):
try:
from bmsdna.lakeapi.utils.meta_cache import get_deltalake_meta

meta = get_deltalake_meta(uri_obj)
meta = get_deltalake_meta(
basic_config.default_engine == "polars", uri_obj
)
assert meta.last_metadata is not None
cfg = json.loads(
_to_dict(meta.last_metadata.get("configuration", {})).get(
5 changes: 4 additions & 1 deletion bmsdna/lakeapi/core/datasource.py
@@ -162,6 +162,7 @@ def get_execution_uri(self, meta_only: bool):
"meta"
if (meta_only and self.config.file_type == "delta")
else (self.config.file_type == "delta"),
self.basic_config.default_engine == "polars",
)
if (
meta_only
@@ -183,7 +184,9 @@ def get_delta_table(self, schema_only: bool):
try:
from bmsdna.lakeapi.utils.meta_cache import get_deltalake_meta

meta = get_deltalake_meta(self.source_uri)
meta = get_deltalake_meta(
self.basic_config.default_engine == "polars", self.source_uri
)
return meta
except FileNotFoundError:
return None
2 changes: 1 addition & 1 deletion bmsdna/lakeapi/core/partition_utils.py
@@ -23,7 +23,7 @@ def _with_implicit_parameters(
try:
from bmsdna.lakeapi.utils.meta_cache import get_deltalake_meta

meta = get_deltalake_meta(uri)
meta = get_deltalake_meta(basic_config.default_engine == "polars", uri)
assert meta.last_metadata is not None
part_cols = meta.last_metadata.get("partitionColumns", [])
if part_cols and len(part_cols) > 0:
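
All call sites now pass the engine choice as the leading argument of `get_deltalake_meta`, as defined in `bmsdna/lakeapi/utils/meta_cache.py` below; a minimal sketch, assuming a `SourceUri` named `uri` and a `BasicConfig` named `basic_config`:

```python
from bmsdna.lakeapi.utils.meta_cache import get_deltalake_meta

# use_polars=True -> PolarsMetaEngine; False -> shared DuckDB connection.
meta = get_deltalake_meta(basic_config.default_engine == "polars", uri)
print(meta.version, meta.last_write_time)
```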
32 changes: 25 additions & 7 deletions bmsdna/lakeapi/utils/meta_cache.py
@@ -1,23 +1,41 @@
import os
import duckdb
from bmsdna.lakeapi.context.source_uri import SourceUri
from deltalake2db import (
get_deltalake_meta as _get_deltalake_meta,
PolarsMetaEngine,
DuckDBMetaEngine,
DeltaTableMeta,
duckdb_apply_storage_options,
)
from typing import Optional


_cached_meta: dict[SourceUri, DeltaTableMeta] = {}

_global_duck_con: Optional[duckdb.DuckDBPyConnection] = None


def get_deltalake_meta(use_polars: bool, uri: SourceUri):
global _global_duck_con
if use_polars:
ab_uri, ab_opts = uri.get_uri_options(flavor="object_store")

meta_engine = PolarsMetaEngine(ab_opts)
else:
if _global_duck_con is None:
_global_duck_con = duckdb.connect(":memory:")
ab_uri, ab_opts = uri.get_uri_options(flavor="original")

if not uri.is_local():
duckdb_apply_storage_options(
_global_duck_con,
ab_uri,
ab_opts,
use_fsspec=os.getenv("DUCKDB_DELTA_USE_FSSPEC", "0") == "1",
)
meta_engine = DuckDBMetaEngine(_global_duck_con)

def get_deltalake_meta(uri: SourceUri):
ab_uri, ab_opts = uri.get_uri_options(flavor="object_store")
meta_engine = (
PolarsMetaEngine(ab_opts)
if not uri.is_local()
else DuckDBMetaEngine(duckdb.default_connection())
)
if mt := _cached_meta.get(uri):
mt.update_incremental(meta_engine)
return mt
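
The DuckDB branch above replaces `duckdb.default_connection()` with one lazily created in-memory connection that is reused across metadata lookups; reduced to a standalone sketch:

```python
import duckdb
from typing import Optional

_global_duck_con: Optional[duckdb.DuckDBPyConnection] = None


def _get_meta_con() -> duckdb.DuckDBPyConnection:
    # Create the shared connection on first use, then hand it out every time.
    global _global_duck_con
    if _global_duck_con is None:
        _global_duck_con = duckdb.connect(":memory:")
    return _global_duck_con
```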
8 changes: 3 additions & 5 deletions pyproject.toml
@@ -1,16 +1,13 @@
[project]
name = "bmsdna-lakeapi"
version = "0.26.0"
version = "0.27.0"
description = ""
authors = [{ name = "DWH Team", email = "[email protected]" }]
dependencies = [
"pyyaml ~=6.0",
"duckdb >=1.1.0,<2",
"polars >=1.12.0,<2",
"sqlglot >=24.0.0",
"fastexcel >=0.10.4",
"argon2-cffi >=23.1.0,<24",
"xlsxwriter >=3.1.0,<4",
"pyjwt >=2.6.0,<3",
"ruamel.yaml >=0.18.5",
"fastapi >=0.110.0",
@@ -27,7 +24,7 @@ validate_lakeapi_schema = "bmsdna.lakeapi.tools.validateschema:validate_schema_c
add_lakeapi_user = "bmsdna.lakeapi.tools.useradd:useradd_cli"

[project.optional-dependencies]
polars = ["fastexcel"]
polars = ["fastexcel", "polars >=1.12.0,<2", "xlsxwriter >=3.1.0,<4"]
auth = ["argon2-cffi", "pyjwt"]
useradd = ["ruamel.yaml"]
odbc = ["arrow-odbc"]
@@ -93,4 +90,5 @@ test = [
"python-dotenv >=1.0.1,<2",
"deltalake>=1.1.4",
"pytest-asyncio>=0.23.8",
"polars >=1.12.0,<2",
]