Skip to content

Commit 6e6d9c8

Browse files
[autorevert] properly resolve github workflow names (#7010)
Introduces `WorkflowResolver`, which resolves GitHub workflow file names <-> display names. `WorkflowRestartChecker` now uses it instead of relying on a naming convention (which apparently is not reliable, see lint.yml). ### Testing python -m unittest -v pytorch_auto_revert/tests/test_workflow_resolver.py ``` python -m unittest -v pytorch_auto_revert/tests/test_workflow_resolver.py test_resolve_pull_workflow (pytorch_auto_revert.tests.test_workflow_resolver.TestWorkflowResolverRealRepo.test_resolve_pull_workflow) ... ok test_resolve_trunk_workflow (pytorch_auto_revert.tests.test_workflow_resolver.TestWorkflowResolverRealRepo.test_resolve_trunk_workflow) ... ok ---------------------------------------------------------------------- Ran 2 tests in 3.369s OK ``` ``` python -m pytorch_auto_revert Fetching commits for workflow(s) 'Lint, trunk, pull, inductor' (last 48h)... Commit data by workflow: Fetching workflow data for 4 workflows since 2025-08-11T13:40:22.263487... Found 115 commits with job data for workflow 'Lint' Found 113 commits with job data for workflow 'inductor' Found 117 commits with job data for workflow 'pull' Found 112 commits with job data for workflow 'trunk' Lint: 115 commits trunk: 112 commits pull: 117 commits inductor: 113 commits ✓ 2 AUTOREVERT PATTERNS DETECTED Pattern #1: Failure rule: 'Lintrunner failure' Recent commits with failure: a3df3aa8 d3329604 Older commit without failure: 01bcf9a4 ✗ NOT REVERTED: d3329604e2f2c5523c68003a1093ba31e4937d42 was not reverted INFO:root:Successfully dispatched workflow Lint for commit d3329604e2f2c5523c68003a1093ba31e4937d42 View at: https://github.com/pytorch/pytorch/actions/workflows/lint.yml?query=branch%3Atrunk%2Fd3329604e2f2c5523c68003a1093ba31e4937d42 ✓ RESTARTED: Lint for d3329604 INFO:root:Successfully dispatched workflow Lint for commit 01bcf9a40dea937637d2cdd530bed2652510943d View at: https://github.com/pytorch/pytorch/actions/workflows/lint.yml?query=branch%3Atrunk%2F01bcf9a40dea937637d2cdd530bed2652510943d ✓ 
RESTARTED: Lint for 01bcf9a4 Failed jobs (1): - lintrunner-noclang / linux-job Pattern #2: Failure rule: 'pytest failure' Recent commits with failure: 9708fcf9 a288b15e Older commit without failure: d2d29707 ✗ NOT REVERTED: a288b15ea9f87ddd665f249d492e0fb0861f5a69 was not reverted ⟳ ALREADY RESTARTED: pull for a288b15e INFO:root:Successfully dispatched workflow pull for commit d2d2970734e0c3d241002f386bb8ca208e2594bb View at: https://github.com/pytorch/pytorch/actions/workflows/pull.yml?query=branch%3Atrunk%2Fd2d2970734e0c3d241002f386bb8ca208e2594bb ✓ RESTARTED: pull for d2d29707 Failed jobs (8): - linux-jammy-py3.9-gcc11 / test (distributed, 2, 2, linux.2xlarge) - linux-jammy-py3.9-clang12 / test (default, 1, 5, linux.4xlarge) - linux-jammy-py3.10-clang18-asan / test (default, 2, 6, linux.4xlarge) - linux-jammy-py3.10-clang18-asan / test (default, 3, 6, linux.4xlarge) - linux-jammy-py3.13-clang12 / test (crossref, 1, 2, linux.2xlarge) ... and 3 more ``` --------- Co-authored-by: Jean Schmidt <[email protected]>
1 parent 5382f4d commit 6e6d9c8

File tree

6 files changed

+202
-59
lines changed

6 files changed

+202
-59
lines changed

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/__main__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def main(*args, **kwargs) -> None:
174174
os.environ.get("WORKFLOWS", "Lint,trunk,pull,inductor").split(","),
175175
do_restart=True,
176176
do_revert=True,
177-
hours=int(os.environ.get("HOURS", 48)),
177+
hours=int(os.environ.get("HOURS", 16)),
178178
verbose=True,
179179
dry_run=opts.dry_run,
180180
ignore_common_errors=True,

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/testers/autorevert.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,6 @@ def autorevert_checker(
221221
print(f" ... and {len(pattern['failed_job_names']) - 5} more")
222222

223223
# Job coverage overlap logging removed (older_job_coverage dropped from pattern)
224-
225224
if revert_result and verbose:
226225
print(f"Revert message: {revert_result['revert_message'][:100]}...")
227226

@@ -281,7 +280,6 @@ def autorevert_checker(
281280
if len_non_ghfirst_reverts > 0
282281
else 0
283282
)
284-
# recall_non_ghfirst = 1 - ratio_non_ghfirst_reverts
285283
print(
286284
"Reverts (excluding ghfirst) that dont match any auto revert pattern detected (%): "
287285
+ f"({len_not_found_non_ghfirst}) ({ratio_non_ghfirst_reverts * 100:.1f}%)"
@@ -302,14 +300,19 @@ def autorevert_checker(
302300
else 0.0
303301
)
304302

305-
print()
306-
print("*********************************************************************")
307-
print("STATS SUMMARY:")
308-
print(f" PRECISION: {stats_precision * 100:.1f}%")
309-
print(f" RECALL: {stats_recall * 100:.1f}%")
310-
print(f" F1: {stats_f1 * 100:.1f}%")
311-
print("*********************************************************************")
312-
print()
303+
if verbose:
304+
print()
305+
print(
306+
"*********************************************************************"
307+
)
308+
print("STATS SUMMARY:")
309+
print(f" PRECISION: {stats_precision * 100:.1f}%")
310+
print(f" RECALL: {stats_recall * 100:.1f}%")
311+
print(f" F1: {stats_f1 * 100:.1f}%")
312+
print(
313+
"*********************************************************************"
314+
)
315+
print()
313316

314317
workflow_statistics = defaultdict(
315318
lambda: {"match_pattern": 0, "reverts": 0, "reverts_non_ghfirst": 0}
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import os
2+
import sys
3+
import unittest
4+
5+
6+
# Ensure package import when running from repo root
7+
sys.path.insert(0, "aws/lambda/pytorch-auto-revert")
8+
9+
from pytorch_auto_revert.github_client_helper import GHClientFactory
10+
from pytorch_auto_revert.workflow_resolver import WorkflowResolver
11+
12+
13+
class TestWorkflowResolverRealRepo(unittest.TestCase):
    """Resolve well-known workflows of the real ``pytorch/pytorch`` repository.

    Each test asks ``WorkflowResolver.get("pytorch/pytorch")`` for a workflow
    twice: once by display name and once by file basename.
    """

    @classmethod
    def setUpClass(cls):
        # Configure the shared GitHub client only when a token is exported;
        # without one the client factory is left untouched.
        gh_token = os.environ.get("GITHUB_TOKEN") or os.environ.get("GH_TOKEN")
        if not gh_token:
            return
        GHClientFactory.setup_client(
            app_id="",
            app_secret="",
            installation_id=0,
            token=gh_token,
        )

    def test_resolve_pull_workflow(self):
        workflows = WorkflowResolver.get("pytorch/pytorch")

        # Lookup keyed on the display name.
        by_display = workflows.resolve("pull")
        self.assertIsNotNone(by_display, "Expected to resolve 'pull' by display name")

        # Lookup keyed on the file basename.
        by_basename = workflows.resolve("pull.yml")
        self.assertIsNotNone(
            by_basename, "Expected to resolve 'pull.yml' by file name"
        )
        has_expected_suffix = by_basename.file_name.endswith("pull.yml")
        self.assertTrue(
            has_expected_suffix,
            "Resolved file name should be 'pull.yml'",
        )

    def test_resolve_trunk_workflow(self):
        workflows = WorkflowResolver.get("pytorch/pytorch")

        # Lookup keyed on the display name.
        by_display = workflows.resolve("trunk")
        self.assertIsNotNone(
            by_display, "Expected to resolve 'trunk' by display name"
        )

        # Lookup keyed on the file basename.
        by_basename = workflows.resolve("trunk.yml")
        self.assertIsNotNone(
            by_basename, "Expected to resolve 'trunk.yml' by file name"
        )
        has_expected_suffix = by_basename.file_name.endswith("trunk.yml")
        self.assertTrue(
            has_expected_suffix,
            "Resolved file name should be 'trunk.yml'",
        )
58+
59+
60+
if __name__ == "__main__":
61+
unittest.main()

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/workflow_checker.py

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
"""
2-
WorkflowRestartChecker for querying restarted workflows via ClickHouse.
2+
WorkflowRestartChecker for querying restarted workflows via ClickHouse and
3+
dispatching workflows via GitHub with consistent workflow name resolution.
34
"""
45

56
import logging
67
from datetime import datetime, timedelta
8+
from functools import cached_property
79
from typing import Dict, Set
810

911
from .clickhouse_client_helper import CHCliFactory
12+
from .workflow_resolver import WorkflowResolver
1013

1114

1215
class WorkflowRestartChecker:
@@ -28,9 +31,10 @@ def has_restarted_workflow(self, workflow_name: str, commit_sha: str) -> bool:
2831
Returns:
2932
bool: True if workflow was restarted (workflow_dispatch with trunk/* branch)
3033
"""
31-
# Normalize workflow name - remove .yml extension for consistency
32-
normalized_workflow_name = workflow_name.replace(".yml", "")
33-
cache_key = f"{normalized_workflow_name}:{commit_sha}"
34+
# Resolve to display name via GitHub (exact display or file name)
35+
display_name = self.resolver.require(workflow_name).display_name
36+
37+
cache_key = f"{display_name}:{commit_sha}"
3438
if cache_key in self._cache:
3539
return self._cache[cache_key]
3640

@@ -54,7 +58,7 @@ def has_restarted_workflow(self, workflow_name: str, commit_sha: str) -> bool:
5458
"commit_sha": commit_sha,
5559
"workflow_event": "workflow_dispatch",
5660
"head_branch": f"trunk/{commit_sha}",
57-
"workflow_name": normalized_workflow_name,
61+
"workflow_name": display_name,
5862
},
5963
)
6064

@@ -73,8 +77,7 @@ def get_restarted_commits(self, workflow_name: str, days_back: int = 7) -> Set[s
7377
Returns:
7478
Set of commit SHAs that have restarted workflows
7579
"""
76-
# Normalize workflow name - remove .yml extension for consistency
77-
normalized_workflow_name = workflow_name.replace(".yml", "")
80+
display_name = self.resolver.require(workflow_name).display_name
7881
since_date = datetime.now() - timedelta(days=days_back)
7982

8083
query = """
@@ -87,14 +90,14 @@ def get_restarted_commits(self, workflow_name: str, days_back: int = 7) -> Set[s
8790
"""
8891

8992
result = CHCliFactory().client.query(
90-
query, {"workflow_name": normalized_workflow_name, "since_date": since_date}
93+
query, {"workflow_name": display_name, "since_date": since_date}
9194
)
9295

9396
commits = {row[0] for row in result.result_rows}
9497

9598
# Update cache
9699
for commit_sha in commits:
97-
cache_key = f"{normalized_workflow_name}:{commit_sha}"
100+
cache_key = f"{display_name}:{commit_sha}"
98101
self._cache[cache_key] = True
99102

100103
return commits
@@ -114,13 +117,10 @@ def restart_workflow(self, workflow_name: str, commit_sha: str) -> bool:
114117
Returns:
115118
bool: True if workflow was successfully dispatched, False otherwise
116119
"""
117-
# Normalize workflow name
118-
normalized_workflow_name = workflow_name.replace(".yml", "")
119-
120120
# Check if already restarted
121-
if self.has_restarted_workflow(normalized_workflow_name, commit_sha):
121+
if self.has_restarted_workflow(workflow_name, commit_sha):
122122
logging.warning(
123-
f"Workflow {normalized_workflow_name} already restarted for commit {commit_sha}"
123+
f"Workflow {workflow_name} already restarted for commit {commit_sha}"
124124
)
125125
return False
126126

@@ -144,33 +144,35 @@ def restart_workflow(self, workflow_name: str, commit_sha: str) -> bool:
144144
# Use trunk/{sha} tag format
145145
tag_ref = f"trunk/{commit_sha}"
146146

147-
# Add .yml extension for workflow name
148-
workflow_file_name = f"{normalized_workflow_name}.yml"
147+
# Resolve workflow
148+
wf_ref = self.resolver.require(workflow_name)
149149

150-
# Get repo and workflow objects
151-
repo = client.get_repo(f"{self.repo_owner}/{self.repo_name}")
152-
workflow = repo.get_workflow(workflow_file_name)
153-
154-
# Dispatch the workflow
155-
workflow.create_dispatch(ref=tag_ref, inputs={})
150+
# Dispatch via file name
151+
client.get_repo(f"{self.repo_owner}/{self.repo_name}").get_workflow(
152+
wf_ref.file_name
153+
).create_dispatch(ref=tag_ref, inputs={})
156154

157155
# Construct the workflow runs URL
158156
workflow_url = (
159157
f"https://github.com/{self.repo_owner}/{self.repo_name}"
160-
f"/actions/workflows/{workflow_file_name}"
158+
f"/actions/workflows/{wf_ref.file_name}"
161159
f"?query=branch%3Atrunk%2F{commit_sha}"
162160
)
163161
logging.info(
164-
f"Successfully dispatched workflow {normalized_workflow_name} for commit {commit_sha}\n"
162+
f"Successfully dispatched workflow {wf_ref.display_name} for commit {commit_sha}\n"
165163
f" View at: {workflow_url}"
166164
)
167165

168166
# Invalidate cache for this workflow/commit
169-
cache_key = f"{normalized_workflow_name}:{commit_sha}"
167+
cache_key = f"{wf_ref.display_name}:{commit_sha}"
170168
if cache_key in self._cache:
171169
del self._cache[cache_key]
172170
return True
173171

174172
except Exception as e:
175-
logging.error(f"Error dispatching workflow {normalized_workflow_name}: {e}")
173+
logging.error(f"Error dispatching workflow {workflow_name}: {e}")
176174
return False
175+
176+
@cached_property
177+
def resolver(self) -> WorkflowResolver:
178+
return WorkflowResolver.get(f"{self.repo_owner}/{self.repo_name}")
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
"""
2+
WorkflowResolver: Resolve GitHub Actions workflows by exact display or file name.
3+
4+
- Exact matches only (no lowercasing or fuzzy matching)
5+
- Caches per-repo resolver instances
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import os
11+
from dataclasses import dataclass
12+
from functools import lru_cache
13+
from typing import Optional
14+
15+
import github
16+
17+
from .github_client_helper import GHClientFactory
18+
19+
20+
@dataclass(frozen=True)
class WorkflowRef:
    """Immutable pair of names identifying one workflow in a repository."""

    # The workflow's `name` attribute as reported by the repository listing.
    display_name: str
    # Basename of the workflow's `path`, e.g. "pull.yml".
    file_name: str


class WorkflowResolver:
    """Caches workflows for a repo and resolves them by exact names.

    Matching is exact: no lowercasing and no fuzzy matching. Three keys are
    accepted, tried in this order: display name, file basename, full path.

    Usage:
        resolver = WorkflowResolver.get("owner/repo")
        wf = resolver.resolve("pull")                        # display name
        wf = resolver.resolve("pull.yml")                    # file basename
        wf = resolver.resolve(".github/workflows/pull.yml")  # full path
    """

    def __init__(
        self, repo_full_name: str, repository: "github.Repository.Repository"
    ) -> None:
        """Index every workflow returned by ``repository.get_workflows()``.

        Args:
            repo_full_name: Repository in ``owner/repo`` form; used only in
                error messages.
            repository: Object exposing ``get_workflows()``, whose items carry
                ``name`` and ``path`` attributes.
        """
        self._repo_full_name = repo_full_name
        self._repository = repository
        self._by_display: dict[str, WorkflowRef] = {}
        self._by_file: dict[str, WorkflowRef] = {}
        # Full workflow path -> ref, so resolve() honours its documented
        # "full path" lookup.
        self._by_path: dict[str, WorkflowRef] = {}
        self._build_indices()

    @staticmethod
    @lru_cache(maxsize=None)
    def get(repo: str) -> "WorkflowResolver":
        """Get a cached resolver for a repo in owner/repo format.

        Internally creates a GitHub Repository client using GHClientFactory when
        available; otherwise falls back to an anonymous client for public repos.
        Results are memoized per ``repo`` string, so the workflow list is
        fetched once per repository for the lifetime of the process.
        """
        # Build a client: prefer configured factory; fall back to anonymous
        try:
            client = GHClientFactory().client
        except Exception:
            # Deliberate best-effort: anonymous client for public data; may be
            # rate limited.
            client = github.Github()

        repository = client.get_repo(repo)
        return WorkflowResolver(repo_full_name=repo, repository=repository)

    def resolve(self, input_name: str) -> Optional[WorkflowRef]:
        """Resolve by exact display name, file basename, or full path.

        The indices are consulted in that order, so a display name wins over a
        file name that happens to be spelled the same way.

        Returns None if no exact match is found.
        """
        for index in (self._by_display, self._by_file, self._by_path):
            ref = index.get(input_name)
            if ref is not None:
                return ref
        return None

    def require(self, input_name: str) -> WorkflowRef:
        """Resolve or raise ValueError with a helpful message.

        Raises:
            ValueError: if ``input_name`` matches no indexed workflow; the
                message lists the known display names and file names.
        """
        ref = self.resolve(input_name)
        if ref is None:
            # Build an informative message with available names
            display = ", ".join(sorted(self._by_display))
            files = ", ".join(sorted(self._by_file))
            raise ValueError(
                f"Workflow '{input_name}' not found in {self._repo_full_name}. "
                f"Available display names: [{display}]. Available files: [{files}]"
            )
        return ref

    # Internal helpers

    def _build_indices(self) -> None:
        """Populate the lookup tables from the repository's workflows.

        Workflows lacking either a name or a path are skipped. When two
        workflows share a key, the one listed later overwrites the earlier.
        """
        for w in self._repository.get_workflows():
            name = getattr(w, "name", "") or ""
            path = getattr(w, "path", "") or ""
            base = os.path.basename(path) if path else ""
            if not (name and base):
                continue
            ref = WorkflowRef(display_name=name, file_name=base)
            self._by_display[name] = ref
            self._by_file[base] = ref
            self._by_path[path] = ref

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/workflow_restart.py

Lines changed: 0 additions & 22 deletions
This file was deleted.

0 commit comments

Comments
 (0)