Skip to content

Add some more types #145

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: add-mypy-typing
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ repos:
ssz/utils.py
ssz/constants.py
ssz/abc.py
ssz/exceptions.py
ssz/hash.py
ssz/hash_tree.py
ssz/hashable_structure.py
language: system
always_run: true
pass_filenames: false
Expand Down
59 changes: 41 additions & 18 deletions ssz/hash_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@
partial,
)
import itertools
from numbers import (
Integral,
)
from typing import (
Any,
Callable,
Generator,
Iterable,
Optional,
Tuple,
Union,
cast,
overload,
)

# `transform` comes from a non-public API which is considered stable, but future changes
Expand All @@ -37,6 +38,7 @@
from pyrsistent.typing import (
PMap,
PVector,
PVectorEvolver,
)

from ssz.constants import (
Expand Down Expand Up @@ -105,10 +107,15 @@ def chunks(self) -> RawHashTreeLayer:
def root(self) -> Hash32:
return self.raw_hash_tree[-1][0]

def transform(self, *transformations):
return transform(self, transformations)
def transform(
self,
*transformations: Tuple[
Tuple[Tuple[int, ...], Union[Any, Callable[[Any], Any]]], ...
],
) -> "HashTree":
return cast("HashTree", transform(self, transformations))

def evolver(self):
def evolver(self) -> "HashTreeEvolver":
return HashTreeEvolver(self)

#
Expand All @@ -133,7 +140,7 @@ def __len__(self) -> int:
def __getitem__(self, index: Union[int, slice]) -> Hash32:
return self.chunks[index]

def index(self, value: Hash32, *args, **kwargs) -> Hash32:
def index(self, value: Hash32, *args: Any, **kwargs: Any) -> Hash32:
return self.chunks.index(value, *args, **kwargs)

def count(self, value: Hash32) -> int:
Expand All @@ -155,7 +162,9 @@ def extend(self, value: Iterable[Hash32]) -> "HashTree":
def __add__(self, other: Iterable[Hash32]) -> "HashTree":
return self.extend(other)

def __mul__(self, times: int) -> "HashTree":
# we override __mul__ to allow for a more natural syntax
# when using the evolver
def __mul__(self, times: int) -> "HashTree": # type: ignore[override]
if times <= 0:
raise ValueError(f"Multiplier must be greater or equal to 1, got {times}")

Expand Down Expand Up @@ -194,7 +203,7 @@ def remove(self, value: Hash32) -> "HashTree":
return self.__class__.compute(chunks, self.chunk_count)


class HashTreeEvolver:
class HashTreeEvolver(PVectorEvolver[Hash32]):
def __init__(self, hash_tree: "HashTree") -> None:
self.original_hash_tree = hash_tree
self.updated_chunks: PMap[int, Hash32] = pmap()
Expand All @@ -203,7 +212,18 @@ def __init__(self, hash_tree: "HashTree") -> None:
#
# Getters
#
@overload
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overloads added with subclassing from PVectorEvolver.

def __getitem__(self, index: int) -> Hash32:
...

@overload
def __getitem__(self, index: slice) -> "HashTreeEvolver":
...

def __getitem__(self, index: Union[int, slice]) -> Union[Hash32, "HashTreeEvolver"]:
if isinstance(index, slice):
raise NotImplementedError("Slicing not implemented.")

if index < 0:
index += len(self)

Expand All @@ -222,15 +242,16 @@ def __len__(self) -> int:
return len(self.original_hash_tree) + len(self.appended_chunks)

def is_dirty(self) -> bool:
return self.updated_chunks or self.appended_chunks
return any([self.updated_chunks, self.appended_chunks])

#
# Setters
#
def set(self, index: Integral, value: Hash32) -> None:
def set(self, index: int, value: Hash32) -> "HashTreeEvolver":
self[index] = value
return self

def __setitem__(self, index: Integral, value: Hash32) -> None:
def __setitem__(self, index: int, value: Hash32) -> None:
if index < 0:
index += len(self)

Expand All @@ -245,13 +266,15 @@ def __setitem__(self, index: Integral, value: Hash32) -> None:
#
# Length changing modifiers
#
def append(self, value: Hash32) -> None:
def append(self, value: Hash32) -> "HashTreeEvolver":
self.appended_chunks = self.appended_chunks.append(value)
self._check_chunk_count()
return self

def extend(self, values: Iterable[Hash32]) -> None:
def extend(self, values: Iterable[Hash32]) -> "HashTreeEvolver":
self.appended_chunks = self.appended_chunks.extend(values)
self._check_chunk_count()
return self

def _check_chunk_count(self) -> None:
chunk_count = self.original_hash_tree.chunk_count
Expand All @@ -261,13 +284,13 @@ def _check_chunk_count(self) -> None:
#
# Not implemented
#
def delete(self, index, stop=None):
def delete(self, index: int, stop: Optional[int] = None) -> None: # type: ignore[override] # noqa: E501
raise NotImplementedError()

def __delitem__(self, index):
def __delitem__(self, index: Union[int, slice]) -> None:
raise NotImplementedError()

def remove(self, value):
def remove(self, value: Hash32) -> None:
raise NotImplementedError()

#
Expand Down Expand Up @@ -422,7 +445,7 @@ def set_chunk_in_tree(hash_tree: RawHashTree, index: int, chunk: Hash32) -> RawH
for layer_index, hash_index in zip(parent_layer_indices, parent_hash_indices)
)

