Skip to content

Commit 1249aec

Browse files
authored
Merge branch 'master' into fix-djl-lmi-regions
2 parents 9ac5b6c + 5d766c4 commit 1249aec

File tree

3 files changed

+41
-1
lines changed

3 files changed

+41
-1
lines changed

src/sagemaker/pytorch/estimator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,6 +1180,9 @@ def _setup_for_nova_recipe(
11801180
# Set up Nova-specific configuration
11811181
run_config = recipe.get("run", {})
11821182
model_name_or_path = run_config.get("model_name_or_path")
1183+
# Set hyperparameters model_type
1184+
model_type = run_config.get("model_type")
1185+
args["hyperparameters"]["model_type"] = model_type
11831186

11841187
# Set hyperparameters based on model_name_or_path
11851188
if model_name_or_path:

tests/integ/sagemaker/serve/test_base_model_builder_deploy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def test_serverless_deployment(xgboost_model_builder):
211211

212212
def test_async_deployment(xgboost_model_builder, mb_sagemaker_session):
213213
async_predictor = xgboost_model_builder.deploy(
214-
endpoint_name="test2",
214+
endpoint_name=f"test2-{uuid.uuid1().hex}",
215215
inference_config=AsyncInferenceConfig(
216216
output_path=s3_path_join(
217217
"s3://", mb_sagemaker_session.default_bucket(), "async_inference/output"

tests/unit/test_pytorch_nova.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -795,3 +795,40 @@ def test_setup_for_nova_recipe_with_distillation(mock_resolve_save, sagemaker_se
795795
pytorch._hyperparameters.get("role_arn")
796796
== "arn:aws:iam::123456789012:role/SageMakerRole"
797797
)
798+
799+
800+
@patch("sagemaker.pytorch.estimator.PyTorch._recipe_resolve_and_save")
801+
def test_setup_for_nova_recipe_sets_model_type(mock_resolve_save, sagemaker_session):
802+
"""Test that _setup_for_nova_recipe correctly sets model_type hyperparameter."""
803+
# Create a mock nova recipe with model_type
804+
recipe = OmegaConf.create(
805+
{
806+
"run": {
807+
"model_type": "amazon.nova.llama-2-7b",
808+
"model_name_or_path": "llama/llama-2-7b",
809+
"replicas": 1,
810+
}
811+
}
812+
)
813+
814+
with patch(
815+
"sagemaker.pytorch.estimator.PyTorch._recipe_load", return_value=("nova_recipe", recipe)
816+
):
817+
mock_resolve_save.return_value = recipe
818+
819+
pytorch = PyTorch(
820+
training_recipe="nova_recipe",
821+
role=ROLE,
822+
sagemaker_session=sagemaker_session,
823+
instance_count=INSTANCE_COUNT,
824+
instance_type=INSTANCE_TYPE_GPU,
825+
image_uri=IMAGE_URI,
826+
framework_version="1.13.1",
827+
py_version="py3",
828+
)
829+
830+
# Check that the Nova recipe was correctly identified
831+
assert pytorch.is_nova_recipe is True
832+
833+
# Verify that model_type hyperparameter was set correctly
834+
assert pytorch._hyperparameters.get("model_type") == "amazon.nova.llama-2-7b"

0 commit comments

Comments
 (0)