Skip to content

Commit b5c5037

Browse files
author
Wei Chu
committed
add test
1 parent c774dd0 commit b5c5037

File tree

4 files changed

+158
-5
lines changed

4 files changed

+158
-5
lines changed

src/sagemaker_pytorch_serving_container/handler_service.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from __future__ import absolute_import
1414

1515
from sagemaker_inference.default_handler_service import DefaultHandlerService
16-
from sagemaker_pytorch_serving_container.default_pytorch_inference_handler import DefaultPytorchInferenceHandler
1716
from sagemaker_pytorch_serving_container.transformer import PyTorchTransformer
1817

1918
import os
@@ -37,7 +36,7 @@ class HandlerService(DefaultHandlerService):
3736
def __init__(self):
3837
self._initialized = False
3938

40-
transformer = PyTorchTransformer(default_inference_handler=DefaultPytorchInferenceHandler())
39+
transformer = PyTorchTransformer()
4140
super(HandlerService, self).__init__(transformer=transformer)
4241

4342
def initialize(self, context):

src/sagemaker_pytorch_serving_container/transformer.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,14 @@
1919
from sagemaker_inference.transformer import Transformer
2020
from sagemaker_inference import content_types, environment, utils
2121
from sagemaker_inference.errors import BaseInferenceToolkitError, GenericInferenceToolkitError
22+
from sagemaker_pytorch_serving_container.default_pytorch_inference_handler import DefaultPytorchInferenceHandler
2223

2324

2425
class PyTorchTransformer(Transformer):
2526
"""Represents the execution workflow for handling pytorch inference requests
2627
sent to the model server.
2728
"""
28-
def __init__(self, default_inference_handler=None):
29+
def __init__(self, default_inference_handler=DefaultPytorchInferenceHandler()):
2930
super().__init__(default_inference_handler)
3031
self._context = None
3132

@@ -44,7 +45,7 @@ def transform(self, data, context):
4445
try:
4546
properties = context.system_properties
4647
model_dir = properties.get("model_dir")
47-
self.validate_and_initialize(model_dir=model_dir, context=self._context)
48+
self.validate_and_initialize(model_dir=model_dir, context=context)
4849

4950
response_list = []
5051
for i in range(len(data)):

test/unit/test_handler_service.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def test_hosting_start(PyTorchTransformer, DefaultPytorchInferenceHandler):
2222

2323
handler_service.HandlerService()
2424

25-
PyTorchTransformer.assert_called_with(default_inference_handler=DefaultPytorchInferenceHandler())
25+
PyTorchTransformer.assert_called_with()
2626

2727

2828
@patch('sagemaker_pytorch_serving_container.default_pytorch_inference_handler.DefaultPytorchInferenceHandler')

test/unit/test_transformer.py

