
Commit 701c11a

[DataOriented] Revert "[DataOriented] Fix ndarrays on data oriented (#704)" (#719)
1 parent 57ad234 commit 701c11a

11 files changed

Lines changed: 55 additions & 1875 deletions


docs/source/user_guide/compound_types.md

Lines changed: 15 additions & 181 deletions
Large diffs are not rendered by default.

docs/source/user_guide/fastcache.md

Lines changed: 33 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -50,13 +50,43 @@ qd.init(arch=qd.gpu)
5050
# qd.init(arch=qd.gpu, print_non_pure=True)
5151
```
5252

53+
## Dataclass fields with cached values
54+
55+
By default, for `dataclasses.dataclass` parameters, fastcache only includes the *types* of each field in the cache key, not their values. This is fine for fields like ndarrays, where the compiled kernel doesn't depend on the actual data, only the dtype and dimensionality.
56+
57+
However, some dataclass fields hold configuration values that get baked into the compiled kernel — typically values used with `qd.static()`, such as loop bounds or feature flags:
58+
59+
```python
60+
for i in qd.static(range(config.num_layers)):
61+
...
62+
```
63+
64+
Here the value of `num_layers` is compiled into the kernel. Concretely, the loop will be unrolled at compile time. If `num_layers` changes, a different kernel must be compiled.
65+
66+
Mark such fields with `add_value_to_cache_key` so their values are included in the cache key:
67+
68+
```python
69+
import dataclasses
70+
from quadrants.lang._fast_caching import FIELD_METADATA_CACHE_VALUE
71+
72+
@dataclasses.dataclass
73+
class SimConfig:
74+
num_envs: int = dataclasses.field(metadata={FIELD_METADATA_CACHE_VALUE: True})
75+
dt: float = dataclasses.field(metadata={FIELD_METADATA_CACHE_VALUE: True})
76+
use_gravity: bool = dataclasses.field(metadata={FIELD_METADATA_CACHE_VALUE: True})
77+
```
78+
79+
With this annotation, changing `num_envs` from 100 to 200 produces a different cache key so the correct compiled kernel is looked up (or compiled if not yet cached). Without it, the wrong kernel could be loaded.
80+
81+
Note: `@qd.data_oriented` objects and `qd.Template` parameters already include primitive values in the cache key automatically — this annotation is only needed for `dataclasses.dataclass` fields.
82+
5383
## Constraints
5484

5585
A kernel is eligible for fastcache only if all of the following hold:
5686

5787
### 1. All data flows through parameters
5888

59-
The kernel must receive every piece of data it operates on as an explicit parameter. It must **not** capture variables from the enclosing Python scope (closures over ndarrays, mutable globals, or any other external state). This is the core "purity" constraint — the compiled kernel's behavior must be fully determined by its arguments.
89+
The kernel must receive every piece of data it operates on as an explicit parameter. It must **not** capture variables from the enclosing Python scope (closures over fields, ndarrays, or mutable globals). This is the core "purity" constraint — the compiled kernel's behavior must be fully determined by its arguments.
6090

6191
```python
6292
a = qd.ndarray(qd.f32, (10,))
@@ -95,8 +125,8 @@ Fastcache supports the following parameter types:
95125
| `qd.types.NDArray` (scalar, vector, matrix) | Yes | dtype, ndim, layout |
96126
| `torch.Tensor` | Yes | dtype, ndim |
97127
| `numpy.ndarray` | Yes | dtype, ndim |
98-
| `dataclasses.dataclass` | Yes | member types recursively; member values if annotated with `FIELD_METADATA_CACHE_VALUE` (see [Appendix — compound-type cache keying](#compound-type-cache-keying)) |
99-
| `@qd.data_oriented` objects | Yes | member types recursively; primitive member types and values baked into kernel (see [Appendix — compound-type cache keying](#compound-type-cache-keying)) |
128+
| `dataclasses.dataclass` | Yes | field types recursively; field values if annotated with `add_value_to_cache_key` (see [above](#dataclass-fields-with-cached-values)) |
129+
| `@qd.data_oriented` objects | Yes | member types and primitive member values recursively |
100130
| `qd.Template` primitives (int, float, bool) | Yes | type and value (baked into kernel) |
101131
| Non-template primitives (int, float, bool) | Yes | type only |
102132
| `enum.Enum` | Yes | name and value |
@@ -142,33 +172,3 @@ print(obs.cache_stored) # True if the compiled kernel was stored to cach
142172
```
143173

144174
On the first run you'll see `cache_stored=True` but `cache_loaded=False`. On the second run (after `qd.init`), `cache_loaded=True`.
145-
146-
## Appendix
147-
148-
### Compound-type cache keying
149-
150-
The args hasher walks compound-type kernel parameters recursively. For each leaf member it decides what (if anything) contributes to the cache key. The headline rules:
151-
152-
**`@qd.data_oriented`:** the walker descends into `vars(obj)`. For each child:
153-
154-
- `qd.ndarray` member — `(dtype, ndim, layout)` is included in the cache key. Element values are not.
155-
- Primitive (`int` / `float` / `bool` / `enum.Enum`) member — value is baked into the kernel (same semantics as a `qd.Template` primitive). Two instances of the same class with different primitive member values get different cache entries.
156-
- Nested `@qd.data_oriented` member — recurses.
157-
- Nested `dataclasses.dataclass` member — recurses (with the dataclass rules below).
158-
- `qd.field` member — fastcache is disabled for the entire kernel call. The kernel still runs via normal compilation; a warn-level log line is emitted.
159-
160-
**`dataclasses.dataclass`:** the walker descends into the declared members. For each member, only the *type* is included in the cache key by default — **not** the value. To include a member's value, annotate it:
161-
162-
```python
163-
import dataclasses
164-
from quadrants.lang._fast_caching import FIELD_METADATA_CACHE_VALUE
165-
166-
@dataclasses.dataclass
167-
class SimConfig:
168-
num_layers: int = dataclasses.field(metadata={FIELD_METADATA_CACHE_VALUE: True})
169-
dt: float = dataclasses.field(metadata={FIELD_METADATA_CACHE_VALUE: True})
170-
```
171-
172-
This is necessary whenever the compiled kernel depends on the member's *value* rather than just its type (for example, when the value is used as a loop bound that the compiler bakes into the generated code). Without the annotation, two `SimConfig` instances with different `num_layers` values would share a fastcache key, and the second instance would silently load a kernel compiled for the wrong value.
173-
174-
Note the asymmetry: `@qd.data_oriented` primitive members are baked into the kernel automatically (same semantics as `qd.Template`); `dataclasses.dataclass` members contribute only their *type* to the cache key unless you opt in per-member.
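
The opt-in rule described in the `fastcache.md` hunks above (field types always enter the cache key, field values only when annotated) can be illustrated with a small standalone sketch. This is not the quadrants implementation: `CACHE_VALUE_KEY` is a hypothetical stand-in for `FIELD_METADATA_CACHE_VALUE` (whose actual value is not shown in the diff), and `cache_key` is an illustrative helper, not a library function.

```python
import dataclasses

# Hypothetical stand-in for quadrants' FIELD_METADATA_CACHE_VALUE constant.
CACHE_VALUE_KEY = "cache_value"


@dataclasses.dataclass
class SimConfig:
    # Annotated: the value is assumed to be baked into the kernel, so it must enter the key.
    num_layers: int = dataclasses.field(default=4, metadata={CACHE_VALUE_KEY: True})
    # Not annotated: only the type contributes to the key.
    dt: float = 0.01


def cache_key(obj):
    """Illustrative key builder: type for every field, value only for opted-in fields."""
    parts = []
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        if f.metadata.get(CACHE_VALUE_KEY, False):
            parts.append((f.name, type(value).__name__, value))
        else:
            parts.append((f.name, type(value).__name__))
    return tuple(parts)


key_a = cache_key(SimConfig(num_layers=4, dt=0.01))
key_b = cache_key(SimConfig(num_layers=8, dt=0.01))  # annotated value changed
key_c = cache_key(SimConfig(num_layers=4, dt=0.02))  # unannotated value changed
```

In this sketch `key_a` and `key_b` differ because `num_layers` is annotated, while `key_a` and `key_c` are equal because `dt` contributes only its type. That second case is the collision the docs warn about when a value that affects compilation is left unannotated.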

docs/source/user_guide/tensor.md

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -205,15 +205,6 @@ fill(b) # ndarray branch
205205

206206
The kernel argument is unwrapped to the bare impl before the template-mapper / AST sees it, so kernel bodies still write `x[i, j]` and pay no per-call cost for the wrapper.
207207

208-
`qd.Tensor` is also the right annotation when storing a tensor as a `dataclasses.dataclass` member:
209-
210-
```python
211-
@dataclass
212-
class State:
213-
a: qd.Tensor
214-
b: qd.Tensor
215-
```
216-
217208
## Pickle
218209

219210
`qd.Tensor` objects are picklable on **both** backends, including under non-identity layouts. Round-trip (pickle then unpickle) preserves the canonical data, the dtype, the shape, and the layout:

python/quadrants/lang/_template_mapper.py

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,10 @@
55
from quadrants.lang import impl
66
from quadrants.lang.impl import Program
77
from quadrants.lang.kernel_arguments import ArgMetadata
8-
from quadrants.lang.util import is_data_oriented
98

109
from .._test_tools import warnings_helper
1110
from ._kernel_types import ArgsHash
12-
from ._template_mapper_hotpath import (
13-
_extract_arg,
14-
_primitive_types,
15-
_struct_nd_paths_for,
16-
)
17-
18-
19-
def _collect_data_oriented_nd_ids(arg: Any, out: list) -> None:
20-
"""Append ``id(ndarray)`` for every ndarray reachable from ``arg``, using the per-class path cache in
21-
``_template_mapper_hotpath._struct_nd_paths_for`` so the first call walks ``vars(arg)`` once and subsequent calls
22-
are just ``getattr`` chains. Empty path list short-circuits with zero work — critical for genesis's
23-
``@qd.data_oriented`` Solver passed as ``self`` to every kernel.
24-
"""
25-
for chain in _struct_nd_paths_for(arg):
26-
v = arg
27-
for a in chain:
28-
v = getattr(v, a)
29-
out.append(id(v))
30-
11+
from ._template_mapper_hotpath import _extract_arg, _primitive_types
3112

3213
Key: TypeAlias = tuple[Any, ...]
3314

@@ -90,17 +71,6 @@ def lookup(self, raise_on_templated_floats: bool, args: tuple[Any, ...]) -> tupl
9071
# branching for primitive types dramatically improve performance of hash computation.
9172
mapping_cache_tracker: list[ReferenceType | None] | None = None
9273
args_hash: ArgsHash = tuple([id(arg) for arg in args])
93-
# ``@qd.data_oriented`` containers can have their member ndarrays reassigned between calls on the same instance
94-
# (``state.x = other_ndarray``). The id(arg) alone does not capture that, so the spec-key cache below would
95-
# serve a stale entry and the new ndarray's dtype/ndim would be wrong. Fold the reachable ndarray ids into the
96-
# hash. No-op for data_oriented containers that hold no ndarrays — the walker returns an empty list. See
97-
# ``_collect_data_oriented_nd_ids``.
98-
nd_ids: list = []
99-
for arg in args:
100-
if is_data_oriented(arg):
101-
_collect_data_oriented_nd_ids(arg, nd_ids)
102-
if nd_ids:
103-
args_hash = args_hash + tuple(nd_ids)
10474
try:
10575
mapping_cache_tracker = self._mapping_cache_tracker[args_hash]
10676
except KeyError:

python/quadrants/lang/_template_mapper_hotpath.py

Lines changed: 1 addition & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
a consequence of inlining 'is_dataclass' and 'fields'.
2626
"""
2727

28-
import dataclasses
2928
import weakref
3029
from dataclasses import _FIELD, _FIELDS
3130
from typing import Any, Union
@@ -72,112 +71,6 @@
7271
_primitive_types = {int, float, bool}
7372

7473

75-
# Per-class cache: ``type(arg) -> list[tuple[str, ...]]`` of attribute paths whose values are ``Ndarray`` instances at
76-
# first observation. Populated lazily by ``_struct_nd_paths_for`` on the first call with each new data_oriented (or
77-
# nested dataclass) class. Empty list means "this class holds no ndarrays anywhere", in which case subsequent calls
78-
# pay only a dict-lookup per arg. Non-empty list short-circuits the full ``vars()`` recursion and just resolves each
79-
# cached path via ``getattr`` chains. Critical for the genesis field-backend hot path: the ``@qd.data_oriented``
80-
# Solver is passed as ``self`` to most kernels and holds dozens of attributes, so a full per-call ``vars()`` walk
81-
# costs >100ns per kernel and trashed FPS until this cache was added.
82-
_struct_nd_paths_cache: dict[type, list[tuple]] = {}
83-
84-
85-
def _build_struct_nd_paths(obj: Any, prefix: tuple, out: list) -> None:
86-
if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
87-
children = ((f.name, getattr(obj, f.name)) for f in dataclasses.fields(obj))
88-
else:
89-
# ``NamedTuple`` (decorated as ``@qd.data_oriented``) has no instance ``__dict__`` — fall back to ``_asdict()``
90-
# which materialises a dict view of the named fields. Mirrors the same fallback in
91-
# ``args_hasher.stringify_obj_type`` so the per-class path cache here picks up ndarray members on NamedTuples
92-
# too (regression covered by ``test_args_hasher_named_tuple``).
93-
try:
94-
children = obj._asdict().items()
95-
except AttributeError:
96-
children = obj.__dict__.items()
97-
for k, v in children:
98-
chain = prefix + (k,)
99-
if type(v) in _TENSOR_WRAPPER_TYPES:
100-
v = v._unwrap()
101-
v_type = type(v)
102-
if issubclass(v_type, Ndarray):
103-
out.append(chain)
104-
elif is_data_oriented(v) or (dataclasses.is_dataclass(v) and not isinstance(v, type)):
105-
_build_struct_nd_paths(v, chain, out)
106-
107-
108-
def _struct_nd_paths_for(arg: Any) -> list[tuple]:
109-
"""Return the cached attribute paths (each a tuple of attr-name strings) at which ``Ndarray`` instances are
110-
reachable from ``arg`` of type ``type(arg)``. First call for a class walks ``arg`` once via
111-
``_build_struct_nd_paths``; subsequent calls are dict-lookups.
112-
113-
Trades freshness for speed: assumes the *set* of ndarray-holding attribute paths is stable across instances of
114-
the same class. The genesis Solver and similar ``@qd.data_oriented`` containers satisfy this — their ndarray
115-
members are declared in ``__init__`` and not added later. If you need to add an ndarray attribute after the first
116-
kernel launch on an instance of a given class, the new attribute won't be tracked. Call ``invalidate_struct_nd_
117-
paths_for`` (below) or restart the program.
118-
119-
FIXME (Codex #3 on PR #704, https://github.com/Genesis-Embodied-AI/quadrants/pull/704#discussion_r3253281957):
120-
the cache is keyed by ``type(arg)`` only. If two instances of the same class have *polymorphic attribute
121-
structure* — e.g. instance A has ``.x`` as a ``qd.ndarray``-backed ``qd.Tensor`` while instance B has the same
122-
``.x`` as a field-backed ``qd.Tensor`` — the paths discovered from the first-walked instance are reused for the
123-
second. ``_collect_struct_nd_descriptors`` then unconditionally reads ndarray-only attrs (``element_type``,
124-
``grad``, ``_qd_layout``) on what is now a ``ScalarField``, raising before the kernel can run. The fix is the
125-
per-instance walk implemented on top of this branch in PR #705; this branch ships the class-level cache as-is.
126-
"""
127-
cls = type(arg)
128-
paths = _struct_nd_paths_cache.get(cls)
129-
if paths is None:
130-
paths = []
131-
_build_struct_nd_paths(arg, (), paths)
132-
_struct_nd_paths_cache[cls] = paths
133-
return paths
134-
135-
136-
def chain_has_mutable_container(args, template_arg_idx, attr_chain) -> bool:
137-
"""Return True if any container along ``attr_chain`` from ``args[template_arg_idx]`` down to (but excluding) the
138-
leaf ndarray attribute is mutable in a way that lets it rebind its child attribute. Such a parent makes
139-
``id(args[template_arg_idx])`` alone insufficient to uniquely identify the leaf, so the leaf id must be folded
140-
into the launch-context cache key.
141-
142-
A container is "mutable" here iff:
143-
- its type has ``__hash__ is None`` (Python sets this for non-frozen ``@dataclass(eq=True)`` types), or
144-
- it is a ``@qd.data_oriented`` instance (these inherit ``object.__hash__`` so the ``__hash__ is None`` check
145-
misses them; they support normal attribute assignment).
146-
147-
Walks all parents from the root down to ``attr_chain[:-1]`` — the final entry is the leaf itself, whose own
148-
mutability does not affect rebinding by its parent. Returns on the first mutable parent.
149-
"""
150-
cur = args[template_arg_idx]
151-
if type(cur).__hash__ is None or is_data_oriented(cur):
152-
return True
153-
for attr_name in attr_chain[:-1]:
154-
cur = getattr(cur, attr_name)
155-
if type(cur).__hash__ is None or is_data_oriented(cur):
156-
return True
157-
return False
158-
159-
160-
def _collect_struct_nd_descriptors(arg: Any, out: list) -> None:
161-
"""Emit per-ndarray shape descriptors ``(joined-path, element_type, ndim, needs_grad, layout)`` for every ndarray
162-
reachable from ``arg``. Used by the template-mapper to refine the spec key for ``@qd.data_oriented`` args holding
163-
ndarrays — see the data_oriented branch in ``_extract_arg``.
164-
165-
FIXME (Codex #3 on PR #704): when a polymorphic instance reuses a cached path that pointed to an ``Ndarray`` on
166-
the first-walked instance, ``v`` here can be a ``ScalarField`` and the ``v.element_type`` / ``v.grad`` /
167-
``v._qd_layout`` reads will raise. See ``_struct_nd_paths_for`` above for details. Fixed in PR #705 via the
168-
per-instance walk redesign.
169-
"""
170-
for chain in _struct_nd_paths_for(arg):
171-
v = arg
172-
for a in chain:
173-
v = getattr(v, a)
174-
if type(v) in _TENSOR_WRAPPER_TYPES:
175-
v = v._unwrap()
176-
type_id = id(v.element_type)
177-
element_type = type_id if type_id in primitive_types.type_ids else v.element_type
178-
out.append((".".join(chain), element_type, len(v.shape), v.grad is not None, v._qd_layout))
179-
180-
18174
def _extract_arg(raise_on_templated_floats: bool, arg: Any, annotation: AnnotationType, arg_name: str) -> Any:
18275
# ``qd.Tensor`` wrappers passed as struct fields. Top-level kernel-arg unwrap in ``Kernel.__call__`` covers direct
18376
# args, but the dataclass-field recursion at the bottom of this function walks struct attributes via raw
@@ -231,7 +124,7 @@ def _extract_arg(raise_on_templated_floats: bool, arg: Any, annotation: Annotati
231124
raise QuadrantsRuntimeTypeError(
232125
"Ndarray shouldn't be passed in via `qd.template()`, please annotate your kernel using `qd.types.ndarray(...)` instead"
233126
)
234-
if arg_type in _composite_mutable_types:
127+
if arg_type in _composite_mutable_types or is_data_oriented(arg):
235128
# [Composite arguments] Return weak reference to the object
236129
# Quadrants kernel will cache the extracted arguments, thus we can't simply return the original argument.
237130
# Instead, a weak reference to the original value is returned to avoid memory leak.
@@ -241,21 +134,6 @@ def _extract_arg(raise_on_templated_floats: bool, arg: Any, annotation: Annotati
241134
# 1. Invalid weak-ref will leave a dead(dangling) entry in both caches: "self.mapping" and "self.compiled_functions"
242135
# 2. Different argument instances with same type and same value, will get templatized into separate kernels.
243136
return weakref.ref(arg)
244-
if is_data_oriented(arg):
245-
# Same memory-leak avoidance as above — keep ``weakref.ref(arg)`` so the spec key never holds a strong
246-
# reference to user state. But for data_oriented containers that hold ``Ndarray`` members, the live
247-
# ``weakref`` alone is too coarse: same instance with ``state.x = other_ndarray`` of a different dtype/ndim
248-
# would re-use the previously-compiled kernel, which was specialised for the old shape. Walk the reachable
249-
# ndarrays and prepend their shape descriptors so dtype/ndim changes trigger re-specialisation. Mirrors what
250-
# the dataclass branch below does via ``annotation_fields``.
251-
#
252-
# Containers with no ndarrays keep the original short-path (one spec per instance via weakref) so this is
253-
# a no-op for the existing data_oriented + qd.field workloads (genesis field-backend).
254-
nd_descriptors: list = []
255-
_collect_struct_nd_descriptors(arg, nd_descriptors)
256-
if nd_descriptors:
257-
return (id(type(arg)), tuple(nd_descriptors), weakref.ref(arg))
258-
return weakref.ref(arg)
259137

260138
# Return value directly for other types, i.e. primitive types and all qd.Field-derived classes
261139
if raise_on_templated_floats and arg_type is float:
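
The comment removed from `_template_mapper.py` argues that hashing arguments by `id(arg)` alone cannot see a member being rebound on the same container instance. The toy sketch below reproduces that effect in plain Python. `State`, `args_hash_ids_only`, and `args_hash_with_members` are hypothetical names for illustration only, and plain lists stand in for ndarrays; none of this is the library's code.

```python
class State:
    """Toy mutable container standing in for a data_oriented object."""

    def __init__(self, x):
        self.x = x


def args_hash_ids_only(args):
    # Mirrors the shape of `tuple([id(arg) for arg in args])` in the diff.
    return tuple(id(a) for a in args)


def args_hash_with_members(args):
    # Also folds in the identity of each container's member, as the removed code did for ndarrays.
    ids = [id(a) for a in args]
    for a in args:
        if isinstance(a, State):
            ids.append(id(a.x))
    return tuple(ids)


first, second = [1.0, 2.0], [3, 4, 5]  # both stay alive, so their ids are distinct
state = State(first)

plain_before = args_hash_ids_only((state,))
folded_before = args_hash_with_members((state,))

state.x = second  # rebind the member on the same instance

plain_after = args_hash_ids_only((state,))
folded_after = args_hash_with_members((state,))
```

Here `plain_before == plain_after` even though the member changed, while the folded hash changes. With this revert applied, the lookup path goes back to the ids-only behaviour shown by the first function.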
