Skip to content

Used TypedDict for CachingVisitor.state #19135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 24 commits into
base: branch-25.08
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
cf2a77b
Used TypedDict for CachingVisitor.state
TomAugspurger Jun 11, 2025
9a828f5
fixed generic syntax
TomAugspurger Jun 11, 2025
2d70672
Try typing-extensions
TomAugspurger Jun 11, 2025
4ed5168
Attempted packaging fix
TomAugspurger Jun 12, 2025
b492498
Revert "Attempted packaging fix"
TomAugspurger Jun 12, 2025
2077154
Embed the marker
TomAugspurger Jun 12, 2025
07acbc0
Merge remote-tracking branch 'upstream/branch-25.08' into tom/typed-s…
TomAugspurger Jun 16, 2025
36efb76
Coverage ignore
TomAugspurger Jun 16, 2025
b51ce87
More coverage skips
TomAugspurger Jun 16, 2025
68b4a3c
wip -- partial
TomAugspurger Jun 17, 2025
8859f6f
Visitor-specific types
TomAugspurger Jun 17, 2025
65344f5
Merge remote-tracking branch 'upstream/branch-25.08' into tom/typed-s…
TomAugspurger Jun 17, 2025
443758f
linting
TomAugspurger Jun 18, 2025
6a6c933
docs
TomAugspurger Jun 18, 2025
9048d6d
doc fixes
TomAugspurger Jun 18, 2025
aaa305a
Merge remote-tracking branch 'upstream/branch-25.08' into tom/typed-s…
TomAugspurger Jun 18, 2025
1a7622d
docs
TomAugspurger Jun 18, 2025
729b0b8
simplify
TomAugspurger Jun 18, 2025
73fae84
Reorganize
TomAugspurger Jun 18, 2025
a19936a
Remove unnecessary StateT
TomAugspurger Jun 18, 2025
cd78544
doc fixes
TomAugspurger Jun 18, 2025
53fe0d6
Merge remote-tracking branch 'upstream/branch-25.08' into tom/typed-s…
TomAugspurger Jun 18, 2025
282dd28
Get typing_extensions.TypedDict for 3.10
TomAugspurger Jun 18, 2025
bf8cd79
Merge remote-tracking branch 'upstream/branch-25.08' into tom/typed-s…
TomAugspurger Jun 24, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions conda/recipes/cudf-polars/recipe.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ requirements:
- pylibcudf =${{ version }}
- polars >=1.24,<1.31
- ${{ pin_compatible("cuda-version", upper_bound="x", lower_bound="x") }}
- if: python == "3.10"
then: typing_extensions
ignore_run_exports:
by_name:
- cuda-version
Expand Down
6 changes: 6 additions & 0 deletions dependencies.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,12 @@ dependencies:
- output_types: [conda, requirements, pyproject]
packages:
- polars>=1.25,<1.31
specific:
- output_types: [requirements, pyproject]
matrices:
- matrix: null
packages:
- "typing-extensions; python_version < '3.11'"
run_cudf_polars_experimental:
common:
- output_types: [conda, requirements, pyproject]
Expand Down
58 changes: 44 additions & 14 deletions python/cudf_polars/cudf_polars/dsl/to_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from __future__ import annotations

from functools import partial, reduce, singledispatch
from typing import TYPE_CHECKING, TypeAlias
from typing import TYPE_CHECKING, TypeAlias, TypedDict

import pylibcudf as plc
from pylibcudf import expressions as plc_expr
Expand All @@ -18,7 +18,6 @@
if TYPE_CHECKING:
from collections.abc import Mapping

from cudf_polars.typing import ExprTransformer

# Can't merge these op-mapping dictionaries because scoped enum values
# are exposed by cython with equality/hash based on their underlying
Expand Down Expand Up @@ -91,7 +90,43 @@
}


Transformer: TypeAlias = GenericTransformer[expr.Expr, plc_expr.Expression]
class ASTState(TypedDict):
    """
    State for AST transformations.

    Instances are passed as the ``state`` argument of the
    ``CachingVisitor`` that wraps ``_to_ast``.

    Parameters
    ----------
    for_parquet
        Indicator for whether this transformation should provide an expression
        suitable for use in parquet filters.
    """

    # Set to True by `to_parquet_filter` and to False by `to_ast`.
    for_parquet: bool


