Skip to content

Commit 79c8294

Browse files
sylvie7788 and xibei chen authored
feature: add model_type hyperparameter support for Nova recipes (#5291)
Co-authored-by: xibei chen <[email protected]>
1 parent 143c128 commit 79c8294

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

src/sagemaker/pytorch/estimator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,6 +1180,9 @@ def _setup_for_nova_recipe(
11801180
# Set up Nova-specific configuration
11811181
run_config = recipe.get("run", {})
11821182
model_name_or_path = run_config.get("model_name_or_path")
1183+
# Set hyperparameters model_type
1184+
model_type = run_config.get("model_type")
1185+
args["hyperparameters"]["model_type"] = model_type
11831186

11841187
# Set hyperparameters based on model_name_or_path
11851188
if model_name_or_path:

tests/unit/test_pytorch_nova.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -795,3 +795,40 @@ def test_setup_for_nova_recipe_with_distillation(mock_resolve_save, sagemaker_se
795795
pytorch._hyperparameters.get("role_arn")
796796
== "arn:aws:iam::123456789012:role/SageMakerRole"
797797
)
798+
799+
800+
@patch("sagemaker.pytorch.estimator.PyTorch._recipe_resolve_and_save")
def test_setup_for_nova_recipe_sets_model_type(mock_resolve_save, sagemaker_session):
    """Verify that a Nova recipe's ``run.model_type`` becomes a hyperparameter.

    A minimal recipe carrying ``model_type`` is served through the patched
    ``_recipe_load`` and ``_recipe_resolve_and_save`` hooks; the estimator
    built from it must be flagged as a Nova recipe and must expose the same
    ``model_type`` value in its hyperparameters.
    """
    expected_model_type = "amazon.nova.llama-2-7b"

    # Smallest recipe that carries the fields this test cares about.
    run_section = {
        "model_type": "amazon.nova.llama-2-7b",
        "model_name_or_path": "llama/llama-2-7b",
        "replicas": 1,
    }
    nova_recipe = OmegaConf.create({"run": run_section})

    # The resolve/save hook hands back the very same recipe object.
    mock_resolve_save.return_value = nova_recipe

    load_patch = patch(
        "sagemaker.pytorch.estimator.PyTorch._recipe_load",
        return_value=("nova_recipe", nova_recipe),
    )
    with load_patch:
        estimator = PyTorch(
            training_recipe="nova_recipe",
            role=ROLE,
            sagemaker_session=sagemaker_session,
            instance_count=INSTANCE_COUNT,
            instance_type=INSTANCE_TYPE_GPU,
            image_uri=IMAGE_URI,
            framework_version="1.13.1",
            py_version="py3",
        )

    # The recipe must have been recognised as a Nova recipe ...
    assert estimator.is_nova_recipe is True

    # ... and its model_type forwarded verbatim as a hyperparameter.
    assert estimator._hyperparameters.get("model_type") == expected_model_type

0 commit comments

Comments
 (0)