Skip to content

Commit cf2a77b

Browse files
committed
Used TypedDict for CachingVisitor.state
This changes `CachingVisitor.state` to use a `TypedDict`. Previously, we used a `Mapping[str, Any]`, which had two problems: 1. It risked typos in the key names causing unexpected KeyErrors. 2. The `Any` type for the values of `state` means that, at least by default, you won't get any type checking on things using values looked up from `state`. Unit tests should catch all these, but this can provide some earlier feedback on any issues. The `total=False` keyword is used in the `TypedDict` to indicate that these keys are all optional. `mypy` *doesn't* error if you do a lookup on an optional field of a `TypedDict`, so we do still need to handle missing keys like we have been.
1 parent ca6e9e7 commit cf2a77b

File tree

5 files changed

+37
-13
lines changed

5 files changed

+37
-13
lines changed

python/cudf_polars/cudf_polars/dsl/traversal.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
22
# SPDX-License-Identifier: Apache-2.0
33

44
"""Traversal and visitor utilities for nodes."""
@@ -7,7 +7,7 @@
77

88
from typing import TYPE_CHECKING, Any, Generic
99

10-
from cudf_polars.typing import U_contra, V_co
10+
from cudf_polars.typing import CachingVisitorState, U_contra, V_co
1111

1212
if TYPE_CHECKING:
1313
from collections.abc import Callable, Generator, Mapping, MutableMapping, Sequence
@@ -150,11 +150,11 @@ def __init__(
150150
self,
151151
fn: Callable[[U_contra, GenericTransformer[U_contra, V_co]], V_co],
152152
*,
153-
state: Mapping[str, Any] | None = None,
153+
state: CachingVisitorState | None = None,
154154
) -> None:
155155
self.fn = fn
156156
self.cache: MutableMapping[U_contra, V_co] = {}
157-
self.state = state if state is not None else {}
157+
self.state = state if state is not None else CachingVisitorState()
158158

159159
def __call__(self, value: U_contra) -> V_co:
160160
"""

python/cudf_polars/cudf_polars/experimental/expressions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757

5858
from cudf_polars.dsl.expressions.base import Expr
5959
from cudf_polars.dsl.ir import IR
60-
from cudf_polars.typing import GenericTransformer, Schema
60+
from cudf_polars.typing import CachingVisitorState, GenericTransformer, Schema
6161
from cudf_polars.utils.config import ConfigOptions
6262

6363

@@ -510,7 +510,7 @@ def decompose_expr_graph(
510510
This function recursively decomposes ``named_expr.value`` and
511511
``input_ir`` into multiple partition-wise stages.
512512
"""
513-
state = {
513+
state: CachingVisitorState = {
514514
"input_ir": input_ir,
515515
"input_partition_info": partition_info[input_ir],
516516
"config_options": config_options,

python/cudf_polars/cudf_polars/typing/__init__.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,16 @@
2121
from polars.polars import _expr_nodes as pl_expr, _ir_nodes as pl_ir
2222

2323
if TYPE_CHECKING:
24-
from collections.abc import Callable, Mapping
24+
from collections.abc import Callable, Generator, Mapping
2525
from typing import TypeAlias
2626

2727
import pylibcudf as plc
2828

2929
from cudf_polars.containers import DataFrame, DataType
3030
from cudf_polars.dsl import expr, ir, nodebase
31+
from cudf_polars.experimental.base import PartitionInfo
32+
from cudf_polars.utils.config import ConfigOptions
33+
3134

3235
__all__: list[str] = [
3336
"ClosedInterval",
@@ -158,7 +161,7 @@ def __call__(self, __value: U_contra) -> V_co:
158161
...
159162

160163
@property
161-
def state(self) -> Mapping[str, Any]:
164+
def state(self) -> CachingVisitorState:
162165
"""Arbitrary immutable state."""
163166
...
164167

@@ -198,3 +201,16 @@ class DataFrameHeader(TypedDict):
198201

199202
columns_kwargs: list[ColumnOptions]
200203
frame_count: int
204+
205+
206+
class CachingVisitorState[NodeT](TypedDict, total=False):
207+
"""State for CachingVisitor."""
208+
209+
config_options: ConfigOptions
210+
for_parquet: bool
211+
input_ir: ir.IR
212+
input_partition_info: PartitionInfo
213+
name_to_index: Mapping[str, int]
214+
replacements: Mapping[NodeT, NodeT]
215+
table_ref: plc.expressions.TableReference
216+
unique_names: Generator[str, None, None]

python/cudf_polars/docs/overview.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,10 @@ def rename(e: Expr, mapping: Mapping[str, str]) -> Expr:
371371
return mapper(e)
372372
```
373373

374+
In practice, `state` is a `TypedDict` defined in `cudf_polars.typing`. To add a
375+
new field to the state, you'll need to update `CachingVisitorState` to include
376+
the key and type of that field.
377+
374378
# Containers
375379

376380
Containers should be constructed as relatively lightweight objects

python/cudf_polars/tests/dsl/test_traversal.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,19 +66,23 @@ def test_caching_visitor():
6666

6767
e1 = make_expr(dt, "a", "b")
6868

69-
mapper = CachingVisitor(rename, state={"mapping": {"b": "c"}})
69+
# This test adds an extra key to CachingVisitorState.
70+
# That's just a TypedDict, so adding extra keys is fine at runtime.
71+
# We don't want to include these test-only fields in the TypedDict,
72+
# so we use an ignore here.
73+
mapper = CachingVisitor(rename, state={"mapping": {"b": "c"}}) # type: ignore
7074

7175
renamed = mapper(e1)
7276
assert renamed == make_expr(dt, "a", "c")
7377
assert len(mapper.cache) == 3
7478

7579
e2 = make_expr(dt, "a", "a")
76-
mapper = CachingVisitor(rename, state={"mapping": {"b": "c"}})
80+
mapper = CachingVisitor(rename, state={"mapping": {"b": "c"}}) # type: ignore
7781

7882
renamed = mapper(e2)
7983
assert renamed == make_expr(dt, "a", "a")
8084
assert len(mapper.cache) == 2
81-
mapper = CachingVisitor(rename, state={"mapping": {"a": "c"}})
85+
mapper = CachingVisitor(rename, state={"mapping": {"a": "c"}}) # type: ignore
8286

8387
renamed = mapper(e2)
8488
assert renamed == make_expr(dt, "c", "c")
@@ -189,7 +193,7 @@ def _transform(e: expr.Expr, fn: ExprTransformer) -> expr.Expr:
189193

190194
@_transform.register
191195
def _(e: expr.Col, fn: ExprTransformer):
192-
mapping = fn.state["mapping"]
196+
mapping = fn.state["mapping"] # type: ignore
193197
if e.name in mapping:
194198
return type(e)(e.dtype, mapping[e.name])
195199
return e
@@ -210,7 +214,7 @@ def _rewrite(node: ir.IR, fn: IRTransformer) -> ir.IR:
210214

211215
@_rewrite.register
212216
def _(node: ir.Select, fn: IRTransformer):
213-
expr_mapper = fn.state["expr_mapper"]
217+
expr_mapper = fn.state["expr_mapper"] # type: ignore
214218
return type(node)(
215219
node.schema,
216220
[expr.NamedExpr(e.name, expr_mapper(e.value)) for e in node.exprs],

0 commit comments

Comments
 (0)