77
88from lib .commands import ssh
99
10- JOBS = {
10+ from typing import NotRequired , TypedDict , cast
11+
12+ class JobData (TypedDict ):
13+ description : str
14+ requirements : list [str ]
15+ nb_pools : int
16+ params : dict [str , str ]
17+ paths : list [str ]
18+ markers : NotRequired [str ]
19+ name_filter : NotRequired [str ]
20+
21+ JOBS : dict [str , JobData ] = {
1122 "main" : {
1223 "description" : "a group of not-too-long tests that run either without a VM, or with a single small one" ,
1324 "requirements" : [
472483 "tests/storage/zfsvol/test_zfsvol_sr.py::TestZfsvolVm::test_quicktest" ,
473484]
474485
486+ VmDef = str | tuple [str , str ]
487+ VMSDef = dict [str , dict [str , VmDef | list [VmDef ]]]
488+
475489# Returns the vm filename or None if a host_version is passed and matches the one specified
476490# with the vm filename in vm_data.py. ex: ("centos6-32-hvm-created_8.2-zstd.xva", "8\.2\..*")
477- def filter_vm (vm , host_version ) :
491+ def filter_vm (vm : VmDef , host_version : str | None ) -> str | None :
478492 import re
479493
480- if type (vm ) is tuple :
494+ if isinstance (vm , tuple ) :
481495 if len (vm ) != 2 :
482496 print (f"ERROR: VM definition from vm_data.py is a tuple so it should contain exactly two items:\n { vm } " )
483497 sys .exit (1 )
@@ -496,34 +510,34 @@ def filter_vm(vm, host_version):
496510
497511 return vm
498512
499- def get_vm_or_vms_refs (handle , host_version = None ):
513+ def get_vm_or_vms_refs (handle : str , host_version : str | None = None ) -> str | list [ str ] :
500514 try :
501- from vm_data import VMS
515+ from vm_data import VMS as VMS_untyped
502516 except ImportError :
503517 print ("ERROR: Could not import VMS from vm_data.py." )
504- print ("Get the latest vm_data.py from XCP-ng's internal lab or copy data.py-dist and fill with your VM refs." )
518+ print ("Get the latest vm_data.py from XCP-ng's internal lab or copy vm_data.py-dist and fill"
519+ " with your VM refs." )
505520 print ("You may also bypass this error by providing your own --vm parameter(s)." )
506521 sys .exit (1 )
507522
523+ VMS = cast (VMSDef , VMS_untyped )
508524 category , key = handle .split ("/" )
509- if category not in VMS or not VMS [category ]. get ( key ) :
525+ if category not in VMS or key not in VMS [category ]:
510526 print (f"ERROR: Could not find VMS['{ category } ']['{ key } '] in vm_data.py, or it's empty." )
511527 print ("You need to update your local vm_data.py." )
512528 print ("You may also bypass this error by providing your own --vm parameter(s)." )
513529 sys .exit (1 )
514530
515- if type (VMS [category ][key ]) is list :
531+ vms : str | list [str ] | None = []
532+ vms_unfiltered = VMS [category ][key ]
533+ if isinstance (vms_unfiltered , list ):
516534 # Multi VMs
517- vms = list ()
518- for vm in VMS [category ][key ]:
519- xva = filter_vm (vm , host_version )
520- if xva is not None :
521- vms .append (xva )
522- if len (vms ) == 0 :
535+ vms = [xva for vm in vms_unfiltered if (xva := filter_vm (vm , host_version )) is not None ]
536+ if vms == []:
523537 vms = None
524- else :
538+ elif isinstance ( vms_unfiltered , str ) :
525539 # Single VMs
526- vms = filter_vm (VMS [ category ][ key ] , host_version )
540+ vms = filter_vm (vms_unfiltered , host_version )
527541
528542 if vms is None :
529543 print (f"ERROR: Could not find VMS['{ category } ']['{ key } '] for host version { host_version } ." )
@@ -533,7 +547,8 @@ def get_vm_or_vms_refs(handle, host_version=None):
533547
534548 return vms
535549
536- def build_pytest_cmd (job_data , hosts = None , host_version = None , pytest_args = []):
550+ def build_pytest_cmd (job_data : JobData , hosts : str | None = None , host_version : str | None = None ,
551+ pytest_args : list [str ] = []) -> list [str ]:
537552 markers = job_data .get ("markers" , None )
538553 name_filter = job_data .get ("name_filter" , None )
539554
@@ -543,13 +558,12 @@ def build_pytest_cmd(job_data, hosts=None, host_version=None, pytest_args=[]):
543558 if hosts is not None :
544559 try :
545560 host = hosts .split (',' )[0 ]
546- cmd = "lsb_release -sr"
547- host_version = ssh (host , cmd )
561+ host_version = ssh (host , "lsb_release -sr" )
548562 except Exception as e :
549563 print (e , file = sys .stderr )
550564
551- def _join_pytest_args (arg , option ) :
552- cli_args = []
565+ def _join_pytest_args (arg : str | None , option : str ) -> str | None :
566+ cli_args : list [ str ] = []
553567 try :
554568 while True :
555569 i = pytest_args .index (option )
@@ -600,21 +614,21 @@ def _join_pytest_args(arg, option):
600614 cmd += pytest_args
601615 return cmd
602616
603- def action_list (args ) :
617+ def action_list (args : argparse . Namespace ) -> None :
604618 for job , data in JOBS .items ():
605619 print (f"{ job } : { data ['description' ]} " )
606620
607- def action_show (args ) :
621+ def action_show (args : argparse . Namespace ) -> None :
608622 print (json .dumps (JOBS [args .job ], indent = 4 ))
609623
610- def action_collect (args ) :
624+ def action_collect (args : argparse . Namespace ) -> None :
611625 cmd = build_pytest_cmd (JOBS [args .job ], None , args .host_version , ["--collect-only" ] + args .pytest_args )
612626 subprocess .run (cmd )
613627
614- def action_check (args ) :
628+ def action_check (args : argparse . Namespace ) -> None :
615629 error = False
616630
617- def extract_tests (cmd ) :
631+ def extract_tests (cmd : list [ str ]) -> set [ str ] :
618632 tests = set ()
619633 res = subprocess .run (cmd , stdout = subprocess .PIPE , stderr = subprocess .PIPE )
620634 if res .returncode != 0 and res .returncode != 5 : # 5 means no test found
@@ -663,6 +677,7 @@ def extract_tests(cmd):
663677 multi_vm_tests = extract_tests (["pytest" , "--collect-only" , "-q" , "-m" , "multi_vms" ]) - broken_tests
664678 job_tests = set ()
665679 for job_data in JOBS .values ():
680+ assert isinstance (job_data ["params" ], dict )
666681 if "--vm[]" in job_data ["params" ]:
667682 job_tests |= extract_tests (build_pytest_cmd (job_data , None , None , ["--collect-only" , "-q" , "--vm=a_vm" ]))
668683 tests_missing = sorted (list (multi_vm_tests - job_tests ))
@@ -676,23 +691,25 @@ def extract_tests(cmd):
676691 if error :
677692 sys .exit (1 )
678693
679- def action_run (args ) :
694+ def action_run (args : argparse . Namespace ) -> None :
680695 cmd = build_pytest_cmd (JOBS [args .job ], args .hosts , None , args .pytest_args )
681696 print (subprocess .list2cmdline (cmd ))
682697 if args .print_only :
683698 return
684699
685700 # check that enough pool masters have been provided
686701 nb_pools = len (args .hosts .split ("," ))
687- if nb_pools < JOBS [args .job ]["nb_pools" ]:
688- print (f"Error: only { nb_pools } master host(s) provided, { JOBS [args .job ]['nb_pools' ]} required." )
702+ job_nb_pools = JOBS [args .job ]["nb_pools" ]
703+ assert isinstance (job_nb_pools , int )
704+ if nb_pools < job_nb_pools :
705+ print (f"Error: only { nb_pools } master host(s) provided, { job_nb_pools } required." )
689706 sys .exit (1 )
690707
691708 res = subprocess .run (cmd )
692709 if res .returncode :
693710 sys .exit (1 )
694711
695- def main ():
712+ def main () -> None :
696713 parser = argparse .ArgumentParser (description = "Manage test jobs" )
697714 subparsers = parser .add_subparsers (dest = "action" , metavar = "action" )
698715 subparsers .required = True
0 commit comments