Skip to content

Commit 217563e

Browse files
authored
[Feature] Introduce ProcessingStep to use SageMaker Processing Job (aws#68)
* Add ProcessingStep for SageMaker Processing Job * Add integration test for ProcessingStep * Update the doc for ProcessingStep * Upgrade SageMaker version to have all the fixes for ProcessingStep * Fix the failed unit tests
1 parent 3afdefd commit 217563e

File tree

9 files changed

+294
-5
lines changed

9 files changed

+294
-5
lines changed

doc/StepFunctionsWorkflowExecutionPolicy.json

+4
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
"sagemaker:DeleteEndpoint",
2121
"sagemaker:UpdateEndpoint",
2222
"sagemaker:ListTags",
23+
"sagemaker:CreateProcessingJob",
24+
"sagemaker:DescribeProcessingJob",
25+
"sagemaker:StopProcessingJob",
2326
"lambda:InvokeFunction",
2427
"sqs:SendMessage",
2528
"sns:Publish",
@@ -63,6 +66,7 @@
6366
"arn:aws:events:*:*:rule/StepFunctionsGetEventsForSageMakerTrainingJobsRule",
6467
"arn:aws:events:*:*:rule/StepFunctionsGetEventsForSageMakerTransformJobsRule",
6568
"arn:aws:events:*:*:rule/StepFunctionsGetEventsForSageMakerTuningJobsRule",
69+
"arn:aws:events:*:*:rule/StepFunctionsGetEventsForSageMakerProcessingJobsRule",
6670
"arn:aws:events:*:*:rule/StepFunctionsGetEventsForECSTaskRule",
6771
"arn:aws:events:*:*:rule/StepFunctionsGetEventsForBatchJobsRule"
6872
]

doc/sagemaker.rst

+2
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,5 @@ This module provides classes to build steps that integrate with Amazon SageMaker
1414
.. autoclass:: stepfunctions.steps.sagemaker.EndpointConfigStep
1515

1616
.. autoclass:: stepfunctions.steps.sagemaker.EndpointStep
17+
18+
.. autoclass:: stepfunctions.steps.sagemaker.ProcessingStep

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
sagemaker>=1.42.8
1+
sagemaker>=1.71.0
22
boto3>=1.9.213
33
pyyaml

src/stepfunctions/steps/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from stepfunctions.steps.states import Pass, Succeed, Fail, Wait, Choice, Parallel, Map, Task, Chain, Retry, Catch
1818
from stepfunctions.steps.states import Graph, FrozenGraph
19-
from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointConfigStep, EndpointStep
19+
from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointConfigStep, EndpointStep, ProcessingStep
2020
from stepfunctions.steps.compute import LambdaStep, BatchSubmitJobStep, GlueStartJobRunStep, EcsRunTaskStep
2121
from stepfunctions.steps.service import DynamoDBGetItemStep, DynamoDBPutItemStep, DynamoDBUpdateItemStep, DynamoDBDeleteItemStep
2222
from stepfunctions.steps.service import SnsPublishStep, SqsSendMessageStep

src/stepfunctions/steps/sagemaker.py

+56-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from stepfunctions.steps.fields import Field
1818
from stepfunctions.steps.utils import tags_dict_to_kv_list
1919

20-
from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config
20+
from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config
2121
from sagemaker.model import Model, FrameworkModel
2222
from sagemaker.model_monitor import DataCaptureConfig
2323

@@ -356,3 +356,58 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta
356356
kwargs[Field.Parameters.value] = parameters
357357

