Skip to content

ruff rules: TCH → TC #3032

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,8 @@ extend-exclude = [
extend-select = [
"ANN", # flake8-annotations
"B", # flake8-bugbear
"EXE", # flake8-executable
"C4", # flake8-comprehensions
"EXE", # flake8-executable
"FA", # flake8-future-annotations
"FLY", # flynt
"FURB", # refurb
Expand All @@ -310,7 +310,7 @@ extend-select = [
"RUF",
"SIM", # flake8-simplify
"SLOT", # flake8-slots
"TCH", # flake8-type-checking
"TC", # flake8-type-checking
"TRY", # tryceratops
"UP", # pyupgrade
"W", # pycodestyle warnings
Expand Down
2 changes: 1 addition & 1 deletion src/zarr/api/asynchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ async def open(
try:
metadata_dict = await get_array_metadata(store_path, zarr_format=zarr_format)
# TODO: remove this cast when we fix typing for array metadata dicts
_metadata_dict = cast(ArrayMetadataDict, metadata_dict)
_metadata_dict = cast("ArrayMetadataDict", metadata_dict)
# for v2, the above would already have raised an exception if not an array
zarr_format = _metadata_dict["zarr_format"]
is_v3_array = zarr_format == 3 and _metadata_dict.get("node_type") == "array"
Expand Down
6 changes: 4 additions & 2 deletions src/zarr/codecs/crc32c_.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ async def _decode_single(
inner_bytes = data[:-4]

# Need to do a manual cast until https://github.com/numpy/numpy/issues/26783 is resolved
computed_checksum = np.uint32(crc32c(cast(typing_extensions.Buffer, inner_bytes))).tobytes()
computed_checksum = np.uint32(
crc32c(cast("typing_extensions.Buffer", inner_bytes))
).tobytes()
stored_checksum = bytes(crc32_bytes)
if computed_checksum != stored_checksum:
raise ValueError(
Expand All @@ -55,7 +57,7 @@ async def _encode_single(
) -> Buffer | None:
data = chunk_bytes.as_numpy_array()
# Calculate the checksum and "cast" it to a numpy array
checksum = np.array([crc32c(cast(typing_extensions.Buffer, data))], dtype=np.uint32)
checksum = np.array([crc32c(cast("typing_extensions.Buffer", data))], dtype=np.uint32)
# Append the checksum (as bytes) to the data
return chunk_spec.prototype.buffer.from_array_like(np.append(data, checksum.view("B")))

Expand Down
2 changes: 1 addition & 1 deletion src/zarr/codecs/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class _ShardIndex(NamedTuple):
def chunks_per_shard(self) -> ChunkCoords:
result = tuple(self.offsets_and_lengths.shape[0:-1])
# The cast is required until https://github.com/numpy/numpy/pull/27211 is merged
return cast(ChunkCoords, result)
return cast("ChunkCoords", result)

def _localize_chunk(self, chunk_coords: ChunkCoords) -> ChunkCoords:
return tuple(
Expand Down
2 changes: 1 addition & 1 deletion src/zarr/codecs/transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def parse_transpose_order(data: JSON | Iterable[int]) -> tuple[int, ...]:
raise TypeError(f"Expected an iterable. Got {data} instead.")
if not all(isinstance(a, int) for a in data):
raise TypeError(f"Expected an iterable of integers. Got {data} instead.")
return tuple(cast(Iterable[int], data))
return tuple(cast("Iterable[int]", data))


@dataclass(frozen=True)
Expand Down
26 changes: 13 additions & 13 deletions src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def __init__(
if isinstance(metadata, dict):
zarr_format = metadata["zarr_format"]
# TODO: remove this when we extensively type the dict representation of metadata
_metadata = cast(dict[str, JSON], metadata)
_metadata = cast("dict[str, JSON]", metadata)
if zarr_format == 2:
metadata = ArrayV2Metadata.from_dict(_metadata)
elif zarr_format == 3:
Expand Down Expand Up @@ -898,7 +898,7 @@ async def open(
store_path = await make_store_path(store)
metadata_dict = await get_array_metadata(store_path, zarr_format=zarr_format)
# TODO: remove this cast when we have better type hints
_metadata_dict = cast(ArrayV3MetadataDict, metadata_dict)
_metadata_dict = cast("ArrayV3MetadataDict", metadata_dict)
return cls(store_path=store_path, metadata=_metadata_dict)

@property
Expand Down Expand Up @@ -1394,7 +1394,7 @@ async def _set_selection(
if isinstance(array_like, np._typing._SupportsArrayFunc):
# TODO: need to handle array types that don't support __array_function__
# like PyTorch and JAX
array_like_ = cast(np._typing._SupportsArrayFunc, array_like)
array_like_ = cast("np._typing._SupportsArrayFunc", array_like)
value = np.asanyarray(value, dtype=self.metadata.dtype, like=array_like_)
else:
if not hasattr(value, "shape"):
Expand All @@ -1408,7 +1408,7 @@ async def _set_selection(
value = value.astype(dtype=self.metadata.dtype, order="A")
else:
value = np.array(value, dtype=self.metadata.dtype, order="A")
value = cast(NDArrayLike, value)
value = cast("NDArrayLike", value)
# We accept any ndarray like object from the user and convert it
# to a NDBuffer (or subclass). From this point onwards, we only pass
# Buffer and NDBuffer between components.
Expand Down Expand Up @@ -2431,11 +2431,11 @@ def __getitem__(self, selection: Selection) -> NDArrayLikeOrScalar:
"""
fields, pure_selection = pop_fields(selection)
if is_pure_fancy_indexing(pure_selection, self.ndim):
return self.vindex[cast(CoordinateSelection | MaskSelection, selection)]
return self.vindex[cast("CoordinateSelection | MaskSelection", selection)]
elif is_pure_orthogonal_indexing(pure_selection, self.ndim):
return self.get_orthogonal_selection(pure_selection, fields=fields)
else:
return self.get_basic_selection(cast(BasicSelection, pure_selection), fields=fields)
return self.get_basic_selection(cast("BasicSelection", pure_selection), fields=fields)

def __setitem__(self, selection: Selection, value: npt.ArrayLike) -> None:
"""Modify data for an item or region of the array.
Expand Down Expand Up @@ -2530,11 +2530,11 @@ def __setitem__(self, selection: Selection, value: npt.ArrayLike) -> None:
"""
fields, pure_selection = pop_fields(selection)
if is_pure_fancy_indexing(pure_selection, self.ndim):
self.vindex[cast(CoordinateSelection | MaskSelection, selection)] = value
self.vindex[cast("CoordinateSelection | MaskSelection", selection)] = value
elif is_pure_orthogonal_indexing(pure_selection, self.ndim):
self.set_orthogonal_selection(pure_selection, value, fields=fields)
else:
self.set_basic_selection(cast(BasicSelection, pure_selection), value, fields=fields)
self.set_basic_selection(cast("BasicSelection", pure_selection), value, fields=fields)

@_deprecate_positional_args
def get_basic_selection(
Expand Down Expand Up @@ -3652,7 +3652,7 @@ def update_attributes(self, new_attributes: dict[str, JSON]) -> Array:
# TODO: remove this cast when type inference improves
new_array = sync(self._async_array.update_attributes(new_attributes))
# TODO: remove this cast when type inference improves
_new_array = cast(AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata], new_array)
_new_array = cast("AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]", new_array)
return type(self)(_new_array)

def __repr__(self) -> str:
Expand Down Expand Up @@ -4238,7 +4238,7 @@ async def init_array(
serializer=serializer,
dtype=dtype_parsed,
)
sub_codecs = cast(tuple[Codec, ...], (*array_array, array_bytes, *bytes_bytes))
sub_codecs = cast("tuple[Codec, ...]", (*array_array, array_bytes, *bytes_bytes))
codecs_out: tuple[Codec, ...]
if shard_shape_parsed is not None:
index_location = None
Expand Down Expand Up @@ -4509,7 +4509,7 @@ def _parse_keep_array_attr(
compressors = "auto"
if serializer == "keep":
if zarr_format == 3 and data.metadata.zarr_format == 3:
serializer = cast(SerializerLike, data.serializer)
serializer = cast("SerializerLike", data.serializer)
else:
serializer = "auto"
if fill_value is None:
Expand Down Expand Up @@ -4687,7 +4687,7 @@ def _parse_chunk_encoding_v3(
if isinstance(filters, dict | Codec):
maybe_array_array = (filters,)
else:
maybe_array_array = cast(Iterable[Codec | dict[str, JSON]], filters)
maybe_array_array = cast("Iterable[Codec | dict[str, JSON]]", filters)
out_array_array = tuple(_parse_array_array_codec(c) for c in maybe_array_array)

if serializer == "auto":
Expand All @@ -4704,7 +4704,7 @@ def _parse_chunk_encoding_v3(
if isinstance(compressors, dict | Codec):
maybe_bytes_bytes = (compressors,)
else:
maybe_bytes_bytes = cast(Iterable[Codec | dict[str, JSON]], compressors)
maybe_bytes_bytes = cast("Iterable[Codec | dict[str, JSON]]", compressors)

out_bytes_bytes = tuple(_parse_bytes_bytes_codec(c) for c in maybe_bytes_bytes)

Expand Down
2 changes: 1 addition & 1 deletion src/zarr/core/array_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def from_dict(cls, data: ArrayConfigParams) -> Self:
"""
kwargs_out: ArrayConfigParams = {}
for f in fields(ArrayConfig):
field_name = cast(Literal["order", "write_empty_chunks"], f.name)
field_name = cast("Literal['order', 'write_empty_chunks']", f.name)
if field_name not in data:
kwargs_out[field_name] = zarr_config.get(f"array.{field_name}")
else:
Expand Down
12 changes: 6 additions & 6 deletions src/zarr/core/buffer/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def create_zero_length(cls) -> Self:
if cls is Buffer:
raise NotImplementedError("Cannot call abstract method on the abstract class 'Buffer'")
return cls(
cast(ArrayLike, None)
cast("ArrayLike", None)
) # This line will never be reached, but it satisfies the type checker

@classmethod
Expand Down Expand Up @@ -207,7 +207,7 @@ def from_buffer(cls, buffer: Buffer) -> Self:
if cls is Buffer:
raise NotImplementedError("Cannot call abstract method on the abstract class 'Buffer'")
return cls(
cast(ArrayLike, None)
cast("ArrayLike", None)
) # This line will never be reached, but it satisfies the type checker

@classmethod
Expand All @@ -227,7 +227,7 @@ def from_bytes(cls, bytes_like: BytesLike) -> Self:
if cls is Buffer:
raise NotImplementedError("Cannot call abstract method on the abstract class 'Buffer'")
return cls(
cast(ArrayLike, None)
cast("ArrayLike", None)
) # This line will never be reached, but it satisfies the type checker

def as_array_like(self) -> ArrayLike:
Expand Down Expand Up @@ -358,7 +358,7 @@ def create(
"Cannot call abstract method on the abstract class 'NDBuffer'"
)
return cls(
cast(NDArrayLike, None)
cast("NDArrayLike", None)
) # This line will never be reached, but it satisfies the type checker

@classmethod
Expand Down Expand Up @@ -395,7 +395,7 @@ def from_numpy_array(cls, array_like: npt.ArrayLike) -> Self:
"Cannot call abstract method on the abstract class 'NDBuffer'"
)
return cls(
cast(NDArrayLike, None)
cast("NDArrayLike", None)
) # This line will never be reached, but it satisfies the type checker

def as_ndarray_like(self) -> NDArrayLike:
Expand Down Expand Up @@ -427,7 +427,7 @@ def as_scalar(self) -> ScalarType:
"""Returns the buffer as a scalar value"""
if self._data.size != 1:
raise ValueError("Buffer does not contain a single scalar value")
return cast(ScalarType, self.as_numpy_array()[()])
return cast("ScalarType", self.as_numpy_array()[()])

@property
def dtype(self) -> np.dtype[Any]:
Expand Down
4 changes: 2 additions & 2 deletions src/zarr/core/buffer/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@
return cls.from_array_like(cp.frombuffer(bytes_like, dtype="B"))

def as_numpy_array(self) -> npt.NDArray[Any]:
return cast(npt.NDArray[Any], cp.asnumpy(self._data))
return cast("npt.NDArray[Any]", cp.asnumpy(self._data))

def __add__(self, other: core.Buffer) -> Self:
other_array = other.as_array_like()
Expand Down Expand Up @@ -204,7 +204,7 @@
-------
NumPy array of this buffer (might be a data copy)
"""
return cast(npt.NDArray[Any], cp.asnumpy(self._data))
return cast("npt.NDArray[Any]", cp.asnumpy(self._data))

Check warning on line 207 in src/zarr/core/buffer/gpu.py

View check run for this annotation

Codecov / codecov/patch

src/zarr/core/buffer/gpu.py#L207

Added line #L207 was not covered by tests

def __getitem__(self, key: Any) -> Self:
return self.__class__(self._data.__getitem__(key))
Expand Down
4 changes: 2 additions & 2 deletions src/zarr/core/chunk_key_encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
def parse_separator(data: JSON) -> SeparatorLiteral:
if data not in (".", "/"):
raise ValueError(f"Expected an '.' or '/' separator. Got {data} instead.")
return cast(SeparatorLiteral, data)
return cast("SeparatorLiteral", data)


class ChunkKeyEncodingParams(TypedDict):
Expand Down Expand Up @@ -48,7 +48,7 @@ def from_dict(cls, data: dict[str, JSON] | ChunkKeyEncodingLike) -> ChunkKeyEnco
data = {"name": data["name"], "configuration": {"separator": data["separator"]}}

# TODO: remove this cast when we are statically typing the JSON metadata completely.
data = cast(dict[str, JSON], data)
data = cast("dict[str, JSON]", data)

# configuration is optional for chunk key encodings
name_parsed, config_parsed = parse_named_configuration(data, require_configuration=False)
Expand Down
4 changes: 2 additions & 2 deletions src/zarr/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def parse_fill_value(data: Any) -> Any:

def parse_order(data: Any) -> Literal["C", "F"]:
if data in ("C", "F"):
return cast(Literal["C", "F"], data)
return cast("Literal['C', 'F']", data)
raise ValueError(f"Expected one of ('C', 'F'), got {data} instead.")


Expand Down Expand Up @@ -201,4 +201,4 @@ def _warn_order_kwarg() -> None:

def _default_zarr_format() -> ZarrFormat:
"""Return the default zarr_version"""
return cast(ZarrFormat, int(zarr_config.get("default_zarr_format", 3)))
return cast("ZarrFormat", int(zarr_config.get("default_zarr_format", 3)))
2 changes: 1 addition & 1 deletion src/zarr/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,6 @@ def enable_gpu(self) -> ConfigSet:

def parse_indexing_order(data: Any) -> Literal["C", "F"]:
if data in ("C", "F"):
return cast(Literal["C", "F"], data)
return cast("Literal['C', 'F']", data)
msg = f"Expected one of ('C', 'F'), got {data} instead."
raise ValueError(msg)
12 changes: 6 additions & 6 deletions src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import logging
import warnings
from collections import defaultdict
from collections.abc import Iterator, Mapping
from dataclasses import asdict, dataclass, field, fields, replace
from itertools import accumulate
from typing import TYPE_CHECKING, Literal, TypeVar, assert_never, cast, overload
Expand Down Expand Up @@ -64,6 +63,8 @@
Coroutine,
Generator,
Iterable,
Iterator,
Mapping,
)
from typing import Any

Expand All @@ -80,15 +81,15 @@
def parse_zarr_format(data: Any) -> ZarrFormat:
"""Parse the zarr_format field from metadata."""
if data in (2, 3):
return cast(ZarrFormat, data)
return cast("ZarrFormat", data)
msg = f"Invalid zarr_format. Expected one of 2 or 3. Got {data}."
raise ValueError(msg)


def parse_node_type(data: Any) -> NodeType:
"""Parse the node_type field from metadata."""
if data in ("array", "group"):
return cast(Literal["array", "group"], data)
return cast("Literal['array', 'group']", data)
raise MetadataValidationError("node_type", "array or group", data)


Expand Down Expand Up @@ -361,7 +362,7 @@ def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]:
# it's an array
if isinstance(v.get("fill_value", None), np.void):
v["fill_value"] = base64.standard_b64encode(
cast(bytes, v["fill_value"])
cast("bytes", v["fill_value"])
).decode("ascii")
else:
v = _replace_special_floats(v)
Expand Down Expand Up @@ -3245,8 +3246,7 @@ def _ensure_consistent_zarr_format(
raise ValueError(msg)

return cast(
Mapping[str, GroupMetadata | ArrayV2Metadata]
| Mapping[str, GroupMetadata | ArrayV3Metadata],
"Mapping[str, GroupMetadata | ArrayV2Metadata] | Mapping[str, GroupMetadata | ArrayV3Metadata]",
data,
)

Expand Down
Loading