15
15
build_pydantic_error_message ,
16
16
get_resource_type ,
17
17
load_config ,
18
+ load_gpu_shapes_index ,
18
19
)
19
20
from ads .aqua .shaperecommend .constants import (
20
21
SAFETENSORS ,
22
+ SHAPE_MAP ,
21
23
TEXT_GENERATION ,
22
24
TROUBLESHOOT_MSG ,
23
25
)
30
32
ShapeReport ,
31
33
)
32
34
from ads .model .datascience_model import DataScienceModel
35
+ from ads .model .service .oci_datascience_model_deployment import (
36
+ OCIDataScienceModelDeployment ,
37
+ )
33
38
34
39
35
- class AquaShapeRecommend ( BaseModel ) :
40
+ class AquaShapeRecommend :
36
41
"""
37
42
Interface for recommending GPU shapes for machine learning model deployments
38
43
on Oracle Cloud Infrastructure Data Science service.
@@ -42,7 +47,7 @@ class AquaShapeRecommend(BaseModel):
42
47
Must be used within a properly configured and authenticated OCI environment.
43
48
"""
44
49
45
- def which_shapes (self , ** kwargs ) -> Union [ShapeRecommendationReport , Table ]:
50
+ def which_shapes (self , request : RequestRecommend ) -> Union [ShapeRecommendationReport , Table ]:
46
51
"""
47
52
Lists valid GPU deployment shapes for the provided model and configuration.
48
53
@@ -77,7 +82,8 @@ def which_shapes(self, **kwargs) -> Union[ShapeRecommendationReport, Table]:
77
82
If parameters are missing or invalid, or if no valid sequence length is requested.
78
83
"""
79
84
try :
80
- request = RequestRecommend (** kwargs )
85
+ shapes = self .valid_compute_shapes (compartment_id = request .compartment_id )
86
+
81
87
ds_model = self ._validate_model_ocid (request .model_id )
82
88
data = self ._get_model_config (ds_model )
83
89
@@ -86,7 +92,7 @@ def which_shapes(self, **kwargs) -> Union[ShapeRecommendationReport, Table]:
86
92
model_name = ds_model .display_name if ds_model .display_name else ""
87
93
88
94
shape_recommendation_report = self ._summarize_shapes_for_seq_lens (
89
- llm_config , request . shapes , model_name
95
+ llm_config , shapes , model_name
90
96
)
91
97
92
98
if request .generate_table and shape_recommendation_report .recommendations :
@@ -107,10 +113,61 @@ def which_shapes(self, **kwargs) -> Union[ShapeRecommendationReport, Table]:
107
113
) from ex
108
114
except AquaValueError as ex :
109
115
logger .error (f"Error with LLM config: { ex } " )
110
- raise
116
+ raise AquaValueError ( # noqa: B904
117
+ f"An error occured while producing recommendations: { ex } "
118
+ )
111
119
112
120
return shape_recommendation_report
113
121
122
def valid_compute_shapes(self, compartment_id: str) -> List["ComputeShapeSummary"]:
    """
    Builds the list of GPU compute shapes to evaluate for recommendation.

    Every shape present in the GPU shapes index (``load_gpu_shapes_index().shapes``)
    is returned as a ``ComputeShapeSummary``. Shapes that also appear in the result of
    ``OCIDataScienceModelDeployment.shapes`` for the given compartment are marked
    ``available=True`` and carry the core count, memory and shape series reported by
    that call; all other indexed shapes are marked ``available=False`` and carry only
    their name and the GPU specs from the index.

    Parameters
    ----------
    compartment_id : str
        Compartment identifier forwarded to ``OCIDataScienceModelDeployment.shapes``
        to list the shapes usable for model deployment.

    Returns
    -------
    List[ComputeShapeSummary]
        One entry per shape in the GPU shapes index, sorted by
        ``gpu_specs.gpu_memory_in_gbs`` in descending order.

    Notes
    -----
    No exceptions are caught here; errors raised by the shape listing call or by
    loading the GPU shapes index propagate to the caller.
    """
    oci_shapes = OCIDataScienceModelDeployment.shapes(compartment_id=compartment_id)
    # Index the listed shapes by name so each GPU index entry needs one lookup.
    shapes_by_name = {shape.name: shape for shape in oci_shapes}

    gpu_shapes_metadata = load_gpu_shapes_index().shapes

    valid_shapes = []
    # only loops through GPU shapes, update later to include CPU shapes
    for name, spec in gpu_shapes_metadata.items():
        oci_shape = shapes_by_name.get(name)
        if oci_shape is not None:
            compute_shape = ComputeShapeSummary(
                available=True,
                core_count=oci_shape.core_count,
                memory_in_gbs=oci_shape.memory_in_gbs,
                # Fall back to "GPU" when the reported series has no SHAPE_MAP entry.
                shape_series=SHAPE_MAP.get(oci_shape.shape_series, "GPU"),
                name=oci_shape.name,
                gpu_specs=spec,
            )
        else:
            # Indexed shape that was not returned by the listing call: keep it in
            # the result, flagged as unavailable.
            compute_shape = ComputeShapeSummary(
                available=False, name=name, shape_series="GPU", gpu_specs=spec
            )
        valid_shapes.append(compute_shape)

    # Largest GPU memory first.
    valid_shapes.sort(
        key=lambda shape: shape.gpu_specs.gpu_memory_in_gbs, reverse=True
    )
    return valid_shapes
170
+
114
171
@staticmethod
115
172
def _rich_diff_table (shape_report : ShapeRecommendationReport ) -> Table :
116
173
"""
@@ -321,7 +378,7 @@ def _summarize_shapes_for_seq_lens(
321
378
recommendations = []
322
379
323
380
if not shapes :
324
- raise ValueError (
381
+ raise AquaValueError (
325
382
"No GPU shapes were passed for recommendation. Ensure shape parsing succeeded."
326
383
)
327
384
0 commit comments