Skip to content

Commit 293e5fc

Browse files
committed
feat: add ZarrsArray
1 parent 234030b commit 293e5fc

File tree

8 files changed

+1171
-58
lines changed

8 files changed

+1171
-58
lines changed

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ crate-type = ["cdylib", "rlib"]
1010

1111
[dependencies]
1212
pyo3 = { version = "0.27.1", features = ["abi3-py311"] }
13-
zarrs = { version = "0.23.0", features = ["async", "zlib", "pcodec", "bz2"] }
13+
zarrs = { version = "0.23.1", features = ["async", "zlib", "pcodec", "bz2"] }
1414
rayon_iter_concurrent_limit = "0.2.0"
1515
rayon = "1.10.0"
1616
# fix for https://stackoverflow.com/questions/76593417/package-openssl-was-not-found-in-the-pkg-config-search-path
@@ -26,6 +26,7 @@ itertools = "0.14.0"
2626
bytemuck = { version = "1.24.0", features = ["must_cast"] }
2727
pyo3-object_store = "0.7.0" # object_store 0.12
2828
zarrs_object_store = "0.5.0" # object_store 0.12
29+
mimalloc = { version = "0.1", default-features = false }
2930

3031
[profile.release]
3132
lto = true

python/zarrs/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from ._internal import __version__
2+
from .array import ZarrsArray
23
from .pipeline import ZarrsCodecPipeline as _ZarrsCodecPipeline
34
from .utils import CollapsedDimensionError, DiscontiguousArrayError
45

@@ -9,6 +10,7 @@ class ZarrsCodecPipeline(_ZarrsCodecPipeline):
910

1011

