Skip to content

Commit d942d2c

Browse files
ianhiclaude
andcommitted
Simplify multi-actor stateful tests and fix delete+recreate rebase
- Refactor ClaimTracker: remove _protected, _try_clear, _stash_session, _on_clear
- Add delete_dir() and rewrite_manifests() to MultiActorModel
- Simplify IcechunkModel with cleaner lifecycle methods
- Extract actor logic to multi-actor case only
- Fix delete+recreate in rebase (Rust conflict detector)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 3fc6f89 commit d942d2c

7 files changed

Lines changed: 542 additions & 140 deletions

File tree

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
"""Multi-actor coordinator for stateful tests.
2+
3+
Manages multiple actors sharing a repository. Each actor has its own
4+
IcechunkModel. The coordinator swaps the active model and computes
5+
the blocked set (paths other actors have modified).
6+
7+
The test class's self.model is always a plain IcechunkModel — the
8+
coordinator just tracks which one is active and what's blocked.
9+
The committed baseline lives on the test (self.committed), not here.
10+
"""
11+
12+
from __future__ import annotations
13+
14+
from pathlib import PurePosixPath
15+
16+
from icechunk.testing.stateful_models import IcechunkModel
17+
18+
19+
class MultiActorCoordinator:
    """Coordinates multiple actors sharing a repository.

    Holds the actor dict. Computes the blocked set from other actors'
    changes. The test class owns self.model and self.committed directly.

    Attributes:
        current_actor: Name of the active actor ("" until one is selected).
        blocked: Paths the current actor must avoid, as last computed by
            recompute_blocked().
    """

    @staticmethod
    def _has_blocked_ancestor(path: str, blocked: set[str]) -> bool:
        """True if path or any ancestor of path is in the blocked set."""
        pp = PurePosixPath(path)
        # is_relative_to compares whole path components, so "a/bc" is
        # not considered to be under "a/b".
        return any(pp.is_relative_to(p) for p in blocked)

    @staticmethod
    def _has_blocked_relative(path: str, blocked: set[str]) -> bool:
        """True if any blocked path is an ancestor, descendant, or equal."""
        pp = PurePosixPath(path)
        return any(
            pp.is_relative_to(p) or PurePosixPath(p).is_relative_to(pp)
            for p in blocked
        )

    def __init__(self) -> None:
        self.current_actor: str = ""
        self._actors: dict[str, IcechunkModel] = {}
        self.blocked: set[str] = set()

    @property
    def current(self) -> IcechunkModel:
        """The current actor's model.

        Raises:
            KeyError: if no model is registered under ``current_actor``.
        """
        return self._actors[self.current_actor]

    @current.setter
    def current(self, model: IcechunkModel) -> None:
        self._actors[self.current_actor] = model

    def recompute_blocked(self, committed: IcechunkModel | None) -> None:
        """Recompute the blocked set for the current actor.

        The blocked set is:
        1. Other actors' uncommitted changes (structural + data)
        2. Main drift: paths that changed on committed since this actor's
           baseline

        The drift term is skipped when ``committed`` is None or when the
        current actor has no registered model. The result replaces
        ``self.blocked``.
        """
        actor = self.current_actor
        result: set[str] = set()

        for name, state in self._actors.items():
            if name != actor:
                result |= state.changes()

        current = self._actors.get(actor)
        if current is not None and committed is not None:
            committed_nodes = frozenset(committed.all_arrays) | frozenset(
                committed.all_groups
            )
            # Symmetric difference: nodes present on exactly one side,
            # i.e. added to or removed from committed since the baseline.
            result |= committed_nodes ^ current.baseline

        self.blocked = result

    # ── Actor management ─────────────────────────────────────────────

    def switch(self, actor: str, committed: IcechunkModel | None) -> None:
        """Switch to an existing actor and recompute the blocked set."""
        self.current_actor = actor
        self.recompute_blocked(committed)

    def init_actor(
        self, committed: IcechunkModel | None, bootstrap: str = "default"
    ) -> None:
        """Promote the bootstrap model to the current actor's model.

        The model stored under ``bootstrap`` is removed from that key and
        re-registered under ``current_actor``. If ``committed`` is given,
        the model's baseline is synced to it. The blocked set is then
        recomputed.

        Raises:
            KeyError: if no model is registered under ``bootstrap``.
        """
        try:
            model = self._actors.pop(bootstrap)
        except KeyError:
            raise KeyError(
                f"no bootstrap model registered under {bootstrap!r}"
            ) from None
        self._actors[self.current_actor] = model
        if committed is not None:
            model.sync_baseline(committed)
        self.recompute_blocked(committed)

    def other_actors_clean(self) -> bool:
        """True if no other actor has uncommitted changes."""
        return all(
            not state.changes()
            for name, state in self._actors.items()
            if name != self.current_actor
        )

    def can_add(self, path: str) -> bool:
        """True if neither path nor any of its ancestors is blocked."""
        return not self._has_blocked_ancestor(path, self.blocked)
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
from __future__ import annotations
2+
3+
from zarr.core.buffer import default_buffer_prototype
4+
from zarr.storage import MemoryStore
5+
6+
PROTOTYPE = default_buffer_prototype()
7+
8+
9+
class IcechunkModel(MemoryStore):
    """MemoryStore with move, copy, and lifecycle methods for testing.

    Tracks all_arrays/all_groups alongside store data so that copy()
    preserves them.

    Attributes:
        all_arrays: Paths of array nodes known to this model.
        all_groups: Paths of group nodes known to this model.
        baseline: Node paths recorded by the last sync_baseline() call.
        claims: Nodes whose keys were set or deleted through this model
            since claims was last cleared.
    """

    # Only declared here; nothing in this class assigns it, so it must be
    # set by the owner of the model before it can be read.
    spec_version: int

    def __init__(self) -> None:
        super().__init__()
        self.all_arrays: set[str] = set()
        self.all_groups: set[str] = set()
        self.baseline: frozenset[str] = frozenset()
        self.claims: set[str] = set()

    @classmethod
    def from_store(cls, store: MemoryStore) -> IcechunkModel:
        """Promote a plain MemoryStore to an IcechunkModel.

        The new model shares ``store``'s underlying dict (no data is
        copied); all_arrays, all_groups, baseline and claims start empty.
        """
        model = cls()
        model._store_dict = store._store_dict
        return model

    def _claim_node(self, key: str) -> None:
        """Record which array or group a key belongs to.

        Arrays are checked before groups. Within each kind the longest
        matching node path is claimed, so the result does not depend on
        set iteration order when nested nodes (e.g. "a" and "a/b") both
        prefix the key. Keys that fall under no known node are ignored.
        """
        for nodes in (self.all_arrays, self.all_groups):
            matches = [node for node in nodes if key.startswith(node + "/")]
            if matches:
                # All matches are prefixes of the same key, so their
                # lengths are distinct and max() is unambiguous.
                self.claims.add(max(matches, key=len))
                return

    async def set(
        self, key: str, value: object, byte_range: tuple[int, int] | None = None
    ) -> None:
        """Claim the node owning ``key``, then store the value."""
        self._claim_node(key)
        await super().set(key, value, byte_range)

    async def delete(self, key: str) -> None:
        """Claim the node owning ``key``, then delete the key."""
        self._claim_node(key)
        await super().delete(key)

    async def move(self, source: str, dest: str) -> None:
        """Move all keys from source to dest.

        Store keys always have form "node/zarr.json" or "node/c/...", never bare "node".
        """
        # Materialize the key list first: the loop below mutates the store.
        keys_to_move = [k async for k in self.list_prefix(source + "/")]
        for old_key in keys_to_move:
            new_key = dest + old_key[len(source) :]
            data = await self.get(old_key, prototype=PROTOTYPE)
            if data is not None:
                await self.set(new_key, data)
            await self.delete(old_key)

    async def copy(self) -> IcechunkModel:
        """Create a copy of this store (data + arrays + groups).

        ``spec_version`` is carried over when it has been assigned.
        ``baseline`` is not copied.
        """
        new_store = IcechunkModel()
        if hasattr(self, "spec_version"):
            new_store.spec_version = self.spec_version
        new_store.all_arrays = self.all_arrays.copy()
        new_store.all_groups = self.all_groups.copy()
        # NOTE(review): keys are written through the overridden set(), so
        # the copy's claims ends up containing every node that owns a key.
        # new_session() clears claims afterwards; commit() and rebase() do
        # not clear the copy's claims — confirm this is intended.
        async for key in self.list_prefix(""):
            data = await self.get(key, prototype=PROTOTYPE)
            if data is not None:
                await new_store.set(key, data)
        return new_store

    # things for tracking changes relative to committed baseline
    @property
    def nodes(self) -> frozenset[str]:
        """Current set of all node paths in this model."""
        return frozenset(self.all_arrays) | frozenset(self.all_groups)

    def changes(self) -> set[str]:
        """Paths modified since last sync: structural diff + data claims."""
        # nodes ^ baseline: nodes added or removed relative to the baseline.
        return set(self.claims) | (self.nodes ^ self.baseline)

    def sync_baseline(self, committed: IcechunkModel) -> None:
        """Record committed state as this actor's baseline.

        Only ``committed.all_arrays`` and ``committed.all_groups`` are read.
        """
        self.baseline = frozenset(committed.all_arrays) | frozenset(committed.all_groups)

    # ── Lifecycle methods ─────────────────────────────────────────────

    async def commit(self) -> IcechunkModel:
        """Snapshot this model as the new committed baseline.

        Clears this model's claims and syncs its baseline to the snapshot.
        Returns the committed copy; caller stores it as shared state.
        """
        committed = await self.copy()
        self.claims.clear()
        self.sync_baseline(committed)
        return committed

    async def rebase(self, committed: IcechunkModel) -> IcechunkModel:
        """Merge committed baseline with local changes.

        Starts from a copy of ``committed``; for every path in this
        model's changes(), node membership and all keys under that path
        are taken from this model instead.

        Returns a new model with the merge result; caller swaps it in.
        """
        merged = await committed.copy()
        my_changes = self.changes()

        # For changed paths this model decides whether the node exists;
        # for all other paths committed decides.
        merged.all_arrays = (committed.all_arrays - my_changes) | (
            my_changes & self.all_arrays
        )
        merged.all_groups = (committed.all_groups - my_changes) | (
            my_changes & self.all_groups
        )

        for path in my_changes:
            # Drop committed's keys under the path, then copy in ours.
            keys_to_delete = [k async for k in merged.list_prefix(path + "/")]
            for key in keys_to_delete:
                await merged.delete(key)
            async for key in self.list_prefix(path + "/"):
                data = await self.get(key, prototype=PROTOTYPE)
                if data is not None:
                    await merged.set(key, data)

        merged.sync_baseline(committed)
        return merged

    async def new_session(self, committed: IcechunkModel) -> IcechunkModel:
        """Reset to the committed baseline, discarding local changes.

        Returns a fresh copy of ``committed`` with empty claims and its
        baseline synced to ``committed``; this model is not modified.
        """
        fresh = await committed.copy()
        fresh.claims.clear()
        fresh.sync_baseline(committed)
        return fresh

    async def rewrite_manifests(self, committed: IcechunkModel) -> IcechunkModel:
        """Reset to committed baseline (rewrite_manifests doesn't change data)."""
        # Same model-level effect as starting a new session.
        return await self.new_session(committed)

0 commit comments

Comments
 (0)