Skip to content

Commit 0d7ade4

Browse files
committed
addressed comments
1 parent fcb5162 commit 0d7ade4

File tree

10 files changed

+188
-171
lines changed

10 files changed

+188
-171
lines changed

ads/aqua/extension/deployment_handler.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,11 @@ def get(self, id: Union[str, List[str]] = None):
5858
model_id=id.split(",") if "," in id else id
5959
)
6060
elif paths.startswith("aqua/deployments/recommend_shapes"):
61-
id = id or self.get_argument("model_id", default=None)
6261
if not id or not isinstance(id, str):
6362
raise HTTPError(
6463
400,
6564
f"Invalid request format for {self.request.path}. "
66-
"Expected a single model OCID",
65+
"Expected a single model OCID specified as --model_id",
6766
)
6867
id = id.replace(" ", "")
6968
return self.get_recommend_shape(model_id=id)
@@ -189,14 +188,10 @@ def get_recommend_shape(self, model_id: str):
189188

190189
compartment_id = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
191190

192-
generate_table = (
193-
self.get_argument("generate_table", default="True").lower() == "true"
194-
)
195-
196191
recommend_report = app.recommend_shape(
197192
model_id=model_id,
198193
compartment_id=compartment_id,
199-
generate_table=generate_table,
194+
generate_table=False,
200195
)
201196

202197
return self.finish(recommend_report)

ads/aqua/modeldeployment/constants.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,3 @@
1212
DEFAULT_WAIT_TIME = 12000
1313
DEFAULT_POLL_INTERVAL = 10
1414

15-
SHAPE_MAP = {"NVIDIA_GPU": "GPU"}

ads/aqua/modeldeployment/deployment.py

Lines changed: 12 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@
6767
from ads.aqua.modeldeployment.constants import (
6868
DEFAULT_POLL_INTERVAL,
6969
DEFAULT_WAIT_TIME,
70-
SHAPE_MAP,
7170
)
7271
from ads.aqua.modeldeployment.entities import (
7372
AquaDeployment,
@@ -77,7 +76,10 @@
7776
)
7877
from ads.aqua.modeldeployment.model_group_config import ModelGroupConfig
7978
from ads.aqua.shaperecommend.recommend import AquaShapeRecommend
80-
from ads.aqua.shaperecommend.shape_report import ShapeRecommendationReport
79+
from ads.aqua.shaperecommend.shape_report import (
80+
RequestRecommend,
81+
ShapeRecommendationReport,
82+
)
8183
from ads.common.object_storage_details import ObjectStorageDetails
8284
from ads.common.utils import UNKNOWN, get_log_links
8385
from ads.common.work_request import DataScienceWorkRequest
@@ -1250,60 +1252,6 @@ def validate_deployment_params(
12501252
)
12511253
return {"valid": True}
12521254

