Skip to content

Commit 2274961

Browse files
authored
feat: ensure correct inference component resolution (#142)
1 parent a2d876e commit 2274961

File tree

1 file changed

+35
-19
lines changed

1 file changed

+35
-19
lines changed

integrations/SageMaker/testing_scripts/test_endpoint.py

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,28 @@ def _resolve_inference_component_name(sm_client, endpoint_name: str, debug: bool
3333
break
3434
if not ics:
3535
return None
36-
name = ics[0].get("InferenceComponentName")
36+
37+
# Prefer components that are InService; fall back to the most recently created otherwise
38+
def sort_key(ic: dict):
39+
# CreationTime is a datetime; if missing, use 0
40+
return (
41+
ic.get("InferenceComponentStatus") == "InService",
42+
ic.get("CreationTime") or 0,
43+
)
44+
45+
ics_sorted = sorted(ics, key=sort_key, reverse=True)
46+
chosen = ics_sorted[0]
47+
name = chosen.get("InferenceComponentName")
3748
if name and debug:
38-
print(f"[debug] Using inference component: {name}")
49+
status = chosen.get("InferenceComponentStatus")
50+
print(f"[debug] Using inference component: {name} (status={status})")
3951
return name
4052
except Exception as exc:
4153
if debug:
4254
print(f"[debug] Could not list inference components: {exc}. Assuming non-IC endpoint.")
4355
return None
4456

45-
def invoke_endpoint(runtime, endpoint_name: str, prompt: str, max_tokens: int = 100, temperature: float = 0.7, inference_component_name: str | None = None) -> Dict[str, Any]:
57+
def invoke_endpoint(runtime, sm_client, endpoint_name: str, prompt: str, max_tokens: int = 100, temperature: float = 0.7, inference_component_name: str | None = None) -> Dict[str, Any]:
4658
"""
4759
Invoke the SageMaker endpoint with a single prompt.
4860
"""
@@ -55,24 +67,26 @@ def invoke_endpoint(runtime, endpoint_name: str, prompt: str, max_tokens: int =
5567
"presence_penalty": 0.0
5668
}
5769

70+
# Use the provided IC name (resolved once in main)
71+
ic_to_use = inference_component_name
72+
5873
try:
5974
kwargs = {
6075
"EndpointName": endpoint_name,
6176
"ContentType": 'application/json',
6277
"Body": json.dumps(payload),
6378
}
64-
if inference_component_name:
65-
kwargs["InferenceComponentName"] = inference_component_name
79+
if ic_to_use:
80+
kwargs["InferenceComponentName"] = ic_to_use
6681

6782
response = runtime.invoke_endpoint(**kwargs)
68-
6983
result = json.loads(response['Body'].read().decode())
7084
return result
7185
except Exception as e:
7286
print(f"Error invoking endpoint: {str(e)}")
7387
return None
7488

75-
def invoke_chat_endpoint(runtime, endpoint_name: str, messages: List[Dict[str, str]], max_tokens: int = 100, temperature: float = 0.7, inference_component_name: str | None = None) -> Dict[str, Any]:
89+
def invoke_chat_endpoint(runtime, sm_client, endpoint_name: str, messages: List[Dict[str, str]], max_tokens: int = 100, temperature: float = 0.7, inference_component_name: str | None = None) -> Dict[str, Any]:
7690
"""
7791
Invoke the SageMaker endpoint using chat format.
7892
"""
@@ -83,24 +97,26 @@ def invoke_chat_endpoint(runtime, endpoint_name: str, messages: List[Dict[str, s
8397
"top_p": 0.9
8498
}
8599

100+
# Use the provided IC name (resolved once in main)
101+
ic_to_use = inference_component_name
102+
86103
try:
87104
kwargs = {
88105
"EndpointName": endpoint_name,
89106
"ContentType": 'application/json',
90107
"Body": json.dumps(payload),
91108
}
92-
if inference_component_name:
93-
kwargs["InferenceComponentName"] = inference_component_name
109+
if ic_to_use:
110+
kwargs["InferenceComponentName"] = ic_to_use
94111

95112
response = runtime.invoke_endpoint(**kwargs)
96-
97113
result = json.loads(response['Body'].read().decode())
98114
return result
99115
except Exception as e:
100116
print(f"Error invoking endpoint: {str(e)}")
101117
return None
102118

103-
def test_completions_api(runtime, endpoint_name: str, inference_component_name: str | None) -> tuple[int, int]:
119+
def test_completions_api(runtime, sm_client, endpoint_name: str, inference_component_name: str | None) -> tuple[int, int]:
104120
"""
105121
Test the completions API with various prompts.
106122
"""
@@ -123,7 +139,7 @@ def test_completions_api(runtime, endpoint_name: str, inference_component_name:
123139
print("-" * 30)
124140

125141
start_time = time.time()
126-
result = invoke_endpoint(runtime, endpoint_name, prompt, max_tokens=50, inference_component_name=inference_component_name)
142+
result = invoke_endpoint(runtime, sm_client, endpoint_name, prompt, max_tokens=50, inference_component_name=inference_component_name)
127143
elapsed_time = time.time() - start_time
128144

129145
if result:
@@ -141,7 +157,7 @@ def test_completions_api(runtime, endpoint_name: str, inference_component_name:
141157

142158
return successes, failures
143159

144-
def test_chat_api(runtime, endpoint_name: str, inference_component_name: str | None) -> tuple[int, int]:
160+
def test_chat_api(runtime, sm_client, endpoint_name: str, inference_component_name: str | None) -> tuple[int, int]:
145161
"""
146162
Test the chat completions API with conversation.
147163
"""
@@ -171,7 +187,7 @@ def test_chat_api(runtime, endpoint_name: str, inference_component_name: str | N
171187
print(f"Messages: {json.dumps(messages, indent=2)}")
172188

173189
start_time = time.time()
174-
result = invoke_chat_endpoint(runtime, endpoint_name, messages, max_tokens=100, inference_component_name=inference_component_name)
190+
result = invoke_chat_endpoint(runtime, sm_client, endpoint_name, messages, max_tokens=100, inference_component_name=inference_component_name)
175191
elapsed_time = time.time() - start_time
176192

177193
if result:
@@ -199,13 +215,13 @@ def main():
199215

200216
runtime = boto3.client('runtime.sagemaker', region_name=args.region)
201217
sm_client = boto3.client('sagemaker', region_name=args.region)
202-
ic_name = _resolve_inference_component_name(sm_client, args.endpoint_name)
218+
ic_name = _resolve_inference_component_name(sm_client, args.endpoint_name, debug=True)
203219
if ic_name:
204-
print(f"[info] Using inference component: {ic_name}")
220+
print(f"[info] Initial inference component: {ic_name}")
205221

206-
# Run tests
207-
comp_ok, comp_fail = test_completions_api(runtime, args.endpoint_name, ic_name)
208-
chat_ok, chat_fail = test_chat_api(runtime, args.endpoint_name, ic_name)
222+
# Run tests with the resolved IC name
223+
comp_ok, comp_fail = test_completions_api(runtime, sm_client, args.endpoint_name, ic_name)
224+
chat_ok, chat_fail = test_chat_api(runtime, sm_client, args.endpoint_name, ic_name)
209225

210226
total_ok = comp_ok + chat_ok
211227
total_fail = comp_fail + chat_fail

0 commit comments

Comments
 (0)