@@ -795,3 +795,40 @@ def test_setup_for_nova_recipe_with_distillation(mock_resolve_save, sagemaker_se
795
795
pytorch ._hyperparameters .get ("role_arn" )
796
796
== "arn:aws:iam::123456789012:role/SageMakerRole"
797
797
)
798
+
799
+
800
+ @patch ("sagemaker.pytorch.estimator.PyTorch._recipe_resolve_and_save" )
801
+ def test_setup_for_nova_recipe_sets_model_type (mock_resolve_save , sagemaker_session ):
802
+ """Test that _setup_for_nova_recipe correctly sets model_type hyperparameter."""
803
+ # Create a mock nova recipe with model_type
804
+ recipe = OmegaConf .create (
805
+ {
806
+ "run" : {
807
+ "model_type" : "amazon.nova.llama-2-7b" ,
808
+ "model_name_or_path" : "llama/llama-2-7b" ,
809
+ "replicas" : 1 ,
810
+ }
811
+ }
812
+ )
813
+
814
+ with patch (
815
+ "sagemaker.pytorch.estimator.PyTorch._recipe_load" , return_value = ("nova_recipe" , recipe )
816
+ ):
817
+ mock_resolve_save .return_value = recipe
818
+
819
+ pytorch = PyTorch (
820
+ training_recipe = "nova_recipe" ,
821
+ role = ROLE ,
822
+ sagemaker_session = sagemaker_session ,
823
+ instance_count = INSTANCE_COUNT ,
824
+ instance_type = INSTANCE_TYPE_GPU ,
825
+ image_uri = IMAGE_URI ,
826
+ framework_version = "1.13.1" ,
827
+ py_version = "py3" ,
828
+ )
829
+
830
+ # Check that the Nova recipe was correctly identified
831
+ assert pytorch .is_nova_recipe is True
832
+
833
+ # Verify that model_type hyperparameter was set correctly
834
+ assert pytorch ._hyperparameters .get ("model_type" ) == "amazon.nova.llama-2-7b"
0 commit comments