358358
super(TuningStep, self).__init__(state_id, **kwargs)
359+
360+
361+
class ProcessingStep(Task):

    """
    Creates a Task State to execute a SageMaker Processing Job.
    """

    def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None, container_arguments=None, container_entrypoint=None, kms_key_id=None, wait_for_completion=True, tags=None, **kwargs):
        """
        Args:
            state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
            processor (sagemaker.processing.Processor): The processor for the processing step.
            job_name (str or Placeholder): Name of the processing job; required for the job to run. We recommend using an :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection so the value can be supplied dynamically for each execution.
            inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
                the processing job, given as
                :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
            outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
                the processing job, given either as path strings or as
                :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
            experiment_config (dict, optional): Experiment config for the processing job; when not None it is stored under the `ExperimentConfig` parameter. (default: None)
            container_arguments ([str]): The arguments for a container used to run a processing job.
            container_entrypoint ([str]): The entrypoint for a container used to run a processing job.
            kms_key_id (str): The AWS Key Management Service (AWS KMS) key that Amazon SageMaker
                uses to encrypt the processing job output. KmsKeyId can be the ID, ARN,
                or alias of a KMS key. The KmsKeyId is applied to all outputs.
            wait_for_completion (bool, optional): Set to `True` if the Task state should wait for the processing job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the processing job and proceed to the next step. (default: True)
            tags (list[dict], optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
        """
        # The ".sync" resource makes the state wait for the job to finish.
        kwargs[Field.Resource.value] = (
            'arn:aws:states:::sagemaker:createProcessingJob.sync'
            if wait_for_completion
            else 'arn:aws:states:::sagemaker:createProcessingJob'
        )

        config_kwargs = {
            'processor': processor,
            'inputs': inputs,
            'outputs': outputs,
            'container_arguments': container_arguments,
            'container_entrypoint': container_entrypoint,
            'kms_key_id': kms_key_id,
        }
        # Only a literal string name is handed to processing_config; a
        # placeholder name is written into the parameters afterwards.
        if isinstance(job_name, str):
            config_kwargs['job_name'] = job_name
        parameters = processing_config(**config_kwargs)

        if isinstance(job_name, (ExecutionInput, StepInput)):
            parameters['ProcessingJobName'] = job_name

        if experiment_config is not None:
            parameters['ExperimentConfig'] = experiment_config

        if tags:
            parameters['Tags'] = tags_dict_to_kv_list(tags)

        # 'S3Operations' is dropped from the generated config before it is
        # used as the state's parameters.
        parameters.pop('S3Operations', None)

        kwargs[Field.Parameters.value] = parameters

        super(ProcessingStep, self).__init__(state_id, **kwargs)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import argparse
2+
import os
3+
import warnings
4+
5+
import pandas as pd
6+
import numpy as np
7+
from sklearn.model_selection import train_test_split
8+
from sklearn.preprocessing import StandardScaler, OneHotEncoder, LabelBinarizer, KBinsDiscretizer
9+
from sklearn.preprocessing import PolynomialFeatures
10+
from sklearn.compose import make_column_transformer
11+
12+
from sklearn.exceptions import DataConversionWarning
13+
warnings.filterwarnings(action='ignore', category=DataConversionWarning)
14+
15+
16+
columns = ['age', 'education', 'major industry code', 'class of worker', 'num persons worked for employer',
17+
'capital gains', 'capital losses', 'dividends from stocks', 'income']
18+
class_labels = [' - 50000.', ' 50000+.']
19+
20+
def print_shape(df):
    """Print the shape of ``df`` and its positive/negative example counts.

    Args:
        df (pandas.DataFrame): Frame with an ``income`` column holding
            non-negative integer labels, where 0 counts as a negative
            example and 1 as a positive example.
    """
    # minlength=2 guarantees two bins, so the unpack below still works when
    # every row carries label 0 (bincount would otherwise return one bin).
    negative_examples, positive_examples = np.bincount(df['income'], minlength=2)
    print('Data shape: {}, {} positive examples, {} negative examples'.format(df.shape, positive_examples, negative_examples))
23+
24+
if __name__ == '__main__':
    # Script entry point: read the census CSV from the processing input
    # directory, clean it, split it, transform the features and write the
    # resulting train/test CSV files to the processing output directories.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--train-test-split-ratio', type=float, default=0.3)
    # parse_known_args ignores any arguments this script does not declare.
    args, _ = arg_parser.parse_known_args()

    print('Received arguments {}'.format(args))

    input_data_path = os.path.join('/opt/ml/processing/input', 'census-income.csv')

    print('Reading input data from {}'.format(input_data_path))
    # Keep only the columns listed in `columns`, then drop incomplete and
    # duplicate rows and map the two class labels to 0 and 1.
    census = pd.DataFrame(data=pd.read_csv(input_data_path), columns=columns)
    census.dropna(inplace=True)
    census.drop_duplicates(inplace=True)
    census.replace(class_labels, [0, 1], inplace=True)

    label_counts = np.bincount(census['income'])
    negative_count, positive_count = label_counts
    print('Data after cleaning: {}, {} positive examples, {} negative examples'.format(census.shape, positive_count, negative_count))

    split_ratio = args.train_test_split_ratio
    print('Splitting data into train and test sets with ratio {}'.format(split_ratio))
    feature_frame = census.drop('income', axis=1)
    label_series = census['income']
    X_train, X_test, y_train, y_test = train_test_split(
        feature_frame,
        label_series,
        test_size=split_ratio,
        random_state=0,
    )

    binned_columns = ['age', 'num persons worked for employer']
    scaled_columns = ['capital gains', 'capital losses', 'dividends from stocks']
    encoded_columns = ['education', 'major industry code', 'class of worker']
    preprocess = make_column_transformer(
        (binned_columns, KBinsDiscretizer(encode='onehot-dense', n_bins=10)),
        (scaled_columns, StandardScaler()),
        (encoded_columns, OneHotEncoder(sparse=False))
    )
    print('Running preprocessing and feature engineering transformations')
    # The transformer is fitted on the training split only and then reused
    # unchanged for the test split.
    train_features = preprocess.fit_transform(X_train)
    test_features = preprocess.transform(X_test)

    print('Train data shape after preprocessing: {}'.format(train_features.shape))
    print('Test data shape after preprocessing: {}'.format(test_features.shape))

    train_features_output_path = os.path.join('/opt/ml/processing/train', 'train_features.csv')
    train_labels_output_path = os.path.join('/opt/ml/processing/train', 'train_labels.csv')

    test_features_output_path = os.path.join('/opt/ml/processing/test', 'test_features.csv')
    test_labels_output_path = os.path.join('/opt/ml/processing/test', 'test_labels.csv')

    # All four files are written without a header row or an index column.
    print('Saving training features to {}'.format(train_features_output_path))
    train_feature_frame = pd.DataFrame(train_features)
    train_feature_frame.to_csv(train_features_output_path, header=False, index=False)

    print('Saving test features to {}'.format(test_features_output_path))
    test_feature_frame = pd.DataFrame(test_features)
    test_feature_frame.to_csv(test_features_output_path, header=False, index=False)

    print('Saving training labels to {}'.format(train_labels_output_path))
    y_train.to_csv(train_labels_output_path, header=False, index=False)

    print('Saving test labels to {}'.format(test_labels_output_path))
    y_test.to_csv(test_labels_output_path, header=False, index=False)