+153
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# Copyright 2019-2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the 'License'). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the 'license' file accompanying this file. This file is
10+
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
14+
from mock import Mock, patch
15+
import pytest
16+
17+
from sagemaker_inference import environment
18+
from sagemaker_pytorch_serving_container.default_pytorch_inference_handler import DefaultPytorchInferenceHandler
19+
from sagemaker_pytorch_serving_container.transformer import PyTorchTransformer
20+
21+
22+
INPUT_DATA = "input_data"
23+
CONTENT_TYPE = "content_type"
24+
ACCEPT = "accept"
25+
RESULT = "result"
26+
MODEL = "foo"
27+
28+
PREPROCESSED_DATA = "preprocessed_data"
29+
PREDICT_RESULT = "prediction_result"
30+
PROCESSED_RESULT = "processed_result"
31+
32+
33+
def test_default_transformer():
34+
transformer = PyTorchTransformer()
35+
36+
assert isinstance(transformer._default_inference_handler, DefaultPytorchInferenceHandler)
37+
assert transformer._initialized is False
38+
assert transformer._environment is None
39+
assert transformer._pre_model_fn is None
40+
assert transformer._model_warmup_fn is None
41+
assert transformer._model is None
42+
assert transformer._model_fn is None
43+
assert transformer._transform_fn is None
44+
assert transformer._input_fn is None
45+
assert transformer._predict_fn is None
46+
assert transformer._output_fn is None
47+
assert transformer._context is None
48+
49+
50+
def test_transformer_with_custom_default_inference_handler():
51+
default_inference_handler = Mock()
52+
53+
transformer = PyTorchTransformer(default_inference_handler)
54+
55+
assert transformer._default_inference_handler == default_inference_handler
56+
assert transformer._initialized is False
57+
assert transformer._environment is None
58+
assert transformer._pre_model_fn is None
59+
assert transformer._model_warmup_fn is None
60+
assert transformer._model is None
61+
assert transformer._model_fn is None
62+
assert transformer._transform_fn is None
63+
assert transformer._input_fn is None
64+
assert transformer._predict_fn is None
65+
assert transformer._output_fn is None
66+
assert transformer._context is None
67+
68+
69+
@pytest.mark.parametrize("accept_key", ["Accept", "accept"])
70+
@patch("sagemaker_inference.utils.retrieve_content_type_header", return_value=CONTENT_TYPE)
71+
@patch("sagemaker_pytorch_serving_container.transformer.PyTorchTransformer.validate_and_initialize")
72+
def test_transform(validate, retrieve_content_type_header, accept_key):
73+
data = [{"body": INPUT_DATA}]
74+
context = Mock()
75+
request_processor = Mock()
76+
transform_fn = Mock(return_value=RESULT)
77+
78+
context.request_processor = [request_processor]
79+
request_property = {accept_key: ACCEPT}
80+
request_processor.get_request_properties.return_value = request_property
81+
82+
transformer = PyTorchTransformer()
83+
transformer._model = MODEL
84+
transformer._transform_fn = transform_fn
85+
transformer._context = context
86+
87+
result = transformer.transform(data, context)
88+
89+
validate.assert_called_once()
90+
retrieve_content_type_header.assert_called_once_with(request_property)
91+
transform_fn.assert_called_once_with(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT)
92+
context.set_response_content_type.assert_called_once_with(0, ACCEPT)
93+
assert isinstance(result, list)
94+
assert result[0] == RESULT
95+
96+
97+
@patch("sagemaker_pytorch_serving_container.transformer.PyTorchTransformer._validate_user_module_and_set_functions")
98+
@patch("sagemaker_inference.environment.Environment")
99+
def test_validate_and_initialize(env, validate_user_module):
100+
transformer = PyTorchTransformer()
101+
102+
model_fn = Mock()
103+
context = Mock()
104+
transformer._model_fn = model_fn
105+
106+
assert transformer._initialized is False
107+
assert transformer._context is None
108+
109+
transformer.validate_and_initialize(context=context)
110+
111+
assert transformer._initialized is True
112+
assert transformer._context == context
113+
114+
transformer.validate_and_initialize()
115+
116+
model_fn.assert_called_once_with(environment.model_dir, context)
117+
env.assert_called_once_with()
118+
validate_user_module.assert_called_once_with()
119+
120+
121+
def test_default_transform_fn():
122+
transformer = PyTorchTransformer()
123+
context = Mock()
124+
transformer._context = context
125+
126+
input_fn = Mock(return_value=PREPROCESSED_DATA)
127+
predict_fn = Mock(return_value=PREDICT_RESULT)
128+
output_fn = Mock(return_value=PROCESSED_RESULT)
129+
130+
transformer._input_fn = input_fn
131+
transformer._predict_fn = predict_fn
132+
transformer._output_fn = output_fn
133+
134+
result = transformer._default_transform_fn(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT)
135+
136+
input_fn.assert_called_once_with(INPUT_DATA, CONTENT_TYPE, context)
137+
predict_fn.assert_called_once_with(PREPROCESSED_DATA, MODEL, context)
138+
output_fn.assert_called_once_with(PREDICT_RESULT, ACCEPT, context)
139+
assert result == PROCESSED_RESULT
140+
141+
142+
def test_run_handle_function():
143+
def three_inputs_func(a, b, c): pass
144+
145+
three_inputs_mock = Mock(spec=three_inputs_func)
146+
a = Mock()
147+
b = Mock()
148+
context = Mock()
149+
150+
transformer = PyTorchTransformer()
151+
transformer._context = context
152+
transformer._run_handle_function(three_inputs_mock, a, b)
153+
three_inputs_mock.assert_called_with(a, b, context)

0 commit comments

Comments
 (0)