class ExprTransformerState(TypedDict):
    """
    State used for AST transformation when inserting column references.

    Constructed by ``insert_colrefs`` and handed to the
    ``CachingVisitor`` that wraps ``_insert_colrefs``.

    Parameters
    ----------
    name_to_index
        Mapping from column names to column indices in the table
        eventually used for evaluation.
    table_ref
        pylibcudf `TableReference` indicating whether column
        references are coming from the left or right table.
    """

    # `Mapping` is only imported under TYPE_CHECKING; this works because the
    # module uses `from __future__ import annotations` (lazy annotations).
    name_to_index: Mapping[str, int]
    table_ref: plc.expressions.TableReference


Transformer: TypeAlias = GenericTransformer[expr.Expr, plc_expr.Expression, ASTState]
ExprTransformer: TypeAlias = GenericTransformer[
expr.Expr, expr.Expr, ExprTransformerState
]
"""Protocol for transformation of Expr nodes."""


@singledispatch
Expand All @@ -104,14 +139,8 @@ def _to_ast(node: expr.Expr, self: Transformer) -> plc_expr.Expression:
node
Expression to translate.
self
Recursive transformer. The state dictionary should contain a
`for_parquet` key indicating if this transformation should
provide an expression suitable for use in parquet filters.

If `for_parquet` is `False`, the dictionary should contain a
`name_to_index` mapping that maps column names to their
integer index in the table that will be used for evaluation of
the expression.
Recursive transformer. The state dictionary is an instance of
:class:`ASTState`.

Returns
-------
Expand Down Expand Up @@ -240,7 +269,7 @@ def to_parquet_filter(node: expr.Expr) -> plc_expr.Expression | None:
-------
pylibcudf Expression if conversion is possible, otherwise None.
"""
mapper = CachingVisitor(_to_ast, state={"for_parquet": True})
mapper = CachingVisitor(_to_ast, state=ASTState(for_parquet=True))
try:
return mapper(node)
except (KeyError, NotImplementedError):
Expand All @@ -266,7 +295,7 @@ def to_ast(node: expr.Expr) -> plc_expr.Expression | None:
-------
pylibcudf Expression if conversion is possible, otherwise None.
"""
mapper = CachingVisitor(_to_ast, state={"for_parquet": False})
mapper = CachingVisitor(_to_ast, state=ASTState(for_parquet=False))
try:
return mapper(node)
except (KeyError, NotImplementedError):
Expand Down Expand Up @@ -315,6 +344,7 @@ def insert_colrefs(
New expression with column references inserted.
"""
mapper = CachingVisitor(
_insert_colrefs, state={"table_ref": table_ref, "name_to_index": name_to_index}
_insert_colrefs,
state=ExprTransformerState(name_to_index=name_to_index, table_ref=table_ref),
)
return mapper(node)
24 changes: 15 additions & 9 deletions python/cudf_polars/cudf_polars/dsl/traversal.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

"""Traversal and visitor utilities for nodes."""
Expand All @@ -7,7 +7,11 @@

from typing import TYPE_CHECKING, Any, Generic

from cudf_polars.typing import U_contra, V_co
from cudf_polars.typing import (
StateT_co,
U_contra,
V_co,
)

if TYPE_CHECKING:
from collections.abc import Callable, Generator, Mapping, MutableMapping, Sequence
Expand Down Expand Up @@ -49,7 +53,9 @@ def traversal(nodes: Sequence[NodeT]) -> Generator[NodeT, None, None]:
lifo.append(child)


def reuse_if_unchanged(node: NodeT, fn: GenericTransformer[NodeT, NodeT]) -> NodeT:
def reuse_if_unchanged(
node: NodeT, fn: GenericTransformer[NodeT, NodeT, StateT_co]
) -> NodeT:
"""
Recipe for transforming nodes that returns the old object if unchanged.

Expand Down Expand Up @@ -77,10 +83,10 @@ def reuse_if_unchanged(node: NodeT, fn: GenericTransformer[NodeT, NodeT]) -> Nod


def make_recursive(
fn: Callable[[U_contra, GenericTransformer[U_contra, V_co]], V_co],
fn: Callable[[U_contra, GenericTransformer[U_contra, V_co, StateT_co]], V_co],
*,
state: Mapping[str, Any] | None = None,
) -> GenericTransformer[U_contra, V_co]:
) -> GenericTransformer[U_contra, V_co, StateT_co]:
"""
No-op wrapper for recursive visitors.

Expand Down Expand Up @@ -120,7 +126,7 @@ def rec(node: U_contra) -> V_co:
return rec # type: ignore[return-value]


class CachingVisitor(Generic[U_contra, V_co]):
class CachingVisitor(Generic[U_contra, V_co, StateT_co]):
"""
Caching wrapper for recursive visitors.

