From 658303836821fb05b78d41d3d601c264a0ad12b9 Mon Sep 17 00:00:00 2001 From: Dmytro Soltysiuk Date: Wed, 30 Apr 2025 17:03:07 -0700 Subject: [PATCH] Merge kwargs for fit and deploy Sagemaker methods --- .../cloud/backend/sagemaker_backend.py | 25 ++++++++------- src/autogluon/cloud/job/sagemaker_job.py | 31 +++++++++++-------- 2 files changed, 31 insertions(+), 25 deletions(-) diff --git a/src/autogluon/cloud/backend/sagemaker_backend.py b/src/autogluon/cloud/backend/sagemaker_backend.py index 5ce8853..ccdda78 100644 --- a/src/autogluon/cloud/backend/sagemaker_backend.py +++ b/src/autogluon/cloud/backend/sagemaker_backend.py @@ -489,19 +489,20 @@ def deploy( else: model_kwargs_env = {SAGEMAKER_MODEL_SERVER_WORKERS: "1"} - model = model_cls( - model_data=predictor_path, - role=self.role_arn, - region=self._region, - framework_version=framework_version, - py_version=py_version, - instance_type=instance_type, - custom_image_uri=custom_image_uri, - entry_point=entry_point, - predictor_cls=predictor_cls, - env=model_kwargs_env, + merged_model_kwargs = { + "model_data": predictor_path, + "role": self.role_arn, + "region": self._region, + "framework_version": framework_version, + "py_version": py_version, + "instance_type": instance_type, + "custom_image_uri": custom_image_uri, + "entry_point": entry_point, + "predictor_cls": predictor_cls, + "env": model_kwargs_env, **model_kwargs, - ) + } + model = model_cls(**merged_model_kwargs) if deploy_kwargs is None: deploy_kwargs = {} diff --git a/src/autogluon/cloud/job/sagemaker_job.py b/src/autogluon/cloud/job/sagemaker_job.py index 764a8c1..7e97a73 100644 --- a/src/autogluon/cloud/job/sagemaker_job.py +++ b/src/autogluon/cloud/job/sagemaker_job.py @@ -199,21 +199,26 @@ def run( **kwargs, ): self._local_mode = instance_type in (LOCAL_MODE, LOCAL_MODE_GPU) - sagemaker_estimator = AutoGluonSagemakerEstimator( - role=role, - entry_point=entry_point, - region=region, - instance_type=instance_type, - instance_count=instance_count, - volume_size=volume_size, - framework_version=framework_version, - py_version=py_version, - base_job_name=base_job_name, - output_path=output_path, - code_location=code_location, - image_uri=custom_image_uri, + + merged_kwargs = { + "role": role, + "entry_point": entry_point, + "region": region, + "instance_type": instance_type, + "instance_count": instance_count, + "volume_size": volume_size, + "framework_version": framework_version, + "py_version": py_version, + "base_job_name": base_job_name, + "output_path": output_path, + "code_location": code_location, + "image_uri": custom_image_uri, **autogluon_sagemaker_estimator_kwargs, + } + sagemaker_estimator = AutoGluonSagemakerEstimator( + **merged_kwargs ) + logger.log(20, f"Start sagemaker training job `{job_name}`") try: sagemaker_estimator.fit(inputs=inputs, wait=wait, job_name=job_name, **kwargs)