Skip to content

Commit b314edd

Browse files
committed
Add arguments linter
Checks types in `api/arguments.py` and `api/arguments_typed.py` align.
1 parent e565648 commit b314edd

File tree

6 files changed

+233
-25
lines changed

6 files changed

+233
-25
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ jobs:
3838
uses: astral-sh/setup-uv@v6
3939
- run: uv sync --group test --no-default-groups
4040
- run: uv run mypy
41+
- run: uv run python scripts/lint_arguments_sync.py
4142

4243
spell-check:
4344
runs-on: ubuntu-24.04

scripts/dev-lint.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,7 @@ uv run ruff check
88
echo "Execute mypy..."
99
uv run mypy
1010

11+
echo "Execute arguments type check..."
12+
uv run python scripts/lint_arguments_sync.py
13+
1114
echo "Linting complete!"

scripts/lint_arguments_sync.py

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
#!/usr/bin/env python
2+
"""
3+
Lint script to verify that arguments in arguments.py and arguments_typed.py are in sync.
4+
5+
Due to Python's typing limitations (ParamSpec args cannot be concatenated), we maintain
6+
two separate definitions of global arguments:
7+
- arguments.py: TypedDicts (ConnectorArguments, MetaArguments, ExecutionArguments)
8+
- arguments_typed.py: PyinfraOperation.__call__ method signature
9+
10+
This script ensures they stay synchronized.
11+
"""
12+
13+
import ast
14+
import sys
15+
from os import path
16+
from typing import NamedTuple
17+
18+
19+
class ArgumentInfo(NamedTuple):
20+
name: str
21+
type_annotation: str
22+
has_default: bool
23+
24+
25+
def get_typeddict_keys(tree: ast.Module, class_names: list[str]) -> dict[str, str]:
26+
"""Extract keys and their type annotations from TypedDict classes."""
27+
keys: dict[str, str] = {}
28+
29+
for node in ast.walk(tree):
30+
if isinstance(node, ast.ClassDef) and node.name in class_names:
31+
for item in node.body:
32+
if isinstance(item, ast.AnnAssign) and isinstance(item.target, ast.Name):
33+
key = item.target.id
34+
type_str = ast.unparse(item.annotation)
35+
keys[key] = type_str
36+
37+
return keys
38+
39+
40+
def get_argument_meta_keys(tree: ast.Module) -> set[str]:
41+
"""Extract keys from all *_argument_meta dictionaries."""
42+
keys: set[str] = set()
43+
44+
for node in ast.walk(tree):
45+
if isinstance(node, ast.Assign):
46+
for target in node.targets:
47+
if isinstance(target, ast.Name) and target.id.endswith("_argument_meta"):
48+
if isinstance(node.value, ast.Dict):
49+
for key in node.value.keys:
50+
if isinstance(key, ast.Constant) and isinstance(key.value, str):
51+
keys.add(key.value)
52+
53+
return keys
54+
55+
56+
def get_call_parameters(tree: ast.Module, class_name: str, method_name: str) -> list[ArgumentInfo]:
57+
"""Extract parameters from a class method."""
58+
params: list[ArgumentInfo] = []
59+
60+
for node in ast.walk(tree):
61+
if isinstance(node, ast.ClassDef) and node.name == class_name:
62+
for item in node.body:
63+
if isinstance(item, ast.FunctionDef) and item.name == method_name:
64+
args = item.args
65+
66+
# Count defaults - they align with the end of the args list
67+
num_defaults = len(args.defaults)
68+
num_args = len(args.args)
69+
70+
for i, arg in enumerate(args.args):
71+
if arg.arg == "self":
72+
continue
73+
74+
type_str = ast.unparse(arg.annotation) if arg.annotation else ""
75+
76+
# Check if this arg has a default
77+
default_index = i - (num_args - num_defaults)
78+
has_default = default_index >= 0
79+
80+
params.append(
81+
ArgumentInfo(
82+
name=arg.arg,
83+
type_annotation=type_str,
84+
has_default=has_default,
85+
)
86+
)
87+
88+
# Also check kwonly args
89+
for i, arg in enumerate(args.kwonlyargs):
90+
type_str = ast.unparse(arg.annotation) if arg.annotation else ""
91+
has_default = args.kw_defaults[i] is not None
92+
params.append(
93+
ArgumentInfo(
94+
name=arg.arg,
95+
type_annotation=type_str,
96+
has_default=has_default,
97+
)
98+
)
99+
100+
return params
101+
102+
103+
def normalize_type(type_str: str) -> str:
104+
"""Normalize type string for comparison (handle Optional, Union, etc.)."""
105+
# Remove whitespace
106+
type_str = type_str.replace(" ", "")
107+
108+
# Sort Union members for consistent comparison
109+
if type_str.startswith("Union[") or type_str.startswith("Optional["):
110+
# This is a simplified normalization - just for basic comparison
111+
pass
112+
113+
return type_str
114+
115+
116+
def main() -> int:
117+
this_dir = path.dirname(path.realpath(__file__))
118+
repo_root = path.abspath(path.join(this_dir, ".."))
119+
120+
arguments_path = path.join(repo_root, "src", "pyinfra", "api", "arguments.py")
121+
arguments_typed_path = path.join(repo_root, "src", "pyinfra", "api", "arguments_typed.py")
122+
123+
# Parse both files
124+
with open(arguments_path, "r", encoding="utf-8") as f:
125+
arguments_tree = ast.parse(f.read())
126+
127+
with open(arguments_typed_path, "r", encoding="utf-8") as f:
128+
arguments_typed_tree = ast.parse(f.read())
129+
130+
# Extract TypedDict keys from arguments.py
131+
typeddict_classes = ["ConnectorArguments", "MetaArguments", "ExecutionArguments"]
132+
typeddict_keys = get_typeddict_keys(arguments_tree, typeddict_classes)
133+
134+
# Extract argument meta keys (the actual source of truth for what arguments exist)
135+
meta_keys = get_argument_meta_keys(arguments_tree)
136+
137+
# Extract PyinfraOperation.__call__ parameters from arguments_typed.py
138+
call_params = get_call_parameters(arguments_typed_tree, "PyinfraOperation", "__call__")
139+
call_param_names = {p.name for p in call_params}
140+
call_param_types = {p.name: p.type_annotation for p in call_params}
141+
142+
errors: list[str] = []
143+
warnings: list[str] = []
144+
145+
# Check that all TypedDict keys are in PyinfraOperation.__call__
146+
for key, type_str in typeddict_keys.items():
147+
if key not in call_param_names:
148+
errors.append(
149+
f"Argument '{key}' is in arguments.py TypedDicts but missing from "
150+
f"PyinfraOperation.__call__ in arguments_typed.py"
151+
)
152+
elif key in call_param_types:
153+
typed_type = normalize_type(call_param_types[key])
154+
expected_type = normalize_type(type_str)
155+
156+
# TypedDict uses non-Optional types, but PyinfraOperation uses Optional
157+
# So we do a loose check - the base type should match
158+
# This is a simplified check that may need refinement
159+
if expected_type not in typed_type and typed_type not in expected_type:
160+
# Check if it's just an Optional wrapper difference
161+
if f"Optional[{expected_type}]" != typed_type and expected_type != typed_type:
162+
warnings.append(
163+
f"Type mismatch for '{key}': "
164+
f"arguments.py has '{type_str}', "
165+
f"arguments_typed.py has '{call_param_types[key]}'"
166+
)
167+
168+
# Check that all PyinfraOperation.__call__ params (except special ones) are in TypedDicts
169+
# Skip *args and **kwargs which are the P.args/P.kwargs for operation-specific args
170+
special_params = {"args", "kwargs"}
171+
for param in call_params:
172+
if param.name in special_params:
173+
continue
174+
if param.name not in typeddict_keys:
175+
errors.append(
176+
f"Parameter '{param.name}' is in PyinfraOperation.__call__ but missing from "
177+
f"TypedDicts in arguments.py"
178+
)
179+
180+
# Check that all meta keys are represented
181+
all_typeddict_keys = set(typeddict_keys.keys())
182+
for key in meta_keys:
183+
if key not in all_typeddict_keys:
184+
errors.append(f"Argument '{key}' is in argument_meta dicts but missing from TypedDicts")
185+
186+
# Report results
187+
if warnings:
188+
print("Warnings:")
189+
for warning in warnings:
190+
print(f" ⚠️ {warning}")
191+
print()
192+
193+
if errors:
194+
print("Errors:")
195+
for error in errors:
196+
print(f" ❌ {error}")
197+
print()
198+
print(f"Found {len(errors)} error(s) and {len(warnings)} warning(s)")
199+
return 1
200+
201+
print("✅ arguments.py and arguments_typed.py are in sync!")
202+
if warnings:
203+
print(f" ({len(warnings)} warning(s) - type annotations may differ slightly)")
204+
return 0
205+
206+
207+
if __name__ == "__main__":
208+
sys.exit(main())

