Commit d869551

Author: Wei Chu
Message: fix sanity
1 parent 2a0dbda

4 files changed: +10 −9 lines


src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py (+5 −4)

@@ -66,7 +66,7 @@ def default_model_fn(self, model_dir, context=None):
                 "Failed to load {}. Please ensure model is saved using torchscript.".format(model_path)
             ) from e
         else:
-            if context:
+            if context:
                 properties = context.system_properties
                 device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
             else:
@@ -100,7 +100,7 @@ def default_input_fn(self, input_data, content_type, context=None):
         Returns: input_data deserialized into torch.FloatTensor or torch.cuda.FloatTensor,
             depending if cuda is available.
         """
-        if context:
+        if context:
             properties = context.system_properties
             device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
         else:
@@ -130,9 +130,10 @@ def default_predict_fn(self, data, model, context=None):
             with torch.jit.optimized_execution(True, {"target_device": "eia:0"}):
                 output = model(input_data)
         else:
-            if context:
+            if context:
                 properties = context.system_properties
-                device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
+                device = torch.device("cuda:" + str(properties.get("gpu_id"))
+                                      if torch.cuda.is_available() else "cpu")
             else:
                 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
             model = model.to(device)
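Note that the `if context:` lines in these hunks change only in whitespace (most likely trailing whitespace being stripped, which renders identically here), and the last hunk splits an over-long `torch.device(...)` expression across two lines; both are typical flake8 "sanity" fixes. The device-selection pattern this file repeats is worth spelling out. A minimal sketch, where `select_device` is a name introduced here for illustration, assuming a `context` object shaped like the one the model server passes (a `system_properties` dict carrying the worker's `gpu_id`):

import torch

def select_device(context=None):
    # With a serving context, pin this worker to its assigned GPU id;
    # without one, fall back to the generic "cuda"/"cpu" choice.
    if context:
        properties = context.system_properties
        device = torch.device("cuda:" + str(properties.get("gpu_id"))
                              if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return device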

src/sagemaker_pytorch_serving_container/handler_service.py (+1)

@@ -22,6 +22,7 @@
 PYTHON_PATH_ENV = "PYTHONPATH"
 ENABLE_MULTI_MODEL = os.getenv("SAGEMAKER_MULTI_MODEL", "false") == "true"
 
+
 class HandlerService(DefaultHandlerService):
     """
     Handler service that is executed by the model server.
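The only visible change here is a second blank line before the class definition; PEP 8 (enforced by flake8 as E302) expects two blank lines between module-level code and a top-level class, so this is most likely part of the same lint cleanup.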

src/sagemaker_pytorch_serving_container/transformer.py (+3 −4)

@@ -13,7 +13,6 @@
 
 from __future__ import absolute_import
 
-import logging
 import traceback
 
 from six.moves import http_client
@@ -88,7 +87,7 @@ def transform(self, data, context):
                 GenericInferenceToolkitError(http_client.INTERNAL_SERVER_ERROR, str(e)),
                 trace,
             )
-
+
     def validate_and_initialize(self, model_dir=environment.model_dir, context=None):
         """Validates the user module against the SageMaker inference contract.
         Load the model as defined by the ``model_fn`` to prepare handling predictions.
@@ -126,7 +125,7 @@ def _default_transform_fn(self, model, input_data, content_type, accept):
         result = self._run_handle_function(self._output_fn, *(prediction, accept))
 
         return result
-
+
     def _run_handle_function(self, func, *argv):
         """Wrapper to call the handle function which covers 2 cases:
         1. context passed to the handle function
@@ -137,5 +136,5 @@ def _run_handle_function(self, func, *argv):
             result = func(*argv_context)
         except TypeError:
             result = func(*argv)
-
+
         return result
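The changes here are also lint cleanup: the unused `import logging` is dropped (flake8 F401), and several blank lines that apparently carried trailing whitespace are rewritten (W293). The last two hunks also show the shape of `_run_handle_function`: call the user handler with the serving context appended, and retry without it when the handler's signature rejects the extra argument. A minimal sketch of that dispatch, assuming `self.context` holds the current serving context:

    def _run_handle_function(self, func, *argv):
        # Try the handler with the serving context appended (assumed to
        # live on self.context); if the user-supplied function does not
        # accept a context argument, retry without it.
        try:
            argv_context = argv + (self.context,)
            result = func(*argv_context)
        except TypeError:
            result = func(*argv)
        return result

One caveat of this pattern: a TypeError raised inside the handler itself would also trigger the no-context retry, which is a known trade-off of signature probing via exception handling.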

test/utils/file_utils.py (+1 −1)

@@ -19,7 +19,7 @@
 
 def make_tarfile(script, model, output_path, filename="model.tar.gz", script_path=None):
     output_filename = os.path.join(output_path, filename)
     with tarfile.open(output_filename, "w:gz") as tar:
-        if(script_path):
+        if (script_path):
             tar.add(script, arcname=os.path.join(script_path, os.path.basename(script)))
         else:
             tar.add(script, arcname=os.path.basename(script))
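The one-character change satisfies flake8's E275 (missing whitespace after keyword): `if(script_path):` parses fine but reads like a function call, while `if (script_path):` appeases the linter (the idiomatic spelling would drop the parentheses entirely, `if script_path:`). For context, a self-contained version of how such a tarball helper could look; the lines past the visible hunk, including adding the `model` artifact and the return value, are assumptions, not the file's actual tail:

import os
import tarfile

def make_tarfile(script, model, output_path, filename="model.tar.gz", script_path=None):
    # Bundle an inference script into a gzipped tar, optionally nesting
    # the script under script_path (e.g. "code/").
    output_filename = os.path.join(output_path, filename)
    with tarfile.open(output_filename, "w:gz") as tar:
        if script_path:
            tar.add(script, arcname=os.path.join(script_path, os.path.basename(script)))
        else:
            tar.add(script, arcname=os.path.basename(script))
        tar.add(model, arcname=os.path.basename(model))  # assumed: not shown in the hunk
    return output_filename  # assumed: not shown in the hunk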

0 commit comments