From 1f1ea242034febb92c808d5d862edaf8e730f282 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Thu, 28 Jul 2022 16:17:12 -0700 Subject: [PATCH 01/12] multi-gpu support --- .../default_pytorch_inference_handler.py | 28 +++- .../handler_service.py | 28 ++-- .../transformer.py | 148 ++++++++++++++++++ 3 files changed, 188 insertions(+), 16 deletions(-) create mode 100644 src/sagemaker_pytorch_serving_container/transformer.py diff --git a/src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py b/src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py index 313f956f..955e57d4 100644 --- a/src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py +++ b/src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py @@ -43,12 +43,13 @@ def _is_model_file(filename): is_model_file = ext in [".pt", ".pth"] return is_model_file - def default_model_fn(self, model_dir): + def default_model_fn(self, model_dir, context=None): """Loads a model. For PyTorch, a default function to load a model only if Elastic Inference is used. In other cases, users should provide customized model_fn() in script. Args: model_dir: a directory where model is saved. + context: context for the request. Returns: A PyTorch model. """ @@ -65,7 +66,12 @@ def default_model_fn(self, model_dir): "Failed to load {}. Please ensure model is saved using torchscript.".format(model_path) ) from e else: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if context: + properties = context.system_properties + device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu") + else: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model_path = os.path.join(model_dir, DEFAULT_MODEL_FILENAME) if not os.path.exists(model_path): model_files = [file for file in os.listdir(model_dir) if self._is_model_file(file)] @@ -83,29 +89,35 @@ def default_model_fn(self, model_dir): model = model.to(device) return model - def default_input_fn(self, input_data, content_type): + def default_input_fn(self, input_data, content_type, context=None): """A default input_fn that can handle JSON, CSV and NPZ formats. Args: input_data: the request payload serialized in the content_type format content_type: the request content_type + context: context for the request Returns: input_data deserialized into torch.FloatTensor or torch.cuda.FloatTensor, depending if cuda is available. """ - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if context: + properties = context.system_properties + device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu") + else: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") np_array = decoder.decode(input_data, content_type) tensor = torch.FloatTensor( np_array) if content_type in content_types.UTF8_TYPES else torch.from_numpy(np_array) return tensor.to(device) - def default_predict_fn(self, data, model): + def default_predict_fn(self, data, model, context=None): """A default predict_fn for PyTorch. Calls a model on data deserialized in input_fn. Runs prediction on GPU if cuda is available. 
Args: data: input data (torch.Tensor) for prediction deserialized by input_fn model: PyTorch model loaded in memory by model_fn + context: context for the request Returns: a prediction """ @@ -118,7 +130,11 @@ def default_predict_fn(self, data, model): with torch.jit.optimized_execution(True, {"target_device": "eia:0"}): output = model(input_data) else: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if context: + properties = context.system_properties + device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu") + else: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) input_data = data.to(device) model.eval() diff --git a/src/sagemaker_pytorch_serving_container/handler_service.py b/src/sagemaker_pytorch_serving_container/handler_service.py index 07e81c7b..c54d4556 100644 --- a/src/sagemaker_pytorch_serving_container/handler_service.py +++ b/src/sagemaker_pytorch_serving_container/handler_service.py @@ -13,32 +13,30 @@ from __future__ import absolute_import from sagemaker_inference.default_handler_service import DefaultHandlerService -from sagemaker_inference.transformer import Transformer from sagemaker_pytorch_serving_container.default_pytorch_inference_handler import DefaultPytorchInferenceHandler +from sagemaker_pytorch_serving_container.transformer import PTTransformer import os import sys +PYTHON_PATH_ENV = "PYTHONPATH" ENABLE_MULTI_MODEL = os.getenv("SAGEMAKER_MULTI_MODEL", "false") == "true" - class HandlerService(DefaultHandlerService): + """ + Handler service that is executed by the model server. - """Handler service that is executed by the model server. - - Determines specific default inference handlers to use based on the type MXNet model being used. + Determines specific default inference handlers to use based on the type pytorch model being used. This class extends ``DefaultHandlerService``, which define the following: - The ``handle`` method is invoked for all incoming inference requests to the model server. - The ``initialize`` method is invoked at model server start up. - - Based on: https://github.com/awslabs/mxnet-model-server/blob/master/docs/custom_service.md - """ + def __init__(self): self._initialized = False - transformer = Transformer(default_inference_handler=DefaultPytorchInferenceHandler()) + transformer = PTTransformer(default_inference_handler=DefaultPytorchInferenceHandler()) super(HandlerService, self).__init__(transformer=transformer) def initialize(self, context): @@ -48,4 +46,14 @@ def initialize(self, context): sys.path.append(code_dir) self._initialized = True - super().initialize(context) + properties = context.system_properties + model_dir = properties.get("model_dir") + + # add model_dir/code to python path + code_dir_path = "{}:".format(model_dir + "/code") + if PYTHON_PATH_ENV in os.environ: + os.environ[PYTHON_PATH_ENV] = code_dir_path + os.environ[PYTHON_PATH_ENV] + else: + os.environ[PYTHON_PATH_ENV] = code_dir_path + + self._service.validate_and_initialize(model_dir=model_dir, context=context) diff --git a/src/sagemaker_pytorch_serving_container/transformer.py b/src/sagemaker_pytorch_serving_container/transformer.py new file mode 100644 index 00000000..97a8b2e7 --- /dev/null +++ b/src/sagemaker_pytorch_serving_container/transformer.py @@ -0,0 +1,148 @@ +# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). 
You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +from __future__ import absolute_import + +import logging +import traceback + +from six.moves import http_client +from sagemaker_inference.transformer import Transformer +from sagemaker_inference import content_types, environment, utils +from sagemaker_inference.errors import BaseInferenceToolkitError, GenericInferenceToolkitError + +logger = logging.getLogger() + +class PTTransformer(Transformer): + """Represents the execution workflow for handling pytorch inference requests + sent to the model server. + """ + def __init__(self, default_inference_handler=None): + super().__init__(default_inference_handler) + self._context = None + + def transform(self, data, context): + """Take a request with input data, deserialize it, make a prediction, and return a + serialized response. + Args: + data (obj): the request data. + context (obj): metadata on the incoming request data. + Returns: + list[obj]: The serialized prediction result wrapped in a list if + inference is successful. Otherwise returns an error message + with the context set appropriately. + """ + + try: + properties = context.system_properties + model_dir = properties.get("model_dir") + self.validate_and_initialize(model_dir=model_dir, cotext=self._context) + + input_data = data[0].get("body") + + request_processor = context.request_processor[0] + + request_property = request_processor.get_request_properties() + content_type = utils.retrieve_content_type_header(request_property) + accept = request_property.get("Accept") or request_property.get("accept") + + if not accept or accept == content_types.ANY: + accept = self._environment.default_accept + + if content_type in content_types.UTF8_TYPES: + input_data = input_data.decode("utf-8") + + try: + # custom/default handler takes context (for multi-gpu setup) + logger.info('running transform function with context.') + result = self._transform_fn(self._model, input_data, content_type, accept, self._context) + except TypeError: + # custom handler does not take context + logger.info('running transform function without context.') + result = self._transform_fn(self._model, input_data, content_type, accept) + + response = result + response_content_type = accept + + if isinstance(result, tuple): + # handles tuple for backwards compatibility + response = result[0] + response_content_type = result[1] + + context.set_response_content_type(0, response_content_type) + return [response] + except Exception as e: # pylint: disable=broad-except + trace = traceback.format_exc() + if isinstance(e, BaseInferenceToolkitError): + return super().handle_error(context, e, trace) + else: + return super().handle_error( + context, + GenericInferenceToolkitError(http_client.INTERNAL_SERVER_ERROR, str(e)), + trace, + ) + + def validate_and_initialize(self, model_dir=environment.model_dir, context=None): + """Validates the user module against the SageMaker inference contract. + Load the model as defined by the ``model_fn`` to prepare handling predictions. 
+ """ + if not self._initialized: + self._context = context + self._environment = environment.Environment() + self._validate_user_module_and_set_functions() + try: + # custom/default model function takes context (for multi-gpu setup) + logger.info('running model functions with context.') + if self._pre_model_fn is not None: + self._pre_model_fn(model_dir, context) + self._model = self._model_fn(model_dir, context) + if self._model_warmup_fn is not None: + self._model_warmup_fn(model_dir, self._model, context) + except TypeError: + # custom model function does not take context + logger.info('running model functions without context.') + if self._pre_model_fn is not None: + self._pre_model_fn(model_dir) + self._model = self._model_fn(model_dir) + if self._model_warmup_fn is not None: + self._model_warmup_fn(model_dir, self._model) + self._initialized = True + + def _default_transform_fn(self, model, input_data, content_type, accept): + """Make predictions against the model and return a serialized response. + This serves as the default implementation of transform_fn, used when the + user has not provided an implementation. + Args: + model (obj): model loaded by model_fn. + input_data (obj): the request data. + content_type (str): the request content type. + accept (str): accept header expected by the client. + Returns: + obj: the serialized prediction result or a tuple of the form + (response_data, content_type) + """ + try: + # custom/default handler takes context (for multi-gpu setup) + logger.info('running handler functions with context.') + data = self._input_fn(input_data, content_type, self._context) + prediction = self._predict_fn(data, model, self._context) + result = self._output_fn(prediction, accept, self._context) + except TypeError: + # custom handler does not take context + logger.info('running handler functions without context.') + data = self._input_fn(input_data, content_type) + prediction = self._predict_fn(data, model) + result = self._output_fn(prediction, accept) + + return result + \ No newline at end of file From 47d6f21b0eac5fa22f0f75f90732eee87def87dd Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Thu, 4 Aug 2022 15:14:28 -0700 Subject: [PATCH 02/12] add fucntion wrapper --- .../transformer.py | 66 ++++++++----------- 1 file changed, 27 insertions(+), 39 deletions(-) diff --git a/src/sagemaker_pytorch_serving_container/transformer.py b/src/sagemaker_pytorch_serving_container/transformer.py index 97a8b2e7..b85d9a80 100644 --- a/src/sagemaker_pytorch_serving_container/transformer.py +++ b/src/sagemaker_pytorch_serving_container/transformer.py @@ -21,7 +21,6 @@ from sagemaker_inference import content_types, environment, utils from sagemaker_inference.errors import BaseInferenceToolkitError, GenericInferenceToolkitError -logger = logging.getLogger() class PTTransformer(Transformer): """Represents the execution workflow for handling pytorch inference requests @@ -46,7 +45,7 @@ def transform(self, data, context): try: properties = context.system_properties model_dir = properties.get("model_dir") - self.validate_and_initialize(model_dir=model_dir, cotext=self._context) + self.validate_and_initialize(model_dir=model_dir, context=self._context) input_data = data[0].get("body") @@ -62,14 +61,7 @@ def transform(self, data, context): if content_type in content_types.UTF8_TYPES: input_data = input_data.decode("utf-8") - try: - # custom/default handler takes context (for multi-gpu setup) - logger.info('running transform function with context.') - result = 
self._transform_fn(self._model, input_data, content_type, accept, self._context) - except TypeError: - # custom handler does not take context - logger.info('running transform function without context.') - result = self._transform_fn(self._model, input_data, content_type, accept) + result = self._run_handle_function(self._transform_fn, *(self._model, input_data, content_type, accept)) response = result response_content_type = accept @@ -100,22 +92,15 @@ def validate_and_initialize(self, model_dir=environment.model_dir, context=None) self._context = context self._environment = environment.Environment() self._validate_user_module_and_set_functions() - try: - # custom/default model function takes context (for multi-gpu setup) - logger.info('running model functions with context.') - if self._pre_model_fn is not None: - self._pre_model_fn(model_dir, context) - self._model = self._model_fn(model_dir, context) - if self._model_warmup_fn is not None: - self._model_warmup_fn(model_dir, self._model, context) - except TypeError: - # custom model function does not take context - logger.info('running model functions without context.') - if self._pre_model_fn is not None: - self._pre_model_fn(model_dir) - self._model = self._model_fn(model_dir) - if self._model_warmup_fn is not None: - self._model_warmup_fn(model_dir, self._model) + + if self._pre_model_fn is not None: + self._run_handle_function(self._pre_model_fn, *(model_dir, )) + + self._model = self._run_handle_function(self._model_fn, *(model_dir, )) + + if self._model_warmup_fn is not None: + self._run_handle_function(self._model_warmup_fn, *(model_dir, self._model)) + self._initialized = True def _default_transform_fn(self, model, input_data, content_type, accept): @@ -131,18 +116,21 @@ def _default_transform_fn(self, model, input_data, content_type, accept): obj: the serialized prediction result or a tuple of the form (response_data, content_type) """ - try: - # custom/default handler takes context (for multi-gpu setup) - logger.info('running handler functions with context.') - data = self._input_fn(input_data, content_type, self._context) - prediction = self._predict_fn(data, model, self._context) - result = self._output_fn(prediction, accept, self._context) - except TypeError: - # custom handler does not take context - logger.info('running handler functions without context.') - data = self._input_fn(input_data, content_type) - prediction = self._predict_fn(data, model) - result = self._output_fn(prediction, accept) + data = self._run_handle_function(self._input_fn, *(input_data, content_type)) + prediction = self._run_handle_function(self._predict_fn, *(data, model)) + result = self._run_handle_function(self._output_fn, *(prediction, accept)) return result - \ No newline at end of file + + def _run_handle_function(self, func, *argv): + """Wrapper to call the handle function which covers 2 cases: + 1. context passed to the handle function + 2. 
context not passed to the handle function + """ + try: + argv_context = argv + (self._context, ) + result = func(*argv_context) + except TypeError: + result = func(*argv) + + return result From 58f74252a63a77c9ada2e8c3fdeeb47a06aa1299 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Fri, 5 Aug 2022 16:14:52 -0700 Subject: [PATCH 03/12] fix for batch inference --- .../handler_service.py | 4 +- .../transformer.py | 43 +++++++++++-------- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/src/sagemaker_pytorch_serving_container/handler_service.py b/src/sagemaker_pytorch_serving_container/handler_service.py index c54d4556..ed2cddfb 100644 --- a/src/sagemaker_pytorch_serving_container/handler_service.py +++ b/src/sagemaker_pytorch_serving_container/handler_service.py @@ -14,7 +14,7 @@ from sagemaker_inference.default_handler_service import DefaultHandlerService from sagemaker_pytorch_serving_container.default_pytorch_inference_handler import DefaultPytorchInferenceHandler -from sagemaker_pytorch_serving_container.transformer import PTTransformer +from sagemaker_pytorch_serving_container.transformer import PyTorchTransformer import os import sys @@ -36,7 +36,7 @@ class HandlerService(DefaultHandlerService): def __init__(self): self._initialized = False - transformer = PTTransformer(default_inference_handler=DefaultPytorchInferenceHandler()) + transformer = PyTorchTransformer(default_inference_handler=DefaultPytorchInferenceHandler()) super(HandlerService, self).__init__(transformer=transformer) def initialize(self, context): diff --git a/src/sagemaker_pytorch_serving_container/transformer.py b/src/sagemaker_pytorch_serving_container/transformer.py index b85d9a80..7f1f3c99 100644 --- a/src/sagemaker_pytorch_serving_container/transformer.py +++ b/src/sagemaker_pytorch_serving_container/transformer.py @@ -22,7 +22,7 @@ from sagemaker_inference.errors import BaseInferenceToolkitError, GenericInferenceToolkitError -class PTTransformer(Transformer): +class PyTorchTransformer(Transformer): """Represents the execution workflow for handling pytorch inference requests sent to the model server. 
""" @@ -47,32 +47,37 @@ def transform(self, data, context): model_dir = properties.get("model_dir") self.validate_and_initialize(model_dir=model_dir, context=self._context) - input_data = data[0].get("body") + response_list = [] + for i in range(len(data)): + input_data = data[i].get("body") - request_processor = context.request_processor[0] + request_processor = context.request_processor[0] - request_property = request_processor.get_request_properties() - content_type = utils.retrieve_content_type_header(request_property) - accept = request_property.get("Accept") or request_property.get("accept") + request_property = request_processor.get_request_properties() + content_type = utils.retrieve_content_type_header(request_property) + accept = request_property.get("Accept") or request_property.get("accept") - if not accept or accept == content_types.ANY: - accept = self._environment.default_accept + if not accept or accept == content_types.ANY: + accept = self._environment.default_accept - if content_type in content_types.UTF8_TYPES: - input_data = input_data.decode("utf-8") + if content_type in content_types.UTF8_TYPES: + input_data = input_data.decode("utf-8") - result = self._run_handle_function(self._transform_fn, *(self._model, input_data, content_type, accept)) + result = self._run_handle_function(self._transform_fn, *(self._model, input_data, content_type, accept)) - response = result - response_content_type = accept + response = result + response_content_type = accept - if isinstance(result, tuple): - # handles tuple for backwards compatibility - response = result[0] - response_content_type = result[1] + if isinstance(result, tuple): + # handles tuple for backwards compatibility + response = result[0] + response_content_type = result[1] - context.set_response_content_type(0, response_content_type) - return [response] + context.set_response_content_type(0, response_content_type) + + response_list.append(response) + + return response_list except Exception as e: # pylint: disable=broad-except trace = traceback.format_exc() if isinstance(e, BaseInferenceToolkitError): From 2a0dbdab1651496f2f17ce53af49b714e6341f19 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Fri, 5 Aug 2022 16:33:13 -0700 Subject: [PATCH 04/12] update unit test --- test/unit/test_handler_service.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/unit/test_handler_service.py b/test/unit/test_handler_service.py index fd3dfc60..25b3d984 100644 --- a/test/unit/test_handler_service.py +++ b/test/unit/test_handler_service.py @@ -16,18 +16,18 @@ @patch('sagemaker_pytorch_serving_container.default_pytorch_inference_handler.DefaultPytorchInferenceHandler') -@patch('sagemaker_inference.transformer.Transformer') -def test_hosting_start(Transformer, DefaultPytorchInferenceHandler): +@patch('sagemaker_pytorch_serving_container.transformer.PyTorchTransformer') +def test_hosting_start(PyTorchTransformer, DefaultPytorchInferenceHandler): from sagemaker_pytorch_serving_container import handler_service handler_service.HandlerService() - Transformer.assert_called_with(default_inference_handler=DefaultPytorchInferenceHandler()) + PyTorchTransformer.assert_called_with(default_inference_handler=DefaultPytorchInferenceHandler()) @patch('sagemaker_pytorch_serving_container.default_pytorch_inference_handler.DefaultPytorchInferenceHandler') -@patch('sagemaker_inference.transformer.Transformer') -def test_hosting_start_enable_multi_model(Transformer, DefaultPytorchInferenceHandler): 
+@patch('sagemaker_pytorch_serving_container.transformer.PyTorchTransformer') +def test_hosting_start_enable_multi_model(PyTorchTransformer, DefaultPytorchInferenceHandler): from sagemaker_pytorch_serving_container import handler_service context = Mock() From d869551cd86e2c956bb203bccb6c85617922f592 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Fri, 5 Aug 2022 18:17:08 -0700 Subject: [PATCH 05/12] fix sanity --- .../default_pytorch_inference_handler.py | 9 +++++---- .../handler_service.py | 1 + src/sagemaker_pytorch_serving_container/transformer.py | 7 +++---- test/utils/file_utils.py | 2 +- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py b/src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py index 955e57d4..4d6a7aab 100644 --- a/src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py +++ b/src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py @@ -66,7 +66,7 @@ def default_model_fn(self, model_dir, context=None): "Failed to load {}. Please ensure model is saved using torchscript.".format(model_path) ) from e else: - if context: + if context: properties = context.system_properties device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu") else: @@ -100,7 +100,7 @@ def default_input_fn(self, input_data, content_type, context=None): Returns: input_data deserialized into torch.FloatTensor or torch.cuda.FloatTensor, depending if cuda is available. """ - if context: + if context: properties = context.system_properties device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu") else: @@ -130,9 +130,10 @@ def default_predict_fn(self, data, model, context=None): with torch.jit.optimized_execution(True, {"target_device": "eia:0"}): output = model(input_data) else: - if context: + if context: properties = context.system_properties - device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu") + device = torch.device("cuda:" + str(properties.get("gpu_id")) + if torch.cuda.is_available() else "cpu") else: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) diff --git a/src/sagemaker_pytorch_serving_container/handler_service.py b/src/sagemaker_pytorch_serving_container/handler_service.py index ed2cddfb..3f314289 100644 --- a/src/sagemaker_pytorch_serving_container/handler_service.py +++ b/src/sagemaker_pytorch_serving_container/handler_service.py @@ -22,6 +22,7 @@ PYTHON_PATH_ENV = "PYTHONPATH" ENABLE_MULTI_MODEL = os.getenv("SAGEMAKER_MULTI_MODEL", "false") == "true" + class HandlerService(DefaultHandlerService): """ Handler service that is executed by the model server. diff --git a/src/sagemaker_pytorch_serving_container/transformer.py b/src/sagemaker_pytorch_serving_container/transformer.py index 7f1f3c99..9a172d5b 100644 --- a/src/sagemaker_pytorch_serving_container/transformer.py +++ b/src/sagemaker_pytorch_serving_container/transformer.py @@ -13,7 +13,6 @@ from __future__ import absolute_import -import logging import traceback from six.moves import http_client @@ -88,7 +87,7 @@ def transform(self, data, context): GenericInferenceToolkitError(http_client.INTERNAL_SERVER_ERROR, str(e)), trace, ) - + def validate_and_initialize(self, model_dir=environment.model_dir, context=None): """Validates the user module against the SageMaker inference contract. 
Load the model as defined by the ``model_fn`` to prepare handling predictions. @@ -126,7 +125,7 @@ def _default_transform_fn(self, model, input_data, content_type, accept): result = self._run_handle_function(self._output_fn, *(prediction, accept)) return result - + def _run_handle_function(self, func, *argv): """Wrapper to call the handle function which covers 2 cases: 1. context passed to the handle function @@ -137,5 +136,5 @@ def _run_handle_function(self, func, *argv): result = func(*argv_context) except TypeError: result = func(*argv) - + return result diff --git a/test/utils/file_utils.py b/test/utils/file_utils.py index 8cc3771d..327c8ecc 100644 --- a/test/utils/file_utils.py +++ b/test/utils/file_utils.py @@ -19,7 +19,7 @@ def make_tarfile(script, model, output_path, filename="model.tar.gz", script_path=None): output_filename = os.path.join(output_path, filename) with tarfile.open(output_filename, "w:gz") as tar: - if(script_path): + if (script_path): tar.add(script, arcname=os.path.join(script_path, os.path.basename(script))) else: tar.add(script, arcname=os.path.basename(script)) From c774dd03d0555ace13f05fbf7394d57807fdde33 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Fri, 5 Aug 2022 18:36:10 -0700 Subject: [PATCH 06/12] fix sanity --- .../default_pytorch_inference_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py b/src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py index 4d6a7aab..e1f1dce4 100644 --- a/src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py +++ b/src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py @@ -132,7 +132,7 @@ def default_predict_fn(self, data, model, context=None): else: if context: properties = context.system_properties - device = torch.device("cuda:" + str(properties.get("gpu_id")) + device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu") else: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") From b5c50374b99bed0547de5265692ec067431b3127 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Mon, 8 Aug 2022 20:43:56 -0700 Subject: [PATCH 07/12] add test --- .../handler_service.py | 3 +- .../transformer.py | 5 +- test/unit/test_handler_service.py | 2 +- test/unit/test_transformer.py | 153 ++++++++++++++++++ 4 files changed, 158 insertions(+), 5 deletions(-) create mode 100644 test/unit/test_transformer.py diff --git a/src/sagemaker_pytorch_serving_container/handler_service.py b/src/sagemaker_pytorch_serving_container/handler_service.py index 3f314289..6cce4ca8 100644 --- a/src/sagemaker_pytorch_serving_container/handler_service.py +++ b/src/sagemaker_pytorch_serving_container/handler_service.py @@ -13,7 +13,6 @@ from __future__ import absolute_import from sagemaker_inference.default_handler_service import DefaultHandlerService -from sagemaker_pytorch_serving_container.default_pytorch_inference_handler import DefaultPytorchInferenceHandler from sagemaker_pytorch_serving_container.transformer import PyTorchTransformer import os @@ -37,7 +36,7 @@ class HandlerService(DefaultHandlerService): def __init__(self): self._initialized = False - transformer = PyTorchTransformer(default_inference_handler=DefaultPytorchInferenceHandler()) + transformer = PyTorchTransformer() super(HandlerService, self).__init__(transformer=transformer) def initialize(self, context): diff --git a/src/sagemaker_pytorch_serving_container/transformer.py 
b/src/sagemaker_pytorch_serving_container/transformer.py index 9a172d5b..6da834c5 100644 --- a/src/sagemaker_pytorch_serving_container/transformer.py +++ b/src/sagemaker_pytorch_serving_container/transformer.py @@ -19,13 +19,14 @@ from sagemaker_inference.transformer import Transformer from sagemaker_inference import content_types, environment, utils from sagemaker_inference.errors import BaseInferenceToolkitError, GenericInferenceToolkitError +from sagemaker_pytorch_serving_container.default_pytorch_inference_handler import DefaultPytorchInferenceHandler class PyTorchTransformer(Transformer): """Represents the execution workflow for handling pytorch inference requests sent to the model server. """ - def __init__(self, default_inference_handler=None): + def __init__(self, default_inference_handler=DefaultPytorchInferenceHandler()): super().__init__(default_inference_handler) self._context = None @@ -44,7 +45,7 @@ def transform(self, data, context): try: properties = context.system_properties model_dir = properties.get("model_dir") - self.validate_and_initialize(model_dir=model_dir, context=self._context) + self.validate_and_initialize(model_dir=model_dir, context=context) response_list = [] for i in range(len(data)): diff --git a/test/unit/test_handler_service.py b/test/unit/test_handler_service.py index 25b3d984..8be732ec 100644 --- a/test/unit/test_handler_service.py +++ b/test/unit/test_handler_service.py @@ -22,7 +22,7 @@ def test_hosting_start(PyTorchTransformer, DefaultPytorchInferenceHandler): handler_service.HandlerService() - PyTorchTransformer.assert_called_with(default_inference_handler=DefaultPytorchInferenceHandler()) + PyTorchTransformer.assert_called_with() @patch('sagemaker_pytorch_serving_container.default_pytorch_inference_handler.DefaultPytorchInferenceHandler') diff --git a/test/unit/test_transformer.py b/test/unit/test_transformer.py new file mode 100644 index 00000000..452c46c8 --- /dev/null +++ b/test/unit/test_transformer.py @@ -0,0 +1,153 @@ +# Copyright 2019-2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+ +from mock import Mock, patch +import pytest + +from sagemaker_inference import environment +from sagemaker_pytorch_serving_container.default_pytorch_inference_handler import DefaultPytorchInferenceHandler +from sagemaker_pytorch_serving_container.transformer import PyTorchTransformer + + +INPUT_DATA = "input_data" +CONTENT_TYPE = "content_type" +ACCEPT = "accept" +RESULT = "result" +MODEL = "foo" + +PREPROCESSED_DATA = "preprocessed_data" +PREDICT_RESULT = "prediction_result" +PROCESSED_RESULT = "processed_result" + + +def test_default_transformer(): + transformer = PyTorchTransformer() + + assert isinstance(transformer._default_inference_handler, DefaultPytorchInferenceHandler) + assert transformer._initialized is False + assert transformer._environment is None + assert transformer._pre_model_fn is None + assert transformer._model_warmup_fn is None + assert transformer._model is None + assert transformer._model_fn is None + assert transformer._transform_fn is None + assert transformer._input_fn is None + assert transformer._predict_fn is None + assert transformer._output_fn is None + assert transformer._context is None + + +def test_transformer_with_custom_default_inference_handler(): + default_inference_handler = Mock() + + transformer = PyTorchTransformer(default_inference_handler) + + assert transformer._default_inference_handler == default_inference_handler + assert transformer._initialized is False + assert transformer._environment is None + assert transformer._pre_model_fn is None + assert transformer._model_warmup_fn is None + assert transformer._model is None + assert transformer._model_fn is None + assert transformer._transform_fn is None + assert transformer._input_fn is None + assert transformer._predict_fn is None + assert transformer._output_fn is None + assert transformer._context is None + + +@pytest.mark.parametrize("accept_key", ["Accept", "accept"]) +@patch("sagemaker_inference.utils.retrieve_content_type_header", return_value=CONTENT_TYPE) +@patch("sagemaker_pytorch_serving_container.transformer.PyTorchTransformer.validate_and_initialize") +def test_transform(validate, retrieve_content_type_header, accept_key): + data = [{"body": INPUT_DATA}] + context = Mock() + request_processor = Mock() + transform_fn = Mock(return_value=RESULT) + + context.request_processor = [request_processor] + request_property = {accept_key: ACCEPT} + request_processor.get_request_properties.return_value = request_property + + transformer = PyTorchTransformer() + transformer._model = MODEL + transformer._transform_fn = transform_fn + transformer._context = context + + result = transformer.transform(data, context) + + validate.assert_called_once() + retrieve_content_type_header.assert_called_once_with(request_property) + transform_fn.assert_called_once_with(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT) + context.set_response_content_type.assert_called_once_with(0, ACCEPT) + assert isinstance(result, list) + assert result[0] == RESULT + + +@patch("sagemaker_pytorch_serving_container.transformer.PyTorchTransformer._validate_user_module_and_set_functions") +@patch("sagemaker_inference.environment.Environment") +def test_validate_and_initialize(env, validate_user_module): + transformer = PyTorchTransformer() + + model_fn = Mock() + context = Mock() + transformer._model_fn = model_fn + + assert transformer._initialized is False + assert transformer._context is None + + transformer.validate_and_initialize(context=context) + + assert transformer._initialized is True + assert transformer._context == 
context + + transformer.validate_and_initialize() + + model_fn.assert_called_once_with(environment.model_dir, context) + env.assert_called_once_with() + validate_user_module.assert_called_once_with() + + +def test_default_transform_fn(): + transformer = PyTorchTransformer() + context = Mock() + transformer._context = context + + input_fn = Mock(return_value=PREPROCESSED_DATA) + predict_fn = Mock(return_value=PREDICT_RESULT) + output_fn = Mock(return_value=PROCESSED_RESULT) + + transformer._input_fn = input_fn + transformer._predict_fn = predict_fn + transformer._output_fn = output_fn + + result = transformer._default_transform_fn(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT) + + input_fn.assert_called_once_with(INPUT_DATA, CONTENT_TYPE, context) + predict_fn.assert_called_once_with(PREPROCESSED_DATA, MODEL, context) + output_fn.assert_called_once_with(PREDICT_RESULT, ACCEPT, context) + assert result == PROCESSED_RESULT + + +def test_run_handle_function(): + def three_inputs_func(a, b, c): pass + + three_inputs_mock = Mock(spec=three_inputs_func) + a = Mock() + b = Mock() + context = Mock() + + transformer = PyTorchTransformer() + transformer._context = context + transformer._run_handle_function(three_inputs_mock, a, b) + three_inputs_mock.assert_called_with(a, b, context) From 1ab82495518e265b2de9639fb9afb744cdb45e5d Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Tue, 9 Aug 2022 13:53:16 -0700 Subject: [PATCH 08/12] fix sanity --- test/unit/test_transformer.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/test/unit/test_transformer.py b/test/unit/test_transformer.py index 452c46c8..fc1dbca9 100644 --- a/test/unit/test_transformer.py +++ b/test/unit/test_transformer.py @@ -9,9 +9,11 @@ # or in the 'license' file accompanying this file. This file is # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. -from mock import Mock, patch +from __future__ import absolute_import + +from mock import Mock, patch import pytest from sagemaker_inference import environment @@ -32,7 +34,7 @@ def test_default_transformer(): transformer = PyTorchTransformer() - + assert isinstance(transformer._default_inference_handler, DefaultPytorchInferenceHandler) assert transformer._initialized is False assert transformer._environment is None @@ -140,7 +142,8 @@ def test_default_transform_fn(): def test_run_handle_function(): - def three_inputs_func(a, b, c): pass + def three_inputs_func(a, b, c): + pass three_inputs_mock = Mock(spec=three_inputs_func) a = Mock() From 7fad13e07ba97b8de49b7ad1847d528fe33f774e Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Tue, 9 Aug 2022 14:16:51 -0700 Subject: [PATCH 09/12] fix sanity --- test/unit/test_transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/unit/test_transformer.py b/test/unit/test_transformer.py index fc1dbca9..042de62c 100644 --- a/test/unit/test_transformer.py +++ b/test/unit/test_transformer.py @@ -9,11 +9,11 @@ # or in the 'license' file accompanying this file. This file is # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. 
from __future__ import absolute_import -from mock import Mock, patch +from mock import Mock, patch import pytest from sagemaker_inference import environment From 2d6ce947d672d266eb1c5a637de034f73979d429 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Tue, 9 Aug 2022 14:34:12 -0700 Subject: [PATCH 10/12] fix test --- test/unit/test_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/unit/test_transformer.py b/test/unit/test_transformer.py index 042de62c..4673cd63 100644 --- a/test/unit/test_transformer.py +++ b/test/unit/test_transformer.py @@ -90,7 +90,7 @@ def test_transform(validate, retrieve_content_type_header, accept_key): validate.assert_called_once() retrieve_content_type_header.assert_called_once_with(request_property) - transform_fn.assert_called_once_with(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT) + transform_fn.assert_called_once_with(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT, context) context.set_response_content_type.assert_called_once_with(0, ACCEPT) assert isinstance(result, list) assert result[0] == RESULT From 0d293f47b3860883ceb10dcf4ae72b167d9bc2b9 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Tue, 9 Aug 2022 16:58:56 -0700 Subject: [PATCH 11/12] add test --- test/unit/test_transformer.py | 152 +++++++++++++++++++++++++++++++++- 1 file changed, 151 insertions(+), 1 deletion(-) diff --git a/test/unit/test_transformer.py b/test/unit/test_transformer.py index 4673cd63..7b304acd 100644 --- a/test/unit/test_transformer.py +++ b/test/unit/test_transformer.py @@ -16,14 +16,21 @@ from mock import Mock, patch import pytest -from sagemaker_inference import environment +try: + import http.client as http_client +except ImportError: + import httplib as http_client + +from sagemaker_inference import content_types, environment from sagemaker_pytorch_serving_container.default_pytorch_inference_handler import DefaultPytorchInferenceHandler +from sagemaker_inference.errors import BaseInferenceToolkitError from sagemaker_pytorch_serving_container.transformer import PyTorchTransformer INPUT_DATA = "input_data" CONTENT_TYPE = "content_type" ACCEPT = "accept" +DEFAULT_ACCEPT = "default_accept" RESULT = "result" MODEL = "foo" @@ -96,6 +103,81 @@ def test_transform(validate, retrieve_content_type_header, accept_key): assert result[0] == RESULT +@patch("sagemaker_inference.utils.retrieve_content_type_header", return_value=CONTENT_TYPE) +@patch("sagemaker_pytorch_serving_container.transformer.PyTorchTransformer.validate_and_initialize") +def test_transform_any_accept(validate, retrieve_content_type_header): + data = [{"body": INPUT_DATA}] + context = Mock() + request_processor = Mock() + transform_fn = Mock() + environment = Mock() + environment.default_accept = DEFAULT_ACCEPT + + context.request_processor = [request_processor] + request_processor.get_request_properties.return_value = {"accept": content_types.ANY} + + transformer = PyTorchTransformer() + transformer._model = MODEL + transformer._transform_fn = transform_fn + transformer._environment = environment + transformer._context = context + + transformer.transform(data, context) + + validate.assert_called_once() + transform_fn.assert_called_once_with(MODEL, INPUT_DATA, CONTENT_TYPE, DEFAULT_ACCEPT, context) + + +@pytest.mark.parametrize("content_type", content_types.UTF8_TYPES) +@patch("sagemaker_inference.utils.retrieve_content_type_header") +@patch("sagemaker_pytorch_serving_container.transformer.PyTorchTransformer.validate_and_initialize") +def test_transform_decode(validate, retrieve_content_type_header, 
content_type): + input_data = Mock() + context = Mock() + request_processor = Mock() + transform_fn = Mock() + data = [{"body": input_data}] + + input_data.decode.return_value = INPUT_DATA + context.request_processor = [request_processor] + request_processor.get_request_properties.return_value = {"accept": ACCEPT} + retrieve_content_type_header.return_value = content_type + + transformer = PyTorchTransformer() + transformer._model = MODEL + transformer._transform_fn = transform_fn + transformer._context = context + + transformer.transform(data, context) + + input_data.decode.assert_called_once_with("utf-8") + transform_fn.assert_called_once_with(MODEL, INPUT_DATA, content_type, ACCEPT, context) + + +@patch("sagemaker_inference.utils.retrieve_content_type_header", return_value=CONTENT_TYPE) +@patch("sagemaker_pytorch_serving_container.transformer.PyTorchTransformer.validate_and_initialize") +def test_transform_tuple(validate, retrieve_content_type_header): + data = [{"body": INPUT_DATA}] + context = Mock() + request_processor = Mock() + transform_fn = Mock(return_value=(RESULT, ACCEPT)) + + context.request_processor = [request_processor] + request_processor.get_request_properties.return_value = {"accept": ACCEPT} + + transformer = PyTorchTransformer() + transformer._model = MODEL + transformer._transform_fn = transform_fn + transformer._context = context + + result = transformer.transform(data, context) + + transform_fn.assert_called_once_with(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT, context) + context.set_response_content_type.assert_called_once_with(0, transform_fn()[1]) + assert isinstance(result, list) + assert result[0] == transform_fn()[0] + + @patch("sagemaker_pytorch_serving_container.transformer.PyTorchTransformer._validate_user_module_and_set_functions") @patch("sagemaker_inference.environment.Environment") def test_validate_and_initialize(env, validate_user_module): @@ -120,6 +202,74 @@ def test_validate_and_initialize(env, validate_user_module): validate_user_module.assert_called_once_with() +@patch("sagemaker_inference.transformer.Transformer._validate_user_module_and_set_functions") +@patch("sagemaker_inference.environment.Environment") +def test_handle_validate_and_initialize_error(env, validate_user_module): + data = [{"body": INPUT_DATA}] + request_processor = Mock() + + context = Mock() + context.request_processor = [request_processor] + + transform_fn = Mock() + model_fn = Mock() + + transformer = PyTorchTransformer() + + transformer._model = MODEL + transformer._transform_fn = transform_fn + transformer._model_fn = model_fn + transformer._context = context + + test_error_message = "Foo" + validate_user_module.side_effect = ValueError(test_error_message) + + assert transformer._initialized is False + + response = transformer.transform(data, context) + assert test_error_message in str(response) + assert "Traceback (most recent call last)" in str(response) + context.set_response_status.assert_called_with( + code=http_client.INTERNAL_SERVER_ERROR, phrase=test_error_message + ) + + +@patch("sagemaker_inference.transformer.Transformer._validate_user_module_and_set_functions") +@patch("sagemaker_inference.environment.Environment") +def test_handle_validate_and_initialize_user_error(env, validate_user_module): + test_status_code = http_client.FORBIDDEN + test_error_message = "Foo" + + class FooUserError(BaseInferenceToolkitError): + def __init__(self, status_code, message): + self.status_code = status_code + self.message = message + self.phrase = "Foo" + + data = [{"body": 
INPUT_DATA}] + context = Mock() + transform_fn = Mock() + model_fn = Mock() + + transformer = PyTorchTransformer() + + transformer._model = MODEL + transformer._transform_fn = transform_fn + transformer._model_fn = model_fn + transformer._context = context + + validate_user_module.side_effect = FooUserError(test_status_code, test_error_message) + + assert transformer._initialized is False + + response = transformer.transform(data, context) + assert test_error_message in str(response) + assert "Traceback (most recent call last)" in str(response) + context.set_response_status.assert_called_with( + code=http_client.FORBIDDEN, phrase=test_error_message + ) + + def test_default_transform_fn(): transformer = PyTorchTransformer() context = Mock() From ef19cb7537e9380ea32e0a70e3ed1258d71cdacd Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Wed, 10 Aug 2022 10:52:05 -0700 Subject: [PATCH 12/12] fix protobuf version --- tox.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/tox.ini b/tox.ini index 3ee46f4e..ee9acbc0 100644 --- a/tox.ini +++ b/tox.ini @@ -70,6 +70,7 @@ deps = six future pyyaml + protobuf <= 3.20.1 #https://exerror.com/typeerror-descriptors-cannot-not-be-created-directly/ [testenv:flake8] basepython = python3
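Taken together, these patches let every handler function (model_fn, input_fn, predict_fn, output_fn, transform_fn) optionally accept the model server's context object, so each worker process can pin itself to the GPU id it was assigned instead of always defaulting to cuda:0. The sketch below shows what a user-side inference script that opts into this could look like; it is an illustrative example only, not part of the patch series, and the model file name ("model.pt"), the use of torch.jit.load, and the helper name _get_device are assumptions made for the example.

    # inference.py -- minimal sketch of a context-aware user script (illustrative only)
    import os
    import torch

    def _get_device(context=None):
        # Mirror the default handlers in this patch series: when a context is
        # provided, bind to the worker's assigned GPU id; otherwise fall back
        # to the previous single-device behaviour.
        if context is not None and torch.cuda.is_available():
            gpu_id = context.system_properties.get("gpu_id")
            return torch.device("cuda:" + str(gpu_id))
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def model_fn(model_dir, context=None):
        device = _get_device(context)
        # Assumes a torchscript artifact named model.pt for illustration.
        model = torch.jit.load(os.path.join(model_dir, "model.pt"), map_location=device)
        return model.to(device)

    def predict_fn(data, model, context=None):
        device = _get_device(context)
        model.eval()
        with torch.no_grad():
            return model(data.to(device))

Because PyTorchTransformer._run_handle_function first calls each handler with the extra context argument and retries without it on TypeError, existing scripts whose functions do not declare a context parameter continue to work unchanged.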