Skip to content

Commit

Permalink
Added custom hash (#4047)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinlu1248 authored Jun 19, 2024
2 parents 080b571 + 3673fcd commit bb50fde
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 77 deletions.
13 changes: 10 additions & 3 deletions sweepai/agents/modify.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def generate_code_suggestions(
modify_files_dict: dict[str, dict[str, str]],
fcrs: list[FileChangeRequest],
error_messages_dict: dict[int, str],
cloned_repo: ClonedRepo,
) -> list[StatefulCodeSuggestion]:
modify_order = []
for fcr in fcrs:
Expand All @@ -33,6 +34,7 @@ def generate_code_suggestions(
file_path=file_path,
original_code=file_data["original_contents"],
new_code=file_data["contents"],
file_contents=file_data["original_contents"],
state="done"
))

Expand All @@ -43,11 +45,16 @@ def generate_code_suggestions(
continue
else:
parsed_fcr = parse_fcr(fcr)
try:
file_contents = cloned_repo.get_file_contents(fcr.filename)
except FileNotFoundError:
file_contents = ""
code_suggestions.append(StatefulCodeSuggestion(
file_path=fcr.filename,
original_code=parsed_fcr["original_code"][0] if parsed_fcr["original_code"] else "",
new_code=parsed_fcr["new_code"][0] if parsed_fcr["new_code"] else "",
state=("processing" if i == current_fcr_index else "pending") if i not in error_messages_dict else "error",
file_contents=file_contents,
state=("processing" if i == current_fcr_index else "pending"),
error=error_messages_dict.get(i, None)
))
return code_suggestions
Expand Down Expand Up @@ -147,7 +154,7 @@ def modify(
error_messages_dict = get_error_message_dict(fcrs, cloned_repo, modify_files_dict, renames_dict)
previous_modify_files_dict = copy.deepcopy(modify_files_dict)
for i in range(len(fcrs) * 15):
yield generate_code_suggestions(modify_files_dict, fcrs, error_messages_dict)
yield generate_code_suggestions(modify_files_dict, fcrs, error_messages_dict, cloned_repo)
function_call = validate_and_parse_function_call(function_calls_string, chat_gpt)
if function_call:
num_of_tasks_done = tasks_completed(fcrs)
Expand Down Expand Up @@ -312,7 +319,7 @@ def modify(
):
diff_string += f"\nChanges made to {file_name}:\n{diff}"
logger.info("\n".join(generate_diff(file_data["original_contents"], file_data["contents"]) for file_data in modify_files_dict.values())) # adding this as a useful way to render the diffs
yield generate_code_suggestions(modify_files_dict, fcrs, error_messages_dict)
yield generate_code_suggestions(modify_files_dict, fcrs, error_messages_dict, cloned_repo)
return modify_files_dict


Expand Down
1 change: 1 addition & 0 deletions sweepai/agents/modify_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1377,6 +1377,7 @@ def get_file_contents(file_path):

if not best_match.strip():
error_messages.append(f"<original_code> does not exist in `{file_change_request.filename}`. Your proposed <original_code> contains:\n```\n{indent(original_code, best_indent)}\n```\nBut the code is no where to be found in the file. There are also no similar code snippets in this file.{too_long_message}{ellipses_message}")
error_indices.append(i)
continue
if best_score != 100:
if not check_valid_parentheses(best_match):
Expand Down
138 changes: 67 additions & 71 deletions sweepai/chat/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,39 @@

DEFAULT_K = 8

def get_cloned_repo(
repo_name: str,
access_token: str,
branch: str = None,
messages: list[Message] = [],
):
org_name, repo = repo_name.split("/")
if branch:
cloned_repo = ClonedRepo(
repo_name,
token=access_token,
installation_id=get_cached_installation_id(org_name)
)
cloned_repo.branch = branch
try:
cloned_repo.git_repo.git.checkout(branch)
except Exception as e:
logger.warning(f"Error checking out branch {branch}: {e}. Trying to checkout PRs.")
for message in messages:
for pull in message.annotations["pulls"]:
if pull["branch"] == branch:
pr = cloned_repo.repo.get_pull(pull["number"])
sha = pr.head.sha
cloned_repo.git_repo.git.fetch("origin", sha)
cloned_repo.git_repo.git.checkout(sha)
logger.info(f"Checked out PR {pull['number']} with SHA {sha}")
return cloned_repo
raise Exception(f"Branch {branch} not found")
else:
cloned_repo = MockClonedRepo(f"{repo_cache}/{repo}", repo_name, token=access_token)
cloned_repo.git_repo.git.pull()
return cloned_repo

def get_pr_snippets(
repo_name: str,
annotations: dict,
Expand Down Expand Up @@ -240,57 +273,14 @@ def check_repo_exists(
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}

@app.get("/backend/search")
def search_codebase_endpoint_get(
repo_name: str,
query: str,
stream: bool = False,
access_token: str = Depends(get_token_header)
):
"""
DEPRECATED, use POST instead.
"""
with Timer() as timer:
g = get_authenticated_github_client(repo_name, access_token)
logger.debug(f"Getting authenticated GitHub client took {timer.time_elapsed} seconds")
if not g:
return {"success": False, "error": "The repository may not exist or you may not have access to this repository."}
username = Github(access_token).get_user().login
token = g.token if isinstance(g, CustomGithub) else access_token
if stream:
def stream_response():
yield json.dumps(["Building lexical index...", []])
for message, snippets in wrapped_search_codebase.stream(
username,
repo_name,
query,
access_token=token,
metadata={
"repo_name": repo_name,
"query": query,
}
):
yield json.dumps((message, [snippet.model_dump() for snippet in snippets]))
return StreamingResponse(stream_response())
else:
return [snippet.model_dump() for snippet in wrapped_search_codebase(
username,
repo_name,
query,
access_token=token,
metadata={
"repo_name": repo_name,
"query": query,
}
)]

@app.post("/backend/search")
def search_codebase_endpoint_post(
def search_codebase_endpoint(
repo_name: str = Body(...),
query: str = Body(...),
annotations: dict = Body({}),
access_token: str = Depends(get_token_header)
access_token: str = Depends(get_token_header),
branch: str = Body(None),
):
with Timer() as timer:
g = get_authenticated_github_client(repo_name, access_token)
Expand All @@ -307,6 +297,7 @@ def stream_response():
query,
token,
annotations=annotations,
branch=branch,
metadata={
"repo_name": repo_name,
"query": query,
Expand All @@ -323,19 +314,24 @@ def wrapped_search_codebase(
query: str,
access_token: str,
annotations: dict = {},
branch: str = None,
metadata: dict = {},
):
org_name, repo = repo_name.split("/")
if not os.path.exists(f"{repo_cache}/{repo}"):
if not os.path.exists(f"{repo_cache}/{repo}") and not branch:
yield "Cloning repository...", []
print(f"Cloning {repo_name} to {repo_cache}/{repo}")
git.Repo.clone_from(f"https://x-access-token:{access_token}@github.com/{repo_name}", f"{repo_cache}/{repo}")
print(f"Cloned {repo_name} to {repo_cache}/{repo}")
yield "Repository cloned.", []
cloned_repo = MockClonedRepo(f"{repo_cache}/{repo}", repo_name, token=access_token)
else:
cloned_repo = MockClonedRepo(f"{repo_cache}/{repo}", repo_name, token=access_token)
cloned_repo.pull()
yield f"Cloning into {repo_name}:{branch}...", []
cloned_repo = get_cloned_repo(repo_name, access_token, branch, [Message(
content=query,
role="user",
annotations=annotations,
)])
yield "Repository pulled.", []
if annotations:
yield "Getting pull request snippets...", []
Expand All @@ -353,7 +349,7 @@ def wrapped_search_codebase(
repo_name,
query,
access_token,
use_optimized_query=not bool(annotations),
use_optimized_query=not bool(annotations["pulls"]),
):
yield message, snippets

Expand Down Expand Up @@ -402,7 +398,7 @@ def chat_codebase(
messages: list[Message] = Body(...),
snippets: list[Snippet] = Body(...),
model: str = Body(...),
use_patch: bool = Body(True),
branch: str = Body(None),
k: int = Body(DEFAULT_K),
access_token: str = Depends(get_token_header)
):
Expand All @@ -427,7 +423,7 @@ def chat_codebase(
"snippets": [snippet.model_dump() for snippet in snippets],
},
model=model,
use_patch=use_patch,
branch=branch,
k=k
)

Expand Down Expand Up @@ -456,14 +452,13 @@ def chat_codebase_stream(
metadata: dict = {},
k: int = DEFAULT_K,
model: str = "claude-3-opus-20240229",
use_patch: bool = False,
branch: str = None,
):
EXPAND_SIZE = 100
if not snippets:
raise ValueError("No snippets were sent.")
org_name, repo = repo_name.split("/")
cloned_repo = MockClonedRepo(f"{repo_cache}/{repo}", repo_name, token=access_token)
cloned_repo.git_repo.git.pull()
cloned_repo = get_cloned_repo(repo_name, access_token, branch, messages)
repo_specific_description = get_repo_specific_description(cloned_repo=cloned_repo)
use_openai = model.startswith("gpt")
snippets_message = relevant_snippets_message.format(
Expand Down Expand Up @@ -751,24 +746,18 @@ def stream_state(
"messages": [message.model_dump() for message in messages],
})

def postprocessed_stream(*args, use_patch=False, **kwargs):
def postprocessed_stream(*args, **kwargs):
previous_state = []
try:
for messages in stream_state(*args, **kwargs):
if not use_patch:
yield json.dumps([
message.model_dump()
for message in messages
])
else:
current_state = [
message.model_dump()
for message in messages
]
patch = jsonpatch.JsonPatch.from_diff(previous_state, current_state)
if patch:
yield patch.to_string()
previous_state = current_state
current_state = [
message.model_dump()
for message in messages
]
patch = jsonpatch.JsonPatch.from_diff(previous_state, current_state)
if patch:
yield patch.to_string()
previous_state = current_state
except Exception as e:
yield json.dumps([
{
Expand All @@ -786,7 +775,6 @@ def postprocessed_stream(*args, use_patch=False, **kwargs):
metadata,
model,
use_openai=use_openai,
use_patch=use_patch,
k=k
)
)
Expand Down Expand Up @@ -952,6 +940,8 @@ async def create_pull(
head=new_branch,
base=base_branch,
)
g = get_authenticated_github_client(repo_name, access_token)
pull_request.add_to_assignees(g.get_user().login)
file_diffs = pull_request.get_files()

return {
Expand Down Expand Up @@ -1014,8 +1004,11 @@ async def write_message_to_disk(
repo_name: str = Body(...),
messages: list[Message] = Body(...),
snippets: list[Snippet] = Body(...),
original_code_suggestions: list[CodeSuggestion] = Body([]),
code_suggestions: list = Body([]),
pull_request: dict | None = Body(None),
pull_request_title: str = Body(""),
pull_request_description: str = Body(""),
message_id: str = Body(""),
):
if not message_id:
Expand All @@ -1025,8 +1018,11 @@ async def write_message_to_disk(
"repo_name": repo_name,
"messages": [message.model_dump() for message in messages],
"snippets": [snippet.model_dump() for snippet in snippets],
"original_code_suggestions": [code_suggestion.__dict__ if isinstance(code_suggestion, CodeSuggestion) else code_suggestion for code_suggestion in original_code_suggestions],
"code_suggestions": [code_suggestion.__dict__ if isinstance(code_suggestion, CodeSuggestion) else code_suggestion for code_suggestion in code_suggestions],
"pull_request": pull_request,
"pull_request_title": pull_request_title,
"pull_request_description": pull_request_description,
}
with open(f"{CACHE_DIRECTORY}/messages/{message_id}.json", "w") as file:
json.dump(data, file)
Expand Down
2 changes: 1 addition & 1 deletion sweepai/config/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def is_file_bad(self, file_name: str, repo_dir: str) -> tuple[bool, str]:
if bool(match):
return True, "The filename means that this file is likely auto generated."
except Exception as e:
logger.error(f"Error when checking if file is autogenerated: {e}")
logger.error(f"Error when checking if file is autogenerated: {e}, run `sudo apt-get install cmake pkg-config libicu-dev zlib1g-dev libcurl4-openssl-dev libssl-dev ruby-dev && gem install github-linguist`")
posthog.capture(
"is_file_auto_generated_or_vendored",
"is_file_auto_generated_or_vendored error",
Expand Down
4 changes: 3 additions & 1 deletion sweepai/dataclasses/code_suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ class CodeSuggestion:
original_code: str
new_code: str

file_contents: str = ""

@dataclass
class StatefulCodeSuggestion(CodeSuggestion):
state: Literal["pending", "processing", "done", "error"]
state: Literal["pending", "processing", "done", "error"] = "pending"
error: Optional[str] = None

1 change: 0 additions & 1 deletion sweepai/utils/ticket_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from loguru import logger
from tqdm import tqdm
import networkx as nx
from sweepai.core.chat import call_llm
from sweepai.utils.streamable_functions import streamable

from sweepai.utils.timer import Timer
Expand Down

0 comments on commit bb50fde

Please sign in to comment.