Expand Down Expand Up @@ -148,13 +154,13 @@ class CachingVisitor(Generic[U_contra, V_co]):

def __init__(
self,
fn: Callable[[U_contra, GenericTransformer[U_contra, V_co]], V_co],
fn: Callable[[U_contra, GenericTransformer[U_contra, V_co, StateT_co]], V_co],
*,
state: Mapping[str, Any] | None = None,
state: StateT_co,
) -> None:
self.fn = fn
self.cache: MutableMapping[U_contra, V_co] = {}
self.state = state if state is not None else {}
self.state = state

def __call__(self, value: U_contra) -> V_co:
"""
Expand Down
25 changes: 20 additions & 5 deletions python/cudf_polars/cudf_polars/dsl/utils/replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,34 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Generic

from cudf_polars.dsl.traversal import CachingVisitor, reuse_if_unchanged
from cudf_polars.typing import NodeT, TypedDict

if TYPE_CHECKING:
from collections.abc import Mapping, Sequence

from cudf_polars.typing import GenericTransformer, NodeT
from cudf_polars.typing import GenericTransformer

__all__ = ["replace"]


def _replace(node: NodeT, fn: GenericTransformer[NodeT, NodeT]) -> NodeT:
# NOTE(review): `TypedDict` is imported from `cudf_polars.typing` rather than
# `typing`; presumably this selects `typing_extensions.TypedDict` on Python
# 3.10 so that a generic TypedDict is allowed -- confirm.
class State(Generic[NodeT], TypedDict):
    """
    State used when replacing nodes in expressions.

    Built by ``replace`` and passed as the ``state`` of the
    ``CachingVisitor`` wrapping ``_replace``, which looks up each
    visited node via ``fn.state["replacements"][node]``.

    Parameters
    ----------
    replacements
        Mapping from nodes to be replaced to their replacements.
        This state is generic over the type of these nodes.
    """

    # Keys are the nodes to replace; values are what they are replaced with.
    replacements: Mapping[NodeT, NodeT]


def _replace(node: NodeT, fn: GenericTransformer[NodeT, NodeT, State]) -> NodeT:
try:
return fn.state["replacements"][node]
except KeyError:
Expand All @@ -40,7 +55,7 @@ def replace(nodes: Sequence[NodeT], replacements: Mapping[NodeT, NodeT]) -> list
list
Of nodes with replacements performed.
"""
mapper: GenericTransformer[NodeT, NodeT] = CachingVisitor(
_replace, state={"replacements": replacements}
mapper: GenericTransformer[NodeT, NodeT, State] = CachingVisitor(
_replace, state=State({"replacements": replacements})
)
return [mapper(node) for node in nodes]
29 changes: 22 additions & 7 deletions python/cudf_polars/cudf_polars/experimental/dispatch.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,39 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
"""Multi-partition dispatch functions."""

from __future__ import annotations

from functools import singledispatch
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, TypeAlias, TypedDict

from cudf_polars.typing import GenericTransformer

if TYPE_CHECKING:
from collections.abc import MutableMapping
from typing import TypeAlias

from cudf_polars.dsl import ir
from cudf_polars.dsl.ir import IR
from cudf_polars.experimental.base import PartitionInfo
from cudf_polars.typing import GenericTransformer
from cudf_polars.utils.config import ConfigOptions


class State(TypedDict):
    """
    State used for lowering IR nodes.

    This is the state type parameter of ``LowerIRTransformer``.

    Parameters
    ----------
    config_options
        GPUEngine configuration options.
    """

    # `ConfigOptions` is only imported under TYPE_CHECKING; this works because
    # the module uses `from __future__ import annotations` (lazy annotations).
    config_options: ConfigOptions


LowerIRTransformer: TypeAlias = (
"GenericTransformer[IR, tuple[IR, MutableMapping[IR, PartitionInfo]]]"
)
LowerIRTransformer: TypeAlias = GenericTransformer[
"ir.IR", "tuple[ir.IR, MutableMapping[ir.IR, PartitionInfo]]", State
]
"""Protocol for Lowering IR nodes."""


Expand Down
42 changes: 30 additions & 12 deletions python/cudf_polars/cudf_polars/experimental/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

import operator
from functools import reduce
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, TypeAlias, TypedDict

import pylibcudf as plc

Expand All @@ -42,7 +42,7 @@
from cudf_polars.dsl.expressions.binaryop import BinOp
from cudf_polars.dsl.expressions.literal import Literal
from cudf_polars.dsl.expressions.unary import Cast, UnaryFunction
from cudf_polars.dsl.ir import Distinct, Empty, HConcat, Select
from cudf_polars.dsl.ir import IR, Distinct, Empty, HConcat, Select
from cudf_polars.dsl.traversal import (
CachingVisitor,
)
Expand All @@ -53,17 +53,33 @@

if TYPE_CHECKING:
from collections.abc import Generator, MutableMapping, Sequence
from typing import TypeAlias

from cudf_polars.dsl.expressions.base import Expr
from cudf_polars.dsl.ir import IR
from cudf_polars.typing import GenericTransformer, Schema
from cudf_polars.utils.config import ConfigOptions


ExprDecomposer: TypeAlias = (
"GenericTransformer[Expr, tuple[Expr, IR, MutableMapping[IR, PartitionInfo]]]"
)
class State(TypedDict):
    """
    State for decomposing expressions.

    Constructed in ``decompose_expr_graph`` and passed as the ``state``
    of the ``CachingVisitor`` wrapping ``_decompose``.

    Parameters
    ----------
    input_ir
        IR of the input expression.
    input_partition_info
        Partition info of the input expression.
    config_options
        GPUEngine configuration options.
    unique_names
        Generator of names, created by calling ``unique_names`` with the
        name of the expression being decomposed and the column names of
        the schema of ``input_ir``.
    """

    input_ir: IR
    # Looked up as `partition_info[input_ir]` when the state is built.
    input_partition_info: PartitionInfo
    config_options: ConfigOptions
    # NOTE(review): presumably yields names that do not collide with the
    # names it was seeded with -- confirm against `unique_names`.
    unique_names: Generator[str, None, None]


ExprDecomposer: TypeAlias = "GenericTransformer[Expr, tuple[Expr, IR, MutableMapping[IR, PartitionInfo]], State]"
"""Protocol for decomposing expressions."""


def select(
Expand Down Expand Up @@ -509,13 +525,15 @@ def decompose_expr_graph(
-----
This function recursively decomposes ``named_expr.value`` and
``input_ir`` into multiple partition-wise stages.

The state dictionary is an instance of :class:`State`.
"""
state = {
"input_ir": input_ir,
"input_partition_info": partition_info[input_ir],
"config_options": config_options,
"unique_names": unique_names((named_expr.name, *input_ir.schema.keys())),
}
state = State(
input_ir=input_ir,
input_partition_info=partition_info[input_ir],
config_options=config_options,
unique_names=unique_names((named_expr.name, *input_ir.schema.keys())),
)
mapper = CachingVisitor(_decompose, state=state)
expr, input_ir, partition_info = mapper(named_expr.value)
return named_expr.reconstruct(expr), input_ir, partition_info
3 changes: 2 additions & 1 deletion python/cudf_polars/cudf_polars/experimental/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from cudf_polars.dsl.traversal import CachingVisitor, traversal
from cudf_polars.experimental.base import PartitionInfo, get_key_name
from cudf_polars.experimental.dispatch import (
State,
generate_ir_tasks,
lower_ir_node,
)
Expand Down Expand Up @@ -81,7 +82,7 @@ def lower_ir_graph(
--------
lower_ir_node
"""
mapper = CachingVisitor(lower_ir_node, state={"config_options": config_options})
mapper = CachingVisitor(lower_ir_node, state=State(config_options=config_options))
return mapper(ir)


Expand Down
Loading