1253-
def valid_compute_shapes(self, **kwargs) -> List["ComputeShapeSummary"]:
1254-
"""
1255-
Returns a filtered list of GPU-only ComputeShapeSummary objects by reading and parsing a JSON file.
1256-
1257-
Parameters
1258-
----------
1259-
file : str
1260-
Path to the JSON file containing shape data.
1261-
1262-
Returns
1263-
-------
1264-
List[ComputeShapeSummary]
1265-
List of ComputeShapeSummary objects passing the checks.
1266-
1267-
Raises
1268-
------
1269-
ValueError
1270-
If the file cannot be opened, parsed, or the 'shapes' key is missing.
1271-
"""
1272-
compartment_id = kwargs.pop("compartment_id", COMPARTMENT_OCID)
1273-
oci_shapes: list[ModelDeploymentShapeSummary] = self.list_resource(
1274-
self.ds_client.list_model_deployment_shapes,
1275-
compartment_id=compartment_id,
1276-
**kwargs,
1277-
)
1278-
set_user_shapes = {shape.name: shape for shape in oci_shapes}
1279-
1280-
gpu_shapes_metadata = load_gpu_shapes_index().shapes
1281-
1282-
valid_shapes = []
1283-
# only loops through GPU shapes, update later to include CPU shapes
1284-
for name, spec in gpu_shapes_metadata.items():
1285-
if name in set_user_shapes:
1286-
oci_shape = set_user_shapes.get(name)
1287-
1288-
compute_shape = ComputeShapeSummary(
1289-
available=True,
1290-
core_count=oci_shape.core_count,
1291-
memory_in_gbs=oci_shape.memory_in_gbs,
1292-
shape_series=SHAPE_MAP.get(oci_shape.shape_series, "GPU"),
1293-
name=oci_shape.name,
1294-
gpu_specs=spec,
1295-
)
1296-
else:
1297-
compute_shape = ComputeShapeSummary(
1298-
available=False, name=name, shape_series="GPU", gpu_specs=spec
1299-
)
1300-
valid_shapes.append(compute_shape)
1301-
1302-
valid_shapes.sort(
1303-
key=lambda shape: shape.gpu_specs.gpu_memory_in_gbs, reverse=True
1304-
)
1305-
return valid_shapes
1306-
13071255
def recommend_shape(self, **kwargs) -> Union[Table, ShapeRecommendationReport]:
13081256
"""
13091257
For the CLI (set generate_table = True), generates the table (in rich diff) with valid
@@ -1335,13 +1283,16 @@ def recommend_shape(self, **kwargs) -> Union[Table, ShapeRecommendationReport]:
13351283
AquaValueError
13361284
If model type is unsupported by tool (no recommendation report generated)
13371285
"""
1338-
compartment_id = kwargs.get("compartment_id", COMPARTMENT_OCID)
1339-
1340-
kwargs["shapes"] = self.valid_compute_shapes(compartment_id=compartment_id)
1286+
try:
1287+
request = RequestRecommend(**kwargs)
1288+
except ValidationError as e:
1289+
custom_error = build_pydantic_error_message(e)
1290+
raise AquaValueError( # noqa: B904
1291+
f"Failed to request shape recommendation due to invalid input parameters: {custom_error}"
1292+
)
13411293

13421294
shape_recommend = AquaShapeRecommend()
1343-
1344-
shape_recommend_report = shape_recommend.which_shapes(**kwargs)
1295+
shape_recommend_report = shape_recommend.which_shapes(request)
13451296

13461297
return shape_recommend_report
13471298

ads/aqua/shaperecommend/constants.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,12 @@
6767
"4bit": 0.5,
6868
"int4": 0.5,
6969
}
70+
71+
SHAPE_MAP = {
72+
"NVIDIA_GPU": "GPU",
73+
"AMD_ROME": "CPU",
74+
"GENERIC": "CPU",
75+
"LEGACY": "CPU",
76+
"ARM": "CPU",
77+
"UNKNOWN_ENUM_VALUE": "N/A",
78+
}

ads/aqua/shaperecommend/llm_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def validate_model_support(cls, raw: dict) -> ValueError:
212212
):
213213
raise AquaRecommendationError(
214214
"Please provide a decoder-only text-generation model (ex. Llama, Falcon, etc). "
215-
"Encoder-decoder models (ex. T5, Gemma) and encoder-only (BERT) are not supported in this tool at this time."
215+
"Encoder-decoder models (ex. T5, Gemma) and encoder-only (BERT) are not supported at this time."
216216
)
217217

218218
@classmethod

ads/aqua/shaperecommend/recommend.py

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
build_pydantic_error_message,
1616
get_resource_type,
1717
load_config,
18+
load_gpu_shapes_index,
1819
)
1920
from ads.aqua.shaperecommend.constants import (
2021
SAFETENSORS,
22+
SHAPE_MAP,
2123
TEXT_GENERATION,
2224
TROUBLESHOOT_MSG,
2325
)
@@ -30,9 +32,12 @@
3032
ShapeReport,
3133
)
3234
from ads.model.datascience_model import DataScienceModel
35+
from ads.model.service.oci_datascience_model_deployment import (
36+
OCIDataScienceModelDeployment,
37+
)
3338

3439

