6
6
#
7
7
# http://www.apache.org/licenses/LICENSE-2.0
8
8
#
9
- # or in the "license" file accompanying this file. This file is distributed
10
- # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11
- # express or implied. See the License for the specific language governing
9
+ # or in the "license" file accompanying this file. This file is distributed
10
+ # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11
+ # express or implied. See the License for the specific language governing
12
12
# permissions and limitations under the License.
13
13
from __future__ import absolute_import
14
14
15
15
import pytest
16
16
import json
17
17
18
18
from sagemaker .utils import unique_name_from_base
19
- from sagemaker .image_uris import retrieve
19
+ from sagemaker .image_uris import retrieve
20
20
from stepfunctions import steps
21
21
from stepfunctions .workflow import Workflow
22
22
from stepfunctions .steps .utils import get_aws_partition
25
25
26
26
@pytest .fixture (scope = "module" )
27
27
def training_job_parameters (sagemaker_session , sagemaker_role_arn , record_set_fixture ):
28
- parameters = {
29
- "AlgorithmSpecification" : {
28
+ parameters = {
29
+ "AlgorithmSpecification" : {
30
30
"TrainingImage" : retrieve (region = sagemaker_session .boto_session .region_name , framework = 'pca' ),
31
31
"TrainingInputMode" : "File"
32
32
},
33
- "OutputDataConfig" : {
33
+ "OutputDataConfig" : {
34
34
"S3OutputPath" : "s3://{}/" .format (sagemaker_session .default_bucket ())
35
35
},
36
- "StoppingCondition" : {
36
+ "StoppingCondition" : {
37
37
"MaxRuntimeInSeconds" : 86400
38
38
},
39
- "ResourceConfig" : {
39
+ "ResourceConfig" : {
40
40
"InstanceCount" : 1 ,
41
41
"InstanceType" : "ml.m5.large" ,
42
42
"VolumeSizeInGB" : 30
43
43
},
44
44
"RoleArn" : sagemaker_role_arn ,
45
- "InputDataConfig" :[
46
- {
47
- "DataSource" : {
48
- "S3DataSource" : {
45
+ "InputDataConfig" :[
46
+ {
47
+ "DataSource" : {
48
+ "S3DataSource" : {
49
49
"S3DataDistributionType" : "ShardedByS3Key" ,
50
50
"S3DataType" : "ManifestFile" ,
51
51
"S3Uri" : record_set_fixture .s3_data
@@ -54,7 +54,7 @@ def training_job_parameters(sagemaker_session, sagemaker_role_arn, record_set_fi
54
54
"ChannelName" : "train"
55
55
}
56
56
],
57
- "HyperParameters" : {
57
+ "HyperParameters" : {
58
58
"num_components" : "48" ,
59
59
"feature_dim" : "784" ,
60
60
"mini_batch_size" : "200"
@@ -93,7 +93,7 @@ def test_pass_state_machine_creation(sfn_client, sfn_role_arn):
93
93
94
94
definition = steps .Pass (pass_state_name , result = pass_state_result )
95
95
workflow = Workflow (
96
- 'Test_Pass_Workflow' ,
96
+ unique_name_from_base ( 'Test_Pass_Workflow' ) ,
97
97
definition = definition ,
98
98
role = sfn_role_arn
99
99
)
@@ -164,7 +164,7 @@ def test_wait_state_machine_creation(sfn_client, sfn_role_arn):
164
164
])
165
165
166
166
workflow = Workflow (
167
- 'Test_Wait_Workflow' ,
167
+ unique_name_from_base ( 'Test_Wait_Workflow' ) ,
168
168
definition = definition ,
169
169
role = sfn_role_arn
170
170
)
@@ -223,7 +223,7 @@ def test_parallel_state_machine_creation(sfn_client, sfn_role_arn):
223
223
])
224
224
225
225
workflow = Workflow (
226
- 'Test_Parallel_Workflow' ,
226
+ unique_name_from_base ( 'Test_Parallel_Workflow' ) ,
227
227
definition = definition ,
228
228
role = sfn_role_arn
229
229
)
@@ -269,9 +269,9 @@ def test_map_state_machine_creation(sfn_client, sfn_role_arn):
269
269
}
270
270
271
271
map_state = steps .Map (
272
- map_state_name ,
272
+ map_state_name ,
273
273
items_path = items_path ,
274
- iterator = steps .Pass (iterated_state_name ),
274
+ iterator = steps .Pass (iterated_state_name ),
275
275
max_concurrency = max_concurrency )
276
276
277
277
definition = steps .Chain ([
@@ -280,7 +280,7 @@ def test_map_state_machine_creation(sfn_client, sfn_role_arn):
280
280
])
281
281
282
282
workflow = Workflow (
283
- 'Test_Map_Workflow' ,
283
+ unique_name_from_base ( 'Test_Map_Workflow' ) ,
284
284
definition = definition ,
285
285
role = sfn_role_arn
286
286
)
@@ -345,8 +345,8 @@ def test_choice_state_machine_creation(sfn_client, sfn_role_arn):
345
345
346
346
definition .default_choice (
347
347
steps .Fail (
348
- default_state_name ,
349
- error = default_error ,
348
+ default_state_name ,
349
+ error = default_error ,
350
350
cause = default_cause
351
351
)
352
352
)
@@ -356,23 +356,23 @@ def test_choice_state_machine_creation(sfn_client, sfn_role_arn):
356
356
value = first_choice_value
357
357
),
358
358
steps .Pass (
359
- first_match_name ,
359
+ first_match_name ,
360
360
result = first_choice_state_result
361
361
)
362
362
)
363
363
definition .add_choice (
364
364
steps .ChoiceRule .NumericEquals (
365
- variable = variable ,
365
+ variable = variable ,
366
366
value = second_choice_value
367
- ),
367
+ ),
368
368
steps .Pass (
369
- second_match_name ,
369
+ second_match_name ,
370
370
result = second_choice_state_result
371
371
)
372
372
)
373
373
374
374
workflow = Workflow (
375
- 'Test_Choice_Workflow' ,
375
+ unique_name_from_base ( 'Test_Choice_Workflow' ) ,
376
376
definition = definition ,
377
377
role = sfn_role_arn
378
378
)
@@ -385,10 +385,10 @@ def test_task_state_machine_creation(sfn_client, sfn_role_arn, training_job_para
385
385
final_state_name = "FinalState"
386
386
resource = f"arn:{ get_aws_partition ()} :states:::sagemaker:createTrainingJob.sync"
387
387
task_state_result = "Task State Result"
388
- asl_state_machine_definition = {
388
+ asl_state_machine_definition = {
389
389
"StartAt" : task_state_name ,
390
- "States" : {
391
- task_state_name : {
390
+ "States" : {
391
+ task_state_name : {
392
392
"Resource" : resource ,
393
393
"Parameters" : training_job_parameters ,
394
394
"Type" : "Task" ,
@@ -410,9 +410,9 @@ def test_task_state_machine_creation(sfn_client, sfn_role_arn, training_job_para
410
410
),
411
411
steps .Pass (final_state_name , result = task_state_result )
412
412
])
413
-
413
+
414
414
workflow = Workflow (
415
- 'Test_Task_Workflow' ,
415
+ unique_name_from_base ( 'Test_Task_Workflow' ) ,
416
416
definition = definition ,
417
417
role = sfn_role_arn
418
418
)
@@ -465,13 +465,13 @@ def test_catch_state_machine_creation(sfn_client, sfn_role_arn, training_job_par
465
465
)
466
466
task .add_catch (
467
467
steps .Catch (
468
- error_equals = [all_fail_error ],
468
+ error_equals = [all_fail_error ],
469
469
next_step = steps .Pass (all_error_state_name , result = catch_state_result )
470
470
)
471
471
)
472
472
473
473
workflow = Workflow (
474
- 'Test_Catch_Workflow' ,
474
+ unique_name_from_base ( 'Test_Catch_Workflow' ) ,
475
475
definition = task ,
476
476
role = sfn_role_arn
477
477
)
@@ -518,15 +518,15 @@ def test_retry_state_machine_creation(sfn_client, sfn_role_arn, training_job_par
518
518
519
519
task .add_retry (
520
520
steps .Retry (
521
- error_equals = [all_fail_error ],
522
- interval_seconds = interval_seconds ,
523
- max_attempts = max_attempts ,
521
+ error_equals = [all_fail_error ],
522
+ interval_seconds = interval_seconds ,
523
+ max_attempts = max_attempts ,
524
524
backoff_rate = backoff_rate
525
525
)
526
526
)
527
527
528
528
workflow = Workflow (
529
- 'Test_Retry_Workflow' ,
529
+ unique_name_from_base ( 'Test_Retry_Workflow' ) ,
530
530
definition = task ,
531
531
role = sfn_role_arn
532
532
)
0 commit comments