-
Notifications
You must be signed in to change notification settings - Fork 26
Expand file tree
/
Copy pathtest_handler_service.py
More file actions
120 lines (88 loc) · 4.54 KB
/
test_handler_service.py
File metadata and controls
120 lines (88 loc) · 4.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
from mock import MagicMock, Mock, patch
import mxnet as mx
import pytest
from sagemaker_inference import environment
from sagemaker_inference.default_inference_handler import DefaultInferenceHandler
from sagemaker_inference.transformer import Transformer
from sagemaker_mxnet_serving_container.default_inference_handler import DefaultGluonBlockInferenceHandler
from sagemaker_mxnet_serving_container.handler_service import HandlerService
from sagemaker_mxnet_serving_container.mxnet_module_transformer import MXNetModuleTransformer
# Module name that the tests below assign to the mocked Environment's
# ``module_name`` attribute.
MODULE_NAME = 'module_name'
@patch('sagemaker_mxnet_serving_container.handler_service.HandlerService._user_module_transformer')
@patch('sagemaker_inference.default_handler_service.DefaultHandlerService.initialize')
def test_handler_service(initialize, user_module_transformer):
    """Check that ``HandlerService.initialize`` leaves a mock in ``_service``.

    Both ``HandlerService._user_module_transformer`` and
    ``DefaultHandlerService.initialize`` are patched, so only the wiring done
    by ``HandlerService.initialize`` itself is exercised.

    Args:
        initialize: mock replacing ``DefaultHandlerService.initialize``.
        user_module_transformer: mock replacing
            ``HandlerService._user_module_transformer``.
    """
    # NOTE: stacked ``patch`` decorators are applied bottom-up, so the mock
    # created by the decorator closest to ``def`` is passed first. The
    # parameter names above follow that order.
    service = HandlerService()
    properties = {
        'model_dir': '/opt/ml/models/model-name'
    }

    # Item lookups on the fake context's system properties are answered from
    # ``properties``.
    context = MagicMock()
    context.system_properties.__getitem__.side_effect = properties.__getitem__

    service.initialize(context)

    # Presumably ``_service`` holds whatever the patched
    # ``_user_module_transformer`` returned -- confirm against HandlerService.
    assert isinstance(service._service, Mock)
class UserModuleTransformFn:
    """Stand-in for a user module that exposes only ``transform_fn``."""

    def __init__(self):
        """Attach a fresh ``Mock`` as this instance's ``transform_fn``."""
        self.transform_fn = Mock()
@patch('sagemaker_inference.environment.Environment')
@patch('importlib.util.module_from_spec', return_value=UserModuleTransformFn())
@patch('os.path.exists', return_value=True)
def test_user_module_transform_fn(path_exists, module_from_spec, env):
    """A user module with ``transform_fn`` yields a plain ``Transformer``.

    ``os.path.exists`` is patched to return True and
    ``importlib.util.module_from_spec`` to return a ``UserModuleTransformFn``
    instance. The result must be a ``Transformer`` whose
    ``_default_inference_handler`` is a ``DefaultInferenceHandler``.
    """
    env.return_value.module_name = MODULE_NAME

    result = HandlerService._user_module_transformer()

    # The user module must have been created exactly once.
    module_from_spec.assert_called_once()

    inference_handler = result._default_inference_handler
    assert isinstance(inference_handler, DefaultInferenceHandler)
    assert isinstance(result, Transformer)
class UserModuleModelFn:
    """Stand-in for a user module that exposes only ``model_fn``."""

    def __init__(self):
        """Attach a fresh ``Mock`` as this instance's ``model_fn``."""
        self.model_fn = Mock()
@patch('sagemaker_inference.environment.Environment')
@patch('importlib.util.module_from_spec', return_value=UserModuleModelFn())
@patch('os.path.exists', return_value=True)
def test_user_module_mxnet_module_transformer(path_exists, module_from_spec, env):
    """A ``model_fn`` returning an MXNet module yields ``MXNetModuleTransformer``.

    The patched user module's ``model_fn`` is configured to return an
    ``mx.module.BaseModule`` instance before the transformer is built.
    """
    env.return_value.module_name = MODULE_NAME

    # Make the fake user module hand back an MXNet Module-API model.
    user_module = module_from_spec.return_value
    user_module.model_fn.return_value = mx.module.BaseModule()

    result = HandlerService._user_module_transformer()

    module_from_spec.assert_called_once()
    assert isinstance(result, MXNetModuleTransformer)
@patch('sagemaker_inference.environment.Environment')
@patch('sagemaker_mxnet_serving_container.default_inference_handler.DefaultMXNetInferenceHandler.default_model_fn')
@patch('importlib.util.module_from_spec', return_value=object())
@patch('os.path.exists', return_value=True)
def test_default_inference_handler_mxnet_gluon_transformer(path_exists, module_from_spec, model_fn, env):
    """A Gluon block from the default ``model_fn`` selects the Gluon handler.

    The user module is a bare ``object()``, so it defines neither
    ``model_fn`` nor ``transform_fn``. The patched
    ``DefaultMXNetInferenceHandler.default_model_fn`` returns an
    ``mx.gluon.block.Block``. The result must be a ``Transformer`` whose
    ``_default_inference_handler`` is a ``DefaultGluonBlockInferenceHandler``.
    """
    env.return_value.module_name = MODULE_NAME
    gluon_block = mx.gluon.block.Block()
    model_fn.return_value = gluon_block

    result = HandlerService._user_module_transformer()

    module_from_spec.assert_called_once()
    # The default model_fn must have been called once, with the model directory.
    model_fn.assert_called_once_with(environment.model_dir)

    assert isinstance(result, Transformer)
    inference_handler = result._default_inference_handler
    assert isinstance(inference_handler, DefaultGluonBlockInferenceHandler)
@patch('sagemaker_inference.environment.Environment')
@patch('importlib.util.module_from_spec', return_value=UserModuleModelFn())
@patch('os.path.exists', return_value=True)
def test_user_module_unsupported(path_exists, module_from_spec, env):
    """An unrecognised model type makes ``_user_module_transformer`` raise.

    The fake user module's ``model_fn`` is a ``Mock`` whose return value is
    left unconfigured here. The call must raise a ``ValueError`` matching
    'Unsupported model type'.
    """
    env.return_value.module_name = MODULE_NAME

    with pytest.raises(ValueError) as exc_info:
        HandlerService._user_module_transformer()

    # The user module was still created once before the failure.
    module_from_spec.assert_called_once()
    exc_info.match('Unsupported model type')
@patch('sagemaker_inference.environment.Environment')
@patch('importlib.util.module_from_spec', return_value=UserModuleModelFn())
def test_user_module_invalid_path(module_from_spec, env):
    """An invalid inference script path makes ``_user_module_transformer`` raise.

    Unlike the sibling tests, ``os.path.exists`` is not patched here. The
    call must raise a ``ValueError`` matching 'Invalid inference_script path'.
    """
    env.return_value.module_name = MODULE_NAME

    with pytest.raises(ValueError) as exc_info:
        HandlerService._user_module_transformer()

    exc_info.match('Invalid inference_script path')