tests/integ/conftest.py

+12
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import pickle
2121
from sagemaker import Session
2222
from sagemaker.amazon import pca
23+
from sagemaker.sklearn.processing import SKLearnProcessor
2324
from tests.integ import DATA_DIR
2425

2526
@pytest.fixture(scope="session")
@@ -58,6 +59,17 @@ def pca_estimator_fixture(sagemaker_role_arn):
5859
)
5960
return estimator
6061

62+
@pytest.fixture(scope="session")
def sklearn_processor_fixture(sagemaker_role_arn):
    """Session-scoped SKLearnProcessor used by the processing integration test.

    Args:
        sagemaker_role_arn: Role passed to the processor as ``role``.
    """
    processor_settings = {
        "framework_version": "0.20.0",
        "role": sagemaker_role_arn,
        "instance_type": "ml.m5.xlarge",
        "instance_count": 1,
        "max_runtime_in_seconds": 300,
    }
    return SKLearnProcessor(**processor_settings)
72+
6173
@pytest.fixture(scope="session")
6274
def train_set():
6375
data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")

tests/integ/test_sagemaker_steps.py

+55-1
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,10 @@
2727
from sagemaker.utils import unique_name_from_base
2828
from sagemaker.parameter import IntegerParameter, CategoricalParameter
2929
from sagemaker.tuner import HyperparameterTuner
30+
from sagemaker.processing import ProcessingInput, ProcessingOutput
3031

3132
from stepfunctions.steps import Chain
32-
from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, TuningStep
33+
from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, TuningStep, ProcessingStep
3334
from stepfunctions.workflow import Workflow
3435

