Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,12 @@ dependencies = [
"tree-sitter-go",
"tree-sitter-javascript",
"tree-sitter-java",
"tree-sitter-php>=0.23.11",
"tree-sitter-ruby",
"tree-sitter-rust==v0.23.2",
"tree-sitter-typescript",
"unidiff",
"textual",
"tree-sitter-php>=0.23.11",
]

[project.optional-dependencies]
Expand Down
2 changes: 2 additions & 0 deletions swesmith/bug_gen/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from swesmith.bug_gen.adapters.python import get_entities_from_file_py
from swesmith.bug_gen.adapters.ruby import get_entities_from_file_rb
from swesmith.bug_gen.adapters.rust import get_entities_from_file_rs
from swesmith.bug_gen.adapters.typescript import get_entities_from_file_ts

get_entities_from_file = {
"c": get_entities_from_file_c,
Expand All @@ -16,4 +17,5 @@
"py": get_entities_from_file_py,
"rb": get_entities_from_file_rb,
"rs": get_entities_from_file_rs,
"ts": get_entities_from_file_ts,
}
4 changes: 2 additions & 2 deletions swesmith/bug_gen/adapters/javascript.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import re
import warnings

import tree_sitter_javascript as tsjs
import tree_sitter_javascript as js

from swesmith.constants import CodeEntity, CodeProperty, TODO_REWRITE
from tree_sitter import Language, Parser

JS_LANGUAGE = Language(tsjs.language())
JS_LANGUAGE = Language(js.language())


class JavaScriptEntity(CodeEntity):
Expand Down
275 changes: 275 additions & 0 deletions swesmith/bug_gen/adapters/typescript.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
import re
import warnings

import tree_sitter_typescript as ts
from swesmith.constants import CodeEntity, CodeProperty, TODO_REWRITE
from tree_sitter import Language, Parser

TS_LANGUAGE = Language(ts.language_typescript())


