
Commit 4a2c63a

started unit tests, added trust-remote-code param detection
1 parent fb33939 commit 4a2c63a

19 files changed, +1051 −73 lines changed

ads/aqua/modeldeployment/deployment.py

Lines changed: 1 addition & 1 deletion
@@ -1291,7 +1291,7 @@ def recommend_shape(self, **kwargs) -> Union[Table, ShapeRecommendationReport]:
         """
         deployment_config = self.get_deployment_config(model_id=kwargs.get("model_id"))
         kwargs["deployment_config"] = deployment_config
-        print(deployment_config)
+
         try:
             request = RequestRecommend(**kwargs)
         except ValidationError as e:

ads/aqua/shaperecommend/constants.py

Lines changed: 1 addition & 0 deletions
@@ -88,6 +88,7 @@
 VLLM_PARAMS = {
     "max_model_len": "--max-model-len",
     "in_flight_quant": "--quantization bitsandbytes --load-format bitsandbytes",
+    "trust_remote_code": "--trust-remote-code"
 }
 
 DEFAULT_WEIGHT_SIZE = "bfloat16"

ads/aqua/shaperecommend/estimator.py

Lines changed: 4 additions & 0 deletions
@@ -131,6 +131,10 @@ def construct_deployment_params(self) -> str:
             # vLLM only supports 4bit in-flight quantization
             params.append(VLLM_PARAMS["in_flight_quant"])
 
+        # add trust-remote-code if custom modules are specified
+        if c.trust_remote_code:
+            params.append(VLLM_PARAMS["trust_remote_code"])
+
         params = " ".join(params) if params else ""
         return params
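
Note: a minimal standalone sketch of what the new branch produces, assuming a config object with the trust_remote_code field added in llm_config.py below. The stub class and helper function here are illustrative, not part of the commit; the flag strings mirror the VLLM_PARAMS mapping from constants.py above.

from dataclasses import dataclass

# Values mirror the VLLM_PARAMS mapping defined in constants.py above.
VLLM_PARAMS = {
    "in_flight_quant": "--quantization bitsandbytes --load-format bitsandbytes",
    "trust_remote_code": "--trust-remote-code",
}

@dataclass
class StubConfig:
    # stand-in for the LLMConfig instance `c` used by construct_deployment_params
    trust_remote_code: bool = False

def build_params(c: StubConfig, in_flight_quant: bool = False) -> str:
    params = []
    if in_flight_quant:
        params.append(VLLM_PARAMS["in_flight_quant"])
    if c.trust_remote_code:
        # new branch: emit --trust-remote-code when custom modules are required
        params.append(VLLM_PARAMS["trust_remote_code"])
    return " ".join(params)

print(build_params(StubConfig(trust_remote_code=True)))
# --trust-remote-code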

ads/aqua/shaperecommend/llm_config.py

Lines changed: 20 additions & 1 deletion
@@ -75,7 +75,9 @@ class LLMConfig(BaseModel):
         None, description="For MoE architectures, size of the MLP activation layer."
     )
 
-    tie_word_embeddings: Optional[bool] = Field(None)
+    tie_word_embeddings: Optional[bool] = Field(True, description="if True, input and output embedding matrices share the same parameters in memory.")
+
+    trust_remote_code: Optional[bool] = Field(False, description="if True, the model requires custom code to operate.")
 
     @property
     def bytes_per_parameter(self) -> float:
@@ -208,6 +210,17 @@ def validate_model_support(cls, raw: dict) -> ValueError:
             "Encoder-decoder models (ex. T5, Gemma) and encoder-only (BERT) are not supported at this time."
         )
 
+    @staticmethod
+    def get_bool(raw, key, default=False):
+        val = raw.get(key)
+        if val is None:
+            return default
+        if isinstance(val, bool):
+            return val
+        if isinstance(val, str):
+            return val.lower() == "true"
+        return bool(val)
+
     @classmethod
     def from_raw_config(cls, raw: dict) -> "LLMConfig":
         """
@@ -258,6 +271,10 @@ def from_raw_config(cls, raw: dict) -> "LLMConfig":
             "intermediate_size"
         )
 
+        tie_word_embeddings = LLMConfig.get_bool(raw, "tie_word_embeddings", True)
+
+        trust_remote_code = "auto_map" in raw  # trust-remote-code is always needed when this key is present
+
         # Type safety: minimal assertion
         if None in [
             num_hidden_layers,
@@ -281,4 +298,6 @@ def from_raw_config(cls, raw: dict) -> "LLMConfig":
             max_seq_len=int(max_seq_len),
             num_local_experts=num_local_experts,
             intermediate_size=intermediate_size,
+            tie_word_embeddings=tie_word_embeddings,
+            trust_remote_code=trust_remote_code
         )
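
Note: a short sketch of how the new detection behaves on a raw Hugging Face-style config dict. The get_bool logic is restated from the helper above, and the example auto_map contents are hypothetical.

def get_bool(raw, key, default=False):
    # same coercion logic as LLMConfig.get_bool above
    val = raw.get(key)
    if val is None:
        return default
    if isinstance(val, bool):
        return val
    if isinstance(val, str):
        return val.lower() == "true"
    return bool(val)

raw_config = {
    "tie_word_embeddings": "false",  # string-typed boolean, coerced by get_bool
    "auto_map": {"AutoModelForCausalLM": "modeling_custom.CustomModel"},  # hypothetical custom-code mapping
}

tie_word_embeddings = get_bool(raw_config, "tie_word_embeddings", True)  # False
trust_remote_code = "auto_map" in raw_config  # True: custom code is required
print(tie_word_embeddings, trust_remote_code)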

ads/aqua/shaperecommend/shape_report.py

Lines changed: 64 additions & 36 deletions
@@ -250,55 +250,83 @@ class ShapeRecommendationReport(BaseModel):
     @classmethod
     def from_deployment_config(cls, deployment_config: AquaDeploymentConfig, model_name: str, valid_shapes: List[ComputeShapeSummary]) -> "ShapeRecommendationReport":
         """
-        For service models, pre-set deployment configurations (AquaDeploymentConfig) are available.
-        Derives ShapeRecommendationReport from AquaDeploymentConfig (if service model & available)
+        Creates a ShapeRecommendationReport from an AquaDeploymentConfig, extracting recommended
+        model configurations for each valid compute shape.
+
+        Parameters
+        ----------
+        deployment_config : AquaDeploymentConfig
+            The object containing per-shape deployment configurations.
+        model_name : str
+            The name of the model for which to generate recommendations.
+        valid_shapes : list of ComputeShapeSummary
+            List of compute shapes to evaluate and recommend deployment configurations for.
+
+        Returns
+        -------
+        ShapeRecommendationReport
+            Report containing recommendations for each valid compute shape.
+
+        Notes
+        -----
+        For service models, this method interprets pre-set deployment configurations to derive
+        recommendations for each allowed compute shape, including environment variables, quantization,
+        and maximum model length parameters.
         """
 
         recs = []
-        # may need to sort?
         for shape in valid_shapes:
             current_config = deployment_config.configuration.get(shape.name)
-            if current_config:
-                quantization = None
-                max_model_len = None
-                recommendation = ""
-                current_params = current_config.parameters.get(VLLM_PARAMS_KEY)
-                current_env = current_config.env.get(VLLM_ENV_KEY)
+            if not current_config:
+                continue
 
-                if current_params:
-                    param_list = current_params.split()
+            quantization = None
+            max_model_len = None
+            recommendation = ""
+            current_params = current_config.parameters.get(VLLM_PARAMS_KEY)
+            current_env = current_config.env.get(VLLM_ENV_KEY)
 
-                    if QUANT_FLAG in param_list and (idx := param_list.index(QUANT_FLAG)) + 1 < len(param_list):
+            if current_params:
+                param_list = current_params.split()
+
+                if QUANT_FLAG in param_list:
+                    idx = param_list.index(QUANT_FLAG)
+                    if idx + 1 < len(param_list):
                         quantization = param_list[idx + 1]
 
-                    if MAX_MODEL_LEN_FLAG in param_list and (idx := param_list.index(MAX_MODEL_LEN_FLAG)) + 1 < len(param_list):
-                        max_model_len = param_list[idx + 1]
-                        max_model_len = int(max_model_len)
+                if MAX_MODEL_LEN_FLAG in param_list:
+                    idx = param_list.index(MAX_MODEL_LEN_FLAG)
+                    if idx + 1 < len(param_list):
+                        try:
+                            max_model_len = int(param_list[idx + 1])
+                        except ValueError:
+                            max_model_len = None
 
-                if current_env:
-                    recommendation += f"ENV: {json.dumps(current_env)}\n\n"
+            if current_env:
+                recommendation += f"ENV: {json.dumps(current_env)}\n\n"
 
-                recommendation += "Model fits well within the allowed compute shape."
+            if not current_params and not current_env:  # model works with default params and no extra env variables
+                recommendation += "No override PARAMS and ENV variables needed. \n\n"
 
-                deployment_params = DeploymentParams(
-                    quantization=quantization if quantization else DEFAULT_WEIGHT_SIZE,
-                    max_model_len=max_model_len,
-                    params=current_params if current_params else "",
-                )
+            recommendation += "Model fits well within the allowed compute shape."
 
-                # TODO: calculate memory footprint based on params??
-                # TODO: add --env vars not just params, current_config.env
-                # are there multiple configurations in the SMM configs per shape??
-                configuration = [ModelConfig(
-                    deployment_params=deployment_params,
-                    recommendation=recommendation,
-                )]
-
-                recs.append(ShapeReport(
-                    shape_details=shape,
-                    configurations=configuration
-                )
-                )
+            deployment_params = DeploymentParams(
+                quantization=quantization if quantization else DEFAULT_WEIGHT_SIZE,
+                max_model_len=max_model_len,
+                params=current_params if current_params else "",
+            )
+
+            # need to adjust for multiple configs per shape
+            configuration = [ModelConfig(
+                deployment_params=deployment_params,
+                recommendation=recommendation,
+            )]
+
+            recs.append(ShapeReport(
+                shape_details=shape,
+                configurations=configuration
+            )
+            )
 
         return ShapeRecommendationReport(
             display_name=model_name,
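
Note: since the commit message says unit tests were started, here is a hedged pytest-style sketch of the parameter-string parsing performed above. The QUANT_FLAG and MAX_MODEL_LEN_FLAG values are assumptions, and the parsing is restated standalone rather than calling the real class.

# Standalone sketch of the flag parsing in from_deployment_config.
QUANT_FLAG = "--quantization"        # assumed value of the real constant
MAX_MODEL_LEN_FLAG = "--max-model-len"  # assumed value of the real constant

def parse_params(current_params: str):
    quantization, max_model_len = None, None
    param_list = current_params.split()
    if QUANT_FLAG in param_list:
        idx = param_list.index(QUANT_FLAG)
        if idx + 1 < len(param_list):
            quantization = param_list[idx + 1]
    if MAX_MODEL_LEN_FLAG in param_list:
        idx = param_list.index(MAX_MODEL_LEN_FLAG)
        if idx + 1 < len(param_list):
            try:
                max_model_len = int(param_list[idx + 1])
            except ValueError:
                max_model_len = None
    return quantization, max_model_len

def test_parse_params_extracts_quant_and_len():
    quant, mml = parse_params("--quantization bitsandbytes --max-model-len 65536")
    assert quant == "bitsandbytes"
    assert mml == 65536

def test_parse_params_handles_trailing_flag():
    # a flag with no value should not raise and should leave fields unset
    quant, mml = parse_params("--max-model-len")
    assert quant is None and mml is None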
