Multi GPU support #127

Open
wants to merge 12 commits into base: master
src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py
@@ -43,12 +43,13 @@ def _is_model_file(filename):
is_model_file = ext in [".pt", ".pth"]
return is_model_file

def default_model_fn(self, model_dir):
def default_model_fn(self, model_dir, context=None):
"""Loads a model. For PyTorch, a default function to load a model only if Elastic Inference is used.
In other cases, users should provide customized model_fn() in script.

Args:
model_dir: a directory where model is saved.
context: context for the request.

Returns: A PyTorch model.
"""
@@ -65,7 +66,12 @@ def default_model_fn(self, model_dir):
"Failed to load {}. Please ensure model is saved using torchscript.".format(model_path)
) from e
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if context:
properties = context.system_properties
device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_path = os.path.join(model_dir, DEFAULT_MODEL_FILENAME)
if not os.path.exists(model_path):
model_files = [file for file in os.listdir(model_dir) if self._is_model_file(file)]
@@ -83,29 +89,35 @@ def default_model_fn(self, model_dir):
model = model.to(device)
return model

def default_input_fn(self, input_data, content_type):
def default_input_fn(self, input_data, content_type, context=None):
"""A default input_fn that can handle JSON, CSV and NPZ formats.

Args:
input_data: the request payload serialized in the content_type format
content_type: the request content_type
context: context for the request

Returns: input_data deserialized into torch.FloatTensor or torch.cuda.FloatTensor,
depending if cuda is available.
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if context:
properties = context.system_properties
device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
np_array = decoder.decode(input_data, content_type)
tensor = torch.FloatTensor(
np_array) if content_type in content_types.UTF8_TYPES else torch.from_numpy(np_array)
return tensor.to(device)

def default_predict_fn(self, data, model):
def default_predict_fn(self, data, model, context=None):
"""A default predict_fn for PyTorch. Calls a model on data deserialized in input_fn.
Runs prediction on GPU if cuda is available.

Args:
data: input data (torch.Tensor) for prediction deserialized by input_fn
model: PyTorch model loaded in memory by model_fn
context: context for the request

Returns: a prediction
"""
@@ -118,7 +130,12 @@ def default_predict_fn(self, data, model):
with torch.jit.optimized_execution(True, {"target_device": "eia:0"}):
output = model(input_data)
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if context:
properties = context.system_properties
device = torch.device("cuda:" + str(properties.get("gpu_id"))
if torch.cuda.is_available() else "cpu")
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
input_data = data.to(device)
model.eval()
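Note: with the optional context argument added above, a user-provided handler script can pin each model server worker to its assigned GPU instead of always using cuda:0. A minimal sketch of such a script follows; the file name inference.py, the artifact name model.pt, and the use of torch.jit.load are illustrative assumptions, not part of this PR.

# inference.py (hypothetical user script placed under model_dir/code)
import os

import torch


def model_fn(model_dir, context=None):
    # Use the GPU assigned to this worker when the model server passes a context;
    # otherwise keep the previous single-device behaviour.
    if context is not None and torch.cuda.is_available():
        device = torch.device("cuda:" + str(context.system_properties.get("gpu_id")))
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.jit.load(os.path.join(model_dir, "model.pt"), map_location=device)  # hypothetical artifact name
    return model.to(device)


def predict_fn(data, model, context=None):
    # Reuse whichever device the model was loaded on so the input and weights agree.
    device = next(model.parameters()).device
    model.eval()
    with torch.no_grad():
        return model(data.to(device))

Handler functions that keep the existing two- and three-argument signatures continue to work: the PyTorchTransformer introduced below retries the call without the context argument when passing it raises a TypeError.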
28 changes: 18 additions & 10 deletions src/sagemaker_pytorch_serving_container/handler_service.py
@@ -13,32 +13,30 @@
from __future__ import absolute_import

from sagemaker_inference.default_handler_service import DefaultHandlerService
from sagemaker_inference.transformer import Transformer
from sagemaker_pytorch_serving_container.default_pytorch_inference_handler import DefaultPytorchInferenceHandler
from sagemaker_pytorch_serving_container.transformer import PyTorchTransformer

import os
import sys

PYTHON_PATH_ENV = "PYTHONPATH"
ENABLE_MULTI_MODEL = os.getenv("SAGEMAKER_MULTI_MODEL", "false") == "true"


class HandlerService(DefaultHandlerService):
"""
Handler service that is executed by the model server.

"""Handler service that is executed by the model server.

Determines specific default inference handlers to use based on the type MXNet model being used.
    Determines specific default inference handlers to use based on the type of PyTorch model being used.

This class extends ``DefaultHandlerService``, which define the following:
- The ``handle`` method is invoked for all incoming inference requests to the model server.
- The ``initialize`` method is invoked at model server start up.

Based on: https://github.com/awslabs/mxnet-model-server/blob/master/docs/custom_service.md

"""

def __init__(self):
self._initialized = False

transformer = Transformer(default_inference_handler=DefaultPytorchInferenceHandler())
transformer = PyTorchTransformer()
super(HandlerService, self).__init__(transformer=transformer)

def initialize(self, context):
@@ -48,4 +46,14 @@ def initialize(self, context):
sys.path.append(code_dir)
self._initialized = True

super().initialize(context)
properties = context.system_properties
model_dir = properties.get("model_dir")

# add model_dir/code to python path
code_dir_path = "{}:".format(model_dir + "/code")
if PYTHON_PATH_ENV in os.environ:
os.environ[PYTHON_PATH_ENV] = code_dir_path + os.environ[PYTHON_PATH_ENV]
else:
os.environ[PYTHON_PATH_ENV] = code_dir_path

self._service.validate_and_initialize(model_dir=model_dir, context=context)
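Note: initialize() now receives the worker's context and forwards it to validate_and_initialize(). For a quick local check of this path, a rough sketch with a mocked context is below; it assumes a TorchServe/MMS-style context whose system_properties carries model_dir and gpu_id (the only keys this PR reads) and a model artifact already saved under /opt/ml/model.

from unittest.mock import Mock

from sagemaker_pytorch_serving_container import handler_service

# Hypothetical stand-in for the model server context; only the attributes this PR reads are populated.
context = Mock()
context.system_properties = {"model_dir": "/opt/ml/model", "gpu_id": 0}

service = handler_service.HandlerService()
service.initialize(context)  # prepends /opt/ml/model/code to PYTHONPATH and loads the model via model_fn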
141 changes: 141 additions & 0 deletions src/sagemaker_pytorch_serving_container/transformer.py
@@ -0,0 +1,141 @@
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from __future__ import absolute_import

import traceback

from six.moves import http_client
from sagemaker_inference.transformer import Transformer
from sagemaker_inference import content_types, environment, utils
from sagemaker_inference.errors import BaseInferenceToolkitError, GenericInferenceToolkitError
from sagemaker_pytorch_serving_container.default_pytorch_inference_handler import DefaultPytorchInferenceHandler


class PyTorchTransformer(Transformer):
"""Represents the execution workflow for handling pytorch inference requests
sent to the model server.
"""
def __init__(self, default_inference_handler=DefaultPytorchInferenceHandler()):
super().__init__(default_inference_handler)
self._context = None

def transform(self, data, context):
"""Take a request with input data, deserialize it, make a prediction, and return a
serialized response.
Args:
data (obj): the request data.
context (obj): metadata on the incoming request data.
Returns:
list[obj]: The serialized prediction result wrapped in a list if
inference is successful. Otherwise returns an error message
with the context set appropriately.
"""

try:
properties = context.system_properties
model_dir = properties.get("model_dir")
self.validate_and_initialize(model_dir=model_dir, context=context)

response_list = []
for i in range(len(data)):
input_data = data[i].get("body")

request_processor = context.request_processor[0]

request_property = request_processor.get_request_properties()
content_type = utils.retrieve_content_type_header(request_property)
accept = request_property.get("Accept") or request_property.get("accept")

if not accept or accept == content_types.ANY:
accept = self._environment.default_accept

if content_type in content_types.UTF8_TYPES:
input_data = input_data.decode("utf-8")

result = self._run_handle_function(self._transform_fn, *(self._model, input_data, content_type, accept))

response = result
response_content_type = accept

if isinstance(result, tuple):
# handles tuple for backwards compatibility
response = result[0]
response_content_type = result[1]

context.set_response_content_type(0, response_content_type)

response_list.append(response)

return response_list
except Exception as e: # pylint: disable=broad-except
trace = traceback.format_exc()
if isinstance(e, BaseInferenceToolkitError):
return super().handle_error(context, e, trace)
else:
return super().handle_error(
context,
GenericInferenceToolkitError(http_client.INTERNAL_SERVER_ERROR, str(e)),
trace,
)

def validate_and_initialize(self, model_dir=environment.model_dir, context=None):
"""Validates the user module against the SageMaker inference contract.
Load the model as defined by the ``model_fn`` to prepare handling predictions.
"""
if not self._initialized:
self._context = context
self._environment = environment.Environment()
self._validate_user_module_and_set_functions()

if self._pre_model_fn is not None:
self._run_handle_function(self._pre_model_fn, *(model_dir, ))

self._model = self._run_handle_function(self._model_fn, *(model_dir, ))

if self._model_warmup_fn is not None:
self._run_handle_function(self._model_warmup_fn, *(model_dir, self._model))

self._initialized = True

def _default_transform_fn(self, model, input_data, content_type, accept):
"""Make predictions against the model and return a serialized response.
This serves as the default implementation of transform_fn, used when the
user has not provided an implementation.
Args:
model (obj): model loaded by model_fn.
input_data (obj): the request data.
content_type (str): the request content type.
accept (str): accept header expected by the client.
Returns:
obj: the serialized prediction result or a tuple of the form
(response_data, content_type)
"""
data = self._run_handle_function(self._input_fn, *(input_data, content_type))
prediction = self._run_handle_function(self._predict_fn, *(data, model))
result = self._run_handle_function(self._output_fn, *(prediction, accept))

return result

def _run_handle_function(self, func, *argv):
"""Wrapper to call the handle function which covers 2 cases:
1. context passed to the handle function
2. context not passed to the handle function
"""
try:
argv_context = argv + (self._context, )
result = func(*argv_context)
except TypeError:
result = func(*argv)

return result
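Note: _run_handle_function is what keeps existing handler scripts working: it first calls the user function with the context appended and falls back to the original signature when that raises a TypeError. A small self-contained sketch of the same dispatch, using two hypothetical predict_fn variants, is below.

import torch


def _call_with_optional_context(func, *argv, context=None):
    # Same fallback idea as PyTorchTransformer._run_handle_function: try the
    # context-aware signature first, retry without the context on TypeError.
    try:
        return func(*argv, context)
    except TypeError:
        return func(*argv)


def legacy_predict_fn(data, model):  # pre-existing two-argument handler
    return model(data)


def gpu_aware_predict_fn(data, model, context):  # new three-argument handler
    device = torch.device("cuda:" + str(context.system_properties.get("gpu_id"))
                          if torch.cuda.is_available() else "cpu")
    return model(data.to(device))

# Both handlers can be driven through the same wrapper:
#   _call_with_optional_context(legacy_predict_fn, data, model, context=ctx)
#   _call_with_optional_context(gpu_aware_predict_fn, data, model, context=ctx)

One side effect of dispatching on TypeError is that a TypeError raised inside a context-aware handler also triggers the retry without the context; the sketch above intentionally mirrors that behaviour.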
10 changes: 5 additions & 5 deletions test/unit/test_handler_service.py
@@ -16,18 +16,18 @@


@patch('sagemaker_pytorch_serving_container.default_pytorch_inference_handler.DefaultPytorchInferenceHandler')
@patch('sagemaker_inference.transformer.Transformer')
def test_hosting_start(Transformer, DefaultPytorchInferenceHandler):
@patch('sagemaker_pytorch_serving_container.transformer.PyTorchTransformer')
def test_hosting_start(PyTorchTransformer, DefaultPytorchInferenceHandler):
from sagemaker_pytorch_serving_container import handler_service

handler_service.HandlerService()

Transformer.assert_called_with(default_inference_handler=DefaultPytorchInferenceHandler())
PyTorchTransformer.assert_called_with()


@patch('sagemaker_pytorch_serving_container.default_pytorch_inference_handler.DefaultPytorchInferenceHandler')
@patch('sagemaker_inference.transformer.Transformer')
def test_hosting_start_enable_multi_model(Transformer, DefaultPytorchInferenceHandler):
@patch('sagemaker_pytorch_serving_container.transformer.PyTorchTransformer')
def test_hosting_start_enable_multi_model(PyTorchTransformer, DefaultPytorchInferenceHandler):
from sagemaker_pytorch_serving_container import handler_service

context = Mock()