class TypeScriptEntity(CodeEntity):
    """A function, method or class parsed from TypeScript source via tree-sitter.

    Reads ``self.node`` (the tree-sitter node) and ``self.src_code`` (the
    entity's source text) and records ``CodeProperty`` tags in ``self._tags``.
    """

    def _analyze_properties(self):
        """Tag the entity kind (function/class), then tag its whole subtree."""
        root = self.node
        kind = root.type
        # Core entity types
        if kind in (
            "function_declaration",
            "function",
            "arrow_function",
            "method_definition",
        ):
            self._tags.add(CodeProperty.IS_FUNCTION)
        elif kind in ("class_declaration", "class"):
            self._tags.add(CodeProperty.IS_CLASS)
        # Control flow / operation analysis over every descendant
        self._walk_for_properties(root)

    def _walk_for_properties(self, n):
        """Run all property checks on ``n`` and, depth-first, on its descendants."""
        self._check_control_flow(n)
        self._check_operations(n)
        self._check_binary_expressions(n)
        for child in n.children:
            self._walk_for_properties(child)

    def _check_control_flow(self, n):
        """Add loop / if / if-else / exception tags when ``n`` is such a node."""
        kind = n.type
        if kind in (
            "for_statement",
            "for_in_statement",
            "for_of_statement",
            "while_statement",
            "do_statement",
        ):
            self._tags.add(CodeProperty.HAS_LOOP)
        if kind == "if_statement":
            self._tags.add(CodeProperty.HAS_IF)
            if any(child.type == "else_clause" for child in n.children):
                self._tags.add(CodeProperty.HAS_IF_ELSE)
        if kind in ("try_statement", "catch_clause", "throw_statement"):
            self._tags.add(CodeProperty.HAS_EXCEPTION)

    def _check_operations(self, n):
        """Add operation-related tags (calls, returns, imports, ...) for ``n``."""
        kind = n.type
        if kind in ("subscript_expression", "member_expression"):
            self._tags.add(CodeProperty.HAS_LIST_INDEXING)
        if kind == "call_expression":
            self._tags.add(CodeProperty.HAS_FUNCTION_CALL)
        if kind == "return_statement":
            self._tags.add(CodeProperty.HAS_RETURN)
        if kind in ("import_statement", "import_clause"):
            self._tags.add(CodeProperty.HAS_IMPORT)
        # NOTE(review): `let`/`const` declarations appear to parse as
        # `lexical_declaration` and are not matched here -- confirm whether
        # that omission is intended.
        if kind in ("assignment_expression", "variable_declaration"):
            self._tags.add(CodeProperty.HAS_ASSIGNMENT)
        if kind == "arrow_function":
            self._tags.add(CodeProperty.HAS_LAMBDA)
        if kind in ("binary_expression", "unary_expression", "update_expression"):
            self._tags.add(CodeProperty.HAS_ARITHMETIC)
        if kind == "decorator":
            self._tags.add(CodeProperty.HAS_DECORATOR)
        if kind in ("try_statement", "with_statement"):
            self._tags.add(CodeProperty.HAS_WRAPPER)
        if kind == "class_declaration" and any(
            child.type == "class_heritage" for child in n.children
        ):
            self._tags.add(CodeProperty.HAS_PARENT)
        if kind in ("unary_expression", "update_expression"):
            self._tags.add(CodeProperty.HAS_UNARY_OP)

    def _check_binary_expressions(self, n):
        """Tag binary expressions, plus boolean and ordering operators inside them."""
        if n.type != "binary_expression":
            return
        self._tags.add(CodeProperty.HAS_BINARY_OP)
        # Check for boolean operators
        if any(
            hasattr(child, "text") and child.text.decode("utf-8") in ("&&", "||")
            for child in n.children
        ):
            self._tags.add(CodeProperty.HAS_BOOL_OP)
        # Check for comparison operators (off by one potential)
        for child in n.children:
            if hasattr(child, "text") and child.text.decode("utf-8") in (
                "<",
                ">",
                "<=",
                ">=",
            ):
                self._tags.add(CodeProperty.HAS_OFF_BY_ONE)

    @property
    def name(self) -> str:
        """Entity name, or ``""`` when the node has no recognizable name child."""
        return self._extract_name_from_node()

    def _extract_name_from_node(self) -> str:
        """Return the text of the child that names this node, or ``""``."""
        kind = self.node.type
        if kind == "class_declaration":
            # The TypeScript grammar names a class with a `type_identifier`
            # (the JavaScript grammar uses `identifier`). Looking only for
            # `identifier` yields "" for every TypeScript class, so try the
            # TypeScript node type first and keep `identifier` as a fallback.
            return self._find_child_text("type_identifier") or self._find_child_text(
                "identifier"
            )
        if kind == "method_definition":
            return self._find_child_text("property_identifier")
        # Function declarations, and variable declarators / assignment
        # expressions holding function expressions
        if kind in (
            "function_declaration",
            "variable_declarator",
            "assignment_expression",
        ):
            return self._find_child_text("identifier")
        return ""

    def _find_child_text(self, child_type: str) -> str:
        """Return the decoded text of the first direct child of ``child_type``, or ``""``."""
        for child in self.node.children:
            if child.type == child_type:
                return child.text.decode("utf-8")
        return ""

    @property
    def signature(self) -> str:
        """Source text preceding the entity's body (no opening brace)."""
        node = self.node
        for child in node.children:
            if child.type in ("statement_block", "class_body"):
                # NOTE(review): this is a byte offset relative to the node
                # start applied to the `src_code` string -- presumably they
                # line up for typical input; verify for non-ASCII text or
                # headers that span several lines.
                body_offset = child.start_byte - node.start_byte
                header = self.src_code[:body_offset].strip()
                if header.endswith(" {"):
                    header = header[:-2].strip()
                return header
        # No block body: expression-bodied arrow function
        if node.type == "arrow_function" and "=>" in self.src_code:
            return self.src_code.split("=>")[0].strip() + " =>"
        first_line = self.src_code.split("\n")[0]
        if node.type == "variable_declarator" and " = function" in first_line:
            brace_pos = first_line.find(" {")
            if brace_pos != -1:
                return first_line[:brace_pos].strip()
            result = first_line.strip()
            if result.endswith(";"):
                result = result[:-1].strip()
            return result
        # Fallback: the first source line
        return first_line.strip()

    @property
    def stub(self) -> str:
        """Signature followed by a body that contains only the rewrite marker."""
        signature = self.signature
        # An arrow function whose signature lost its arrow gets it back here;
        # every other entity kind uses the signature as is.
        if self.node.type == "arrow_function" and "=>" not in signature:
            signature = f"{signature} =>"
        return f"{signature} {{\n\t// {TODO_REWRITE}\n}}"

    @property
    def complexity(self) -> int:
        """1 plus the number of branching constructs and ``&&``/``||`` operators."""
        branch_types = (
            "if_statement",
            "else_clause",
            "for_statement",
            "for_in_statement",
            "for_of_statement",
            "while_statement",
            "do_statement",
            "switch_statement",
            "case_clause",
            "catch_clause",
            "conditional_expression",
        )

        def walk(node):
            score = 1 if node.type in branch_types else 0
            if node.type == "binary_expression":
                # Each short-circuit operator adds one more path.
                score += sum(
                    1
                    for child in node.children
                    if hasattr(child, "text")
                    and child.text.decode("utf-8") in ("&&", "||")
                )
            return score + sum(walk(child) for child in node.children)

        return 1 + walk(self.node)