3536
from tests.integ import DATA_DIR, DEFAULT_TIMEOUT_MINUTES
@@ -297,3 +298,56 @@ def test_tuning_step(sfn_client, record_set_for_hyperparameter_tuning, sagemaker
297298
# Cleanup
298299
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
299300
# End of Cleanup
301+
302+
def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_client, sfn_role_arn):
    """Run a one-step workflow containing a ProcessingStep and check it completes.

    The preprocessing code under DATA_DIR/sklearn_processing is uploaded to
    the session's default bucket and mounted into the container together with
    the sample census dataset; train and test outputs go back to that bucket.
    """
    region = boto3.session.Session().region_name
    input_data = 's3://sagemaker-sample-data-{}/processing/census/census-income.csv'.format(region)

    code_dir = os.path.join(DATA_DIR, 'sklearn_processing')
    input_s3 = sagemaker_session.upload_data(
        path=code_dir,
        bucket=sagemaker_session.default_bucket(),
        key_prefix='integ-test-data/sklearn_processing/code'
    )

    output_s3 = 's3://{}/integ-test-data/sklearn_processing'.format(sagemaker_session.default_bucket())

    dataset_input = ProcessingInput(
        source=input_data,
        destination='/opt/ml/processing/input',
        input_name='input-1'
    )
    code_input = ProcessingInput(
        source=input_s3 + '/preprocessor.py',
        destination='/opt/ml/processing/input/code',
        input_name='code'
    )
    inputs = [dataset_input, code_input]

    train_output = ProcessingOutput(
        source='/opt/ml/processing/train',
        destination=output_s3 + '/train_data',
        output_name='train_data'
    )
    test_output = ProcessingOutput(
        source='/opt/ml/processing/test',
        destination=output_s3 + '/test_data',
        output_name='test_data'
    )
    outputs = [train_output, test_output]

    job_name = generate_job_name()
    processing_step = ProcessingStep(
        'create_processing_job_step',
        processor=sklearn_processor_fixture,
        job_name=job_name,
        inputs=inputs,
        outputs=outputs,
        container_arguments=['--train-test-split-ratio', '0.2'],
        container_entrypoint=['python3', '/opt/ml/processing/input/code/preprocessor.py'],
    )
    workflow_graph = Chain([processing_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # Create workflow and check definition
        workflow_name = unique_name_from_base("integ-test-processing-step-workflow")
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=workflow_name,
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn
        )

        # Execute workflow and block until its output is available
        execution = workflow.execute()
        execution_output = execution.get_output(wait=True)

        # Check workflow output
        assert execution_output.get("ProcessingJobStatus") == "Completed"

        # Cleanup
        state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
        # End of Cleanup

tests/unit/test_sagemaker_steps.py

+88-1
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,11 @@
2222
from sagemaker.pipeline import PipelineModel
2323
from sagemaker.model_monitor import DataCaptureConfig
2424
from sagemaker.debugger import Rule, rule_configs, DebuggerHookConfig, CollectionConfig
25+
from sagemaker.sklearn.processing import SKLearnProcessor
26+
from sagemaker.processing import ProcessingInput, ProcessingOutput
2527

2628
from unittest.mock import MagicMock, patch
27-
from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep
29+
from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, ProcessingStep
2830
from stepfunctions.steps.sagemaker import tuning_config
2931

3032
from tests.unit.utils import mock_boto_api_call
@@ -156,6 +158,22 @@ def tensorflow_estimator():
156158

157159
return estimator
158160

161+
@pytest.fixture
def sklearn_processor():
    """SKLearnProcessor backed by a mocked SageMaker session for unit tests."""
    mock_session = MagicMock()
    mock_session.boto_region_name = 'us-east-1'
    mock_session._default_bucket = 'sagemaker'

    return SKLearnProcessor(
        framework_version="0.20.0",
        role=EXECUTION_ROLE,
        instance_type="ml.m5.xlarge",
        instance_count=1,
        sagemaker_session=mock_session
    )
176+
159177
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
160178
def test_training_step_creation(pca_estimator):
161179
step = TrainingStep('Training',
@@ -566,3 +584,72 @@ def test_endpoint_step_creation(pca_model):
566584
'Resource': 'arn:aws:states:::sagemaker:updateEndpoint',
567585
'End': True
568586
}
587+
588+
def test_processing_step_creation(sklearn_processor):
    """Check the state definition produced by ProcessingStep.to_dict()."""
    output_paths = [
        '/opt/ml/processing/output/train',
        '/opt/ml/processing/output/validation',
        '/opt/ml/processing/output/test',
    ]
    inputs = [ProcessingInput(source='dataset.csv', destination='/opt/ml/processing/input')]
    outputs = [ProcessingOutput(source=local_path) for local_path in output_paths]

    step = ProcessingStep('Feature Transformation', sklearn_processor, 'MyProcessingJob', inputs=inputs, outputs=outputs)

    # One expected entry per container output path, in the order given above.
    expected_outputs = [
        {
            'OutputName': None,
            'S3Output': {
                'LocalPath': local_path,
                'S3UploadMode': 'EndOfJob',
                'S3Uri': None
            }
        }
        for local_path in output_paths
    ]
    expected_inputs = [
        {
            'InputName': None,
            'S3Input': {
                'LocalPath': '/opt/ml/processing/input',
                'S3CompressionType': 'None',
                'S3DataDistributionType': 'FullyReplicated',
                'S3DataType': 'S3Prefix',
                'S3InputMode': 'File',
                'S3Uri': 'dataset.csv'
            }
        }
    ]
    expected_parameters = {
        'AppSpecification': {
            'ImageUri': '683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.20.0-cpu-py3'
        },
        'ProcessingInputs': expected_inputs,
        'ProcessingOutputConfig': {
            'Outputs': expected_outputs
        },
        'ProcessingResources': {
            'ClusterConfig': {
                'InstanceCount': 1,
                'InstanceType': 'ml.m5.xlarge',
                'VolumeSizeInGB': 30
            }
        },
        'ProcessingJobName': 'MyProcessingJob',
        'RoleArn': EXECUTION_ROLE
    }

    assert step.to_dict() == {
        'Type': 'Task',
        'Parameters': expected_parameters,
        'Resource': 'arn:aws:states:::sagemaker:createProcessingJob.sync',
        'End': True
    }

0 commit comments

Comments
 (0)