@@ -36,6 +36,9 @@ def textwrap_for_yml(msg_string):
3636class InferenceBaseJob :
3737 """Base class for inference jobs supporting multiple frameworks."""
3838
39+ # Class-level results dictionary shared across all test iterations
40+ all_test_results = {}
41+
3942 def __init__ (
4043 self ,
4144 c_phdl ,
@@ -67,7 +70,7 @@ def __init__(
6770
6871 self .job_cmd = ''
6972 self .job_cmd_list = []
70- self .inference_result_dict = {}
73+ self .inference_results_dict = {}
7174 print (self .gpu_type )
7275
7376 # Needed only in the case of distributed inference - placeholder for future
@@ -192,6 +195,10 @@ def get_log_subdir(self):
192195 """Get log subdirectory name for this framework."""
193196 raise NotImplementedError ("Derived class must implement get_log_subdir()" )
194197
def collect_test_result(self, status):
    """Hook for recording the outcome of a validation run.

    The base implementation deliberately does nothing and returns ``None``;
    a derived job class may override it to persist its results.

    Args:
        status: Outcome label supplied by the caller. The validation code
            visible in this file passes ``"success"`` or ``"failed"``.
    """
    # Intentional no-op: the base class has nothing to store.
    return None
201+
195202 def run_preinference_tasks (
196203 self ,
197204 ):
@@ -245,6 +252,19 @@ def exec_nic_setup_scripts(
245252 def build_server_inference_job_cmd (
246253 self ,
247254 ):
255+ # Build VLLM env vars from config, with defaults if not specified
256+ vllm_env_vars = self .bp_dict .get (
257+ 'vllm_env_vars' ,
258+ {
259+ 'VLLM_USE_AITER_UNIFIED_ATTENTION' : '1' ,
260+ 'VLLM_ROCM_USE_AITER_MHA' : '0' ,
261+ 'VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4' : '1' ,
262+ },
263+ )
264+
265+ # Convert dict to export statements
266+ vllm_exports = '\n ' .join ([f'export { k } ={ v } ' for k , v in vllm_env_vars .items ()])
267+
248268 s_cmd = f'''docker exec { self .container_name } /bin/bash -c "echo '
249269 export MODEL={ self .bp_dict ['model' ]}
250270 export ISL={ self .bp_dict ['input_sequence_length' ]}
@@ -254,9 +274,7 @@ def build_server_inference_job_cmd(
254274 export TP={ self .bp_dict ['tensor_parallelism' ]}
255275 export CONC={ self .bp_dict ['max_concurrency' ]}
256276 export HF_TOKEN={ self .hf_token }
257- export VLLM_USE_AITER_UNIFIED_ATTENTION=1
258- export VLLM_ROCM_USE_AITER_MHA=0
259- export VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4=1
277+ { vllm_exports }
260278 export PORT={ self .bp_dict ['port_no' ]} ' > /tmp/server_env_script.sh"
261279 '''
262280 time .sleep (3 )
@@ -680,51 +698,60 @@ def verify_inference_results(
680698 print (f"Validating results for config: { config_key } " )
681699 print (f"Expected thresholds: { expected_result_dict } " )
682700
683- # Validate metrics on a per-node basis
684- for node in self .inference_result_dict .keys ():
685- print (f"Validating node: { node } " )
686- print (f"Actual results: { self .inference_result_dict [node ]} " )
687-
688- for metric_name in expected_result_dict .keys ():
689- if metric_name not in self .inference_result_dict [node ]:
690- print (f"WARNING: Metric { metric_name } not found in actual results, skipping" )
691- continue
692-
693- actual_value = float (self .inference_result_dict [node ][metric_name ])
694- expected_value = float (expected_result_dict [metric_name ])
695-
696- # Latency metrics (ms): lower is better
697- if re .search ('ms' , metric_name , re .I ):
698- if actual_value > expected_value :
699- fail_test (
700- f"FAIL - Latency metric '{ metric_name } ' exceeded threshold for { config_key } \n "
701- f" Actual: { actual_value } ms\n "
702- f" Expected: <= { expected_value } ms\n "
703- f" Difference: +{ actual_value - expected_value :.2f} ms ({ ((actual_value / expected_value - 1 ) * 100 ):.1f} % worse)"
704- )
701+ validation_passed = True
702+ try :
703+ # Validate metrics on a per-node basis
704+ for node in self .inference_results_dict .keys ():
705+ print (f"Validating node: { node } " )
706+ print (f"Actual results: { self .inference_results_dict [node ]} " )
707+
708+ for metric_name in expected_result_dict .keys ():
709+ if metric_name not in self .inference_results_dict [node ]:
710+ print (f"WARNING: Metric { metric_name } not found in actual results, skipping" )
711+ continue
712+
713+ actual_value = float (self .inference_results_dict [node ][metric_name ])
714+ expected_value = float (expected_result_dict [metric_name ])
715+
716+ # Latency metrics (ms): lower is better
717+ if re .search ('ms' , metric_name , re .I ):
718+ if actual_value > expected_value :
719+ fail_test (
720+ f"FAIL - Latency metric '{ metric_name } ' exceeded threshold for { config_key } \n "
721+ f" Actual: { actual_value } ms\n "
722+ f" Expected: <= { expected_value } ms\n "
723+ f" Difference: +{ actual_value - expected_value :.2f} ms ({ ((actual_value / expected_value - 1 ) * 100 ):.1f} % worse)"
724+ )
725+ else :
726+ print (f"✓ { metric_name } : { actual_value } ms <= { expected_value } ms" )
727+
728+ # Throughput metrics (per_sec): higher is better
705729 else :
706- print (f"✓ { metric_name } : { actual_value } ms <= { expected_value } ms" )
707-
708- # Throughput metrics (per_sec): higher is better
709- else :
710- if actual_value < expected_value :
711- fail_test (
712- f"FAIL - Throughput metric '{ metric_name } ' below threshold for { config_key } \n "
713- f" Actual: { actual_value } \n "
714- f" Expected: >= { expected_value } \n "
715- f" Difference: -{ expected_value - actual_value :.2f} ({ ((1 - actual_value / expected_value ) * 100 ):.1f} % worse)"
716- )
717- else :
718- print (f"✓ { metric_name } : { actual_value } >= { expected_value } " )
719-
720- # Scan Dmesg for errors
721- self .inference_end_time = self .s_phdl .exec ('date +"%a %b %e %H:%M"' )
722- time .sleep (2 )
723- verify_dmesg_for_errors (self .s_phdl , self .inference_start_time , self .inference_end_time )
724-
725- print (f"✓ All validations passed for { config_key } " )
726- print (self .inference_result_dict )
727-
728- # Auto-store results if this is a VllmJob instance
729- if hasattr (self , 'store_test_result' ):
730- self .store_test_result ()
730+ if actual_value < expected_value :
731+ fail_test (
732+ f"FAIL - Throughput metric '{ metric_name } ' below threshold for { config_key } \n "
733+ f" Actual: { actual_value } \n "
734+ f" Expected: >= { expected_value } \n "
735+ f" Difference: -{ expected_value - actual_value :.2f} ({ ((1 - actual_value / expected_value ) * 100 ):.1f} % worse)"
736+ )
737+ else :
738+ print (f"✓ { metric_name } : { actual_value } >= { expected_value } " )
739+ except Exception as e :
740+ validation_passed = False
741+ raise e
742+ finally :
743+ # Scan Dmesg for errors
744+ self .inference_end_time = self .s_phdl .exec ('date +"%a %b %e %H:%M"' )
745+ time .sleep (2 )
746+ verify_dmesg_for_errors (self .s_phdl , self .inference_start_time , self .inference_end_time )
747+
748+ if validation_passed :
749+ print (f"✓ All validations passed for { config_key } " )
750+ print (self .inference_results_dict )
751+ # Auto-store results
752+ self .collect_test_result ("success" )
753+ else :
754+ print (f"✗ Validations failed for { config_key } " )
755+ print (self .inference_results_dict )
756+ # Auto-store results even on failure
757+ self .collect_test_result ("failed" )
0 commit comments