Skip to content

Commit 8d3ccf8

Browse files
committed
[autorevert] properly resolve github workflow names
1 parent 8aee8b0 commit 8d3ccf8

File tree

4 files changed

+179
-48
lines changed

4 files changed

+179
-48
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import sys
2+
import unittest
3+
4+
5+
# Ensure package import when running from repo root
6+
sys.path.insert(0, "aws/lambda/pytorch-auto-revert")
7+
8+
from pytorch_auto_revert.workflow_resolver import WorkflowResolver
9+
10+
11+
class TestWorkflowResolverRealRepo(unittest.TestCase):
    """Integration tests that resolve real workflows of pytorch/pytorch.

    ``WorkflowResolver.get`` builds a GitHub client and lists the repository's
    workflows, so these tests need network access to GitHub.
    """

    REPO = "pytorch/pytorch"

    def _assert_resolves(self, display_name: str, file_name: str) -> None:
        """Check that a workflow resolves by display name and by file name.

        Args:
            display_name: Workflow display name, e.g. "pull".
            file_name: Workflow file basename, e.g. "pull.yml".
        """
        resolver = WorkflowResolver.get(self.REPO)

        # Resolve by display name
        by_name = resolver.resolve(display_name)
        self.assertIsNotNone(
            by_name, f"Expected to resolve '{display_name}' by display name"
        )

        # Resolve by basename
        by_file = resolver.resolve(file_name)
        self.assertIsNotNone(
            by_file, f"Expected to resolve '{file_name}' by file name"
        )
        # file_name is a basename, so it must match exactly; a suffix check
        # would also accept unrelated files such as "notpull.yml".
        self.assertEqual(
            by_file.file_name,
            file_name,
            f"Resolved file name should be '{file_name}'",
        )

    def test_resolve_pull_workflow(self):
        self._assert_resolves("pull", "pull.yml")

    def test_resolve_trunk_workflow(self):
        self._assert_resolves("trunk", "trunk.yml")
45+
46+
47+
if __name__ == "__main__":
48+
unittest.main()

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/workflow_checker.py

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
"""
2-
WorkflowRestartChecker for querying restarted workflows via ClickHouse.
2+
WorkflowRestartChecker for querying restarted workflows via ClickHouse and
3+
dispatching workflows via GitHub with consistent workflow name resolution.
34
"""
45

56
import logging
67
from datetime import datetime, timedelta
8+
from functools import cached_property
79
from typing import Dict, Set
810

911
from .clickhouse_client_helper import CHCliFactory
12+
from .workflow_resolver import WorkflowResolver
1013

1114