hash_tree_with_updated_branch = pipe(
hash_tree_with_updated_branch: PVector[PVector[Any]] = pipe(
hash_tree_with_updated_chunk, *update_functions
)

Expand Down
40 changes: 26 additions & 14 deletions ssz/hashable_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import itertools
from typing import (
Any,
Callable,
Dict,
Generator,
Generic,
Iterable,
Iterator,
List,
Expand All @@ -12,6 +14,7 @@
Tuple,
TypeVar,
Union,
cast,
)

from eth_typing import (
Expand Down Expand Up @@ -53,11 +56,15 @@
BaseProperCompositeSedes,
)

TStructure = TypeVar("TStructure", bound="BaseHashableStructure")
TElement = TypeVar("TElement")
TStructure = TypeVar("TStructure", bound="BaseHashableStructure[TElement]")
# TStructure = TypeVar("TStructure", bound="BaseHashableStructure")
TResizableStructure = TypeVar(
"TResizableStructure", bound="BaseResizableHashableStructure"
"TResizableStructure", bound="BaseResizableHashableStructure[TElement]"
)
TElement = TypeVar("TElement")

TSerializable = TypeVar("TSerializable")
TDeserialized = TypeVar("TDeserialized")


def update_element_in_chunk(
Expand Down Expand Up @@ -195,12 +202,12 @@ def get_appended_chunks(
yield Hash32(b"".join(elements_in_chunk))


class BaseHashableStructure(HashableStructureAPI[TElement]):
class BaseHashableStructure(HashableStructureAPI[TElement], Generic[TElement]):
def __init__(
self,
elements: PVector[TElement],
hash_tree: HashTree,
sedes: BaseProperCompositeSedes,
sedes: BaseProperCompositeSedes[TSerializable, TDeserialized],
max_length: Optional[int] = None,
) -> None:
self._elements = elements
Expand All @@ -212,9 +219,9 @@ def __init__(
def from_iterable_and_sedes(
cls,
iterable: Iterable[TElement],
sedes: BaseProperCompositeSedes,
sedes: BaseProperCompositeSedes[TSerializable, TDeserialized],
max_length: Optional[int] = None,
):
) -> "BaseHashableStructure[TElement]":
elements = pvector(iterable)
if max_length and len(elements) > max_length:
raise ValueError(
Expand Down Expand Up @@ -257,7 +264,7 @@ def raw_root(self) -> Hash32:
return self.hash_tree.root

@property
def sedes(self) -> BaseProperCompositeSedes:
def sedes(self) -> BaseProperCompositeSedes[TSerializable, TDeserialized]:
return self._sedes

#
Expand Down Expand Up @@ -289,10 +296,15 @@ def __getitem__(self, index: int) -> TElement:
def __iter__(self) -> Iterator[TElement]:
return iter(self.elements)

def transform(self, *transformations):
return transform(self, transformations)
def transform(
self,
*transformations: Tuple[
Tuple[Tuple[int, ...], Union[Any, Callable[[Any], Any]]], ...
],
) -> "BaseHashableStructure[TElement]":
return cast("BaseHashableStructure", transform(self, transformations))

def mset(self: TStructure, *args: Union[int, TElement]) -> TStructure:
def mset(self, *args: Union[int, TElement]) -> "BaseHashableStructure[TElement]":
if len(args) % 2 != 0:
raise TypeError(
f"mset must be called with an even number of arguments, got {len(args)}"
Expand All @@ -303,7 +315,7 @@ def mset(self: TStructure, *args: Union[int, TElement]) -> TStructure:
evolver[index] = value
return evolver.persistent()

def set(self: TStructure, index: int, value: TElement) -> TStructure:
def set(self, index: int, value: TElement) -> "BaseHashableStructure[TElement]":
return self.mset(index, value)

def evolver(
Expand Down Expand Up @@ -399,7 +411,7 @@ def persistent(self) -> TStructure:
)
).extend(self._appended_elements)
hash_tree = self._original_structure.hash_tree.mset(
*itertools.chain.from_iterable(updated_chunks.items()) # type: ignore
*itertools.chain.from_iterable(updated_chunks.items())
).extend(appended_chunks)

return self._original_structure.__class__(
Expand All @@ -408,7 +420,7 @@ def persistent(self) -> TStructure:


class BaseResizableHashableStructure(
BaseHashableStructure, ResizableHashableStructureAPI[TElement]
BaseHashableStructure[TElement], ResizableHashableStructureAPI[TElement]
):
def append(self: TResizableStructure, value: TElement) -> TResizableStructure:
evolver = self.evolver()
Expand Down
2 changes: 1 addition & 1 deletion ssz/sedes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def infer_sedes(
if isinstance(value.__class__, BaseSedes):
return value.__class__
elif isinstance(value, BaseHashableStructure):
return cast(BaseProperCompositeSedes[TSerializable, TDeserialized], value.sedes)
return value.sedes
elif isinstance(value, bool):
return cast(BaseSedes[bool, bool], boolean)
elif isinstance(value, int):
Expand Down