diff --git a/merkly/mtree.py b/merkly/mtree.py index 635e908..c59d62b 100644 --- a/merkly/mtree.py +++ b/merkly/mtree.py @@ -28,18 +28,18 @@ class MerkleTree: def __init__( self, - leaves: List[str], + leaves: List[bytes], hash_function: Callable[[bytes, bytes], bytes] = lambda x, y: keccak(x + y), ) -> None: validate_leafs(leaves) validate_hash_function(hash_function) self.hash_function: Callable[[bytes, bytes], bytes] = hash_function - self.raw_leaves: List[str] = leaves - self.leaves: List[str] = self.__hash_leaves(leaves) - self.short_leaves: List[str] = self.short(self.leaves) + self.raw_leaves: List[bytes] = leaves + self.leaves: List[bytes] = self.__hash_leaves(leaves) + self.short_leaves: List[bytes] = self.short(self.leaves) - def __hash_leaves(self, leaves: List[str]) -> List[str]: - return list(map(lambda x: self.hash_function(x.encode(), bytes()), leaves)) + def __hash_leaves(self, leaves: List[bytes]) -> List[bytes]: + return list(map(lambda x: self.hash_function(x, bytes()), leaves)) def __repr__(self) -> str: return f"""MerkleTree(\nraw_leaves: {self.raw_leaves}\nleaves: {self.leaves}\nshort_leaves: {self.short(self.leaves)})""" @@ -51,13 +51,13 @@ def short(self, data: List[str]) -> List[str]: def root(self) -> bytes: return self.make_root(self.leaves) - def proof(self, raw_leaf: str) -> List[Node]: + def proof(self, raw_leaf: bytes) -> List[Node]: return self.make_proof( - self.leaves, [], self.hash_function(raw_leaf.encode(), bytes()) + self.leaves, [], self.hash_function(raw_leaf, bytes()) ) - def verify(self, proof: List[bytes], raw_leaf: str) -> bool: - full_proof = [self.hash_function(raw_leaf.encode(), bytes())] + def verify(self, proof: List[bytes], raw_leaf: bytes) -> bool: + full_proof = [self.hash_function(raw_leaf, bytes())] full_proof.extend(proof) def concat_nodes(left: Node, right: Node) -> Node: @@ -167,15 +167,15 @@ def up_layer(self, leaves: List[bytes]) -> List[bytes]: return new_layer @property - def human_leaves(self) -> List[str]: + def human_leaves(self) -> List[bytes]: return [leaf.hex() for leaf in self.leaves] @property - def human_short_leaves(self) -> List[str]: + def human_short_leaves(self) -> List[bytes]: return [leaf.hex() for leaf in self.short_leaves] @staticmethod - def verify_proof(proof: List[Node], raw_leaf: str, root: str, **kwargs) -> bool: + def verify_proof(proof: List[Node], raw_leaf: bytes, root: str, **kwargs) -> bool: """ Verify the validity of a Merkle proof for a given leaf against the expected root hash. @@ -210,7 +210,7 @@ def verify_proof(proof: List[Node], raw_leaf: str, root: str, **kwargs) -> bool: else: hash_function = kwargs["hash_function"] - full_proof = [hash_function(raw_leaf.encode(), bytes())] + full_proof = [hash_function(raw_leaf, bytes())] full_proof.extend(proof) def concat_nodes(left: Node, right: Node) -> Node: diff --git a/merkly/utils.py b/merkly/utils.py index e3719e5..7874f01 100644 --- a/merkly/utils.py +++ b/merkly/utils.py @@ -92,7 +92,7 @@ def validate_leafs(leafs: List[str]): raise Exception("Invalid size, need > 2") a = isinstance(leafs, List) - b = all(isinstance(leaf, str) for leaf in leafs) + b = all(isinstance(leaf, bytes) for leaf in leafs) if not (a and b): raise Exception("Invalid type of leafs") diff --git a/test/merkle_root/test_merkle_root.py b/test/merkle_root/test_merkle_root.py index 19cd35b..a2e2a71 100644 --- a/test/merkle_root/test_merkle_root.py +++ b/test/merkle_root/test_merkle_root.py @@ -6,6 +6,8 @@ def test_simple_merkle_tree_constructor(): leaves = ["a", "b", "c", "d"] + leaves = list(map(lambda x: x.encode(),leaves)) + tree = MerkleTree(leaves) assert tree.raw_leaves == leaves diff --git a/test/merkletreejs/merkle_proof/merkle_proof_test.js b/test/merkletreejs/merkle_proof/merkle_proof_test.js index 4d2e44c..c358046 100644 --- a/test/merkletreejs/merkle_proof/merkle_proof_test.js +++ b/test/merkletreejs/merkle_proof/merkle_proof_test.js @@ -1,9 +1,9 @@ const { MerkleTree } = require('merkletreejs'); const SHA256 = require('crypto-js/sha256'); -const leaves = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'].map(x => SHA256(x)); +const leaves = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'].map(x => Buffer.from(x, 'utf-8')).map(SHA256); const tree = new MerkleTree(leaves, SHA256); -const leaf = SHA256('a'); +const leaf = SHA256(Buffer.from('a', 'utf-8')); const proof = tree.getProof(leaf).map(node => ({ data: node.data.toString('hex'), position: node.position })); console.log(JSON.stringify({ proof, isValid: tree.verify(proof, leaf, tree.getRoot()) })) diff --git a/test/merkletreejs/merkle_proof/merkle_proof_test.py b/test/merkletreejs/merkle_proof/merkle_proof_test.py index eaa3957..a45686d 100644 --- a/test/merkletreejs/merkle_proof/merkle_proof_test.py +++ b/test/merkletreejs/merkle_proof/merkle_proof_test.py @@ -9,8 +9,9 @@ def sha256(x, y): leaves = ["a", "b", "c", "d", "e", "f", "g", "h"] +leaves = list(map(lambda x: x.encode(),leaves)) tree = MerkleTree(leaves, sha256) -leaf = "a" +leaf = ("a").encode() proof = tree.proof(leaf) formatted_proof = [ {"data": node.data.hex(), "position": node.side.name.lower()} for node in proof diff --git a/test/merkletreejs/merkle_root/merkle_root_test.js b/test/merkletreejs/merkle_root/merkle_root_test.js index 5f3e2ad..ae2e366 100644 --- a/test/merkletreejs/merkle_root/merkle_root_test.js +++ b/test/merkletreejs/merkle_root/merkle_root_test.js @@ -1,7 +1,8 @@ const { MerkleTree } = require('merkletreejs'); const SHA256 = require('crypto-js/sha256'); -const leaves = ['a', 'b', 'c', 'd'].map(SHA256); + +const leaves = ['a', 'b', 'c', 'd'].map(x => Buffer.from(x, 'utf-8')).map(SHA256); const tree = new MerkleTree(leaves, SHA256, {}); const root = tree.getRoot().toString('hex'); diff --git a/test/merkletreejs/merkle_root/merkle_root_test.py b/test/merkletreejs/merkle_root/merkle_root_test.py index cb69112..850e3ab 100644 --- a/test/merkletreejs/merkle_root/merkle_root_test.py +++ b/test/merkletreejs/merkle_root/merkle_root_test.py @@ -9,6 +9,7 @@ def sha256(x, y): leaves = ["a", "b", "c", "d"] +leaves = list(map(lambda x: x.encode(),leaves)) tree = MerkleTree(leaves, sha256) root = tree.root.hex()