Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 99 additions & 37 deletions src/ethereum_test_rpc/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
)

logger = get_logger(__name__)

BlockNumberType = int | Literal["latest", "earliest", "pending"]


Expand Down Expand Up @@ -79,19 +78,25 @@ def __init_subclass__(cls, namespace: str | None = None) -> None:

def post_request(
self,
*,
method: str,
*params: Any,
params: List[Any] | None = None,
extra_headers: Dict | None = None,
request_id: int | str | None = None,
timeout: int | None = None,
) -> Any:
"""Send JSON-RPC POST request to the client RPC server at port defined in the url."""
if extra_headers is None:
extra_headers = {}
if params is None:
params = []

assert self.namespace, "RPC namespace not set"

next_request_id_counter = next(self.request_id_counter)
if request_id is None:
request_id = next_request_id_counter

payload = {
"jsonrpc": "2.0",
"method": f"{self.namespace}_{method}",
Expand All @@ -103,7 +108,8 @@ def post_request(
}
headers = base_header | extra_headers

response = requests.post(self.url, json=payload, headers=headers)
logger.debug(f"Sending RPC request, timeout is set to {timeout}...")
response = requests.post(self.url, json=payload, headers=headers, timeout=timeout)
response.raise_for_status()
response_json = response.json()

Expand Down Expand Up @@ -135,53 +141,79 @@ def __init__(
super().__init__(*args, **kwargs)
self.transaction_wait_timeout = transaction_wait_timeout

def config(self):
def config(self, timeout: int | None = None):
"""`eth_config`: Returns information about a fork configuration of the client."""
try:
response = self.post_request("config")
response = self.post_request(method="config", timeout=timeout)
if response is None:
logger.warning("eth_config request: failed to get response")
return None
return EthConfigResponse.model_validate(
response, context=self.response_validation_context
)
except ValidationError as e:
pprint(e.errors())
raise e
except Exception as e:
logger.debug(f"exception occurred when sending JSON-RPC request: {e}")
raise e

def chain_id(self) -> int:
"""`eth_chainId`: Returns the current chain id."""
return int(self.post_request("chainId"), 16)
response = self.post_request(method="chainId", timeout=10)

return int(response, 16)

def get_block_by_number(self, block_number: BlockNumberType = "latest", full_txs: bool = True):
"""`eth_getBlockByNumber`: Returns information about a block by block number."""
block = hex(block_number) if isinstance(block_number, int) else block_number
return self.post_request("getBlockByNumber", block, full_txs)
params = [block, full_txs]
response = self.post_request(method="getBlockByNumber", params=params)

return response

def get_block_by_hash(self, block_hash: Hash, full_txs: bool = True):
"""`eth_getBlockByHash`: Returns information about a block by hash."""
return self.post_request("getBlockByHash", f"{block_hash}", full_txs)
params = [f"{block_hash}", full_txs]
response = self.post_request(method="getBlockByHash", params=params)

return response

def get_balance(self, address: Address, block_number: BlockNumberType = "latest") -> int:
"""`eth_getBalance`: Returns the balance of the account of given address."""
block = hex(block_number) if isinstance(block_number, int) else block_number
return int(self.post_request("getBalance", f"{address}", block), 16)
params = [f"{address}", block]

response = self.post_request(method="getBalance", params=params)

return int(response, 16)

def get_code(self, address: Address, block_number: BlockNumberType = "latest") -> Bytes:
"""`eth_getCode`: Returns code at a given address."""
block = hex(block_number) if isinstance(block_number, int) else block_number
return Bytes(self.post_request("getCode", f"{address}", block))
params = [f"{address}", block]

response = self.post_request(method="getCode", params=params)

return Bytes(response)

def get_transaction_count(
self, address: Address, block_number: BlockNumberType = "latest"
) -> int:
"""`eth_getTransactionCount`: Returns the number of transactions sent from an address."""
block = hex(block_number) if isinstance(block_number, int) else block_number
return int(self.post_request("getTransactionCount", f"{address}", block), 16)
params = [f"{address}", block]

response = self.post_request(method="getTransactionCount", params=params)

return int(response, 16)

def get_transaction_by_hash(self, transaction_hash: Hash) -> TransactionByHashResponse | None:
"""`eth_getTransactionByHash`: Returns transaction details."""
try:
response = self.post_request("getTransactionByHash", f"{transaction_hash}")
response = self.post_request(
method="getTransactionByHash", params=[f"{transaction_hash}"]
)
if response is None:
return None
return TransactionByHashResponse.model_validate(
Expand All @@ -196,37 +228,45 @@ def get_storage_at(
) -> Hash:
"""`eth_getStorageAt`: Returns the value from a storage position at a given address."""
block = hex(block_number) if isinstance(block_number, int) else block_number
return Hash(self.post_request("getStorageAt", f"{address}", f"{position}", block))
params = [f"{address}", f"{position}", block]

response = self.post_request(method="getStorageAt", params=params)
return Hash(response)

def gas_price(self) -> int:
"""`eth_gasPrice`: Returns the number of transactions sent from an address."""
return int(self.post_request("gasPrice"), 16)
response = self.post_request(method="gasPrice")

return int(response, 16)

def send_raw_transaction(
self, transaction_rlp: Bytes, request_id: int | str | None = None
) -> Hash:
"""`eth_sendRawTransaction`: Send a transaction to the client."""
try:
result_hash = Hash(
self.post_request(
"sendRawTransaction", f"{transaction_rlp.hex()}", request_id=request_id
),
response = self.post_request(
method="sendRawTransaction",
params=[transaction_rlp.hex()],
request_id=request_id, # noqa: E501
)

result_hash = Hash(response)
assert result_hash is not None
return result_hash
except Exception as e:
raise SendTransactionExceptionError(str(e), tx_rlp=transaction_rlp) from e

def send_transaction(self, transaction: Transaction) -> Hash:
"""`eth_sendRawTransaction`: Send a transaction to the client."""
# TODO: is this a copypaste error from above?
try:
result_hash = Hash(
self.post_request(
"sendRawTransaction",
f"{transaction.rlp().hex()}",
request_id=transaction.metadata_string(),
)
response = self.post_request(
method="sendRawTransaction",
params=[transaction.rlp().hex()],
request_id=transaction.metadata_string(), # noqa: E501
)

result_hash = Hash(response)
assert result_hash == transaction.hash
assert result_hash is not None
return transaction.hash
Expand Down Expand Up @@ -318,7 +358,8 @@ class DebugRPC(EthRPC):

def trace_call(self, tr: dict[str, str], block_number: str):
"""`debug_traceCall`: Returns pre state required for transaction."""
return self.post_request("traceCall", tr, block_number, {"tracer": "prestateTracer"})
params = [tr, block_number, {"tracer": "prestateTracer"}]
return self.post_request(method="traceCall", params=params)


class EngineRPC(BaseRPC):
Expand All @@ -341,10 +382,12 @@ def __init__(

def post_request(
self,
*,
method: str,
*params: Any,
params: Any | None = None,
extra_headers: Dict | None = None,
request_id: int | str | None = None,
timeout: int | None = None,
) -> Any:
"""Send JSON-RPC POST request to the client RPC server at port defined in the url."""
if extra_headers is None:
Expand All @@ -357,14 +400,22 @@ def post_request(
extra_headers = {
"Authorization": f"Bearer {jwt_token}",
} | extra_headers

return super().post_request(
method, *params, extra_headers=extra_headers, request_id=request_id
method=method,
params=params,
extra_headers=extra_headers,
timeout=timeout,
request_id=request_id,
)

def new_payload(self, *params: Any, version: int) -> PayloadStatus:
"""`engine_newPayloadVX`: Attempts to execute the given payload on an execution client."""
method = f"newPayloadV{version}"
params_list = [to_json(param) for param in params]

return PayloadStatus.model_validate(
self.post_request(f"newPayloadV{version}", *[to_json(param) for param in params]),
self.post_request(method=method, params=params_list),
context=self.response_validation_context,
)

Expand All @@ -376,11 +427,17 @@ def forkchoice_updated(
version: int,
) -> ForkchoiceUpdateResponse:
"""`engine_forkchoiceUpdatedVX`: Updates the forkchoice state of the execution client."""
method = f"forkchoiceUpdatedV{version}"

if payload_attributes is None:
params = [to_json(forkchoice_state), None]
else:
params = [to_json(forkchoice_state), to_json(payload_attributes)]

return ForkchoiceUpdateResponse.model_validate(
self.post_request(
f"forkchoiceUpdatedV{version}",
to_json(forkchoice_state),
to_json(payload_attributes) if payload_attributes is not None else None,
method=method,
params=params,
),
context=self.response_validation_context,
)
Expand All @@ -395,10 +452,12 @@ def get_payload(
`engine_getPayloadVX`: Retrieves a payload that was requested through
`engine_forkchoiceUpdatedVX`.
"""
method = f"getPayloadV{version}"

return GetPayloadResponse.model_validate(
self.post_request(
f"getPayloadV{version}",
f"{payload_id}",
method=method,
params=[f"{payload_id}"],
),
context=self.response_validation_context,
)
Expand All @@ -410,9 +469,12 @@ def get_blobs(
version: int,
) -> GetBlobsResponse | None:
"""`engine_getBlobsVX`: Retrieves blobs from an execution layers tx pool."""
method = f"getBlobsV{version}"
params = [f"{h}" for h in versioned_hashes]

response = self.post_request(
f"getBlobsV{version}",
[f"{h}" for h in versioned_hashes],
method=method,
params=[params],
)
if response is None: # for tests that request non-existing blobs
logger.debug("get_blobs response received but it has value: None")
Expand All @@ -429,7 +491,7 @@ class NetRPC(BaseRPC):

def peer_count(self) -> int:
"""`net_peerCount`: Get the number of peers connected to the client."""
response = self.post_request("peerCount")
response = self.post_request(method="peerCount")
return int(response, 16) # hex -> int


Expand All @@ -438,4 +500,4 @@ class AdminRPC(BaseRPC):

def add_peer(self, enode: str) -> bool:
"""`admin_addPeer`: Add a peer by enode URL."""
return self.post_request("addPeer", enode)
return self.post_request(method="addPeer", params=[enode])
Loading
Loading