diff --git a/petercat_utils/data_class.py b/petercat_utils/data_class.py index 4426baec..7377c4f9 100644 --- a/petercat_utils/data_class.py +++ b/petercat_utils/data_class.py @@ -110,3 +110,4 @@ class GitDocTaskNodeType(AutoNameEnum): class GitIssueTaskNodeType(AutoNameEnum): REPO = auto() ISSUE = auto() + ISSUE_PAGE = auto() diff --git a/petercat_utils/rag_helper/git_issue_task.py b/petercat_utils/rag_helper/git_issue_task.py index fac2a077..fa5a8ac0 100644 --- a/petercat_utils/rag_helper/git_issue_task.py +++ b/petercat_utils/rag_helper/git_issue_task.py @@ -4,7 +4,8 @@ from ..data_class import GitIssueTaskNodeType, TaskStatus, TaskType, RAGGitIssueConfig from ..rag_helper import issue_retrieval -g = Github() +GITHUB_PER_PAGE = 30 +g = Github(per_page=GITHUB_PER_PAGE) def add_rag_git_issue_task(config: RAGGitIssueConfig): @@ -22,8 +23,21 @@ def add_rag_git_issue_task(config: RAGGitIssueConfig): return res +def create_rag_git_issue_task(record): + return GitIssueTask(id=record["id"], + issue_id=record["issue_id"], + repo_name=record["repo_name"], + node_type=record["node_type"], + bot_id=record["bot_id"], + status=record["status"], + from_id=record["from_task_id"], + page_index=record["page_index"] + ) + + class GitIssueTask(GitTask): - issue_id: str + issue_id: int + page_index: int node_type: GitIssueTaskNodeType def __init__(self, @@ -33,11 +47,13 @@ def __init__(self, repo_name, status=TaskStatus.NOT_STARTED, from_id=None, - id=None + id=None, + page_index=None ): super().__init__(bot_id=bot_id, type=TaskType.GIT_ISSUE, from_id=from_id, id=id, status=status, repo_name=repo_name) self.issue_id = issue_id + self.page_index = page_index self.node_type = GitIssueTaskNodeType(node_type) def extra_save_data(self): @@ -50,6 +66,8 @@ def handle(self): self.update_status(TaskStatus.IN_PROGRESS) if self.node_type == GitIssueTaskNodeType.REPO: return self.handle_repo_node() + elif self.node_type == GitIssueTaskNodeType.ISSUE_PAGE: + return self.handle_issue_page_node() elif self.node_type == GitIssueTaskNodeType.ISSUE: return self.handle_issue_node() else: @@ -57,13 +75,54 @@ def handle(self): def handle_repo_node(self): repo = g.get_repo(self.repo_name) - repo.get_issues() - issues = [issue for issue in repo.get_issues()] + issues = repo.get_issues(state='all') + latest_page = (self.get_table() + .select('*') + .eq('repo_name', self.repo_name) + .eq('node_type', GitIssueTaskNodeType.ISSUE_PAGE.value) + .order('page_index', desc=True) + .limit(1) + .execute()).data + + slice_page_index = latest_page[0]["page_index"] if len(latest_page) > 0 else 0 + + # The latest page might have a new issue. + if len(latest_page) > 0: + create_rag_git_issue_task(latest_page[0]).send() + + if issues.totalCount > 0: + pages = issues.totalCount // GITHUB_PER_PAGE + (1 if issues.totalCount % GITHUB_PER_PAGE != 0 else 0) + pages_array = list(range(1, pages + 1))[slice_page_index:] + task_list = list( + map( + lambda item: { + "repo_name": self.repo_name, + "status": TaskStatus.NOT_STARTED.value, + "node_type": GitIssueTaskNodeType.ISSUE_PAGE.value, + "from_task_id": self.id, + "bot_id": self.bot_id, + "page_index": item + }, + pages_array, + ), + ) + if len(task_list) > 0: + result = self.get_table().insert(task_list).execute() + for record in result.data: + issue_task = create_rag_git_issue_task(record) + issue_task.send() + + return self.update_status(TaskStatus.COMPLETED) + + def handle_issue_page_node(self): + repo = g.get_repo(self.repo_name) + issues = repo.get_issues(state='all').get_page(self.page_index) + task_list = list( map( lambda item: { "repo_name": self.repo_name, - "issue_id": str(item.number), + "issue_id": item.number, "status": TaskStatus.NOT_STARTED.value, "node_type": GitIssueTaskNodeType.ISSUE.value, "from_task_id": self.id, @@ -73,22 +132,25 @@ def handle_repo_node(self): ), ) if len(task_list) > 0: - result = self.get_table().insert(task_list).execute() - for record in result.data: - issue_task = GitIssueTask(id=record["id"], - issue_id=record["issue_id"], - repo_name=record["repo_name"], - node_type=record["node_type"], - bot_id=record["bot_id"], - status=record["status"], - from_id=record["from_task_id"] - ) - issue_task.send() - - return (self.get_table().update( - {"status": TaskStatus.COMPLETED.value}) - .eq("id", self.id) - .execute()) + existing_issues = (self.get_table() + .select('*') + .in_('issue_id', [item['issue_id'] for item in task_list]) + .eq('repo_name', self.repo_name) + .eq('node_type', GitIssueTaskNodeType.ISSUE.value) + .execute() + ) + + existing_issue_ids = {int(issue['issue_id']) for issue in existing_issues.data} + + new_task_list = [item for item in task_list if item['issue_id'] not in existing_issue_ids] + if len(new_task_list) > 0: + result = self.get_table().insert(new_task_list).execute() + for record in result.data: + issue_task = create_rag_git_issue_task(record) + issue_task.send() + + return self.update_status(TaskStatus.COMPLETED) + def handle_issue_node(self): issue_retrieval.add_knowledge_by_issue( diff --git a/petercat_utils/rag_helper/task.py b/petercat_utils/rag_helper/task.py index e73bee7f..07f8ffc5 100644 --- a/petercat_utils/rag_helper/task.py +++ b/petercat_utils/rag_helper/task.py @@ -5,7 +5,7 @@ import boto3 from .git_doc_task import GitDocTask -from .git_issue_task import GitIssueTask +from .git_issue_task import GitIssueTask, create_rag_git_issue_task from .git_task import GitTask from ..utils.env import get_env_variable @@ -77,15 +77,7 @@ def get_task(task_type: TaskType, task_id: str) -> GitTask: from_id=data["from_task_id"], ) if task_type == TaskType.GIT_ISSUE: - return GitIssueTask( - id=data["id"], - issue_id=data["issue_id"], - repo_name=data["repo_name"], - node_type=data["node_type"], - bot_id=data["bot_id"], - status=data["status"], - from_id=data["from_task_id"], - ) + return create_rag_git_issue_task(data) def trigger_task(task_type: TaskType, task_id: Optional[str]): diff --git a/subscriber/handler.py b/subscriber/handler.py index 84ba74e7..fc5b3508 100644 --- a/subscriber/handler.py +++ b/subscriber/handler.py @@ -9,6 +9,8 @@ def lambda_handler(event, context): batch_item_failures = [] sqs_batch_response = {} + if len(event): + print(f"event batch size is ${len(event)}") for record in event["Records"]: try: body = record["body"]