@@ -33,16 +33,28 @@ def _resolve_inference_component_name(sm_client, endpoint_name: str, debug: bool
3333                break 
3434        if  not  ics :
3535            return  None 
36-         name  =  ics [0 ].get ("InferenceComponentName" )
36+ 
37+         # Prefer InService components, newest first; if none are InService, fall back to the most recently created of the rest 
38+         def  sort_key (ic : dict ):
39+             # Sort key is (is InService, CreationTime or 0). NOTE(review): CreationTime is presumably a datetime; the int 0 fallback cannot be ordered against a datetime, so a tie on the first element between a dated and an undated component would raise TypeError during sorting - confirm and consider a datetime fallback 
40+             return  (
41+                 ic .get ("InferenceComponentStatus" ) ==  "InService" ,
42+                 ic .get ("CreationTime" ) or  0 ,
43+             )
44+ 
45+         ics_sorted  =  sorted (ics , key = sort_key , reverse = True )
46+         chosen  =  ics_sorted [0 ]
47+         name  =  chosen .get ("InferenceComponentName" )
3748        if  name  and  debug :
38-             print (f"[debug] Using inference component: { name }  )
49+             status  =  chosen .get ("InferenceComponentStatus" )
50+             print (f"[debug] Using inference component: { name } { status }  )
3951        return  name 
4052    except  Exception  as  exc :
4153        if  debug :
4254            print (f"[debug] Could not list inference components: { exc }  )
4355        return  None 
4456
45- def  invoke_endpoint (runtime , endpoint_name : str , prompt : str , max_tokens : int  =  100 , temperature : float  =  0.7 , inference_component_name : str  |  None  =  None ) ->  Dict [str , Any ]:
57+ def  invoke_endpoint (runtime , sm_client ,  endpoint_name : str , prompt : str , max_tokens : int  =  100 , temperature : float  =  0.7 , inference_component_name : str  |  None  =  None ) ->  Dict [str , Any ]:
4658    """ 
4759    Invoke the SageMaker endpoint with a single prompt. 
4860    """ 
@@ -55,24 +67,26 @@ def invoke_endpoint(runtime, endpoint_name: str, prompt: str, max_tokens: int =
5567        "presence_penalty" : 0.0 
5668    }
5769
70+     # Use the provided IC name (resolved once in main) 
71+     ic_to_use  =  inference_component_name 
72+     
5873    try :
5974        kwargs  =  {
6075            "EndpointName" : endpoint_name ,
6176            "ContentType" : 'application/json' ,
6277            "Body" : json .dumps (payload ),
6378        }
64-         if  inference_component_name :
65-             kwargs ["InferenceComponentName" ] =  inference_component_name 
79+         if  ic_to_use :
80+             kwargs ["InferenceComponentName" ] =  ic_to_use 
6681
6782        response  =  runtime .invoke_endpoint (** kwargs )
68-         
6983        result  =  json .loads (response ['Body' ].read ().decode ())
7084        return  result 
7185    except  Exception  as  e :
7286        print (f"Error invoking endpoint: { str (e )}  )
7387        return  None 
7488
75- def  invoke_chat_endpoint (runtime , endpoint_name : str , messages : List [Dict [str , str ]], max_tokens : int  =  100 , temperature : float  =  0.7 , inference_component_name : str  |  None  =  None ) ->  Dict [str , Any ]:
89+ def  invoke_chat_endpoint (runtime , sm_client ,  endpoint_name : str , messages : List [Dict [str , str ]], max_tokens : int  =  100 , temperature : float  =  0.7 , inference_component_name : str  |  None  =  None ) ->  Dict [str , Any ]:
7690    """ 
7791    Invoke the SageMaker endpoint using chat format. 
7892    """ 
@@ -83,24 +97,26 @@ def invoke_chat_endpoint(runtime, endpoint_name: str, messages: List[Dict[str, s
8397        "top_p" : 0.9 
8498    }
8599
100+     # Use the provided IC name (resolved once in main) 
101+     ic_to_use  =  inference_component_name 
102+ 
86103    try :
87104        kwargs  =  {
88105            "EndpointName" : endpoint_name ,
89106            "ContentType" : 'application/json' ,
90107            "Body" : json .dumps (payload ),
91108        }
92-         if  inference_component_name :
93-             kwargs ["InferenceComponentName" ] =  inference_component_name 
109+         if  ic_to_use :
110+             kwargs ["InferenceComponentName" ] =  ic_to_use 
94111
95112        response  =  runtime .invoke_endpoint (** kwargs )
96-         
97113        result  =  json .loads (response ['Body' ].read ().decode ())
98114        return  result 
99115    except  Exception  as  e :
100116        print (f"Error invoking endpoint: { str (e )}  )
101117        return  None 
102118
103- def  test_completions_api (runtime , endpoint_name : str , inference_component_name : str  |  None ) ->  tuple [int , int ]:
119+ def  test_completions_api (runtime , sm_client ,  endpoint_name : str , inference_component_name : str  |  None ) ->  tuple [int , int ]:
104120    """ 
105121    Test the completions API with various prompts. 
106122    """ 
@@ -123,7 +139,7 @@ def test_completions_api(runtime, endpoint_name: str, inference_component_name:
123139        print ("-"  *  30 )
124140
125141        start_time  =  time .time ()
126-         result  =  invoke_endpoint (runtime , endpoint_name , prompt , max_tokens = 50 , inference_component_name = inference_component_name )
142+         result  =  invoke_endpoint (runtime , sm_client ,  endpoint_name , prompt , max_tokens = 50 , inference_component_name = inference_component_name )
127143        elapsed_time  =  time .time () -  start_time 
128144
129145        if  result :
@@ -141,7 +157,7 @@ def test_completions_api(runtime, endpoint_name: str, inference_component_name:
141157
142158    return  successes , failures 
143159
144- def  test_chat_api (runtime , endpoint_name : str , inference_component_name : str  |  None ) ->  tuple [int , int ]:
160+ def  test_chat_api (runtime , sm_client ,  endpoint_name : str , inference_component_name : str  |  None ) ->  tuple [int , int ]:
145161    """ 
146162    Test the chat completions API with conversation. 
147163    """ 
@@ -171,7 +187,7 @@ def test_chat_api(runtime, endpoint_name: str, inference_component_name: str | N
171187        print (f"Messages: { json .dumps (messages , indent = 2 )}  )
172188
173189        start_time  =  time .time ()
174-         result  =  invoke_chat_endpoint (runtime , endpoint_name , messages , max_tokens = 100 , inference_component_name = inference_component_name )
190+         result  =  invoke_chat_endpoint (runtime , sm_client ,  endpoint_name , messages , max_tokens = 100 , inference_component_name = inference_component_name )
175191        elapsed_time  =  time .time () -  start_time 
176192
177193        if  result :
@@ -199,13 +215,13 @@ def main():
199215
200216    runtime  =  boto3 .client ('runtime.sagemaker' , region_name = args .region )
201217    sm_client  =  boto3 .client ('sagemaker' , region_name = args .region )
202-     ic_name  =  _resolve_inference_component_name (sm_client , args .endpoint_name )
218+     ic_name  =  _resolve_inference_component_name (sm_client , args .endpoint_name ,  debug = True )
203219    if  ic_name :
204-         print (f"[info] Using  inference component: { ic_name }  )
220+         print (f"[info] Initial  inference component: { ic_name }  )
205221
206-     # Run tests 
207-     comp_ok , comp_fail  =  test_completions_api (runtime , args .endpoint_name , ic_name )
208-     chat_ok , chat_fail  =  test_chat_api (runtime , args .endpoint_name , ic_name )
222+     # Run tests with the resolved IC name  
223+     comp_ok , comp_fail  =  test_completions_api (runtime , sm_client ,  args .endpoint_name , ic_name )
224+     chat_ok , chat_fail  =  test_chat_api (runtime , sm_client ,  args .endpoint_name , ic_name )
209225
210226    total_ok  =  comp_ok  +  chat_ok 
211227    total_fail  =  comp_fail  +  chat_fail 
0 commit comments