35-
class AquaShapeRecommend(BaseModel):
40+
class AquaShapeRecommend:
3641
"""
3742
Interface for recommending GPU shapes for machine learning model deployments
3843
on Oracle Cloud Infrastructure Data Science service.
@@ -42,7 +47,7 @@ class AquaShapeRecommend(BaseModel):
4247
Must be used within a properly configured and authenticated OCI environment.
4348
"""
4449

45-
def which_shapes(self, **kwargs) -> Union[ShapeRecommendationReport, Table]:
50+
def which_shapes(self, request: RequestRecommend) -> Union[ShapeRecommendationReport, Table]:
4651
"""
4752
Lists valid GPU deployment shapes for the provided model and configuration.
4853
@@ -77,7 +82,8 @@ def which_shapes(self, **kwargs) -> Union[ShapeRecommendationReport, Table]:
7782
If parameters are missing or invalid, or if no valid sequence length is requested.
7883
"""
7984
try:
80-
request = RequestRecommend(**kwargs)
85+
shapes = self.valid_compute_shapes(compartment_id=request.compartment_id)
86+
8187
ds_model = self._validate_model_ocid(request.model_id)
8288
data = self._get_model_config(ds_model)
8389

@@ -86,7 +92,7 @@ def which_shapes(self, **kwargs) -> Union[ShapeRecommendationReport, Table]:
8692
model_name = ds_model.display_name if ds_model.display_name else ""
8793

8894
shape_recommendation_report = self._summarize_shapes_for_seq_lens(
89-
llm_config, request.shapes, model_name
95+
llm_config, shapes, model_name
9096
)
9197

9298
if request.generate_table and shape_recommendation_report.recommendations:
@@ -107,10 +113,61 @@ def which_shapes(self, **kwargs) -> Union[ShapeRecommendationReport, Table]:
107113
) from ex
108114
except AquaValueError as ex:
109115
logger.error(f"Error with LLM config: {ex}")
110-
raise
116+
raise AquaValueError( # noqa: B904
117+
f"An error occured while producing recommendations: {ex}"
118+
)
111119

112120
return shape_recommendation_report
113121

