Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion src/ethereum_spec_tools/new_fork/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def _to_args(
[
"codemod",
"--no-format",
"string.StringReplaceCommand",
"string_replace.StringReplaceCommand",
]
+ common,
[
Expand All @@ -229,6 +229,32 @@ def _to_args(
return commands


@dataclass
class ClearDocstring(CodemodArgs):
"""
Describe how to clear the docstring in __init__.py to libcst.tool:main.
"""

@override
def _to_args(
self, fork_builder: "ForkBuilder", working_directory: Path
) -> list[list[str]]:
init_path = (
working_directory
/ "ethereum"
/ fork_builder.new_fork
/ "__init__.py"
)
return [
[
"codemod",
"remove_docstring.RemoveDocstringCommand",
"--no-format",
str(init_path),
]
]


class ForkBuilder:
"""
Takes a template fork and uses it to generate a new fork, applying source
Expand Down Expand Up @@ -353,6 +379,7 @@ def __init__(
RenameFork(),
SetForkCriteria(),
ReplaceForkName(),
ClearDocstring(),
]

def _create_working_directory(self) -> TemporaryDirectory:
Expand Down
35 changes: 35 additions & 0 deletions src/ethereum_spec_tools/new_fork/codemod/remove_docstring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""
libcst codemod that removes the module docstring.
"""

import libcst as cst
from libcst.codemod import CodemodCommand
from typing_extensions import override


class RemoveDocstringCommand(CodemodCommand):
"""
Removes the module docstring if it exists.
"""

DESCRIPTION: str = "Remove the module docstring."

@override
def transform_module_impl(self, tree: cst.Module) -> cst.Module:
"""
Transform the tree by removing the docstring.
"""
if len(tree.body) == 0:
return tree
first_stmt = tree.body[0]
if not isinstance(first_stmt, cst.SimpleStatementLine):
return tree
if len(first_stmt.body) != 1:
return tree
expr = first_stmt.body[0]
if not isinstance(expr, cst.Expr):
return tree
if not isinstance(expr.value, cst.SimpleString):
return tree
new_body = tree.body[1:]
return tree.with_changes(body=new_body)
84 changes: 83 additions & 1 deletion tests/json_infra/test_tools_new_fork.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,15 @@
from pathlib import Path
from tempfile import TemporaryDirectory

import libcst as cst
import pytest
from libcst.codemod import CodemodContext

from ethereum_spec_tools.forks import Hardfork
from ethereum_spec_tools.new_fork.cli import main as new_fork
from ethereum_spec_tools.new_fork.codemod.remove_docstring import (
RemoveDocstringCommand,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -56,10 +61,14 @@ def test_end_to_end(template_fork: str) -> None:
with (fork_dir / "__init__.py").open("r") as f:
source = f.read()

assert '"""' not in source[:20]
assert "FORK_CRITERIA = ByTimestamp(7)" in source
assert "E2E Fork" in source
assert template_fork.capitalize() not in source

with (fork_dir / "utils" / "hexadecimal.py").open("r") as f:
source = f.read()
assert "E2E Fork" in source

with (fork_dir / "vm" / "gas.py").open("r") as f:
source = f.read()

Expand All @@ -82,3 +91,76 @@ def test_end_to_end(template_fork: str) -> None:
"from ethereum.forks.paris import trie as previous_trie"
in f.read()
)


def has_module_docstring(file_path: Path) -> bool:
"""Return True if the file starts with a module-level doc-string."""
tree = cst.parse_module(file_path.read_text())
if not tree.body:
return False
first = tree.body[0]
if not isinstance(first, cst.SimpleStatementLine):
return False
if len(first.body) != 1:
return False
expr = first.body[0]
return isinstance(expr, cst.Expr) and isinstance(
expr.value, cst.SimpleString
)


def test_remove_docstring_command() -> None:
"""Test that RemoveDocstringCommand removes module docstrings."""
source = '"""Module docstring."""\n\nsome_var = 123\n'
module = cst.parse_module(source)
context = CodemodContext()
command = RemoveDocstringCommand(context)

new_module = command.transform_module(module)
result = new_module.code

assert '"""Module docstring."""' not in result
assert "some_var = 123" in result


def test_remove_docstring_preserves_other_docstrings() -> None:
"""Test that function/class docstrings are preserved."""
source = '''"""Module docstring."""

def foo():
"""Function docstring."""
pass
'''
module = cst.parse_module(source)
context = CodemodContext()
command = RemoveDocstringCommand(context)

new_module = command.transform_module(module)
result = new_module.code

assert not result.startswith('"""Module docstring."""')
assert '"""Function docstring."""' in result


def test_remove_docstring_handles_files_without_docstrings() -> None:
"""Test that files without docstrings remain unchanged."""
source_without_docstring = "some_var = 123\n\ndef foo():\n pass\n"
module = cst.parse_module(source_without_docstring)
context = CodemodContext()
command = RemoveDocstringCommand(context)

new_module = command.transform_module(module)

assert new_module.code == source_without_docstring


def test_remove_docstring_handles_empty_files() -> None:
"""Test that empty files remain empty."""
source_empty = ""
module = cst.parse_module(source_empty)
context = CodemodContext()
command = RemoveDocstringCommand(context)

new_module = command.transform_module(module)

assert new_module.code == source_empty
4 changes: 3 additions & 1 deletion vulture_whitelist.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
from ethereum_spec_tools.lint.lints.import_hygiene import ImportHygiene
from ethereum_spec_tools.new_fork.codemod.comment import CommentReplaceCommand
from ethereum_spec_tools.new_fork.codemod.constant import SetConstantCommand
from ethereum_spec_tools.new_fork.codemod.string import StringReplaceCommand
from ethereum_spec_tools.new_fork.codemod.string_replace import (
StringReplaceCommand,
)

# src/ethereum/utils/hexadecimal.py
hex_to_bytes256
Expand Down
Loading