1112
__all__ = [
13+
"ZarrsArray",
1214
"ZarrsCodecPipeline",
1315
"DiscontiguousArrayError",
1416
"CollapsedDimensionError",

python/zarrs/_internal.pyi

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,40 @@ import typing
77
import numpy.typing
88
import zarr.abc.store
99

10+
@typing.final
11+
class ArrayImpl:
12+
@property
13+
def shape(self) -> builtins.list[builtins.int]: ...
14+
@property
15+
def ndim(self) -> builtins.int: ...
16+
@property
17+
def dtype(self) -> builtins.str: ...
18+
def __new__(
19+
cls,
20+
store_config: zarr.abc.store.Store,
21+
path: builtins.str,
22+
*,
23+
validate_checksums: builtins.bool = False,
24+
num_threads: builtins.int | None = None,
25+
direct_io: builtins.bool = False,
26+
) -> ArrayImpl: ...
27+
def retrieve(
28+
self,
29+
ranges: typing.Sequence[tuple[builtins.int, builtins.int]],
30+
output: numpy.typing.NDArray[typing.Any],
31+
) -> None: ...
32+
def store(
33+
self,
34+
ranges: typing.Sequence[tuple[builtins.int, builtins.int]],
35+
input: numpy.typing.NDArray[typing.Any],
36+
) -> None: ...
37+
def copy_from(
38+
self,
39+
source: ArrayImpl,
40+
source_ranges: typing.Sequence[tuple[builtins.int, builtins.int]],
41+
dest_ranges: typing.Sequence[tuple[builtins.int, builtins.int]],
42+
) -> None: ...
43+
1044
@typing.final
1145
class ChunkItem:
1246
def __new__(

python/zarrs/array.py

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
from __future__ import annotations
2+
3+
import numpy as np
4+
import zarr
5+
from zarr.core.array import Array
6+
7+
from ._internal import ArrayImpl
8+
9+
10+
def _is_basic_indexing(key) -> bool:
11+
"""Check if key uses only int, step-1 slices, and/or a single Ellipsis."""
12+
if not isinstance(key, tuple):
13+
key = (key,)
14+
has_ellipsis = False
15+
for k in key:
16+
if isinstance(k, int):
17+
continue
18+
elif isinstance(k, slice):
19+
if k.step is not None and k.step != 1:
20+
return False
21+
elif k is Ellipsis:
22+
if has_ellipsis:
23+
return False # multiple ellipses
24+
has_ellipsis = True
25+
else:
26+
return False
27+
return True
28+
29+
30+
class _LazySlice:
31+
"""Lazy reference to a subset of a ZarrsArray (no I/O until consumed)."""
32+
33+
__slots__ = ("_dtype", "_impl", "_ranges", "_region_shape", "_squeeze_dims")
34+
35+
def __init__(self, impl_, ranges, region_shape, dtype, squeeze_dims):
36+
self._impl = impl_
37+
self._ranges = ranges
38+
self._region_shape = region_shape
39+
self._dtype = dtype
40+
self._squeeze_dims = squeeze_dims
41+
42+
def __array__(self, dtype=None, copy=None) -> np.ndarray:
43+
out = np.empty(self._region_shape, dtype=self._dtype)
44+
if out.size > 0:
45+
self._impl.retrieve(self._ranges, out)
46+
if self._squeeze_dims:
47+
out = out.squeeze(axis=tuple(self._squeeze_dims))
48+
if dtype is not None and out.dtype != dtype:
49+
out = out.astype(dtype, copy=False)
50+
return out
51+
52+
53+
class _LazyIndexer:
54+
"""Proxy returned by ``ZarrsArray.lazy`` that captures indexing lazily."""
55+
56+
__slots__ = ("_pipeline",)
57+
58+
def __init__(self, pipeline: ZarrsArray):
59+
self._pipeline = pipeline
60+
61+
def __getitem__(self, key: slice | int | tuple[slice | int, ...]) -> _LazySlice:
62+
ranges, region_shape, squeeze_dims = self._pipeline._parse_key(key)
63+
return _LazySlice(
64+
self._pipeline._impl,
65+
ranges,
66+
region_shape,
67+
self._pipeline.dtype,
68+
squeeze_dims,
69+
)
70+
71+
72+
class ZarrsArray(Array):
73+
"""zarr.Array subclass backed by zarrs for fast I/O.
74+
75+
Supports all zarr.Array operations. Basic slice indexing (ints, step-1
76+
slices, ellipsis) is handled by the Rust fast path; advanced indexing
77+
falls back to zarr.Array unless ``codec_pipeline.strict`` is set.
78+
"""
79+
80+
def __init__(
81+
self,
82+
array: Array,
83+
*,
84+
validate_checksums: bool = False,
85+
chunk_concurrent_minimum: int | None = None,
86+
num_threads: int | None = None,
87+
direct_io: bool = False,
88+
) -> None:
89+
super().__init__(array._async_array)
90+
store = array.store_path.store
91+
zarr_path = array.store_path.path
92+
zarrs_path = "/" + zarr_path if zarr_path else "/"
93+
self._impl = ArrayImpl(
94+
store,
95+
zarrs_path,
96+
validate_checksums=validate_checksums,
97+
chunk_concurrent_minimum=chunk_concurrent_minimum,
98+
num_threads=num_threads,
99+
direct_io=direct_io,
100+
)
101+
102+
@property
103+
def lazy(self) -> _LazyIndexer:
104+
return _LazyIndexer(self)
105+
106+
def _parse_key(
107+
self, key: slice | int | tuple[slice | int, ...]
108+
) -> tuple[list[tuple[int, int]], list[int], list[int]]:
109+
if not isinstance(key, tuple):
110+
key = (key,)
111+
112+
# Expand Ellipsis
113+
if Ellipsis in key:
114+
idx = key.index(Ellipsis)
115+
n_explicit = len(key) - 1 # everything except the Ellipsis
116+
n_expand = self.ndim - n_explicit
117+
if n_expand < 0:
118+
raise IndexError(
119+
f"too many indices for array: "
120+
f"array is {self.ndim}-dimensional, "
121+
f"but {n_explicit} were indexed"
122+
)
123+
key = key[:idx] + (slice(None),) * n_expand + key[idx + 1 :]
124+
125+
if len(key) > self.ndim:
126+
raise IndexError(
127+
f"too many indices for array: "
128+
f"array is {self.ndim}-dimensional, "
129+
f"but {len(key)} were indexed"
130+
)
131+
132+
# Pad missing dimensions with full slices
133+
if len(key) < self.ndim:
134+
key = key + (slice(None),) * (self.ndim - len(key))
135+
136+
ranges: list[tuple[int, int]] = []
137+
region_shape: list[int] = []
138+
squeeze_dims: list[int] = []
139+
140+
for i, (k, dim_size) in enumerate(zip(key, self.shape)):
141+
if isinstance(k, int):
142+
if k < 0:
143+
k += dim_size
144+
if k < 0 or k >= dim_size:
145+
raise IndexError(
146+
f"index {k} is out of bounds for axis {i} with size {dim_size}"
147+
)
148+
ranges.append((k, k + 1))
149+
region_shape.append(1)
150+
squeeze_dims.append(i)
151+
elif isinstance(k, slice):
152+
start, stop, step = k.indices(dim_size)
153+
if step != 1:
154+
raise IndexError("only step=1 slices are supported")
155+
ranges.append((start, stop))
156+
region_shape.append(max(0, stop - start))
157+
else:
158+
raise IndexError(f"unsupported index type: {type(k).__name__}")
159+
160+
return ranges, region_shape, squeeze_dims
161+
162+
def __getitem__(self, key: slice | int | tuple[slice | int, ...]) -> np.ndarray:
163+
if _is_basic_indexing(key):
164+
ranges, region_shape, squeeze_dims = self._parse_key(key)
165+
out = np.empty(region_shape, dtype=self.dtype)
166+
if out.size > 0:
167+
self._impl.retrieve(ranges, out)
168+
if squeeze_dims:
169+
out = out.squeeze(axis=tuple(squeeze_dims))
170+
return out
171+
172+
strict = zarr.config.get("codec_pipeline.strict", False)
173+
if strict:
174+
raise IndexError(
175+
"ZarrsArray in strict mode does not support advanced indexing"
176+
)
177+
return super().__getitem__(key)
178+
179+
def __setitem__(self, key: slice | int | tuple[slice | int, ...], value) -> None:
180+
if _is_basic_indexing(key):
181+
ranges, region_shape, squeeze_dims = self._parse_key(key)
182+
183+
if isinstance(value, _LazySlice):
184+
if value._region_shape != region_shape:
185+
raise ValueError(
186+
f"could not broadcast input array from shape "
187+
f"{tuple(value._region_shape)} "
188+
f"into shape {tuple(region_shape)}"
189+
)
190+
if all(s > 0 for s in region_shape):
191+
self._impl.copy_from(value._impl, value._ranges, ranges)
192+
return
193+
194+
value = np.asarray(value, dtype=self.dtype)
195+
196+
# Ensure native byte order
197+
if not value.dtype.isnative:
198+
value = value.byteswap().view(value.dtype.newbyteorder("="))
199+
200+
# Expand squeezed dimensions back
201+
for dim in squeeze_dims:
202+
value = np.expand_dims(value, axis=dim)
203+
204+
if value.shape != tuple(region_shape):
205+
raise ValueError(
206+
f"could not broadcast input array from shape {value.shape} "
207+
f"into shape {tuple(region_shape)}"
208+
)
209+
210+
# Ensure C-contiguous before passing to Rust
211+
value = np.ascontiguousarray(value)
212+
213+
if value.size > 0:
214+
self._impl.store(ranges, value)
215+
return
216+
217+
strict = zarr.config.get("codec_pipeline.strict", False)
218+
if strict:
219+
raise IndexError(
220+
"ZarrsArray in strict mode does not support advanced indexing"
221+
)
222+
super().__setitem__(key, value)

0 commit comments

Comments
 (0)