diff --git a/sweepai/core/context_pruning.py b/sweepai/core/context_pruning.py index 68e3315562..51fd390226 100644 --- a/sweepai/core/context_pruning.py +++ b/sweepai/core/context_pruning.py @@ -172,6 +172,19 @@ def escape_ripgrep(text): text = text.replace(s, "\\" + s) return text +def run_ripgrep_command(code_entity, repo_dir): + rg_command = [ + "rg", + "-n", + "-i", + code_entity, + repo_dir, + ] + result = subprocess.run( + " ".join(rg_command), text=True, shell=True, capture_output=True + ) + return result.stdout + @staticmethod def can_add_snippet(snippet: Snippet, current_snippets: list[Snippet]): return ( @@ -752,18 +765,8 @@ def handle_function_call( if function_name == "code_search": code_entity = f'"{function_input["code_entity"]}"' # handles cases with two words code_entity = escape_ripgrep(code_entity) # escape special characters - rg_command = [ - "rg", - "-n", - "-i", - code_entity, - repo_context_manager.cloned_repo.repo_dir, - ] try: - result = subprocess.run( - " ".join(rg_command), text=True, shell=True, capture_output=True - ) - rg_output = result.stdout + rg_output = run_ripgrep_command(code_entity, repo_context_manager.cloned_repo.repo_dir) if rg_output: # post process rip grep output to be more condensed rg_output_pretty, file_output_dict, file_to_num_occurrences = post_process_rg_output( diff --git a/tests/test_context_pruning.py b/tests/test_context_pruning.py new file mode 100644 index 0000000000..af1ff650d1 --- /dev/null +++ b/tests/test_context_pruning.py @@ -0,0 +1,50 @@ +import unittest +from sweepai.core.context_pruning import ( + build_full_hierarchy, + load_graph_from_file, + RepoContextManager, + get_relevant_context, +) +import networkx as nx + +class TestContextPruning(unittest.TestCase): + def test_build_full_hierarchy(self): + G = nx.DiGraph() + G.add_edge("main.py", "database.py") + G.add_edge("database.py", "models.py") + G.add_edge("utils.py", "models.py") + hierarchy = build_full_hierarchy(G, "main.py", 2) + expected_hierarchy = """main.py +├── database.py +│ └── models.py +└── utils.py + └── models.py +""" + self.assertEqual(hierarchy, expected_hierarchy) + + def test_load_graph_from_file(self): + graph = load_graph_from_file("tests/test_import_tree.txt") + self.assertIsInstance(graph, nx.DiGraph) + self.assertEqual(len(graph.nodes), 5) + self.assertEqual(len(graph.edges), 4) + + def test_get_relevant_context(self): + cloned_repo = ClonedRepo("sweepai/sweep", "123", "main") + repo_context_manager = RepoContextManager( + dir_obj=None, + current_top_tree="", + snippets=[], + snippet_scores={}, + cloned_repo=cloned_repo, + ) + query = "allow 'sweep.yaml' to be read from the user/organization's .github repository. this is found in client.py and we need to change this to optionally read from .github/sweep.yaml if it exists there" + rcm = get_relevant_context( + query, + repo_context_manager, + seed=42, + ticket_progress=None, + chat_logger=None, + ) + self.assertIsInstance(rcm, RepoContextManager) + self.assertTrue(len(rcm.current_top_snippets) > 0) + self.assertTrue(any("client.py" in snippet.file_path for snippet in rcm.current_top_snippets)) \ No newline at end of file