src/pyinfra/api/arguments.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,12 @@ class ConnectorArguments(TypedDict, total=False):
7070
_success_exit_codes: Iterable[int]
7171
_timeout: int
7272
_get_pty: bool
73-
_stdin: Union[str, Iterable[str]]
73+
_stdin: Union[str, list[str], Iterable[str]]
7474

7575
# Retry arguments
7676
_retries: int
7777
_retry_delay: Union[int, float]
78-
_retry_until: Optional[Callable[[dict], bool]]
78+
_retry_until: Callable[[dict], bool]
7979

8080
# Temp directory argument
8181
_temp_dir: str

src/pyinfra/api/arguments_typed.py

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import (
4-
TYPE_CHECKING,
5-
Callable,
6-
Generator,
7-
Generic,
8-
Iterable,
9-
List,
10-
Mapping,
11-
Optional,
12-
Union,
13-
)
3+
from typing import TYPE_CHECKING, Callable, Generator, Generic, Iterable, List, Mapping, Union
144

155
from typing_extensions import ParamSpec, Protocol
166

@@ -36,36 +26,41 @@ def __call__(
3626
#
3727
# Auth
3828
_sudo: bool = False,
39-
_sudo_user: Optional[str] = None,
29+
_sudo_user: None | str = None,
4030
_use_sudo_login: bool = False,
41-
_sudo_password: Optional[str] = None,
31+
_sudo_password: None | str = None,
4232
_preserve_sudo_env: bool = False,
43-
_su_user: Optional[str] = None,
33+
_su_user: None | str = None,
4434
_use_su_login: bool = False,
4535
_preserve_su_env: bool = False,
46-
_su_shell: Optional[str] = None,
36+
_su_shell: None | str = None,
4737
_doas: bool = False,
48-
_doas_user: Optional[str] = None,
38+
_doas_user: None | str = None,
4939
# Shell arguments
50-
_shell_executable: Optional[str] = None,
51-
_chdir: Optional[str] = None,
52-
_env: Optional[Mapping[str, str]] = None,
40+
_shell_executable: None | str = None,
41+
_chdir: None | str = None,
42+
_env: None | Mapping[str, str] = None,
5343
# Connector control
5444
_success_exit_codes: Iterable[int] = (0,),
55-
_timeout: Optional[int] = None,
45+
_timeout: None | int = None,
5646
_get_pty: bool = False,
57-
_stdin: Union[None, str, list[str], tuple[str, ...]] = None,
47+
_stdin: None | Union[str, list[str], Iterable[str]] = None,
48+
# Retry arguments
49+
_retries: None | int = None,
50+
_retry_delay: None | Union[int, float] = None,
51+
_retry_until: None | Callable[[dict], bool] = None,
52+
_temp_dir: None | str = None,
5853
#
5954
# MetaArguments
6055
#
61-
name: Optional[str] = None,
56+
name: None | str = None,
6257
_ignore_errors: bool = False,
6358
_continue_on_error: bool = False,
6459
_if: Union[List[Callable[[], bool]], Callable[[], bool], None] = None,
6560
#
6661
# ExecutionArguments
6762
#
68-
_parallel: Optional[int] = None,
63+
_parallel: None | int = None,
6964
_run_once: bool = False,
7065
_serial: bool = False,
7166
#

src/pyinfra/api/command.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ def execute_function() -> None | Exception:
249249
self.function(*self.args, **self.kwargs)
250250
except Exception as e:
251251
return e
252+
return None
252253

253254
greenlet = gevent.spawn(execute_function)
254255
exception = greenlet.get()

0 commit comments

Comments
 (0)