1215
class WorkflowRestartChecker:
@@ -28,9 +31,10 @@ def has_restarted_workflow(self, workflow_name: str, commit_sha: str) -> bool:
2831
Returns:
2932
bool: True if workflow was restarted (workflow_dispatch with trunk/* branch)
3033
"""
31-
# Normalize workflow name - remove .yml extension for consistency
32-
normalized_workflow_name = workflow_name.replace(".yml", "")
33-
cache_key = f"{normalized_workflow_name}:{commit_sha}"
34+
# Resolve to display name via GitHub (exact display or file name)
35+
display_name = self.resolver.require(workflow_name).display_name
36+
37+
cache_key = f"{display_name}:{commit_sha}"
3438
if cache_key in self._cache:
3539
return self._cache[cache_key]
3640

@@ -54,7 +58,7 @@ def has_restarted_workflow(self, workflow_name: str, commit_sha: str) -> bool:
5458
"commit_sha": commit_sha,
5559
"workflow_event": "workflow_dispatch",
5660
"head_branch": f"trunk/{commit_sha}",
57-
"workflow_name": normalized_workflow_name,
61+
"workflow_name": display_name,
5862
},
5963
)
6064

@@ -73,8 +77,7 @@ def get_restarted_commits(self, workflow_name: str, days_back: int = 7) -> Set[s
7377
Returns:
7478
Set of commit SHAs that have restarted workflows
7579
"""
76-
# Normalize workflow name - remove .yml extension for consistency
77-
normalized_workflow_name = workflow_name.replace(".yml", "")
80+
display_name = self.resolver.require(workflow_name).display_name
7881
since_date = datetime.now() - timedelta(days=days_back)
7982

8083
query = """
@@ -87,14 +90,14 @@ def get_restarted_commits(self, workflow_name: str, days_back: int = 7) -> Set[s
8790
"""
8891

8992
result = CHCliFactory().client.query(
90-
query, {"workflow_name": normalized_workflow_name, "since_date": since_date}
93+
query, {"workflow_name": display_name, "since_date": since_date}
9194
)
9295

9396
commits = {row[0] for row in result.result_rows}
9497

9598
# Update cache
9699
for commit_sha in commits:
97-
cache_key = f"{normalized_workflow_name}:{commit_sha}"
100+
cache_key = f"{display_name}:{commit_sha}"
98101
self._cache[cache_key] = True
99102

100103
return commits
@@ -114,13 +117,10 @@ def restart_workflow(self, workflow_name: str, commit_sha: str) -> bool:
114117
Returns:
115118
bool: True if workflow was successfully dispatched, False otherwise
116119
"""
117-
# Normalize workflow name
118-
normalized_workflow_name = workflow_name.replace(".yml", "")
119-
120120
# Check if already restarted
121-
if self.has_restarted_workflow(normalized_workflow_name, commit_sha):
121+
if self.has_restarted_workflow(workflow_name, commit_sha):
122122
logging.warning(
123-
f"Workflow {normalized_workflow_name} already restarted for commit {commit_sha}"
123+
f"Workflow {workflow_name} already restarted for commit {commit_sha}"
124124
)
125125
return False
126126

@@ -144,33 +144,35 @@ def restart_workflow(self, workflow_name: str, commit_sha: str) -> bool:
144144
# Use trunk/{sha} tag format
145145
tag_ref = f"trunk/{commit_sha}"
146146

147-
# Add .yml extension for workflow name
148-
workflow_file_name = f"{normalized_workflow_name}.yml"
147+
# Resolve workflow
148+
wf_ref = self.resolver.require(workflow_name)
149149

150-
# Get repo and workflow objects
151-
repo = client.get_repo(f"{self.repo_owner}/{self.repo_name}")
152-
workflow = repo.get_workflow(workflow_file_name)
153-
154-
# Dispatch the workflow
155-
workflow.create_dispatch(ref=tag_ref, inputs={})
150+
# Dispatch via file name
151+
client.get_repo(f"{self.repo_owner}/{self.repo_name}").get_workflow(
152+
wf_ref.file_name
153+
).create_dispatch(ref=tag_ref, inputs={})
156154

157155
# Construct the workflow runs URL
158156
workflow_url = (
159157
f"https://github.com/{self.repo_owner}/{self.repo_name}"
160-
f"/actions/workflows/{workflow_file_name}"
158+
f"/actions/workflows/{wf_ref.file_name}"
161159
f"?query=branch%3Atrunk%2F{commit_sha}"
162160
)
163161
logging.info(
164-
f"Successfully dispatched workflow {normalized_workflow_name} for commit {commit_sha}\n"
162+
f"Successfully dispatched workflow {wf_ref.display_name} for commit {commit_sha}\n"
165163
f" View at: {workflow_url}"
166164
)
167165

168166
# Invalidate cache for this workflow/commit
169-
cache_key = f"{normalized_workflow_name}:{commit_sha}"
167+
cache_key = f"{wf_ref.display_name}:{commit_sha}"
170168
if cache_key in self._cache:
171169
del self._cache[cache_key]
172170
return True
173171

174172
except Exception as e:
175-
logging.error(f"Error dispatching workflow {normalized_workflow_name}: {e}")
173+
logging.error(f"Error dispatching workflow {workflow_name}: {e}")
176174
return False
175+
176+
@cached_property
def resolver(self) -> WorkflowResolver:
    """Workflow resolver for this checker's repository.

    Computed on first access and then stored on the instance by
    ``cached_property``.
    """
    repo_full_name = f"{self.repo_owner}/{self.repo_name}"
    return WorkflowResolver.get(repo_full_name)
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
"""
2+
WorkflowResolver: Resolve GitHub Actions workflows by exact display or file name.
3+
4+
- Exact matches only (no lowercasing or fuzzy matching)
5+
- Caches per-repo resolver instances
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import os
11+
from dataclasses import dataclass
12+
from functools import lru_cache
13+
from typing import Optional
14+
15+
import github
16+
17+
from .github_client_helper import GHClientFactory
18+
19+
20+
@dataclass(frozen=True)
class WorkflowRef:
    """Immutable pair of names identifying one workflow in a repository."""

    # Name of the workflow as shown by GitHub.
    display_name: str
    # Basename of the workflow file, e.g. "pull.yml".
    file_name: str
26+
27+
28+
class WorkflowResolver:
    """Caches workflows for a repo and resolves by exact names.

    Matching is exact (case-sensitive); lookups are tried in this order:
    display name, file basename, full path.

    Usage:
        resolver = WorkflowResolver.get("owner/repo")
        wf = resolver.resolve("pull")  # display name
        wf = resolver.resolve("pull.yml")  # file basename
        wf = resolver.resolve(".github/workflows/pull.yml")  # full path
    """

    def __init__(
        self, repo_full_name: str, repository: "github.Repository.Repository"
    ) -> None:
        """Index all workflows of ``repository``.

        Args:
            repo_full_name: Repository in ``owner/repo`` format; used only in
                error messages.
            repository: Object exposing ``get_workflows()``; it is queried
                once, during construction.
        """
        self._repo_full_name = repo_full_name
        self._repository = repository
        self._by_display: dict[str, WorkflowRef] = {}
        self._by_file: dict[str, WorkflowRef] = {}
        # Full workflow path (e.g. ".github/workflows/pull.yml") -> ref, so the
        # documented full-path lookup actually works.
        self._by_path: dict[str, WorkflowRef] = {}
        self._build_indices()

    @staticmethod
    @lru_cache(maxsize=None)
    def get(repo: str) -> "WorkflowResolver":
        """Get a cached resolver for a repo in owner/repo format.

        Internally creates a GitHub Repository client using GHClientFactory when
        available; otherwise falls back to an anonymous client for public repos.
        Results are memoized per ``repo`` string, so the workflow list is
        fetched at most once per repo for the lifetime of the process.
        """
        # Build a client: prefer configured factory; fall back to anonymous
        try:
            client = GHClientFactory().client
        except Exception as e:
            # Deliberate best-effort: report the failure and continue.
            print(f"Warning: Failed to create GitHub client: {e}")
            # Anonymous client for public data; may be rate limited
            # (used mostly for the tests)
            client = github.Github()

        repository = client.get_repo(repo)
        return WorkflowResolver(repo_full_name=repo, repository=repository)

    def resolve(self, input_name: str) -> Optional[WorkflowRef]:
        """Resolve by exact display name, file basename, or full path.

        Args:
            input_name: Display name (e.g. "pull"), file basename
                (e.g. "pull.yml") or full path
                (e.g. ".github/workflows/pull.yml").

        Returns:
            The matching WorkflowRef, or None if no exact match is found.
        """
        # Order matters: a display name wins over a file name, which wins
        # over a full path.
        for index in (self._by_display, self._by_file, self._by_path):
            ref = index.get(input_name)
            if ref is not None:
                return ref
        return None

    def require(self, input_name: str) -> WorkflowRef:
        """Resolve or raise ValueError with a helpful message.

        Raises:
            ValueError: If ``input_name`` matches no known workflow; the
                message lists the available display names and file names.
        """
        ref = self.resolve(input_name)
        if ref is None:
            # Build an informative message with available names
            display = ", ".join(sorted(self._by_display))
            files = ", ".join(sorted(self._by_file))
            raise ValueError(
                f"Workflow '{input_name}' not found in {self._repo_full_name}. "
                f"Available display names: [{display}]. Available files: [{files}]"
            )
        return ref

    # Internal helpers

    def _build_indices(self) -> None:
        """Populate the display-name, basename and full-path indices.

        Workflows missing either a name or a path are skipped. If several
        workflows share a key, the one listed last wins.
        """
        for w in self._repository.get_workflows():
            name = getattr(w, "name", "") or ""
            path = getattr(w, "path", "") or ""
            base = os.path.basename(path) if path else ""
            if not (name and base):
                continue
            ref = WorkflowRef(display_name=name, file_name=base)
            self._by_display[name] = ref
            self._by_file[base] = ref
            self._by_path[path] = ref

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/workflow_restart.py

Lines changed: 0 additions & 22 deletions
This file was deleted.

0 commit comments

Comments
 (0)