122+
def valid_compute_shapes(self, compartment_id: str) -> List["ComputeShapeSummary"]:
123+
"""
124+
Returns a ComputeShapeSummary for every shape in the GPU shapes index, marked as available when the shape is also listed for the given compartment, sorted by GPU memory in descending order.
125+
126+
Parameters
127+
----------
128+
compartment_id : str
129+
OCID of the compartment whose model deployment shapes are listed.
130+
131+
Returns
132+
-------
133+
List[ComputeShapeSummary]
134+
List of ComputeShapeSummary objects passing the checks.
135+
136+
Raises
137+
------
138+
ValueError
139+
If the file cannot be opened, parsed, or the 'shapes' key is missing.
140+
"""
141+
oci_shapes = OCIDataScienceModelDeployment.shapes(compartment_id=compartment_id)
142+
set_user_shapes = {shape.name: shape for shape in oci_shapes}
143+
144+
gpu_shapes_metadata = load_gpu_shapes_index().shapes
145+
146+
valid_shapes = []
147+
# only loops through GPU shapes, update later to include CPU shapes
148+
for name, spec in gpu_shapes_metadata.items():
149+
if name in set_user_shapes:
150+
oci_shape = set_user_shapes.get(name)
151+
152+
compute_shape = ComputeShapeSummary(
153+
available=True,
154+
core_count=oci_shape.core_count,
155+
memory_in_gbs=oci_shape.memory_in_gbs,
156+
shape_series=SHAPE_MAP.get(oci_shape.shape_series, "GPU"),
157+
name=oci_shape.name,
158+
gpu_specs=spec,
159+
)
160+
else:
161+
compute_shape = ComputeShapeSummary(
162+
available=False, name=name, shape_series="GPU", gpu_specs=spec
163+
)
164+
valid_shapes.append(compute_shape)
165+
166+
valid_shapes.sort(
167+
key=lambda shape: shape.gpu_specs.gpu_memory_in_gbs, reverse=True
168+
)
169+
return valid_shapes
170+
114171
@staticmethod
115172
def _rich_diff_table(shape_report: ShapeRecommendationReport) -> Table:
116173
"""
@@ -321,7 +378,7 @@ def _summarize_shapes_for_seq_lens(
321378
recommendations = []
322379

323380
if not shapes:
324-
raise ValueError(
381+
raise AquaValueError(
325382
"No GPU shapes were passed for recommendation. Ensure shape parsing succeeded."
326383
)
327384

ads/aqua/shaperecommend/shape_report.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@ class RequestRecommend(BaseModel):
2020
model_id: str = Field(
2121
..., description="The OCID of the model to recommend feasible compute shapes."
2222
)
23-
shapes: List[ComputeShapeSummary] = Field(
24-
..., description="The list of shapes on OCI."
25-
)
2623
generate_table: Optional[bool] = (
2724
Field(
2825
True,

tests/unitary/with_extras/aqua/test_deployment_handler.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,3 +284,74 @@ def test_get_model_list(self, mock_get, mock_finish):
284284
mock_finish.side_effect = lambda x: x
285285
result = self.aqua_model_list_handler.get(model_id="test_model_id")
286286
mock_get.assert_called()
287+
288+
from unittest.mock import MagicMock, patch
289+
290+
import pytest
291+
from tornado.web import HTTPError
292+
293+
from ads.aqua.extension.base_handler import AquaAPIhandler
294+
from ads.aqua.extension.errors import Errors
295+
from ads.aqua.extension.recommend_handler import AquaRecommendHandler
296+
297+
298+
@pytest.fixture
299+
def handler():
300+
# Patch AquaAPIhandler.__init__ for unit test stubbing
301+
AquaAPIhandler.__init__ = lambda self, *args, **kwargs: None
302+
h = AquaRecommendHandler(MagicMock(), MagicMock())
303+
h.finish = MagicMock()
304+
h.request = MagicMock()
305+
# Set required Tornado internal fields
306+
h._headers = {}
307+
h._write_buffer = []
308+
return h
309+
310+
311+
def test_post_valid_input(monkeypatch, handler):
312+
input_data = {"model_ocid": "ocid1.datasciencemodel.oc1.XYZ"}
313+
expected = {"recommendations": ["VM.GPU.A10.1"], "troubleshoot": ""}
314+
315+
# Patch class on correct import path, so handler sees our fake implementation
316+
class FakeAquaRecommendApp:
317+
def which_gpu(self, **kwargs):
318+
return expected
319+
320+
monkeypatch.setattr(
321+
"ads.aqua.extension.recommend_handler.AquaRecommendApp", FakeAquaRecommendApp
322+
)
323+
324+
handler.get_json_body = MagicMock(return_value=input_data)
325+
handler.post()
326+
handler.finish.assert_called_once_with(expected)
327+
328+
329+
def test_post_no_input(handler):
330+
handler.get_json_body = MagicMock(return_value=None)
331+
handler._headers = {}
332+
handler._write_buffer = []
333+
handler.write_error = MagicMock()
334+
handler.post()
335+
handler.write_error.assert_called_once()
336+
exc_info = handler.write_error.call_args.kwargs.get("exc_info")
337+
assert exc_info is not None
338+
exc_type, exc_value, _ = exc_info
339+
assert exc_type is HTTPError
340+
assert exc_value.status_code == 400
341+
assert exc_value.log_message == Errors.NO_INPUT_DATA
342+
343+
344+
def test_post_invalid_input(handler):
345+
handler.get_json_body = MagicMock(side_effect=Exception("bad input"))
346+
handler._headers = {}
347+
handler._write_buffer = []
348+
handler.write_error = MagicMock()
349+
handler.post()
350+
handler.write_error.assert_called_once()
351+
exc_info = handler.write_error.call_args.kwargs.get("exc_info")
352+
assert exc_info is not None
353+
exc_type, exc_value, _ = exc_info
354+
assert exc_type is HTTPError
355+
assert exc_value.status_code == 400
356+
assert exc_value.log_message == Errors.INVALID_INPUT_DATA_FORMAT
357+

0 commit comments

Comments
 (0)