diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 9e2f0f0dd4..611f6757b0 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -163,6 +163,23 @@ def _is_nova_recipe(recipe): return bool(has_nova_model) or bool(has_distillation) +def _is_eval_recipe(recipe): + """Check if the recipe is an eval recipe. + + An eval recipe is identified by: + 1. Having a evaluation section + + Args: + recipe (OmegaConf): The loaded recipe configuration + + Returns: + bool: True if the recipe is an eval recipe, False otherwise + """ + # Check for eval model + eval_config = recipe.get("evaluation", {}) + return bool(eval_config) + + def _recipe_initialize_args(source_dir): """Initialize the arguments dictionary for recipe setup. @@ -949,7 +966,7 @@ def _device_validate_and_get_type(kwargs, recipe): if "instance_type" not in kwargs: raise ValueError("Must pass instance type to estimator when using training recipes.") - if not _is_nova_recipe(recipe) and "trainer" not in recipe: + if not _is_nova_recipe(recipe) and "trainer" not in recipe and not _is_eval_recipe(recipe): raise ValueError("Supplied recipe does not contain required field trainer.") instance_type = kwargs["instance_type"].split(".")[1] @@ -973,7 +990,7 @@ def _device_handle_instance_count(kwargs, recipe): """ # Check if instance_count is already provided in kwargs - is_nova = _is_nova_recipe(recipe) + is_nova_or_eval = _is_nova_recipe(recipe) or _is_eval_recipe(recipe) if "instance_count" in kwargs: # Warn if there are conflicting configurations in the recipe if "num_nodes" in recipe.get("trainer", {}): @@ -981,7 +998,7 @@ def _device_handle_instance_count(kwargs, recipe): "Using instance_count argument to estimator to set number " "of nodes. Ignoring trainer -> num_nodes in recipe." ) - if is_nova and "replicas" in recipe.get("run", {}): + if is_nova_or_eval and "replicas" in recipe.get("run", {}): logger.warning( "Using instance_count argument to estimator to set number " "of nodes. Ignoring run -> replicas in recipe." @@ -993,7 +1010,7 @@ def _device_handle_instance_count(kwargs, recipe): kwargs["instance_count"] = recipe["trainer"]["num_nodes"] return - if is_nova and "run" in recipe and "replicas" in recipe["run"]: + if is_nova_or_eval and "run" in recipe and "replicas" in recipe["run"]: kwargs["instance_count"] = recipe["run"]["replicas"] return