def get_entities_from_file_ts(
    entities: list[TypeScriptEntity],
    file_path: str,
    max_entities: int = -1,
) -> list[TypeScriptEntity]:
    """
    Parse a .ts file and append its functions, methods and classes to ``entities``.

    Collection stops once ``entities`` holds ``max_entities`` items; if
    ``max_entities < 0``, everything is collected. A file that cannot be
    decoded as UTF-8 produces a warning and leaves ``entities`` unchanged.

    Returns the same ``entities`` list that was passed in.
    """
    parser = Parser(TS_LANGUAGE)
    try:
        # Context manager so the handle is closed even when decoding fails.
        with open(file_path, "r", encoding="utf8") as source_file:
            file_content = source_file.read()
    except UnicodeDecodeError:
        warnings.warn(f"Could not decode file {file_path}", stacklevel=2)
        return entities
    tree = parser.parse(bytes(file_content, "utf8"))
    root = tree.root_node
    lines = file_content.splitlines()
    _walk_and_collect_ts(root, entities, lines, str(file_path), max_entities)
    return entities


def _walk_and_collect_ts(node, entities, lines, file_path, max_entities):
    """Recursively collect function, method and class entities at or below ``node``.

    Args:
        node: tree-sitter node to visit.
        entities: list that collected entities are appended to (mutated in place).
        lines: the source file split into lines; used to slice entity source.
        file_path: path stored on each entity and used in warnings.
        max_entities: stop once ``entities`` holds this many items; a negative
            value means no limit.
    """
    if 0 <= max_entities == len(entities):
        return
    if node.type == "ERROR":
        warnings.warn(f"Error encountered parsing {file_path}", stacklevel=2)
        return
    # Only collect classes, functions, and methods (not variables).
    #
    # `export_statement` needs no special case: it is not collectible itself,
    # so the child walk below reaches the wrapped declaration exactly once.
    # Collecting the declaration while visiting the export node and then also
    # recursing into that declaration would append each exported entity twice.
    #
    # NOTE(review): a reviewer asked whether `function_declaration` handling
    # is covered by the unit tests; the review thread left this open.
    if node.type in (
        "function_declaration",
        "method_definition",
        "class_declaration",
    ):
        entity = _build_entity(node, lines, file_path)
        # Skip nodes for which no name or signature could be extracted.
        if entity.name.strip() and entity.signature.strip():
            entities.append(entity)
            if 0 <= max_entities == len(entities):
                return
    for child in node.children:
        _walk_and_collect_ts(child, entities, lines, file_path, max_entities)


def _build_entity(node, lines, file_path: str) -> TypeScriptEntity:
    """Create a ``TypeScriptEntity`` for ``node`` from the file's source lines.

    The lines spanned by the node are cut out of ``lines`` and the leading
    indentation found on the first of them is removed from every line.
    Reported line numbers are 1-based.
    """
    first_row = node.start_point[0]
    last_row = node.end_point[0]
    block = lines[first_row : last_row + 1]

    # Leading tabs/spaces of the first line define the indentation to remove.
    leading = re.match(r"^(?P<indent>[\t ]*)", block[0])
    indent_str = leading.group("indent") if leading else ""
    if "\t" in indent_str:
        indent_size = 1
    else:
        indent_size = len(indent_str) or 2
    indent_level = len(indent_str) // indent_size

    strip_width = indent_level * indent_size
    # Lines shorter than the indent width are just stripped of leading whitespace.
    dedented = [
        line[strip_width:] if len(line) >= strip_width else line.lstrip("\t ")
        for line in block
    ]

    return TypeScriptEntity(
        file_path=file_path,
        indent_level=indent_level,
        indent_size=indent_size,
        line_start=first_row + 1,
        line_end=last_row + 1,
        node=node,
        src_code="\n".join(dedented),
    )
Loading