diff --git a/setup.py b/setup.py index 22cd9d2..6a90303 100644 --- a/setup.py +++ b/setup.py @@ -45,7 +45,7 @@ def read(fname): # freeze numpy version because of the python2 bug # in 16.0: https://github.com/numpy/numpy/pull/12754 install_requires=['sagemaker-containers==2.5.0', 'chainer==5.0.0', 'retrying==1.3.3', - 'numpy==1.16.2'], + 'numpy==1.16.2', 'requests==2.22.0'], extras_require={ 'test': [ diff --git a/src/sagemaker_chainer_container/serving.py b/src/sagemaker_chainer_container/serving.py index 5292c8d..f9ef2a5 100644 --- a/src/sagemaker_chainer_container/serving.py +++ b/src/sagemaker_chainer_container/serving.py @@ -13,6 +13,7 @@ from __future__ import absolute_import import logging +import requests import chainer import numpy as np @@ -29,6 +30,17 @@ logger.setLevel(logging.DEBUG) +def default_healthcheck_fn(): + """A default healthcheck_fn for Chainer. Sends GET request to model server. + + Returns: + (flask.Response): status code returned by server. + """ + ping_url = "http://0.0.0.0:{}/ping".format(env.http_port) + res = requests.get(ping_url) + return requests.Response(res.status_code) + + def default_input_fn(input_data, content_type): """Takes request data and de-serializes the data into an object for prediction. @@ -117,6 +129,7 @@ def main(environ, start_response): user_module_transformer = _user_module_transformer(user_module) user_module_transformer.initialize() app = worker.Worker(transform_fn=user_module_transformer.transform, - module_name=serving_env.module_name) + module_name=serving_env.module_name, + healthcheck_fn=default_healthcheck_fn) return app(environ, start_response)