Skip to content

Commit 9a8d4f3

Browse files
committed
jobs: add type hints
Signed-off-by: Gaëtan Lehmann <gaetan.lehmann@vates.tech>
1 parent 058a757 commit 9a8d4f3

1 file changed

Lines changed: 47 additions & 30 deletions

File tree

jobs.py

Lines changed: 47 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,18 @@
77

88
from lib.commands import ssh
99

10-
JOBS = {
10+
from typing import NotRequired, TypedDict, cast
11+
12+
class JobData(TypedDict):
13+
description: str
14+
requirements: list[str]
15+
nb_pools: int
16+
params: dict[str, str]
17+
paths: list[str]
18+
markers: NotRequired[str]
19+
name_filter: NotRequired[str]
20+
21+
JOBS: dict[str, JobData] = {
1122
"main": {
1223
"description": "a group of not-too-long tests that run either without a VM, or with a single small one",
1324
"requirements": [
@@ -472,12 +483,15 @@
472483
"tests/storage/zfsvol/test_zfsvol_sr.py::TestZfsvolVm::test_quicktest",
473484
]
474485

486+
VmDef = str | tuple[str, str]
487+
VMSDef = dict[str, dict[str, VmDef | list[VmDef]]]
488+
475489
# Returns the vm filename or None if a host_version is passed and matches the one specified
476490
# with the vm filename in vm_data.py. ex: ("centos6-32-hvm-created_8.2-zstd.xva", "8\.2\..*")
477-
def filter_vm(vm, host_version):
491+
def filter_vm(vm: VmDef, host_version: str | None) -> str | None:
478492
import re
479493

480-
if type(vm) is tuple:
494+
if isinstance(vm, tuple):
481495
if len(vm) != 2:
482496
print(f"ERROR: VM definition from vm_data.py is a tuple so it should contain exactly two items:\n{vm}")
483497
sys.exit(1)
@@ -496,34 +510,34 @@ def filter_vm(vm, host_version):
496510

497511
return vm
498512

499-
def get_vm_or_vms_refs(handle, host_version=None):
513+
def get_vm_or_vms_refs(handle: str, host_version: str | None = None) -> str | list[str]:
500514
try:
501-
from vm_data import VMS
515+
from vm_data import VMS as VMS_untyped
502516
except ImportError:
503517
print("ERROR: Could not import VMS from vm_data.py.")
504-
print("Get the latest vm_data.py from XCP-ng's internal lab or copy data.py-dist and fill with your VM refs.")
518+
print("Get the latest vm_data.py from XCP-ng's internal lab or copy vm_data.py-dist and fill"
519+
" with your VM refs.")
505520
print("You may also bypass this error by providing your own --vm parameter(s).")
506521
sys.exit(1)
507522

523+
VMS = cast(VMSDef, VMS_untyped)
508524
category, key = handle.split("/")
509-
if category not in VMS or not VMS[category].get(key):
525+
if category not in VMS or key not in VMS[category]:
510526
print(f"ERROR: Could not find VMS['{category}']['{key}'] in vm_data.py, or it's empty.")
511527
print("You need to update your local vm_data.py.")
512528
print("You may also bypass this error by providing your own --vm parameter(s).")
513529
sys.exit(1)
514530

515-
if type(VMS[category][key]) is list:
531+
vms: str | list[str] | None = []
532+
vms_unfiltered = VMS[category][key]
533+
if isinstance(vms_unfiltered, list):
516534
# Multi VMs
517-
vms = list()
518-
for vm in VMS[category][key]:
519-
xva = filter_vm(vm, host_version)
520-
if xva is not None:
521-
vms.append(xva)
522-
if len(vms) == 0:
535+
vms = [xva for vm in vms_unfiltered if (xva := filter_vm(vm, host_version)) is not None]
536+
if vms == []:
523537
vms = None
524-
else:
538+
elif isinstance(vms_unfiltered, str):
525539
# Single VMs
526-
vms = filter_vm(VMS[category][key], host_version)
540+
vms = filter_vm(vms_unfiltered, host_version)
527541

528542
if vms is None:
529543
print(f"ERROR: Could not find VMS['{category}']['{key}'] for host version {host_version}.")
@@ -533,7 +547,8 @@ def get_vm_or_vms_refs(handle, host_version=None):
533547

534548
return vms
535549

536-
def build_pytest_cmd(job_data, hosts=None, host_version=None, pytest_args=[]):
550+
def build_pytest_cmd(job_data: JobData, hosts: str | None = None, host_version: str | None = None,
551+
pytest_args: list[str] = []) -> list[str]:
537552
markers = job_data.get("markers", None)
538553
name_filter = job_data.get("name_filter", None)
539554

@@ -543,13 +558,12 @@ def build_pytest_cmd(job_data, hosts=None, host_version=None, pytest_args=[]):
543558
if hosts is not None:
544559
try:
545560
host = hosts.split(',')[0]
546-
cmd = "lsb_release -sr"
547-
host_version = ssh(host, cmd)
561+
host_version = ssh(host, "lsb_release -sr")
548562
except Exception as e:
549563
print(e, file=sys.stderr)
550564

551-
def _join_pytest_args(arg, option):
552-
cli_args = []
565+
def _join_pytest_args(arg: str | None, option: str) -> str | None:
566+
cli_args: list[str] = []
553567
try:
554568
while True:
555569
i = pytest_args.index(option)
@@ -600,21 +614,21 @@ def _join_pytest_args(arg, option):
600614
cmd += pytest_args
601615
return cmd
602616

603-
def action_list(args):
617+
def action_list(args: argparse.Namespace) -> None:
604618
for job, data in JOBS.items():
605619
print(f"{job}: {data['description']}")
606620

607-
def action_show(args):
621+
def action_show(args: argparse.Namespace) -> None:
608622
print(json.dumps(JOBS[args.job], indent=4))
609623

610-
def action_collect(args):
624+
def action_collect(args: argparse.Namespace) -> None:
611625
cmd = build_pytest_cmd(JOBS[args.job], None, args.host_version, ["--collect-only"] + args.pytest_args)
612626
subprocess.run(cmd)
613627

614-
def action_check(args):
628+
def action_check(args: argparse.Namespace) -> None:
615629
error = False
616630

617-
def extract_tests(cmd):
631+
def extract_tests(cmd: list[str]) -> set[str]:
618632
tests = set()
619633
res = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
620634
if res.returncode != 0 and res.returncode != 5: # 5 means no test found
@@ -663,6 +677,7 @@ def extract_tests(cmd):
663677
multi_vm_tests = extract_tests(["pytest", "--collect-only", "-q", "-m", "multi_vms"]) - broken_tests
664678
job_tests = set()
665679
for job_data in JOBS.values():
680+
assert isinstance(job_data["params"], dict)
666681
if "--vm[]" in job_data["params"]:
667682
job_tests |= extract_tests(build_pytest_cmd(job_data, None, None, ["--collect-only", "-q", "--vm=a_vm"]))
668683
tests_missing = sorted(list(multi_vm_tests - job_tests))
@@ -676,23 +691,25 @@ def extract_tests(cmd):
676691
if error:
677692
sys.exit(1)
678693

679-
def action_run(args):
694+
def action_run(args: argparse.Namespace) -> None:
680695
cmd = build_pytest_cmd(JOBS[args.job], args.hosts, None, args.pytest_args)
681696
print(subprocess.list2cmdline(cmd))
682697
if args.print_only:
683698
return
684699

685700
# check that enough pool masters have been provided
686701
nb_pools = len(args.hosts.split(","))
687-
if nb_pools < JOBS[args.job]["nb_pools"]:
688-
print(f"Error: only {nb_pools} master host(s) provided, {JOBS[args.job]['nb_pools']} required.")
702+
job_nb_pools = JOBS[args.job]["nb_pools"]
703+
assert isinstance(job_nb_pools, int)
704+
if nb_pools < job_nb_pools:
705+
print(f"Error: only {nb_pools} master host(s) provided, {job_nb_pools} required.")
689706
sys.exit(1)
690707

691708
res = subprocess.run(cmd)
692709
if res.returncode:
693710
sys.exit(1)
694711

695-
def main():
712+
def main() -> None:
696713
parser = argparse.ArgumentParser(description="Manage test jobs")
697714
subparsers = parser.add_subparsers(dest="action", metavar="action")
698715
subparsers.required = True

0 commit comments

Comments
 (0)