From b14349b1b5bc247e66568b6eecef00c2b32aac6c Mon Sep 17 00:00:00 2001 From: pintaoz Date: Wed, 26 Mar 2025 13:24:07 -0700 Subject: [PATCH 1/4] Initialize a branch for SageMaker Core v2 --- CONTRIBUTING.md | 2 +- docs/index.rst | 4 +- example_notebooks/get_started.ipynb | 22 ++++----- .../inference_and_resource_chaining.ipynb | 14 +++--- .../intelligent_defaults_and_logging.ipynb | 22 ++++----- ...core-llama-3-8B-speculative-decoding.ipynb | 12 ++--- .../sagemaker-core-llama-3-8B.ipynb | 12 ++--- .../sagemaker_core_overview.ipynb | 14 +++--- .../track_local_pytorch_experiment.ipynb | 25 +++++------ integ/sagemaker_cleaner.py | 2 +- integ/test_codegen.py | 6 +-- pyproject.toml | 2 +- .../helper => sagemaker}/__init__.py | 0 src/sagemaker/core/__init__.py | 4 ++ .../core}/_version.py | 0 .../main => sagemaker/core}/config_schema.py | 0 .../core/helper}/__init__.py | 0 .../core}/helper/session_helper.py | 2 +- .../main => sagemaker/core}/resources.py | 24 +++++++--- .../main => sagemaker/core}/shapes.py | 2 +- src/sagemaker/core/tools/__init__.py | 1 + .../core}/tools/additional_operations.json | 0 .../core}/tools/api_coverage.json | 0 .../core}/tools/codegen.py | 8 ++-- .../core}/tools/constants.py | 8 ++-- .../core}/tools/data_extractor.py | 2 +- .../core}/tools/method.py | 2 +- .../core}/tools/resource_plan.csv | 0 .../core}/tools/resources_codegen.py | 45 +++++++++++-------- .../core}/tools/resources_extractor.py | 8 ++-- .../core}/tools/shapes_codegen.py | 14 +++--- .../core}/tools/shapes_extractor.py | 6 +-- .../core}/tools/templates.py | 0 .../core/utils}/__init__.py | 0 .../core/utils/code_injection/__init__.py | 0 .../core/utils}/code_injection/base.py | 0 .../core/utils}/code_injection/codec.py | 4 +- .../core/utils}/code_injection/constants.py | 0 .../core/utils}/code_injection/shape_dag.py | 0 .../core/utils}/exceptions.py | 0 .../utils}/intelligent_defaults_helper.py | 6 +-- .../main => sagemaker/core/utils}/logs.py | 2 +- .../core/utils}/user_agent.py | 0 .../main => sagemaker/core/utils}/utils.py | 6 +-- src/sagemaker_core/__init__.py | 4 -- src/sagemaker_core/resources/__init__.py | 1 - src/sagemaker_core/shapes/__init__.py | 1 - src/sagemaker_core/tools/__init__.py | 1 - tst/generated/test_config_schema.py | 2 +- tst/generated/test_logs.py | 4 +- tst/generated/test_resources.py | 24 +++++----- tst/generated/test_shapes.py | 6 +-- tst/generated/test_user_agent.py | 8 ++-- tst/generated/test_utils.py | 6 +-- tst/test_codec.py | 6 +-- tst/tools/test_api_coverage.py | 4 +- tst/tools/test_resources_codegen.py | 8 ++-- workflow_helper/compute_boto_api_coverage.py | 4 +- workflow_helper/compute_resource_coverage.py | 2 +- 59 files changed, 189 insertions(+), 173 deletions(-) rename src/{sagemaker_core/helper => sagemaker}/__init__.py (100%) create mode 100644 src/sagemaker/core/__init__.py rename src/{sagemaker_core => sagemaker/core}/_version.py (100%) rename src/{sagemaker_core/main => sagemaker/core}/config_schema.py (100%) rename src/{sagemaker_core/main => sagemaker/core/helper}/__init__.py (100%) rename src/{sagemaker_core => sagemaker/core}/helper/session_helper.py (99%) rename src/{sagemaker_core/main => sagemaker/core}/resources.py (99%) rename src/{sagemaker_core/main => sagemaker/core}/shapes.py (99%) create mode 100644 src/sagemaker/core/tools/__init__.py rename src/{sagemaker_core => sagemaker/core}/tools/additional_operations.json (100%) rename src/{sagemaker_core => sagemaker/core}/tools/api_coverage.json (100%) rename src/{sagemaker_core => sagemaker/core}/tools/codegen.py (87%) rename src/{sagemaker_core => sagemaker/core}/tools/constants.py (91%) rename src/{sagemaker_core => sagemaker/core}/tools/data_extractor.py (97%) rename src/{sagemaker_core => sagemaker/core}/tools/method.py (93%) rename src/{sagemaker_core => sagemaker/core}/tools/resource_plan.csv (100%) rename src/{sagemaker_core => sagemaker/core}/tools/resources_codegen.py (98%) rename src/{sagemaker_core => sagemaker/core}/tools/resources_extractor.py (98%) rename src/{sagemaker_core => sagemaker/core}/tools/shapes_codegen.py (96%) rename src/{sagemaker_core => sagemaker/core}/tools/shapes_extractor.py (98%) rename src/{sagemaker_core => sagemaker/core}/tools/templates.py (100%) rename src/{sagemaker_core/main/code_injection => sagemaker/core/utils}/__init__.py (100%) create mode 100644 src/sagemaker/core/utils/code_injection/__init__.py rename src/{sagemaker_core/main => sagemaker/core/utils}/code_injection/base.py (100%) rename src/{sagemaker_core/main => sagemaker/core/utils}/code_injection/codec.py (98%) rename src/{sagemaker_core/main => sagemaker/core/utils}/code_injection/constants.py (100%) rename src/{sagemaker_core/main => sagemaker/core/utils}/code_injection/shape_dag.py (100%) rename src/{sagemaker_core/main => sagemaker/core/utils}/exceptions.py (100%) rename src/{sagemaker_core/main => sagemaker/core/utils}/intelligent_defaults_helper.py (97%) rename src/{sagemaker_core/main => sagemaker/core/utils}/logs.py (99%) rename src/{sagemaker_core/main => sagemaker/core/utils}/user_agent.py (100%) rename src/{sagemaker_core/main => sagemaker/core/utils}/utils.py (98%) delete mode 100644 src/sagemaker_core/__init__.py delete mode 100644 src/sagemaker_core/resources/__init__.py delete mode 100644 src/sagemaker_core/shapes/__init__.py delete mode 100644 src/sagemaker_core/tools/__init__.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1aecae76..56cd3ac8 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -33,7 +33,7 @@ source .env ## Run CodeGen * To generate all CodeGen code run the below ``` -python src/sagemaker_core/tools/codegen.py +python src/sagemaker/core/tools/codegen.py ``` ## Testing diff --git a/docs/index.rst b/docs/index.rst index 3059d9fd..ac905bd4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -4,7 +4,7 @@ SageMaker Core Resources ######################## -.. automodule:: sagemaker_core.main.resources +.. automodule:: sagemaker.core.resources :members: :noindex: @@ -13,6 +13,6 @@ SageMaker Core Resources SageMaker Core Shapes ######################## -.. automodule:: sagemaker_core.main.shapes +.. automodule:: sagemaker.core.shapes :members: :noindex: \ No newline at end of file diff --git a/example_notebooks/get_started.ipynb b/example_notebooks/get_started.ipynb index b0fb281a..ee1a29e4 100644 --- a/example_notebooks/get_started.ipynb +++ b/example_notebooks/get_started.ipynb @@ -36,7 +36,7 @@ "outputs": [], "source": [ "import time\n", - "from sagemaker_core.helper.session_helper import Session, get_execution_role\n", + "from sagemaker.core.helper.session_helper import Session, get_execution_role\n", "\n", "# Set up region, role and bucket parameters used throughout the notebook.\n", "sagemaker_session = Session()\n", @@ -197,8 +197,8 @@ "metadata": {}, "outputs": [], "source": [ - "from sagemaker_core.resources import TrainingJob\n", - "from sagemaker_core.shapes import AlgorithmSpecification, Channel, DataSource, S3DataSource, ResourceConfig, StoppingCondition, OutputDataConfig\n", + "from sagemaker.core.resources import TrainingJob\n", + "from sagemaker.core.shapes import AlgorithmSpecification, Channel, DataSource, S3DataSource, ResourceConfig, StoppingCondition, OutputDataConfig\n", "\n", "job_name = 'xgboost-churn-' + time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.gmtime()) # Name of training job\n", "instance_type = 'ml.m4.xlarge' # SageMaker instance type to use for training\n", @@ -291,8 +291,8 @@ "metadata": {}, "outputs": [], "source": [ - "from sagemaker_core.resources import HyperParameterTuningJob\n", - "from sagemaker_core.shapes import HyperParameterTuningJobConfig, \\\n", + "from sagemaker.core.resources import HyperParameterTuningJob\n", + "from sagemaker.core.shapes import HyperParameterTuningJobConfig, \\\n", " ResourceLimits, HyperParameterTuningJobWarmStartConfig, ParameterRanges, AutoParameter, \\\n", " Autotune, HyperParameterTrainingJobDefinition, HyperParameterTuningJobObjective, HyperParameterAlgorithmSpecification, \\\n", " OutputDataConfig, StoppingCondition, ResourceConfig\n", @@ -436,8 +436,8 @@ "metadata": {}, "outputs": [], "source": [ - "from sagemaker_core.resources import Model\n", - "from sagemaker_core.shapes import ContainerDefinition\n", + "from sagemaker.core.resources import Model\n", + "from sagemaker.core.shapes import ContainerDefinition\n", "\n", "#model_s3_uri = training_job.model_artifacts.s3_model_artifacts # Get URI of model artifacts from the training job.\n", "model_s3_uri = TrainingJob.get(tuning_job.best_training_job.training_job_name).model_artifacts.s3_model_artifacts # Get URI of model artifacts of the best model from the tuning job.\n", @@ -469,8 +469,8 @@ "metadata": {}, "outputs": [], "source": [ - "from sagemaker_core.resources import TransformJob\n", - "from sagemaker_core.shapes import TransformInput, TransformDataSource, TransformS3DataSource, TransformOutput, TransformResources\n", + "from sagemaker.core.resources import TransformJob\n", + "from sagemaker.core.shapes import TransformInput, TransformDataSource, TransformS3DataSource, TransformOutput, TransformResources\n", "\n", "model_name = customer_churn_model.get_name()\n", "transform_job_name = 'churn-prediction' + time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.gmtime()) # Name of TranformJob\n", @@ -576,8 +576,8 @@ "metadata": {}, "outputs": [], "source": [ - "from sagemaker_core.resources import Endpoint, EndpointConfig\n", - "from sagemaker_core.shapes import ProductionVariant\n", + "from sagemaker.core.resources import Endpoint, EndpointConfig\n", + "from sagemaker.core.shapes import ProductionVariant\n", "\n", "endpoint_config_name = 'churn-prediction-endpoint-config' # Name of endpoint configuration\n", "model_name = customer_churn_model.get_name() # Get name of SageMaker model created in previous step\n", diff --git a/example_notebooks/inference_and_resource_chaining.ipynb b/example_notebooks/inference_and_resource_chaining.ipynb index cdae56dc..0a410056 100644 --- a/example_notebooks/inference_and_resource_chaining.ipynb +++ b/example_notebooks/inference_and_resource_chaining.ipynb @@ -135,7 +135,7 @@ "metadata": {}, "outputs": [], "source": [ - "from sagemaker_core.helper.session_helper import get_execution_role, Session\n", + "from sagemaker.core.helper.session_helper import get_execution_role, Session\n", "from rich import print\n", "\n", "# Get region, role, bucket\n", @@ -266,7 +266,7 @@ "# Create TrainingJob with SageMakerCore\n", "\n", "import time\n", - "from sagemaker_core.resources import TrainingJob, AlgorithmSpecification, Channel, DataSource, S3DataSource, \\\n", + "from sagemaker.core.resources import TrainingJob, AlgorithmSpecification, Channel, DataSource, S3DataSource, \\\n", " OutputDataConfig, ResourceConfig, StoppingCondition\n", "\n", "job_name_v3 = 'xgboost-iris-' + time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.gmtime())\n", @@ -377,8 +377,8 @@ "metadata": {}, "outputs": [], "source": [ - "from sagemaker_core.shapes import ContainerDefinition, ProductionVariant\n", - "from sagemaker_core.resources import Model, EndpointConfig, Endpoint\n", + "from sagemaker.core.shapes import ContainerDefinition, ProductionVariant\n", + "from sagemaker.core.resources import Model, EndpointConfig, Endpoint\n", "from time import gmtime, strftime\n", "\n", "# Get model_data_url from training_job object\n", @@ -611,7 +611,7 @@ "metadata": {}, "outputs": [], "source": [ - "from sagemaker_core.shapes import ProductionVariant, AsyncInferenceConfig, AsyncInferenceOutputConfig, AsyncInferenceClientConfig\n", + "from sagemaker.core.shapes import ProductionVariant, AsyncInferenceConfig, AsyncInferenceOutputConfig, AsyncInferenceClientConfig\n", "\n", "async_endpoint_config = EndpointConfig.create(\n", " endpoint_config_name=key,\n", @@ -755,7 +755,7 @@ "# Delete any sagemaker core resource objects created in this notebook\n", "def delete_all_sagemaker_resources():\n", " all_objects = list(locals().values()) + list(globals().values())\n", - " deletable_objects = [obj for obj in all_objects if hasattr(obj, 'delete') and obj.__class__.__module__ == 'sagemaker_core.main.resources']\n", + " deletable_objects = [obj for obj in all_objects if hasattr(obj, 'delete') and obj.__class__.__module__ == 'sagemaker.core.resources']\n", " \n", " for obj in deletable_objects:\n", " obj.delete()\n", @@ -766,7 +766,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "py3.10", "language": "python", "name": "python3" }, diff --git a/example_notebooks/intelligent_defaults_and_logging.ipynb b/example_notebooks/intelligent_defaults_and_logging.ipynb index 535f1b9f..ff3bdeda 100644 --- a/example_notebooks/intelligent_defaults_and_logging.ipynb +++ b/example_notebooks/intelligent_defaults_and_logging.ipynb @@ -106,7 +106,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Uninstall previous version of sagemaker_core and restart kernel\n", + "# Uninstall previous version of sagemaker-core and restart kernel\n", "!pip uninstall sagemaker-core -y" ] }, @@ -116,7 +116,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Install the latest version of sagemaker_core\n", + "# Install the latest version of sagemaker-core\n", "\n", "!pip install sagemaker-core --upgrade" ] @@ -127,7 +127,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Check the version of sagemaker_core\n", + "# Check the version of sagemaker-core\n", "!pip show -v sagemaker-core" ] }, @@ -167,7 +167,7 @@ "metadata": {}, "outputs": [], "source": [ - "from sagemaker_core.helper.session_helper import Session, get_execution_role\n", + "from sagemaker.core.helper.session_helper import Session, get_execution_role\n", "from rich import print\n", "\n", "# Get region, role, bucket\n", @@ -362,8 +362,8 @@ "outputs": [], "source": [ "import time\n", - "from sagemaker_core.resources import Cluster\n", - "from sagemaker_core.shapes import ClusterInstanceGroupSpecification, ClusterLifeCycleConfig\n", + "from sagemaker.core.resources import Cluster\n", + "from sagemaker.core.shapes import ClusterInstanceGroupSpecification, ClusterLifeCycleConfig\n", " \n", "cluster_name_v3 = 'xgboost-cluster-' + time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.gmtime())\n", "\n", @@ -408,8 +408,8 @@ "outputs": [], "source": [ "import time\n", - "from sagemaker_core.resources import TrainingJob\n", - "from sagemaker_core.shapes import AlgorithmSpecification, Channel, DataSource, S3DataSource, ResourceConfig, StoppingCondition\n", + "from sagemaker.core.resources import TrainingJob\n", + "from sagemaker.core.shapes import AlgorithmSpecification, Channel, DataSource, S3DataSource, ResourceConfig, StoppingCondition\n", "\n", "job_name_v3 = 'xgboost-iris-' + time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.gmtime())\n", "\n", @@ -481,7 +481,7 @@ "outputs": [], "source": [ "# Setting log_level to DEBUG using configure_logging with string parameter \n", - "from sagemaker_core.main.utils import configure_logging\n", + "from sagemaker.core.utils import configure_logging\n", "\n", "configure_logging('DEBUG')" ] @@ -493,7 +493,7 @@ "outputs": [], "source": [ "# Get TrainingJob with DEBUG log_level\n", - "from sagemaker_core.resources import TrainingJob\n", + "from sagemaker.core.resources import TrainingJob\n", "\n", "training_job = TrainingJob.get(job_name_v3)" ] @@ -549,7 +549,7 @@ "# Delete any sagemaker core resource objects created in this notebook\n", "def delete_all_sagemaker_resources():\n", " all_objects = list(locals().values()) + list(globals().values())\n", - " deletable_objects = [obj for obj in all_objects if hasattr(obj, 'delete') and obj.__class__.__module__ == 'sagemaker_core.main.resources']\n", + " deletable_objects = [obj for obj in all_objects if hasattr(obj, 'delete') and obj.__class__.__module__ == 'sagemaker.core.resources']\n", " \n", " for obj in deletable_objects:\n", " obj.delete()\n", diff --git a/example_notebooks/sagemaker-core-llama-3-8B-speculative-decoding.ipynb b/example_notebooks/sagemaker-core-llama-3-8B-speculative-decoding.ipynb index 71f1e597..9c791acb 100644 --- a/example_notebooks/sagemaker-core-llama-3-8B-speculative-decoding.ipynb +++ b/example_notebooks/sagemaker-core-llama-3-8B-speculative-decoding.ipynb @@ -18,7 +18,7 @@ "## 1. Dependency Installation\n", "### 1.1. Python Dependencies & Imports\n", "This notebook requires the following Python dependencies:\n", - "* AWS [`sagemaker_core`]()\n", + "* AWS [`sagemaker-core`]()\n", "\n", "Let's install or upgrade these dependencies using the following command:" ] @@ -46,7 +46,7 @@ "import json\n", "import os\n", "\n", - "from sagemaker_core.helper.session_helper import get_execution_role, Session\n", + "from sagemaker.core.helper.session_helper import get_execution_role, Session\n", "import pathlib \n", "import huggingface_hub" ] @@ -135,8 +135,8 @@ }, "outputs": [], "source": [ - "from sagemaker_core.shapes import ContainerDefinition, ProductionVariant\n", - "from sagemaker_core.resources import Model, EndpointConfig, Endpoint\n", + "from sagemaker.core.shapes import ContainerDefinition, ProductionVariant\n", + "from sagemaker.core.resources import Model, EndpointConfig, Endpoint\n", "from time import gmtime, strftime" ] }, @@ -194,7 +194,7 @@ }, "outputs": [], "source": [ - "from sagemaker_core.shapes import ProductionVariantRoutingConfig\n", + "from sagemaker.core.shapes import ProductionVariantRoutingConfig\n", "\n", "routing_config = ProductionVariantRoutingConfig(\n", " routing_strategy=\"LEAST_OUTSTANDING_REQUESTS\"\n", @@ -314,7 +314,7 @@ "# Delete any sagemaker core resource objects created in this notebook\n", "def delete_all_sagemaker_resources():\n", " all_objects = list(locals().values()) + list(globals().values())\n", - " deletable_objects = [obj for obj in all_objects if hasattr(obj, 'delete') and obj.__class__.__module__ == 'sagemaker_core.main.resources']\n", + " deletable_objects = [obj for obj in all_objects if hasattr(obj, 'delete') and obj.__class__.__module__ == 'sagemaker.core.resources']\n", " \n", " for obj in deletable_objects:\n", " obj.delete()\n", diff --git a/example_notebooks/sagemaker-core-llama-3-8B.ipynb b/example_notebooks/sagemaker-core-llama-3-8B.ipynb index 317a434f..78ee1e17 100644 --- a/example_notebooks/sagemaker-core-llama-3-8B.ipynb +++ b/example_notebooks/sagemaker-core-llama-3-8B.ipynb @@ -15,7 +15,7 @@ "## 1. Dependency Installation\n", "### 1.1. Python Dependencies & Imports\n", "This notebook requires the following Python dependencies:\n", - "* AWS [`sagemaker_core`]()\n", + "* AWS [`sagemaker-core`]()\n", "\n", "Let's install or upgrade these dependencies using the following command:" ] @@ -43,7 +43,7 @@ "import json\n", "import os\n", "\n", - "from sagemaker_core.helper.session_helper import get_execution_role, Session\n", + "from sagemaker.core.helper.session_helper import get_execution_role, Session\n", "import pathlib \n", "import huggingface_hub" ] @@ -124,8 +124,8 @@ }, "outputs": [], "source": [ - "from sagemaker_core.shapes import ContainerDefinition, ProductionVariant\n", - "from sagemaker_core.resources import Model, EndpointConfig, Endpoint\n", + "from sagemaker.core.shapes import ContainerDefinition, ProductionVariant\n", + "from sagemaker.core.resources import Model, EndpointConfig, Endpoint\n", "from time import gmtime, strftime" ] }, @@ -182,7 +182,7 @@ }, "outputs": [], "source": [ - "from sagemaker_core.shapes import ProductionVariantRoutingConfig\n", + "from sagemaker.core.shapes import ProductionVariantRoutingConfig\n", "\n", "routing_config = ProductionVariantRoutingConfig(\n", " routing_strategy=\"LEAST_OUTSTANDING_REQUESTS\"\n", @@ -302,7 +302,7 @@ "# Delete any sagemaker core resource objects created in this notebook\n", "def delete_all_sagemaker_resources():\n", " all_objects = list(locals().values()) + list(globals().values())\n", - " deletable_objects = [obj for obj in all_objects if hasattr(obj, 'delete') and obj.__class__.__module__ == 'sagemaker_core.main.resources']\n", + " deletable_objects = [obj for obj in all_objects if hasattr(obj, 'delete') and obj.__class__.__module__ == 'sagemaker.core.resources']\n", " \n", " for obj in deletable_objects:\n", " obj.delete()\n", diff --git a/example_notebooks/sagemaker_core_overview.ipynb b/example_notebooks/sagemaker_core_overview.ipynb index 58214c65..819a65c3 100644 --- a/example_notebooks/sagemaker_core_overview.ipynb +++ b/example_notebooks/sagemaker_core_overview.ipynb @@ -92,7 +92,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Uninstall previous version of sagemaker_core and restart kernel\n", + "# Uninstall previous version of sagemaker-core and restart kernel\n", "!pip uninstall sagemaker-core -y" ] }, @@ -102,7 +102,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Install the latest version of sagemaker_core\n", + "# Install the latest version of sagemaker-core\n", "\n", "!pip install sagemaker-core --upgrade" ] @@ -113,7 +113,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Check the version of sagemaker_core\n", + "# Check the version of sagemaker-core\n", "!pip show -v sagemaker-core" ] }, @@ -153,7 +153,7 @@ "metadata": {}, "outputs": [], "source": [ - "from sagemaker_core.helper.session_helper import Session, get_execution_role\n", + "from sagemaker.core.helper.session_helper import Session, get_execution_role\n", "from rich import print\n", "\n", "# Get region, role, bucket\n", @@ -382,7 +382,7 @@ "# Create TrainingJob with SageMakerCore\n", "\n", "import time\n", - "from sagemaker_core.resources import TrainingJob, AlgorithmSpecification, Channel, DataSource, S3DataSource, \\\n", + "from sagemaker.core.resources import TrainingJob, AlgorithmSpecification, Channel, DataSource, S3DataSource, \\\n", " OutputDataConfig, ResourceConfig, StoppingCondition\n", "\n", "job_name_v3 = 'xgboost-iris-' + time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.gmtime())\n", @@ -531,7 +531,7 @@ "source": [ "# List TrainingJobs with SageMakerCore\n", "import datetime\n", - "from sagemaker_core.resources import TrainingJob\n", + "from sagemaker.core.resources import TrainingJob\n", "\n", "creation_time_after = datetime.datetime.now() - datetime.timedelta(days=1)\n", "\n", @@ -562,7 +562,7 @@ "# Delete any sagemaker core resource objects created in this notebook\n", "def delete_all_sagemaker_resources():\n", " all_objects = list(locals().values()) + list(globals().values())\n", - " deletable_objects = [obj for obj in all_objects if hasattr(obj, 'delete') and obj.__class__.__module__ == 'sagemaker_core.main.resources']\n", + " deletable_objects = [obj for obj in all_objects if hasattr(obj, 'delete') and obj.__class__.__module__ == 'sagemaker.core.resources']\n", " \n", " for obj in deletable_objects:\n", " obj.delete()\n", diff --git a/example_notebooks/track_local_pytorch_experiment.ipynb b/example_notebooks/track_local_pytorch_experiment.ipynb index b330e96e..39c27da7 100644 --- a/example_notebooks/track_local_pytorch_experiment.ipynb +++ b/example_notebooks/track_local_pytorch_experiment.ipynb @@ -49,8 +49,8 @@ "metadata": {}, "outputs": [], "source": [ - "# Uninstall previous version of sagemaker_core and restart kernel\n", - "!pip uninstall sagemaker_core -y" + "# Uninstall previous version of sagemaker-core and restart kernel\n", + "!pip uninstall sagemaker-core -y" ] }, { @@ -59,7 +59,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Make dist/ directory to hold the sagemaker_core beta distribution file\n", + "# Make dist/ directory to hold the sagemaker-core beta distribution file\n", "!mkdir dist" ] }, @@ -69,7 +69,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Download and Install the latest version of sagemaker_core\n", + "# Download and Install the latest version of sagemaker-core\n", "!aws s3 cp s3://sagemaker-core-beta-artifacts/sagemaker_core-latest.tar.gz dist/\n", "\n", "!pip install dist/sagemaker_core-latest.tar.gz" @@ -81,8 +81,8 @@ "metadata": {}, "outputs": [], "source": [ - "# Check the version of sagemaker_core\n", - "!pip show -v sagemaker_core" + "# Check the version of sagemaker-core\n", + "!pip show -v sagemaker-core" ] }, { @@ -125,8 +125,8 @@ "import os\n", "import time\n", "from matplotlib import pyplot as plt\n", - "from sagemaker_core.helper.session_helper import Session\n", - "from sagemaker_core.main.utils import get_textual_rich_logger\n", + "from sagemaker.core.helper.session_helper import Session\n", + "from sagemaker.core.utils import get_textual_rich_logger\n", "\n", "logger = get_textual_rich_logger(__name__)\n", "session = Session()\n", @@ -228,9 +228,8 @@ "metadata": {}, "outputs": [], "source": [ - "from sagemaker_core.main.resources import Experiment, Trial, TrialComponent\n", - "from sagemaker_core.main.shapes import TrialComponentParameterValue, TrialComponentArtifact\n", - "from sagemaker_core.main.utils import configure_logging\n", + "from sagemaker.core.resources import Experiment, Trial, TrialComponent\n", + "from sagemaker.core.shapes import TrialComponentParameterValue, TrialComponentArtifact\n", "\n", "experiment = Experiment.create(experiment_name=experiment_name)\n", "trial = Trial.create(trial_name=run_group_name, experiment_name=experiment_name)\n", @@ -363,7 +362,7 @@ " \"\"\"\n", " Function that trains the CNN classifier to identify the MNIST digits.\n", " Args:\n", - " trial_component (sagemaker_core.main.resources.Run): SageMaker Experiment run object\n", + " trial_component (sagemaker.core.resources.Run): SageMaker Experiment run object\n", " train_set (torchvision.datasets.mnist.MNIST): train dataset\n", " test_set (torchvision.datasets.mnist.MNIST): test dataset\n", " data_dir (str): local directory where the MNIST datasource is stored\n", @@ -518,7 +517,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "py3.10", "language": "python", "name": "python3" }, diff --git a/integ/sagemaker_cleaner.py b/integ/sagemaker_cleaner.py index 722d10d9..6e753616 100644 --- a/integ/sagemaker_cleaner.py +++ b/integ/sagemaker_cleaner.py @@ -1,5 +1,5 @@ import datetime -from sagemaker_core.main.resources import Model, EndpointConfig, Endpoint +from sagemaker.core.resources import Model, EndpointConfig, Endpoint class SageMakerCleaner: diff --git a/integ/test_codegen.py b/integ/test_codegen.py index 34e449b7..89ae3213 100644 --- a/integ/test_codegen.py +++ b/integ/test_codegen.py @@ -10,8 +10,8 @@ from sklearn.model_selection import train_test_split from sagemaker_cleaner import handle_cleanup -from sagemaker_core.main.shapes import ContainerDefinition, ProductionVariant, ProfilerConfig -from sagemaker_core.main.resources import ( +from sagemaker.core.shapes import ContainerDefinition, ProductionVariant, ProfilerConfig +from sagemaker.core.resources import ( TrainingJob, AlgorithmSpecification, Channel, @@ -24,7 +24,7 @@ EndpointConfig, Endpoint, ) -from sagemaker_core.helper.session_helper import Session, get_execution_role +from sagemaker.core.helper.session_helper import Session, get_execution_role logger = logging.getLogger() diff --git a/pyproject.toml b/pyproject.toml index 6fe874a3..5166eb24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,4 +49,4 @@ line-length = 100 exclude = '\.ipynb$' [tool.setuptools.dynamic] -version = { attr = "sagemaker_core._version.__version__"} +version = { attr = "sagemaker.core._version.__version__"} diff --git a/src/sagemaker_core/helper/__init__.py b/src/sagemaker/__init__.py similarity index 100% rename from src/sagemaker_core/helper/__init__.py rename to src/sagemaker/__init__.py diff --git a/src/sagemaker/core/__init__.py b/src/sagemaker/core/__init__.py new file mode 100644 index 00000000..8d7103af --- /dev/null +++ b/src/sagemaker/core/__init__.py @@ -0,0 +1,4 @@ +from sagemaker.core.utils.utils import enable_textual_rich_console_and_traceback + + +enable_textual_rich_console_and_traceback() diff --git a/src/sagemaker_core/_version.py b/src/sagemaker/core/_version.py similarity index 100% rename from src/sagemaker_core/_version.py rename to src/sagemaker/core/_version.py diff --git a/src/sagemaker_core/main/config_schema.py b/src/sagemaker/core/config_schema.py similarity index 100% rename from src/sagemaker_core/main/config_schema.py rename to src/sagemaker/core/config_schema.py diff --git a/src/sagemaker_core/main/__init__.py b/src/sagemaker/core/helper/__init__.py similarity index 100% rename from src/sagemaker_core/main/__init__.py rename to src/sagemaker/core/helper/__init__.py diff --git a/src/sagemaker_core/helper/session_helper.py b/src/sagemaker/core/helper/session_helper.py similarity index 99% rename from src/sagemaker_core/helper/session_helper.py rename to src/sagemaker/core/helper/session_helper.py index 1bcf336e..1a2e71cb 100644 --- a/src/sagemaker_core/helper/session_helper.py +++ b/src/sagemaker/core/helper/session_helper.py @@ -564,7 +564,7 @@ def determine_bucket_and_prefix( Args: bucket (Optional[str]): S3 Bucket to use (if it exists) key_prefix (Optional[str]): S3 Object Key Prefix to use or append to (if it exists) - sagemaker_session (sagemaker.session.Session): Session to fetch a default bucket and + sagemaker_session (sagemaker.core.session.Session): Session to fetch a default bucket and prefix from, if bucket doesn't exist. Expected to exist Returns: The correct S3 Bucket and S3 Object Key Prefix that should be used diff --git a/src/sagemaker_core/main/resources.py b/src/sagemaker/core/resources.py similarity index 99% rename from src/sagemaker_core/main/resources.py rename to src/sagemaker/core/resources.py index 80b2c183..19e04ed3 100644 --- a/src/sagemaker_core/main/resources.py +++ b/src/sagemaker/core/resources.py @@ -23,9 +23,10 @@ from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn from rich.status import Status from rich.style import Style -from sagemaker_core.main.code_injection.codec import transform -from sagemaker_core.main.code_injection.constants import Color -from sagemaker_core.main.utils import ( +from sagemaker.core.shapes import * +from sagemaker.core.utils.code_injection.codec import transform +from sagemaker.core.utils.code_injection.constants import Color +from sagemaker.core.utils.utils import ( SageMakerClient, ResourceIterator, Unassigned, @@ -37,13 +38,12 @@ is_primitive_list, serialize, ) -from sagemaker_core.main.intelligent_defaults_helper import ( +from sagemaker.core.utils.intelligent_defaults_helper import ( load_default_configs_for_resource_name, get_config_value, ) -from sagemaker_core.main.logs import MultiLogStreamHandler -from sagemaker_core.main.shapes import * -from sagemaker_core.main.exceptions import * +from sagemaker.core.utils.logs import MultiLogStreamHandler +from sagemaker.core.utils.exceptions import * logger = get_textual_rich_logger(__name__) @@ -12720,6 +12720,8 @@ def create( hub_content_name: Optional[Union[str, object]] = Unassigned(), min_version: Optional[str] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, ) -> Optional["HubContentReference"]: """ Create a HubContentReference resource @@ -24161,6 +24163,8 @@ def create( arn: str, expires_in_seconds: Optional[int] = Unassigned(), session_expiration_duration_in_seconds: Optional[int] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, ) -> Optional["PartnerAppPresignedUrl"]: """ Create a PartnerAppPresignedUrl resource @@ -25419,6 +25423,8 @@ def create( expires_in_seconds: Optional[int] = Unassigned(), space_name: Optional[Union[str, object]] = Unassigned(), landing_uri: Optional[str] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, ) -> Optional["PresignedDomainUrl"]: """ Create a PresignedDomainUrl resource @@ -25516,6 +25522,8 @@ def create( tracking_server_name: str, expires_in_seconds: Optional[int] = Unassigned(), session_expiration_duration_in_seconds: Optional[int] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, ) -> Optional["PresignedMlflowTrackingServerUrl"]: """ Create a PresignedMlflowTrackingServerUrl resource @@ -25604,6 +25612,8 @@ def create( cls, notebook_instance_name: Union[str, object], session_expiration_duration_in_seconds: Optional[int] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, ) -> Optional["PresignedNotebookInstanceUrl"]: """ Create a PresignedNotebookInstanceUrl resource diff --git a/src/sagemaker_core/main/shapes.py b/src/sagemaker/core/shapes.py similarity index 99% rename from src/sagemaker_core/main/shapes.py rename to src/sagemaker/core/shapes.py index a1a96210..aba6f4fe 100644 --- a/src/sagemaker_core/main/shapes.py +++ b/src/sagemaker/core/shapes.py @@ -14,7 +14,7 @@ from pydantic import BaseModel, ConfigDict from typing import List, Dict, Optional, Any, Union -from sagemaker_core.main.utils import Unassigned +from sagemaker.core.utils.utils import Unassigned class Base(BaseModel): diff --git a/src/sagemaker/core/tools/__init__.py b/src/sagemaker/core/tools/__init__.py new file mode 100644 index 00000000..0956ddc7 --- /dev/null +++ b/src/sagemaker/core/tools/__init__.py @@ -0,0 +1 @@ +from ..utils.code_injection.codec import pascal_to_snake diff --git a/src/sagemaker_core/tools/additional_operations.json b/src/sagemaker/core/tools/additional_operations.json similarity index 100% rename from src/sagemaker_core/tools/additional_operations.json rename to src/sagemaker/core/tools/additional_operations.json diff --git a/src/sagemaker_core/tools/api_coverage.json b/src/sagemaker/core/tools/api_coverage.json similarity index 100% rename from src/sagemaker_core/tools/api_coverage.json rename to src/sagemaker/core/tools/api_coverage.json diff --git a/src/sagemaker_core/tools/codegen.py b/src/sagemaker/core/tools/codegen.py similarity index 87% rename from src/sagemaker_core/tools/codegen.py rename to src/sagemaker/core/tools/codegen.py index e29a21b7..dc30101d 100644 --- a/src/sagemaker_core/tools/codegen.py +++ b/src/sagemaker/core/tools/codegen.py @@ -11,12 +11,12 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """Generates the code for the service model.""" -from sagemaker_core.main.utils import reformat_file_with_black -from sagemaker_core.tools.shapes_codegen import ShapesCodeGen -from sagemaker_core.tools.resources_codegen import ResourcesCodeGen +from sagemaker.core.utils.utils import reformat_file_with_black +from sagemaker.core.tools.shapes_codegen import ShapesCodeGen +from sagemaker.core.tools.resources_codegen import ResourcesCodeGen from typing import Optional -from sagemaker_core.tools.data_extractor import ServiceJsonData, load_service_jsons +from sagemaker.core.tools.data_extractor import ServiceJsonData, load_service_jsons def generate_code( diff --git a/src/sagemaker_core/tools/constants.py b/src/sagemaker/core/tools/constants.py similarity index 91% rename from src/sagemaker_core/tools/constants.py rename to src/sagemaker/core/tools/constants.py index bbfc3537..f0b39acc 100644 --- a/src/sagemaker_core/tools/constants.py +++ b/src/sagemaker/core/tools/constants.py @@ -46,7 +46,7 @@ BASIC_RETURN_TYPES = {"str", "int", "bool", "float", "datetime.datetime"} -SHAPE_DAG_FILE_PATH = os.getcwd() + "/src/sagemaker_core/main/code_injection/shape_dag.py" +SHAPE_DAG_FILE_PATH = os.getcwd() + "/src/sagemaker/core/utils/code_injection/shape_dag.py" PYTHON_TYPES_TO_BASIC_JSON_TYPES = { "str": "string", "int": "integer", @@ -77,7 +77,7 @@ # TODO: The file name should be injected, we should update it to be more generic ADDITIONAL_OPERATION_FILE_PATH = ( - os.getcwd() + "/src/sagemaker_core/tools/additional_operations.json" + os.getcwd() + "/src/sagemaker/core/tools/additional_operations.json" ) SERVICE_JSON_FILE_PATH = os.getcwd() + "/sample/sagemaker/2017-07-24/service-2.json" RUNTIME_SERVICE_JSON_FILE_PATH = os.getcwd() + "/sample/sagemaker-runtime/2017-05-13/service-2.json" @@ -86,7 +86,7 @@ ) METRICS_SERVICE_JSON_FILE_PATH = os.getcwd() + "/sample/sagemaker-metrics/2022-09-30/service-2.json" -GENERATED_CLASSES_LOCATION = os.getcwd() + "/src/sagemaker_core/main" +GENERATED_CLASSES_LOCATION = os.getcwd() + "/src/sagemaker/core" UTILS_CODEGEN_FILE_NAME = "utils.py" INTELLIGENT_DEFAULTS_HELPER_CODEGEN_FILE_NAME = "intelligent_defaults_helper.py" @@ -96,4 +96,4 @@ CONFIG_SCHEMA_FILE_NAME = "config_schema.py" -API_COVERAGE_JSON_FILE_PATH = os.getcwd() + "/src/sagemaker_core/tools/api_coverage.json" +API_COVERAGE_JSON_FILE_PATH = os.getcwd() + "/src/sagemaker/core/tools/api_coverage.json" diff --git a/src/sagemaker_core/tools/data_extractor.py b/src/sagemaker/core/tools/data_extractor.py similarity index 97% rename from src/sagemaker_core/tools/data_extractor.py rename to src/sagemaker/core/tools/data_extractor.py index 818fcdcc..21e31347 100644 --- a/src/sagemaker_core/tools/data_extractor.py +++ b/src/sagemaker/core/tools/data_extractor.py @@ -3,7 +3,7 @@ from pydantic import BaseModel -from sagemaker_core.tools.constants import ( +from sagemaker.core.tools.constants import ( ADDITIONAL_OPERATION_FILE_PATH, FEATURE_STORE_SERVICE_JSON_FILE_PATH, METRICS_SERVICE_JSON_FILE_PATH, diff --git a/src/sagemaker_core/tools/method.py b/src/sagemaker/core/tools/method.py similarity index 93% rename from src/sagemaker_core/tools/method.py rename to src/sagemaker/core/tools/method.py index 18f075fb..2f932c5a 100644 --- a/src/sagemaker_core/tools/method.py +++ b/src/sagemaker/core/tools/method.py @@ -1,6 +1,6 @@ from enum import Enum -from sagemaker_core.main.utils import remove_html_tags +from sagemaker.core.utils.utils import remove_html_tags class MethodType(Enum): diff --git a/src/sagemaker_core/tools/resource_plan.csv b/src/sagemaker/core/tools/resource_plan.csv similarity index 100% rename from src/sagemaker_core/tools/resource_plan.csv rename to src/sagemaker/core/tools/resource_plan.csv diff --git a/src/sagemaker_core/tools/resources_codegen.py b/src/sagemaker/core/tools/resources_codegen.py similarity index 98% rename from src/sagemaker_core/tools/resources_codegen.py rename to src/sagemaker/core/tools/resources_codegen.py index dca20be3..72fe55d1 100644 --- a/src/sagemaker_core/tools/resources_codegen.py +++ b/src/sagemaker/core/tools/resources_codegen.py @@ -15,11 +15,11 @@ import os import json -from sagemaker_core.main.code_injection.codec import pascal_to_snake -from sagemaker_core.main.config_schema import SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA -from sagemaker_core.main.exceptions import IntelligentDefaultsError -from sagemaker_core.main.utils import get_textual_rich_logger -from sagemaker_core.tools.constants import ( +from sagemaker.core.utils.code_injection.codec import pascal_to_snake +from sagemaker.core.config_schema import SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA +from sagemaker.core.utils.exceptions import IntelligentDefaultsError +from sagemaker.core.utils.utils import get_textual_rich_logger +from sagemaker.core.tools.constants import ( BASIC_RETURN_TYPES, GENERATED_CLASSES_LOCATION, RESOURCES_CODEGEN_FILE_NAME, @@ -31,17 +31,17 @@ CONFIGURABLE_ATTRIBUTE_SUBSTRINGS, RESOURCE_WITH_LOGS, ) -from sagemaker_core.tools.method import Method, MethodType -from sagemaker_core.main.utils import ( +from sagemaker.core.tools.method import Method, MethodType +from sagemaker.core.utils.utils import ( add_indent, convert_to_snake_case, snake_to_pascal, remove_html_tags, escape_special_rst_characters, ) -from sagemaker_core.tools.resources_extractor import ResourcesExtractor -from sagemaker_core.tools.shapes_extractor import ShapesExtractor -from sagemaker_core.tools.templates import ( +from sagemaker.core.tools.resources_extractor import ResourcesExtractor +from sagemaker.core.tools.shapes_extractor import ShapesExtractor +from sagemaker.core.tools.templates import ( CALL_OPERATION_API_NO_ARG_TEMPLATE, CALL_OPERATION_API_TEMPLATE, CREATE_METHOD_TEMPLATE, @@ -75,7 +75,7 @@ INIT_WAIT_LOGS_TEMPLATE, PRINT_WAIT_LOGS, ) -from sagemaker_core.tools.data_extractor import ( +from sagemaker.core.tools.data_extractor import ( load_combined_shapes_data, load_combined_operations_data, ) @@ -186,14 +186,14 @@ def generate_imports(self) -> str: "from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn", "from rich.status import Status", "from rich.style import Style", - "from sagemaker_core.main.code_injection.codec import transform", - "from sagemaker_core.main.code_injection.constants import Color", - "from sagemaker_core.main.utils import SageMakerClient, ResourceIterator, Unassigned, get_textual_rich_logger, " + "from sagemaker.core.shapes import *", + "from sagemaker.core.utils.code_injection.codec import transform", + "from sagemaker.core.utils.code_injection.constants import Color", + "from sagemaker.core.utils.utils import SageMakerClient, ResourceIterator, Unassigned, get_textual_rich_logger, " "snake_to_pascal, pascal_to_snake, is_not_primitive, is_not_str_dict, is_primitive_list, serialize", - "from sagemaker_core.main.intelligent_defaults_helper import load_default_configs_for_resource_name, get_config_value", - "from sagemaker_core.main.logs import MultiLogStreamHandler", - "from sagemaker_core.main.shapes import *", - "from sagemaker_core.main.exceptions import *", + "from sagemaker.core.utils.intelligent_defaults_helper import load_default_configs_for_resource_name, get_config_value", + "from sagemaker.core.utils.logs import MultiLogStreamHandler", + "from sagemaker.core.utils.exceptions import *", ] formated_imports = "\n".join(imports) @@ -908,11 +908,18 @@ def generate_create_method(self, resource_name: str, **kwargs) -> str: deserialize_response = DESERIALIZE_INPUT_AND_RESPONSE_TO_CLS_TEMPLATE.format( operation_output_shape=operation_output_shape_name ) + method_args = ( + add_indent("cls,\n", 4) + + create_args + + "\n" + + add_indent("session: Optional[Session] = None,\n", 4) + + add_indent("region: Optional[str] = None,", 4) + ) formatted_method = GENERIC_METHOD_TEMPLATE.format( docstring=docstring, decorator=decorator, method_name="create", - method_args=add_indent("cls,\n", 4) + create_args, + method_args=method_args, return_type=f'Optional["{resource_name}"]', serialize_operation_input=serialize_operation_input, initialize_client=initialize_client, diff --git a/src/sagemaker_core/tools/resources_extractor.py b/src/sagemaker/core/tools/resources_extractor.py similarity index 98% rename from src/sagemaker_core/tools/resources_extractor.py rename to src/sagemaker/core/tools/resources_extractor.py index 63179e1c..8a6f3bf6 100644 --- a/src/sagemaker_core/tools/resources_extractor.py +++ b/src/sagemaker/core/tools/resources_extractor.py @@ -15,14 +15,14 @@ import pandas as pd -from sagemaker_core.main.utils import get_textual_rich_logger -from sagemaker_core.tools.constants import CLASS_METHODS, OBJECT_METHODS -from sagemaker_core.tools.data_extractor import ( +from sagemaker.core.utils.utils import get_textual_rich_logger +from sagemaker.core.tools.constants import CLASS_METHODS, OBJECT_METHODS +from sagemaker.core.tools.data_extractor import ( load_additional_operations_data, load_combined_operations_data, load_combined_shapes_data, ) -from sagemaker_core.tools.method import Method +from sagemaker.core.tools.method import Method log = get_textual_rich_logger(__name__) """ diff --git a/src/sagemaker_core/tools/shapes_codegen.py b/src/sagemaker/core/tools/shapes_codegen.py similarity index 96% rename from src/sagemaker_core/tools/shapes_codegen.py rename to src/sagemaker/core/tools/shapes_codegen.py index 0e36a9e5..d0e6fc3a 100644 --- a/src/sagemaker_core/tools/shapes_codegen.py +++ b/src/sagemaker/core/tools/shapes_codegen.py @@ -17,21 +17,21 @@ """ import os -from sagemaker_core.main.code_injection.codec import pascal_to_snake -from sagemaker_core.tools.constants import ( +from sagemaker.core.utils.code_injection.codec import pascal_to_snake +from sagemaker.core.tools.constants import ( LICENCES_STRING, GENERATED_CLASSES_LOCATION, SHAPES_CODEGEN_FILE_NAME, ) -from sagemaker_core.tools.shapes_extractor import ShapesExtractor -from sagemaker_core.main.utils import ( +from sagemaker.core.tools.shapes_extractor import ShapesExtractor +from sagemaker.core.utils.utils import ( add_indent, convert_to_snake_case, remove_html_tags, escape_special_rst_characters, ) -from sagemaker_core.tools.templates import SHAPE_CLASS_TEMPLATE, SHAPE_BASE_CLASS_TEMPLATE -from sagemaker_core.tools.data_extractor import ( +from sagemaker.core.tools.templates import SHAPE_CLASS_TEMPLATE, SHAPE_BASE_CLASS_TEMPLATE +from sagemaker.core.tools.data_extractor import ( load_combined_shapes_data, load_combined_operations_data, ) @@ -206,7 +206,7 @@ def generate_imports(self): imports += "\n" imports += "from pydantic import BaseModel, ConfigDict\n" imports += "from typing import List, Dict, Optional, Any, Union\n" - imports += "from sagemaker_core.main.utils import Unassigned" + imports += "from sagemaker.core.utils.utils import Unassigned" imports += "\n" return imports diff --git a/src/sagemaker_core/tools/shapes_extractor.py b/src/sagemaker/core/tools/shapes_extractor.py similarity index 98% rename from src/sagemaker_core/tools/shapes_extractor.py rename to src/sagemaker/core/tools/shapes_extractor.py index 7ca2a8d0..69ce7ca5 100644 --- a/src/sagemaker_core/tools/shapes_extractor.py +++ b/src/sagemaker/core/tools/shapes_extractor.py @@ -16,13 +16,13 @@ from functools import lru_cache from typing import Optional, Any -from sagemaker_core.tools.constants import BASIC_JSON_TYPES_TO_PYTHON_TYPES, SHAPE_DAG_FILE_PATH -from sagemaker_core.main.utils import ( +from sagemaker.core.tools.constants import BASIC_JSON_TYPES_TO_PYTHON_TYPES, SHAPE_DAG_FILE_PATH +from sagemaker.core.utils.utils import ( reformat_file_with_black, convert_to_snake_case, snake_to_pascal, ) -from sagemaker_core.tools.data_extractor import load_combined_shapes_data +from sagemaker.core.tools.data_extractor import load_combined_shapes_data class ShapesExtractor: diff --git a/src/sagemaker_core/tools/templates.py b/src/sagemaker/core/tools/templates.py similarity index 100% rename from src/sagemaker_core/tools/templates.py rename to src/sagemaker/core/tools/templates.py diff --git a/src/sagemaker_core/main/code_injection/__init__.py b/src/sagemaker/core/utils/__init__.py similarity index 100% rename from src/sagemaker_core/main/code_injection/__init__.py rename to src/sagemaker/core/utils/__init__.py diff --git a/src/sagemaker/core/utils/code_injection/__init__.py b/src/sagemaker/core/utils/code_injection/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/sagemaker_core/main/code_injection/base.py b/src/sagemaker/core/utils/code_injection/base.py similarity index 100% rename from src/sagemaker_core/main/code_injection/base.py rename to src/sagemaker/core/utils/code_injection/base.py diff --git a/src/sagemaker_core/main/code_injection/codec.py b/src/sagemaker/core/utils/code_injection/codec.py similarity index 98% rename from src/sagemaker_core/main/code_injection/codec.py rename to src/sagemaker/core/utils/code_injection/codec.py index 286f7a16..905aa2c4 100644 --- a/src/sagemaker_core/main/code_injection/codec.py +++ b/src/sagemaker/core/utils/code_injection/codec.py @@ -15,8 +15,8 @@ from dataclasses import asdict import re -from sagemaker_core.main.code_injection.shape_dag import SHAPE_DAG -from sagemaker_core.main.code_injection.constants import ( +from sagemaker.core.utils.code_injection.shape_dag import SHAPE_DAG +from sagemaker.core.utils.code_injection.constants import ( BASIC_TYPES, STRUCTURE_TYPE, LIST_TYPE, diff --git a/src/sagemaker_core/main/code_injection/constants.py b/src/sagemaker/core/utils/code_injection/constants.py similarity index 100% rename from src/sagemaker_core/main/code_injection/constants.py rename to src/sagemaker/core/utils/code_injection/constants.py diff --git a/src/sagemaker_core/main/code_injection/shape_dag.py b/src/sagemaker/core/utils/code_injection/shape_dag.py similarity index 100% rename from src/sagemaker_core/main/code_injection/shape_dag.py rename to src/sagemaker/core/utils/code_injection/shape_dag.py diff --git a/src/sagemaker_core/main/exceptions.py b/src/sagemaker/core/utils/exceptions.py similarity index 100% rename from src/sagemaker_core/main/exceptions.py rename to src/sagemaker/core/utils/exceptions.py diff --git a/src/sagemaker_core/main/intelligent_defaults_helper.py b/src/sagemaker/core/utils/intelligent_defaults_helper.py similarity index 97% rename from src/sagemaker_core/main/intelligent_defaults_helper.py rename to src/sagemaker/core/utils/intelligent_defaults_helper.py index c6f08b7e..3f411ee7 100644 --- a/src/sagemaker_core/main/intelligent_defaults_helper.py +++ b/src/sagemaker/core/utils/intelligent_defaults_helper.py @@ -25,14 +25,14 @@ from botocore.utils import merge_dicts from six.moves.urllib.parse import urlparse -from sagemaker_core.main.config_schema import SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA -from sagemaker_core.main.exceptions import ( +from sagemaker.core.config_schema import SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA +from sagemaker.core.utils.exceptions import ( LocalConfigNotFoundError, S3ConfigNotFoundError, IntelligentDefaultsError, ConfigSchemaValidationError, ) -from sagemaker_core.main.utils import get_textual_rich_logger +from sagemaker.core.utils.utils import get_textual_rich_logger logger = get_textual_rich_logger(__name__) diff --git a/src/sagemaker_core/main/logs.py b/src/sagemaker/core/utils/logs.py similarity index 99% rename from src/sagemaker_core/main/logs.py rename to src/sagemaker/core/utils/logs.py index 1dfc9c76..ef91bfae 100644 --- a/src/sagemaker_core/main/logs.py +++ b/src/sagemaker/core/utils/logs.py @@ -5,7 +5,7 @@ import botocore.client from botocore.config import Config from typing import Generator, Tuple, List -from sagemaker_core.main.utils import SingletonMeta +from sagemaker.core.utils.utils import SingletonMeta class CloudWatchLogsClient(metaclass=SingletonMeta): diff --git a/src/sagemaker_core/main/user_agent.py b/src/sagemaker/core/utils/user_agent.py similarity index 100% rename from src/sagemaker_core/main/user_agent.py rename to src/sagemaker/core/utils/user_agent.py diff --git a/src/sagemaker_core/main/utils.py b/src/sagemaker/core/utils/utils.py similarity index 98% rename from src/sagemaker_core/main/utils.py rename to src/sagemaker/core/utils/utils.py index e4ec3871..83f823b0 100644 --- a/src/sagemaker_core/main/utils.py +++ b/src/sagemaker/core/utils/utils.py @@ -26,9 +26,9 @@ from rich.theme import Theme from rich.traceback import install from typing import Any, Dict, List, TypeVar, Generic, Type -from sagemaker_core.main.code_injection.codec import transform -from sagemaker_core.main.code_injection.constants import Color -from sagemaker_core.main.user_agent import get_user_agent_extra_suffix +from sagemaker.core.utils.code_injection.codec import transform +from sagemaker.core.utils.code_injection.constants import Color +from sagemaker.core.utils.user_agent import get_user_agent_extra_suffix def add_indent(text, num_spaces=4): diff --git a/src/sagemaker_core/__init__.py b/src/sagemaker_core/__init__.py deleted file mode 100644 index d15eedfe..00000000 --- a/src/sagemaker_core/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from sagemaker_core.main.utils import enable_textual_rich_console_and_traceback - - -enable_textual_rich_console_and_traceback() diff --git a/src/sagemaker_core/resources/__init__.py b/src/sagemaker_core/resources/__init__.py deleted file mode 100644 index 415fe247..00000000 --- a/src/sagemaker_core/resources/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from ..main.resources import * diff --git a/src/sagemaker_core/shapes/__init__.py b/src/sagemaker_core/shapes/__init__.py deleted file mode 100644 index e3bcec64..00000000 --- a/src/sagemaker_core/shapes/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from ..main.shapes import * diff --git a/src/sagemaker_core/tools/__init__.py b/src/sagemaker_core/tools/__init__.py deleted file mode 100644 index 7d05a151..00000000 --- a/src/sagemaker_core/tools/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from ..main.code_injection.codec import pascal_to_snake diff --git a/tst/generated/test_config_schema.py b/tst/generated/test_config_schema.py index b6271322..29cabe19 100644 --- a/tst/generated/test_config_schema.py +++ b/tst/generated/test_config_schema.py @@ -1,4 +1,4 @@ -from sagemaker_core.main.config_schema import SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA +from sagemaker.core.config_schema import SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA def test_config_schema(): diff --git a/tst/generated/test_logs.py b/tst/generated/test_logs.py index 49fd5ca6..2904d372 100644 --- a/tst/generated/test_logs.py +++ b/tst/generated/test_logs.py @@ -1,7 +1,7 @@ import botocore import pytest from unittest.mock import patch, MagicMock -from sagemaker_core.main.logs import LogStreamHandler, MultiLogStreamHandler +from sagemaker.core.utils.logs import LogStreamHandler, MultiLogStreamHandler def test_single_stream_handler_get_latest(): @@ -36,7 +36,7 @@ def test_single_stream_handler_get_latest(): next(events) -@patch("sagemaker_core.main.logs.MultiLogStreamHandler.ready", autospec=True) +@patch("sagemaker.core.utils.logs.MultiLogStreamHandler.ready", autospec=True) def test_multi_stream_handler_get_latest(mock_ready): mock_ready.return_value = True diff --git a/tst/generated/test_resources.py b/tst/generated/test_resources.py index 081247ec..002dcb83 100644 --- a/tst/generated/test_resources.py +++ b/tst/generated/test_resources.py @@ -4,13 +4,13 @@ import unittest from unittest.mock import patch -from sagemaker_core.main.resources import Base +from sagemaker.core.resources import Base -from sagemaker_core.main.code_injection.codec import pascal_to_snake, snake_to_pascal +from sagemaker.core.utils.code_injection.codec import pascal_to_snake, snake_to_pascal -from sagemaker_core.main.utils import SageMakerClient, serialize -from sagemaker_core.tools.constants import BASIC_RETURN_TYPES -from sagemaker_core.tools.data_extractor import ( +from sagemaker.core.utils.utils import SageMakerClient, serialize +from sagemaker.core.tools.constants import BASIC_RETURN_TYPES +from sagemaker.core.tools.data_extractor import ( load_additional_operations_data, load_combined_operations_data, load_combined_shapes_data, @@ -34,21 +34,21 @@ class ResourcesTest(unittest.TestCase): def setUp(self) -> None: for name, cls in inspect.getmembers( - importlib.import_module("sagemaker_core.main.resources"), inspect.isclass + importlib.import_module("sagemaker.core.resources"), inspect.isclass ): - if cls.__module__ == "sagemaker_core.main.resources": + if cls.__module__ == "sagemaker.core.resources": if hasattr(cls, "get") and callable(cls.get): self.MOCK_RESOURCES_RESPONSE_BY_RESOURCE_NAME[name] = ( self._get_required_parameters_for_function(cls.get) ) for shape_name, shape_cls in inspect.getmembers( - importlib.import_module("sagemaker_core.main.shapes"), inspect.isclass + importlib.import_module("sagemaker.core.shapes"), inspect.isclass ): - if shape_cls.__module__ == "sagemaker_core.main.shapes": + if shape_cls.__module__ == "sagemaker.core.shapes": self.SHAPE_CLASSES_BY_SHAPE_NAME[shape_name] = shape_cls - @patch("sagemaker_core.main.resources.transform") + @patch("sagemaker.core.resources.transform") @patch("boto3.session.Session") def test_resources(self, session, mock_transform): report = { @@ -63,9 +63,9 @@ def test_resources(self, session, mock_transform): resources = set() client = SageMakerClient(session=session).get_client(service_name="sagemaker") for name, cls in inspect.getmembers( - importlib.import_module("sagemaker_core.main.resources"), inspect.isclass + importlib.import_module("sagemaker.core.resources"), inspect.isclass ): - if cls.__module__ == "sagemaker_core.main.resources": + if cls.__module__ == "sagemaker.core.resources": print_string = f"Running the following tests for resource {name}:" resources.add(name) if hasattr(cls, "get") and callable(cls.get): diff --git a/tst/generated/test_shapes.py b/tst/generated/test_shapes.py index 45b346fd..87e36377 100644 --- a/tst/generated/test_shapes.py +++ b/tst/generated/test_shapes.py @@ -3,9 +3,9 @@ from pydantic import BaseModel, ValidationError -from sagemaker_core.main.shapes import Base, AdditionalS3DataSource -from sagemaker_core.main.utils import Unassigned -from sagemaker_core.tools.constants import GENERATED_CLASSES_LOCATION, SHAPES_CODEGEN_FILE_NAME +from sagemaker.core.shapes import Base, AdditionalS3DataSource +from sagemaker.core.utils.utils import Unassigned +from sagemaker.core.tools.constants import GENERATED_CLASSES_LOCATION, SHAPES_CODEGEN_FILE_NAME FILE_NAME = GENERATED_CLASSES_LOCATION + "/" + SHAPES_CODEGEN_FILE_NAME diff --git a/tst/generated/test_user_agent.py b/tst/generated/test_user_agent.py index befb2e98..53b9e6cf 100644 --- a/tst/generated/test_user_agent.py +++ b/tst/generated/test_user_agent.py @@ -16,7 +16,7 @@ from mock import patch, mock_open -from sagemaker_core.main.user_agent import ( +from sagemaker.core.utils.user_agent import ( SagemakerCore_PREFIX, SagemakerCore_VERSION, NOTEBOOK_PREFIX, @@ -25,7 +25,7 @@ process_studio_metadata_file, get_user_agent_extra_suffix, ) -from sagemaker_core.main.user_agent import SagemakerCore_PREFIX +from sagemaker.core.utils.user_agent import SagemakerCore_PREFIX # Test process_notebook_metadata_file function @@ -63,7 +63,7 @@ def test_get_user_agent_extra_suffix(): assert get_user_agent_extra_suffix() == f"lib/{SagemakerCore_PREFIX}#{SagemakerCore_VERSION}" with patch( - "sagemaker_core.main.user_agent.process_notebook_metadata_file", + "sagemaker.core.utils.user_agent.process_notebook_metadata_file", return_value="instance_type", ): assert ( @@ -72,7 +72,7 @@ def test_get_user_agent_extra_suffix(): ) with patch( - "sagemaker_core.main.user_agent.process_studio_metadata_file", return_value="studio_type" + "sagemaker.core.utils.user_agent.process_studio_metadata_file", return_value="studio_type" ): assert ( get_user_agent_extra_suffix() diff --git a/tst/generated/test_utils.py b/tst/generated/test_utils.py index c847b7e3..de67ac55 100644 --- a/tst/generated/test_utils.py +++ b/tst/generated/test_utils.py @@ -2,13 +2,13 @@ import datetime import logging from unittest.mock import Mock, patch, call -from sagemaker_core.main.resources import TrainingJob, DataQualityJobDefinition -from sagemaker_core.main.shapes import ( +from sagemaker.core.resources import TrainingJob, DataQualityJobDefinition +from sagemaker.core.shapes import ( AdditionalS3DataSource, TrialComponent, TrialComponentParameterValue, ) -from sagemaker_core.main.utils import * +from sagemaker.core.utils.utils import * LIST_TRAINING_JOB_RESPONSE_WITH_NEXT_TOKEN = { diff --git a/tst/test_codec.py b/tst/test_codec.py index 4695e5c4..7fa0dd92 100644 --- a/tst/test_codec.py +++ b/tst/test_codec.py @@ -2,9 +2,9 @@ from dateutil.tz import tzlocal from pprint import pprint import unittest -from sagemaker_core.main.code_injection.codec import pascal_to_snake -from sagemaker_core.main.code_injection.codec import transform -from sagemaker_core.main.resources import Model, TrialComponent, AutoMLJobV2 +from sagemaker.core.utils.code_injection.codec import pascal_to_snake +from sagemaker.core.utils.code_injection.codec import transform +from sagemaker.core.resources import Model, TrialComponent, AutoMLJobV2 class TestConversion(unittest.TestCase): diff --git a/tst/tools/test_api_coverage.py b/tst/tools/test_api_coverage.py index ae33552a..4bd9d329 100644 --- a/tst/tools/test_api_coverage.py +++ b/tst/tools/test_api_coverage.py @@ -1,7 +1,7 @@ import json -from sagemaker_core.tools.constants import API_COVERAGE_JSON_FILE_PATH -from sagemaker_core.tools.resources_extractor import ResourcesExtractor +from sagemaker.core.tools.constants import API_COVERAGE_JSON_FILE_PATH +from sagemaker.core.tools.resources_extractor import ResourcesExtractor class TestAPICoverage: diff --git a/tst/tools/test_resources_codegen.py b/tst/tools/test_resources_codegen.py index 1dfcf78e..dccb54b1 100644 --- a/tst/tools/test_resources_codegen.py +++ b/tst/tools/test_resources_codegen.py @@ -1,7 +1,7 @@ import json -from sagemaker_core.tools.method import Method -from sagemaker_core.tools.resources_codegen import ResourcesCodeGen -from sagemaker_core.tools.constants import SERVICE_JSON_FILE_PATH +from sagemaker.core.tools.method import Method +from sagemaker.core.tools.resources_codegen import ResourcesCodeGen +from sagemaker.core.tools.constants import SERVICE_JSON_FILE_PATH class TestGenerateResource: @@ -1500,6 +1500,8 @@ def create( expires_in_seconds: Optional[int] = Unassigned(), space_name: Optional[Union[str, object]] = Unassigned(), landing_uri: Optional[str] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, ) -> Optional["PresignedDomainUrl"]: """ Create a PresignedDomainUrl resource diff --git a/workflow_helper/compute_boto_api_coverage.py b/workflow_helper/compute_boto_api_coverage.py index 474f4914..9e3a9ecd 100644 --- a/workflow_helper/compute_boto_api_coverage.py +++ b/workflow_helper/compute_boto_api_coverage.py @@ -1,5 +1,5 @@ -from sagemaker_core.main.utils import configure_logging -from sagemaker_core.tools.resources_extractor import ResourcesExtractor +from sagemaker.core.utils.utils import configure_logging +from sagemaker.core.tools.resources_extractor import ResourcesExtractor def main(): diff --git a/workflow_helper/compute_resource_coverage.py b/workflow_helper/compute_resource_coverage.py index 46418d59..9365dfe1 100644 --- a/workflow_helper/compute_resource_coverage.py +++ b/workflow_helper/compute_resource_coverage.py @@ -10,7 +10,7 @@ def main(): with open(json_file, "r") as f: data = json.load(f) - print(data["files"]["src/sagemaker_core/main/resources.py"]["summary"]["percent_covered"]) + print(data["files"]["src/sagemaker/core/resources.py"]["summary"]["percent_covered"]) if __name__ == "__main__": From 6c8cedf907f2c1d5319f0ec5469a9185a6504083 Mon Sep 17 00:00:00 2001 From: pintaoz Date: Wed, 26 Mar 2025 13:53:02 -0700 Subject: [PATCH 2/4] Fix version path --- src/sagemaker/core/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/core/_version.py b/src/sagemaker/core/_version.py index ba7131b7..f79bf800 100644 --- a/src/sagemaker/core/_version.py +++ b/src/sagemaker/core/_version.py @@ -3,7 +3,7 @@ script_dir = os.path.dirname(os.path.abspath(__file__)) # Get the root directory of the project -root_dir = os.path.abspath(os.path.join(script_dir, "..", "..")) +root_dir = os.path.abspath(os.path.join(script_dir, "..", "..", "..")) version_file_path = os.path.join(root_dir, "VERSION") From 179839e4e5865a703e0d1c757889694b0c8a864e Mon Sep 17 00:00:00 2001 From: pintaoz Date: Wed, 26 Mar 2025 14:44:37 -0700 Subject: [PATCH 3/4] use logger instead of pring --- src/sagemaker/core/resources.py | 10 +++++----- src/sagemaker/core/tools/templates.py | 2 +- tst/tools/test_resources_codegen.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/sagemaker/core/resources.py b/src/sagemaker/core/resources.py index 19e04ed3..55940496 100644 --- a/src/sagemaker/core/resources.py +++ b/src/sagemaker/core/resources.py @@ -1346,7 +1346,7 @@ def wait_for_delete( status.update(f"Current status: [bold]{current_status}") if current_status.lower() == "deleted": - print("Resource was deleted.") + logger.info("Resource was deleted.") return if timeout is not None and time.time() - start_time >= timeout: @@ -4344,7 +4344,7 @@ def wait_for_delete( status.update(f"Current status: [bold]{current_status}") if current_status.lower() == "deleted": - print("Resource was deleted.") + logger.info("Resource was deleted.") return if timeout is not None and time.time() - start_time >= timeout: @@ -5656,7 +5656,7 @@ def wait_for_delete( status.update(f"Current status: [bold]{current_status}") if current_status.lower() == "deleted": - print("Resource was deleted.") + logger.info("Resource was deleted.") return if timeout is not None and time.time() - start_time >= timeout: @@ -16350,7 +16350,7 @@ def wait_for_delete( status.update(f"Current status: [bold]{current_status}") if current_status.lower() == "deleted": - print("Resource was deleted.") + logger.info("Resource was deleted.") return if timeout is not None and time.time() - start_time >= timeout: @@ -24078,7 +24078,7 @@ def wait_for_delete( status.update(f"Current status: [bold]{current_status}") if current_status.lower() == "deleted": - print("Resource was deleted.") + logger.info("Resource was deleted.") return if timeout is not None and time.time() - start_time >= timeout: diff --git a/src/sagemaker/core/tools/templates.py b/src/sagemaker/core/tools/templates.py index 436a7984..08d2ba12 100644 --- a/src/sagemaker/core/tools/templates.py +++ b/src/sagemaker/core/tools/templates.py @@ -454,7 +454,7 @@ def wait_for_delete( DELETED_STATUS_CHECK = """ if current_status.lower() == "deleted": - print("Resource was deleted.") + logger.info("Resource was deleted.") return """ diff --git a/tst/tools/test_resources_codegen.py b/tst/tools/test_resources_codegen.py index dccb54b1..09a74a80 100644 --- a/tst/tools/test_resources_codegen.py +++ b/tst/tools/test_resources_codegen.py @@ -1856,7 +1856,7 @@ def wait_for_delete( if current_status.lower() == "deleted": - print("Resource was deleted.") + logger.info("Resource was deleted.") return From 6d74ce9c09f54176efaab4c0203c09cbc80de9fc Mon Sep 17 00:00:00 2001 From: pintaoz Date: Mon, 31 Mar 2025 15:39:51 -0700 Subject: [PATCH 4/4] Generate start methods --- resource_plan.csv | 10 +- src/sagemaker/core/resources.py | 218 ++++++++++++++++++ src/sagemaker/core/tools/constants.py | 4 +- src/sagemaker/core/tools/resources_codegen.py | 63 +++++ .../core/tools/resources_extractor.py | 10 +- 5 files changed, 289 insertions(+), 16 deletions(-) diff --git a/resource_plan.csv b/resource_plan.csv index 66884a61..93b42684 100644 --- a/resource_plan.csv +++ b/resource_plan.csv @@ -33,11 +33,11 @@ HyperParameterTuningJob,resource,"['create', 'get', 'get_all']","['delete', 'ref Image,resource,"['create', 'get', 'get_all']","['delete', 'refresh', 'update', 'wait_for_delete', 'wait_for_status']",[],[],"['CreateImage', 'DeleteImage', 'DescribeImage', 'ListImages', 'UpdateImage']","[{'name': 'ImageStatus', 'shape_name': 'ImageStatus'}]","['CREATING', 'CREATED', 'CREATE_FAILED', 'UPDATING', 'UPDATE_FAILED', 'DELETING', 'DELETE_FAILED']" ImageVersion,resource,"['create', 'get', 'get_all']","['delete', 'refresh', 'update', 'wait_for_delete', 'wait_for_status']",['Image'],[],"['CreateImageVersion', 'DeleteImageVersion', 'DescribeImageVersion', 'ListImageVersions', 'UpdateImageVersion']","[{'name': 'ImageVersionStatus', 'shape_name': 'ImageVersionStatus'}]","['CREATING', 'CREATED', 'CREATE_FAILED', 'DELETING', 'DELETE_FAILED']" InferenceComponent,resource,"['create', 'get', 'get_all']","['delete', 'refresh', 'update', 'wait_for_delete', 'wait_for_status']",['Endpoint'],[],"['CreateInferenceComponent', 'DeleteInferenceComponent', 'DescribeInferenceComponent', 'ListInferenceComponents', 'UpdateInferenceComponent']","[{'name': 'InferenceComponentStatus', 'shape_name': 'InferenceComponentStatus'}]","['InService', 'Creating', 'Updating', 'Failed', 'Deleting']" -InferenceExperiment,resource,"['create', 'get', 'get_all', 'start']","['delete', 'refresh', 'stop', 'update', 'wait_for_status']",['Endpoint'],[],"['CreateInferenceExperiment', 'DeleteInferenceExperiment', 'DescribeInferenceExperiment', 'ListInferenceExperiments', 'StartInferenceExperiment', 'StopInferenceExperiment', 'UpdateInferenceExperiment']","[{'name': 'Status', 'shape_name': 'InferenceExperimentStatus'}]","['Creating', 'Created', 'Updating', 'Running', 'Starting', 'Stopping', 'Completed', 'Cancelled']" +InferenceExperiment,resource,"['create', 'get', 'get_all']","['delete', 'refresh', 'start', 'stop', 'update', 'wait_for_status']",['Endpoint'],[],"['CreateInferenceExperiment', 'DeleteInferenceExperiment', 'DescribeInferenceExperiment', 'ListInferenceExperiments', 'StartInferenceExperiment', 'StopInferenceExperiment', 'UpdateInferenceExperiment']","[{'name': 'Status', 'shape_name': 'InferenceExperimentStatus'}]","['Creating', 'Created', 'Updating', 'Running', 'Starting', 'Stopping', 'Completed', 'Cancelled']" InferenceRecommendationsJob,resource,"['create', 'get', 'get_all']","['refresh', 'stop', 'wait', 'wait_for_delete']",[],[],"['CreateInferenceRecommendationsJob', 'DescribeInferenceRecommendationsJob', 'ListInferenceRecommendationsJobs', 'StopInferenceRecommendationsJob']","[{'name': 'Status', 'shape_name': 'RecommendationJobStatus'}]","['PENDING', 'IN_PROGRESS', 'COMPLETED', 'FAILED', 'STOPPING', 'STOPPED', 'DELETING', 'DELETED']" LabelingJob,resource,"['create', 'get', 'get_all']","['refresh', 'stop', 'wait']",[],[],"['CreateLabelingJob', 'DescribeLabelingJob', 'ListLabelingJobs', 'StopLabelingJob']","[{'name': 'LabelingJobStatus', 'shape_name': 'LabelingJobStatus'}]","['Initializing', 'InProgress', 'Completed', 'Failed', 'Stopping', 'Stopped']" LineageGroup,resource,"['get', 'get_all']",['refresh'],[],[],"['DescribeLineageGroup', 'ListLineageGroups']",[],[] -MlflowTrackingServer,resource,"['create', 'get', 'get_all', 'start']","['delete', 'refresh', 'stop', 'update', 'wait_for_delete', 'wait_for_status']",[],[],"['CreateMlflowTrackingServer', 'DeleteMlflowTrackingServer', 'DescribeMlflowTrackingServer', 'ListMlflowTrackingServers', 'StartMlflowTrackingServer', 'StopMlflowTrackingServer', 'UpdateMlflowTrackingServer']","[{'name': 'TrackingServerStatus', 'shape_name': 'TrackingServerStatus'}]","['Creating', 'Created', 'CreateFailed', 'Updating', 'Updated', 'UpdateFailed', 'Deleting', 'DeleteFailed', 'Stopping', 'Stopped', 'StopFailed', 'Starting', 'Started', 'StartFailed', 'MaintenanceInProgress', 'MaintenanceComplete', 'MaintenanceFailed']" +MlflowTrackingServer,resource,"['create', 'get', 'get_all']","['delete', 'refresh', 'start', 'stop', 'update', 'wait_for_delete', 'wait_for_status']",[],[],"['CreateMlflowTrackingServer', 'DeleteMlflowTrackingServer', 'DescribeMlflowTrackingServer', 'ListMlflowTrackingServers', 'StartMlflowTrackingServer', 'StopMlflowTrackingServer', 'UpdateMlflowTrackingServer']","[{'name': 'TrackingServerStatus', 'shape_name': 'TrackingServerStatus'}]","['Creating', 'Created', 'CreateFailed', 'Updating', 'Updated', 'UpdateFailed', 'Deleting', 'DeleteFailed', 'Stopping', 'Stopped', 'StopFailed', 'Starting', 'Started', 'StartFailed', 'MaintenanceInProgress', 'MaintenanceComplete', 'MaintenanceFailed']" Model,resource,"['create', 'get', 'get_all']","['delete', 'refresh']",[],[],"['CreateModel', 'DeleteModel', 'DescribeModel', 'ListModels']",[],[] ModelBiasJobDefinition,resource,"['create', 'get', 'get_all']","['delete', 'refresh']",[],[],"['CreateModelBiasJobDefinition', 'DeleteModelBiasJobDefinition', 'DescribeModelBiasJobDefinition', 'ListModelBiasJobDefinitions']",[],[] ModelCard,resource,"['create', 'get', 'get_all']","['delete', 'refresh', 'update', 'wait_for_status']",[],[],"['CreateModelCard', 'DeleteModelCard', 'DescribeModelCard', 'ListModelCards', 'UpdateModelCard']","[{'name': 'ModelCardStatus', 'shape_name': 'ModelCardStatus'}]","['Draft', 'PendingReview', 'Approved', 'Archived']" @@ -48,14 +48,14 @@ ModelPackageGroup,resource,"['create', 'get', 'get_all']","['delete', 'refresh', ModelQualityJobDefinition,resource,"['create', 'get', 'get_all']","['delete', 'refresh']",[],[],"['CreateModelQualityJobDefinition', 'DeleteModelQualityJobDefinition', 'DescribeModelQualityJobDefinition', 'ListModelQualityJobDefinitions']",[],[] MonitoringAlert,resource,['get_all'],['update'],[],[],"['ListMonitoringAlerts', 'UpdateMonitoringAlert']",[],[] MonitoringExecution,resource,['get_all'],[],[],[],['ListMonitoringExecutions'],[],[] -MonitoringSchedule,resource,"['create', 'get', 'get_all', 'start']","['delete', 'refresh', 'stop', 'update', 'wait_for_status']",[],[],"['CreateMonitoringSchedule', 'DeleteMonitoringSchedule', 'DescribeMonitoringSchedule', 'ListMonitoringSchedules', 'StartMonitoringSchedule', 'StopMonitoringSchedule', 'UpdateMonitoringSchedule']","[{'name': 'MonitoringScheduleStatus', 'shape_name': 'ScheduleStatus'}]","['Pending', 'Failed', 'Scheduled', 'Stopped']" -NotebookInstance,resource,"['create', 'get', 'get_all', 'start']","['delete', 'refresh', 'stop', 'update', 'wait_for_delete', 'wait_for_status']",[],[],"['CreateNotebookInstance', 'DeleteNotebookInstance', 'DescribeNotebookInstance', 'ListNotebookInstances', 'StartNotebookInstance', 'StopNotebookInstance', 'UpdateNotebookInstance']","[{'name': 'NotebookInstanceStatus', 'shape_name': 'NotebookInstanceStatus'}]","['Pending', 'InService', 'Stopping', 'Stopped', 'Failed', 'Deleting', 'Updating']" +MonitoringSchedule,resource,"['create', 'get', 'get_all']","['delete', 'refresh', 'start', 'stop', 'update', 'wait_for_status']",[],[],"['CreateMonitoringSchedule', 'DeleteMonitoringSchedule', 'DescribeMonitoringSchedule', 'ListMonitoringSchedules', 'StartMonitoringSchedule', 'StopMonitoringSchedule', 'UpdateMonitoringSchedule']","[{'name': 'MonitoringScheduleStatus', 'shape_name': 'ScheduleStatus'}]","['Pending', 'Failed', 'Scheduled', 'Stopped']" +NotebookInstance,resource,"['create', 'get', 'get_all']","['delete', 'refresh', 'start', 'stop', 'update', 'wait_for_delete', 'wait_for_status']",[],[],"['CreateNotebookInstance', 'DeleteNotebookInstance', 'DescribeNotebookInstance', 'ListNotebookInstances', 'StartNotebookInstance', 'StopNotebookInstance', 'UpdateNotebookInstance']","[{'name': 'NotebookInstanceStatus', 'shape_name': 'NotebookInstanceStatus'}]","['Pending', 'InService', 'Stopping', 'Stopped', 'Failed', 'Deleting', 'Updating']" NotebookInstanceLifecycleConfig,resource,"['create', 'get', 'get_all']","['delete', 'refresh', 'update']",[],[],"['CreateNotebookInstanceLifecycleConfig', 'DeleteNotebookInstanceLifecycleConfig', 'DescribeNotebookInstanceLifecycleConfig', 'ListNotebookInstanceLifecycleConfigs', 'UpdateNotebookInstanceLifecycleConfig']",[],[] OptimizationJob,resource,"['create', 'get', 'get_all']","['delete', 'refresh', 'stop', 'wait']",[],[],"['CreateOptimizationJob', 'DeleteOptimizationJob', 'DescribeOptimizationJob', 'ListOptimizationJobs', 'StopOptimizationJob']","[{'name': 'OptimizationJobStatus', 'shape_name': 'OptimizationJobStatus'}]","['INPROGRESS', 'COMPLETED', 'FAILED', 'STARTING', 'STOPPING', 'STOPPED']" PartnerApp,resource,"['create', 'get', 'get_all']","['delete', 'refresh', 'update', 'wait_for_delete', 'wait_for_status']",[],[],"['CreatePartnerApp', 'DeletePartnerApp', 'DescribePartnerApp', 'ListPartnerApps', 'UpdatePartnerApp']","[{'name': 'Status', 'shape_name': 'PartnerAppStatus'}]","['Creating', 'Updating', 'Deleting', 'Available', 'Failed', 'UpdateFailed', 'Deleted']" PartnerAppPresignedUrl,resource,['create'],[],[],[],['CreatePartnerAppPresignedUrl'],[],[] Pipeline,resource,"['create', 'get', 'get_all']","['delete', 'refresh', 'update', 'wait_for_delete', 'wait_for_status']",[],[],"['CreatePipeline', 'DeletePipeline', 'DescribePipeline', 'ListPipelines', 'UpdatePipeline']","[{'name': 'PipelineStatus', 'shape_name': 'PipelineStatus'}]","['Active', 'Deleting']" -PipelineExecution,resource,"['get', 'get_all', 'start']","['refresh', 'stop', 'update', 'wait_for_status']",[],[],"['DescribePipelineExecution', 'ListPipelineExecutions', 'StartPipelineExecution', 'StopPipelineExecution', 'UpdatePipelineExecution']","[{'name': 'PipelineExecutionStatus', 'shape_name': 'PipelineExecutionStatus'}]","['Executing', 'Stopping', 'Stopped', 'Failed', 'Succeeded']" +PipelineExecution,resource,"['get', 'get_all']","['refresh', 'start', 'stop', 'update', 'wait_for_status']",[],[],"['DescribePipelineExecution', 'ListPipelineExecutions', 'StartPipelineExecution', 'StopPipelineExecution', 'UpdatePipelineExecution']","[{'name': 'PipelineExecutionStatus', 'shape_name': 'PipelineExecutionStatus'}]","['Executing', 'Stopping', 'Stopped', 'Failed', 'Succeeded']" PresignedDomainUrl,resource,['create'],[],"['Space', 'UserProfile']",[],['CreatePresignedDomainUrl'],[],[] PresignedMlflowTrackingServerUrl,resource,['create'],[],[],[],['CreatePresignedMlflowTrackingServerUrl'],[],[] PresignedNotebookInstanceUrl,resource,['create'],[],['NotebookInstance'],[],['CreatePresignedNotebookInstanceUrl'],[],[] diff --git a/src/sagemaker/core/resources.py b/src/sagemaker/core/resources.py index 55940496..38a058a5 100644 --- a/src/sagemaker/core/resources.py +++ b/src/sagemaker/core/resources.py @@ -15783,6 +15783,48 @@ def delete( logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + @Base.add_validate_call + def start( + self, + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> None: + """ + Start a InferenceExperiment resource + + Parameters: + session: Boto3 session. + region: Region name. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. + ResourceNotFound: Resource being access is not found. + """ + + operation_input_args = { + "Name": self.name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + logger.debug(f"Calling start_inference_experiment API") + response = client.start_inference_experiment(**operation_input_args) + logger.debug(f"Response: {response}") + @Base.add_validate_call def stop(self) -> None: """ @@ -17502,6 +17544,48 @@ def delete( logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + @Base.add_validate_call + def start( + self, + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> None: + """ + Start a MlflowTrackingServer resource + + Parameters: + session: Boto3 session. + region: Region name. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. + ResourceNotFound: Resource being access is not found. + """ + + operation_input_args = { + "TrackingServerName": self.tracking_server_name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + logger.debug(f"Calling start_mlflow_tracking_server API") + response = client.start_mlflow_tracking_server(**operation_input_args) + logger.debug(f"Response: {response}") + @Base.add_validate_call def stop(self) -> None: """ @@ -22015,6 +22099,47 @@ def delete( logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + @Base.add_validate_call + def start( + self, + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> None: + """ + Start a MonitoringSchedule resource + + Parameters: + session: Boto3 session. + region: Region name. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceNotFound: Resource being access is not found. + """ + + operation_input_args = { + "MonitoringScheduleName": self.monitoring_schedule_name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + logger.debug(f"Calling start_monitoring_schedule API") + response = client.start_monitoring_schedule(**operation_input_args) + logger.debug(f"Response: {response}") + @Base.add_validate_call def stop(self) -> None: """ @@ -22586,6 +22711,47 @@ def delete( logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + @Base.add_validate_call + def start( + self, + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> None: + """ + Start a NotebookInstance resource + + Parameters: + session: Boto3 session. + region: Region name. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. + """ + + operation_input_args = { + "NotebookInstanceName": self.notebook_instance_name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + logger.debug(f"Calling start_notebook_instance API") + response = client.start_notebook_instance(**operation_input_args) + logger.debug(f"Response: {response}") + @Base.add_validate_call def stop(self) -> None: """ @@ -24914,6 +25080,58 @@ def update( return self + @Base.add_validate_call + def start( + self, + pipeline_name: str, + client_request_token: str, + pipeline_parameters: Optional[List[Parameter]] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> None: + """ + Start a PipelineExecution resource + + Parameters: + session: Boto3 session. + region: Region name. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. + ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. + ResourceNotFound: Resource being access is not found. + """ + + operation_input_args = { + "PipelineName": pipeline_name, + "PipelineExecutionDisplayName": self.pipeline_execution_display_name, + "PipelineParameters": pipeline_parameters, + "PipelineExecutionDescription": self.pipeline_execution_description, + "ClientRequestToken": client_request_token, + "ParallelismConfiguration": self.parallelism_configuration, + "SelectiveExecutionConfig": self.selective_execution_config, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + logger.debug(f"Calling start_pipeline_execution API") + response = client.start_pipeline_execution(**operation_input_args) + logger.debug(f"Response: {response}") + @Base.add_validate_call def stop(self) -> None: """ diff --git a/src/sagemaker/core/tools/constants.py b/src/sagemaker/core/tools/constants.py index f0b39acc..0acc4b59 100644 --- a/src/sagemaker/core/tools/constants.py +++ b/src/sagemaker/core/tools/constants.py @@ -13,9 +13,9 @@ """Constants used in the code_generator modules.""" import os -CLASS_METHODS = set(["create", "add", "start", "register", "import", "list", "get"]) +CLASS_METHODS = set(["create", "add", "register", "import", "list", "get"]) OBJECT_METHODS = set( - ["refresh", "delete", "update", "stop", "deregister", "wait", "wait_for_status"] + ["refresh", "delete", "update", "start", "stop", "deregister", "wait", "wait_for_status"] ) TERMINAL_STATES = set(["Completed", "Stopped", "Deleted", "Failed", "Succeeded", "Cancelled"]) diff --git a/src/sagemaker/core/tools/resources_codegen.py b/src/sagemaker/core/tools/resources_codegen.py index 72fe55d1..6d06956d 100644 --- a/src/sagemaker/core/tools/resources_codegen.py +++ b/src/sagemaker/core/tools/resources_codegen.py @@ -419,6 +419,11 @@ def generate_resource_class( ): resource_class += add_indent(delete_method, 4) + if start_method := self._evaluate_method( + resource_name, "start", object_methods, resource_attributes=resource_attributes + ): + resource_class += add_indent(start_method, 4) + if stop_method := self._evaluate_method(resource_name, "stop", object_methods): resource_class += add_indent(stop_method, 4) @@ -984,6 +989,7 @@ def _generate_docstring( str: The generated docstring for the IMPORT method. """ docstring = f"{title}\n" + _shape_attr_documentation_string = "" if operation_input_shape_name: _shape_attr_documentation_string = self._get_shape_attr_documentation_string( self.shapes_extractor.fetch_shape_members_and_doc_strings( @@ -1296,6 +1302,63 @@ def generate_delete_method(self, resource_name: str, **kwargs) -> str: ) return formatted_method + def generate_start_method(self, resource_name: str, **kwargs) -> str: + """Auto-Generate 'start' object Method [delete API] for a resource. + + Args: + resource_name (str): The resource name. + + Returns: + str: The formatted stop Method template. + """ + operation_name = "Start" + resource_name + operation_metadata = self.operations[operation_name] + operation_input_shape_name = operation_metadata["input"]["shape"] + resource_attributes = kwargs["resource_attributes"] + + method_args = add_indent("self,\n", 4) + method_args += ( + self._generate_method_args(operation_input_shape_name, resource_attributes) + "\n" + ) + operation_input_args = self._generate_operation_input_args_updated( + operation_metadata, False, resource_attributes + ) + exclude_resource_attrs = resource_attributes + method_args += add_indent("session: Optional[Session] = None,\n", 4) + method_args += add_indent("region: Optional[str] = None,", 4) + + serialize_operation_input = SERIALIZE_INPUT_TEMPLATE.format( + operation_input_args=operation_input_args + ) + call_operation_api = CALL_OPERATION_API_TEMPLATE.format( + operation=convert_to_snake_case(operation_name) + ) + + # generate docstring + docstring = self._generate_docstring( + title=f"Start a {resource_name} resource", + operation_name=operation_name, + resource_name=resource_name, + include_session_region=True, + include_return_resource_docstring=False, + exclude_resource_attrs=exclude_resource_attrs, + ) + + initialize_client = INITIALIZE_CLIENT_TEMPLATE.format(service_name="sagemaker") + + formatted_method = GENERIC_METHOD_TEMPLATE.format( + docstring=docstring, + decorator="", + method_name="start", + method_args=method_args, + return_type="None", + serialize_operation_input=serialize_operation_input, + initialize_client=initialize_client, + call_operation_api=call_operation_api, + deserialize_response="", + ) + return formatted_method + def generate_stop_method(self, resource_name: str) -> str: """Auto-Generate 'stop' object Method [delete API] for a resource. diff --git a/src/sagemaker/core/tools/resources_extractor.py b/src/sagemaker/core/tools/resources_extractor.py index 8a6f3bf6..01072f98 100644 --- a/src/sagemaker/core/tools/resources_extractor.py +++ b/src/sagemaker/core/tools/resources_extractor.py @@ -311,15 +311,7 @@ def _extract_resource_plan_as_dataframe(self): ): chain_resource_names.add(chain_resource_name) action_split = action_low.split(resource_low) - if action_split[0] == "invoke": - if not action_split[1]: - invoke_method = "invoke" - elif action_split[1] == "async": - invoke_method = "invoke_async" - else: - invoke_method = "invoke_with_response_stream" - object_methods.add(invoke_method) - elif action_split[0] in CLASS_METHODS: + if action_split[0] in CLASS_METHODS: if action_low.split(resource_low)[0] == "list": class_methods.add("get_all") else: