Skip to content

Commit

Permalink
Merge pull request #21 from data-exp-lab/fix_zarr_future_compatibility
Browse files Browse the repository at this point in the history
fix compatibility with zarr>=3.1
  • Loading branch information
chrishavlin authored Jan 22, 2025
2 parents 9c97435 + 263897f commit f4a3d84
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
14 changes: 8 additions & 6 deletions src/pyramid_sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,16 @@ def _downsample_by_one_level(
zarr_field: str,
) -> None:
level = coarse_level
level_str = str(level)
fine_level = level - 1
lev_shape = self._get_level_shape(level)
fine_lev_str = str(fine_level)
lev_shape = self._get_level_shape(level).tolist()

field1 = zarr.open(self.zarr_store_path)[zarr_field]
dtype = field1[fine_level].dtype
field1.empty(level, shape=lev_shape, chunks=self.chunks, dtype=dtype)
dtype = field1[fine_lev_str].dtype
field1.empty(name=level_str, shape=lev_shape, chunks=self.chunks, dtype=dtype)

numchunks = field1[str(level)].nchunks
numchunks = field1[level_str].nchunks

chunk_writes = []
for ichunk in range(numchunks):
Expand Down Expand Up @@ -270,7 +272,7 @@ def initialize_test_image(
"""
if dtype is None:
dtype = np.float64
field1 = zarr_store.create_group(zarr_field, overwrite=overwrite_field)
field1 = zarr_store.create_group(name=zarr_field, overwrite=overwrite_field)

if chunks is None:
chunks = (64, 64, 64)
Expand All @@ -288,5 +290,5 @@ def initialize_test_image(
lev0[0 : halfway[0], 0 : halfway[1], 0 : halfway[2]] = (
lev0[0 : halfway[0], 0 : halfway[1], 0 : halfway[2]] + 0.5 * fac
)
field1.empty(0, shape=base_resolution, chunks=chunks, dtype=dtype)
field1.empty(name="0", shape=base_resolution, chunks=chunks, dtype=dtype)
da.to_zarr(lev0, field1["0"])
10 changes: 5 additions & 5 deletions tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ def test_initialize_test_image(tmp_path):
initialize_test_image(zarr_store, fieldname, res, chunks, overwrite_field=False)

assert fieldname in zarr_store
assert zarr_store[fieldname][0].shape == res
assert zarr_store[fieldname][0].chunks == chunks
assert zarr_store[fieldname]["0"].shape == res
assert zarr_store[fieldname]["0"].chunks == chunks
assert Path.exists(tmp_path / "myzarr.zarr" / fieldname)

res = (16, 16, 16)
initialize_test_image(zarr_store, fieldname, res, chunks, overwrite_field=True)
assert zarr_store[fieldname][0].shape == res
assert zarr_store[fieldname]["0"].shape == res


@pytest.mark.parametrize("dtype", ["float32", np.float64, "int", np.int32, np.int16])
Expand All @@ -43,8 +43,8 @@ def test_downsampler(tmp_path, dtype):
dsr.downsample(10, fieldname)
expected_max_lev = 2
for lev in range(expected_max_lev + 1):
assert lev in zarr_store[fieldname]
assert zarr_store[fieldname][lev].dtype == np.dtype(dtype)
assert str(lev) in zarr_store[fieldname]
assert zarr_store[fieldname][str(lev)].dtype == np.dtype(dtype)

with pytest.raises(ValueError, match="max_level must exceed 0"):
dsr.downsample(0, fieldname)
Expand Down

0 comments on commit f4a3d84

Please sign in to comment.