Add PDS-DS Query 1 #19131

Open · wants to merge 15 commits into base: branch-25.08
211 changes: 211 additions & 0 deletions python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds.py
@@ -0,0 +1,211 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

"""
Experimental PDS-DS benchmarks.

Based on https://github.com/pola-rs/polars-benchmark.

WARNING: This is an experimental (and unofficial)
benchmark script. It is not intended for public use
and may be modified or removed at any time.
"""

from __future__ import annotations

import contextlib
import importlib
import os
import time
from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING

import polars as pl

with contextlib.suppress(ImportError):
    from cudf_polars.experimental.benchmarks.utils import (
        Record,
        RunConfig,
        get_executor_options,
        parse_args,
        run_polars,
    )

if TYPE_CHECKING:
    from collections.abc import Sequence
    from types import ModuleType

# Without this setting, the first IO task to run
# on each worker takes ~15 sec extra
os.environ["KVIKIO_COMPAT_MODE"] = os.environ.get("KVIKIO_COMPAT_MODE", "on")
os.environ["KVIKIO_NTHREADS"] = os.environ.get("KVIKIO_NTHREADS", "8")


def valid_query(name: str) -> bool:
    """Return True for valid query names, e.g. 'q9', 'q65', etc."""
    if not name.startswith("q"):
        return False
    try:
        q_num = int(name[1:])
    except ValueError:
        return False
    else:
        return 1 <= q_num <= 99
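
# For example: valid_query("q9") and valid_query("q65") return True, while
# valid_query("q0"), valid_query("q100"), and valid_query("foo") return False.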


class PDSDSQueriesMeta(type):
    """Metaclass used for query lookup."""

    def __getattr__(cls, name: str):  # type: ignore
        """Query lookup."""
        if valid_query(name):
            q_num = int(name[1:])
            module: ModuleType = importlib.import_module(
                f"cudf_polars.experimental.benchmarks.pdsds_queries.q{q_num}"
            )
            return getattr(module, cls.q_impl)
        raise AttributeError(f"{name} is not a valid query name")


class PDSDSQueries(metaclass=PDSDSQueriesMeta):
    """Base class for query loading."""

    q_impl: str


class PDSDSPolarsQueries(PDSDSQueries):
    """Polars Queries."""

    q_impl = "polars_impl"


class PDSDSDuckDBQueries(PDSDSQueries):
    """DuckDB Queries."""

    q_impl = "duckdb_impl"


def execute_duckdb_query(query: str, dataset_path: Path) -> pl.DataFrame:
    """Execute a query with DuckDB."""
    import duckdb

    conn = duckdb.connect()

    statements = [
        f"CREATE VIEW {table.stem} AS SELECT * FROM read_parquet('{table.absolute()}');"
        for table in Path(dataset_path).glob("*.parquet")
    ]
    statements.append(query)
    return conn.execute("\n".join(statements)).pl()
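
# Example (hypothetical dataset layout): with `datasets/sf1/store_sales.parquet`
# on disk, execute_duckdb_query("SELECT COUNT(*) FROM store_sales;",
# Path("datasets/sf1")) registers each parquet file as a view named after its
# file stem, then returns the final query's result as a Polars DataFrame.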


def run_duckdb(options: Sequence[str] | None = None) -> None:
    """Run the benchmark with DuckDB."""
    args = parse_args(options, num_queries=99)
    run_config = RunConfig.from_args(args)
    records: defaultdict[int, list[Record]] = defaultdict(list)

    for q_id in run_config.queries:
        try:
            duckdb_query = getattr(PDSDSDuckDBQueries, f"q{q_id}")(run_config)
        except AttributeError as err:
            raise NotImplementedError(f"Query {q_id} not implemented.") from err

        print(f"DuckDB Executing: {q_id}")
        records[q_id] = []

        for i in range(args.iterations):
            t0 = time.time()

            result = execute_duckdb_query(duckdb_query, run_config.dataset_path)

            t1 = time.time()
            record = Record(query=q_id, duration=t1 - t0)
            if args.print_results:
                print(result)

            print(f"Query {q_id} - Iteration {i} finished in {record.duration:0.4f}s")
            records[q_id].append(record)
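        # Note: each recorded duration includes DuckDB view registration and
        # parquet scanning as well as query execution, since
        # execute_duckdb_query performs all three.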


def run_validate(options: Sequence[str] | None = None) -> None:
    """Validate query results against a DuckDB or Polars CPU baseline."""
    from polars.testing import assert_frame_equal

    args = parse_args(options, num_queries=99)
    run_config = RunConfig.from_args(args)

    baseline = args.baseline
    if baseline not in {"duckdb", "cpu"}:
        raise ValueError("Baseline must be one of: 'duckdb', 'cpu'")

    failures: list[int] = []

    engine: pl.GPUEngine | None = None
    if run_config.executor != "cpu":
        engine = pl.GPUEngine(
            raise_on_fail=True,
            executor=run_config.executor,
            executor_options=get_executor_options(run_config, PDSDSPolarsQueries),
        )

    for q_id in run_config.queries:
        print(f"\nValidating Query {q_id}")
        try:
            polars_query = getattr(PDSDSPolarsQueries, f"q{q_id}")(run_config)
            duckdb_query = getattr(PDSDSDuckDBQueries, f"q{q_id}")(run_config)
        except AttributeError as err:
            raise NotImplementedError(f"Query {q_id} not implemented.") from err

        if baseline == "duckdb":
            base_result = execute_duckdb_query(duckdb_query, run_config.dataset_path)
        elif baseline == "cpu":
            base_result = polars_query.collect(new_streaming=True)

        if run_config.executor == "cpu":
            test_result = polars_query.collect(new_streaming=True)
        else:
            test_result = polars_query.collect(engine=engine)

        try:
            assert_frame_equal(
                base_result,
                test_result,
                check_dtypes=True,
                check_column_order=False,
            )
            print(f"✅ Query {q_id} passed validation.")
        except AssertionError as e:
            failures.append(q_id)
            print(f"❌ Query {q_id} failed validation:\n{e}")
            if args.print_results:
                print("Baseline Result:\n", base_result)
                print("Test Result:\n", test_result)

    if failures:
        print("\nValidation Summary:")
        print("===================")
        print(f"{len(failures)} query(s) failed: {failures}")
    else:
        print("\nAll queries passed validation.")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run PDS-DS benchmarks.")
    parser.add_argument(
        "--engine",
        choices=["polars", "duckdb", "validate"],
        default="polars",
        help="Which engine to use for executing the benchmarks or to validate results.",
    )
    args, extra_args = parser.parse_known_args()

    if args.engine == "polars":
        run_polars(PDSDSPolarsQueries, extra_args, num_queries=99)
    elif args.engine == "duckdb":
        run_duckdb(extra_args)
    elif args.engine == "validate":
        run_validate(extra_args)
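
# Example invocations (module path taken from this PR; assumes cudf_polars is
# installed, and any extra flags are handled by utils.parse_args):
#   python -m cudf_polars.experimental.benchmarks.pdsds --engine duckdb
#   python -m cudf_polars.experimental.benchmarks.pdsds --engine validate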
@@ -0,0 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

"""DuckDB and Polars queries."""
@@ -0,0 +1,88 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

"""Query 1."""

from __future__ import annotations

from typing import TYPE_CHECKING

import polars as pl

from cudf_polars.experimental.benchmarks.utils import get_data

if TYPE_CHECKING:
    from cudf_polars.experimental.benchmarks.utils import RunConfig


def duckdb_impl(run_config: RunConfig) -> str:
    """Query 1."""
    return """
    WITH customer_total_return AS
      (SELECT sr_customer_sk AS ctr_customer_sk,
              sr_store_sk AS ctr_store_sk,
              Sum(sr_return_amt) AS ctr_total_return
       FROM store_returns,
            date_dim
       WHERE sr_returned_date_sk = d_date_sk
         AND d_year = 2001
       GROUP BY sr_customer_sk,
                sr_store_sk)
    SELECT c_customer_id
    FROM customer_total_return ctr1,
         store,
         customer
    WHERE ctr1.ctr_total_return > (SELECT Avg(ctr_total_return) * 1.2
                                   FROM customer_total_return ctr2
                                   WHERE ctr1.ctr_store_sk = ctr2.ctr_store_sk)
      AND s_store_sk = ctr1.ctr_store_sk
      AND s_state = 'TN'
      AND ctr1.ctr_customer_sk = c_customer_sk
    ORDER BY c_customer_id
    LIMIT 100;
    """


def polars_impl(run_config: RunConfig) -> pl.LazyFrame:
    """Query 1."""
    store_returns = get_data(
        run_config.dataset_path, "store_returns", run_config.suffix
    )
    date_dim = get_data(run_config.dataset_path, "date_dim", run_config.suffix)
    store = get_data(run_config.dataset_path, "store", run_config.suffix)
    customer = get_data(run_config.dataset_path, "customer", run_config.suffix)

    # Step 1: Create the customer_total_return CTE equivalent
    customer_total_return = (
        store_returns.join(
            date_dim, left_on="sr_returned_date_sk", right_on="d_date_sk"
        )
        .filter(pl.col("d_year") == 2001)
        .group_by(["sr_customer_sk", "sr_store_sk"])
        .agg(pl.col("sr_return_amt").sum().alias("ctr_total_return"))
        .rename(
            {
                "sr_customer_sk": "ctr_customer_sk",
                "sr_store_sk": "ctr_store_sk",
            }
        )
    )

    # Step 2: Calculate the average return per store for the subquery
    store_avg_returns = customer_total_return.group_by("ctr_store_sk").agg(
        [(pl.col("ctr_total_return").mean() * 1.2).alias("avg_return_threshold")]
    )
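    # Note: this group_by + join decorrelates the SQL's correlated subquery
    # (Avg(ctr_total_return) * 1.2 per ctr_store_sk) into a single
    # aggregation followed by an equi-join in Step 3 below.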

    # Step 3: Join everything together and apply filters
    return (
        customer_total_return.join(
            store_avg_returns, left_on="ctr_store_sk", right_on="ctr_store_sk"
        )
        .filter(pl.col("ctr_total_return") > pl.col("avg_return_threshold"))
        .join(store, left_on="ctr_store_sk", right_on="s_store_sk")
        .filter(pl.col("s_state") == "TN")
        .join(customer, left_on="ctr_customer_sk", right_on="c_customer_sk")
        .select(["c_customer_id"])
        .sort("c_customer_id")
        .limit(100)
    )