Skip to content

Commit cbae73f

Browse files
samadwarbencrabtreenavinsonijeniyatDewen Qi
authored
feature: Support for remote docker host (#2864)
Co-authored-by: Ben Crabtree <[email protected]> Co-authored-by: Navin Soni <[email protected]> Co-authored-by: Jeniya Tabassum <[email protected]> Co-authored-by: Dewen Qi <[email protected]> Co-authored-by: Payton Staub <[email protected]> Co-authored-by: qidewenwhen <[email protected]> Co-authored-by: Qingzi-Lan <[email protected]> Co-authored-by: Payton Staub <[email protected]> Co-authored-by: EC2 Default User <[email protected]> Co-authored-by: Miyoung <[email protected]> Co-authored-by: Shreya Pandit <[email protected]> Co-authored-by: Ahsan Khan <[email protected]>
1 parent fb50e76 commit cbae73f

File tree

6 files changed

+94
-4
lines changed

6 files changed

+94
-4
lines changed

src/sagemaker/local/entities.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
import sagemaker.local.data
2424
from sagemaker.local.image import _SageMakerContainer
25-
from sagemaker.local.utils import copy_directory_structure, move_to_destination
25+
from sagemaker.local.utils import copy_directory_structure, move_to_destination, get_docker_host
2626
from sagemaker.utils import DeferredError, get_config_value
2727

2828
logger = logging.getLogger(__name__)
@@ -295,7 +295,7 @@ def start(self, input_data, output_data, transform_resources, **kwargs):
295295
_wait_for_serving_container(serving_port)
296296

297297
# Get capabilities from Container if needed
298-
endpoint_url = "http://localhost:%s/execution-parameters" % serving_port
298+
endpoint_url = "http://%s:%d/execution-parameters" % (get_docker_host(), serving_port)
299299
response, code = _perform_request(endpoint_url)
300300
if code == 200:
301301
execution_parameters = json.loads(response.read())
@@ -607,7 +607,7 @@ def _wait_for_serving_container(serving_port):
607607
i = 0
608608
http = urllib3.PoolManager()
609609

610-
endpoint_url = "http://localhost:%s/ping" % serving_port
610+
endpoint_url = "http://%s:%d/ping" % (get_docker_host(), serving_port)
611611
while True:
612612
i += 5
613613
if i >= HEALTH_CHECK_TIMEOUT_LIMIT:

src/sagemaker/local/local_session.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from botocore.exceptions import ClientError
2222

2323
from sagemaker.local.image import _SageMakerContainer
24+
from sagemaker.local.utils import get_docker_host
2425
from sagemaker.local.entities import (
2526
_LocalEndpointConfig,
2627
_LocalEndpoint,
@@ -448,7 +449,7 @@ def invoke_endpoint(
448449
Returns:
449450
object: Inference for the given input.
450451
"""
451-
url = "http://localhost:%s/invocations" % self.serving_port
452+
url = "http://%s:%d/invocations" % (get_docker_host(), self.serving_port)
452453
headers = {}
453454

454455
if ContentType is not None:

src/sagemaker/local/utils.py

+25
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import os
1717
import shutil
1818
import subprocess
19+
import json
1920

2021
from distutils.dir_util import copy_tree
2122
from six.moves.urllib.parse import urlparse
@@ -127,3 +128,27 @@ def get_child_process_ids(pid):
127128
return pids + get_child_process_ids(child_pid)
128129
else:
129130
return []
131+
132+
133+
def get_docker_host():
134+
"""Discover remote docker host address (if applicable) or use "localhost"
135+
136+
Use "docker context inspect" to read current docker host endpoint url,
137+
url must start with "tcp://"
138+
139+
Args:
140+
141+
Returns:
142+
docker_host (str): Docker host DNS or IP address
143+
"""
144+
cmd = "docker context inspect".split()
145+
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
146+
output, err = process.communicate()
147+
if err:
148+
return "localhost"
149+
docker_context_string = output.decode("utf-8")
150+
docker_context_host_url = json.loads(docker_context_string)[0]["Endpoints"]["docker"]["Host"]
151+
parsed_url = urlparse(docker_context_host_url)
152+
if parsed_url.hostname and parsed_url.scheme == "tcp":
153+
return parsed_url.hostname
154+
return "localhost"

tests/unit/test_local_entities.py

+25
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,28 @@ def test_local_transform_job_perform_batch_inference(
163163
assert len(output_files) == 2
164164
assert "file1.out" in output_files
165165
assert "file2.out" in output_files
166+
167+
168+
@patch("sagemaker.local.entities._SageMakerContainer", Mock())
169+
@patch("sagemaker.local.entities.get_docker_host")
170+
@patch("sagemaker.local.entities._perform_request")
171+
@patch("sagemaker.local.entities._LocalTransformJob._perform_batch_inference")
172+
def test_start_local_transform_job_from_remote_docker_host(
173+
m_perform_batch_inference, m_perform_request, m_get_docker_host, local_transform_job
174+
):
175+
input_data = {}
176+
output_data = {}
177+
transform_resources = {"InstanceType": "local"}
178+
m_get_docker_host.return_value = "some_host"
179+
perform_request_mock = Mock()
180+
m_perform_request.return_value = (perform_request_mock, 200)
181+
perform_request_mock.read.return_value = '{"BatchStrategy": "SingleRecord"}'
182+
local_transform_job.primary_container["ModelDataUrl"] = "file:///some/model"
183+
local_transform_job.start(input_data, output_data, transform_resources, Environment={})
184+
endpoints = [
185+
"http://%s:%d/ping" % ("some_host", 8080),
186+
"http://%s:%d/execution-parameters" % ("some_host", 8080),
187+
]
188+
calls = m_perform_request.call_args_list
189+
for call, endpoint in zip(calls, endpoints):
190+
assert call[0][0] == endpoint

tests/unit/test_local_session.py

+15
Original file line numberDiff line numberDiff line change
@@ -857,3 +857,18 @@ def test_local_session_download_with_custom_s3_endpoint_url(sagemaker_session_cu
857857
Filename="{}/{}".format(DOWNLOAD_DATA_TESTS_FILES_DIR, "test.csv"),
858858
ExtraArgs=None,
859859
)
860+
861+
862+
@patch("sagemaker.local.local_session.get_docker_host")
863+
@patch("urllib3.PoolManager.request")
864+
def test_invoke_local_endpoint_with_remote_docker_host(
865+
m_request,
866+
m_get_docker_host,
867+
):
868+
m_get_docker_host.return_value = "some_host"
869+
Body = "Body".encode("utf-8")
870+
url = "http://%s:%d/invocations" % ("some_host", 8080)
871+
sagemaker.local.local_session.LocalSagemakerRuntimeClient().invoke_endpoint(
872+
Body, "local_endpoint"
873+
)
874+
m_request.assert_called_with("POST", url, body=Body, preload_content=False, headers={})

tests/unit/test_local_utils.py

+24
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,27 @@ def test_get_child_process_ids(m_subprocess):
9292
m_subprocess.Popen.return_value = process_mock
9393
sagemaker.local.utils.get_child_process_ids("pid")
9494
m_subprocess.Popen.assert_called_with(cmd, stdout=m_subprocess.PIPE, stderr=m_subprocess.PIPE)
95+
96+
97+
@patch("sagemaker.local.utils.subprocess")
98+
def test_get_docker_host(m_subprocess):
99+
cmd = "docker context inspect".split()
100+
process_mock = Mock()
101+
endpoints = [
102+
{"test": "tcp://host:port", "result": "host"},
103+
{"test": "fd://something", "result": "localhost"},
104+
{"test": "unix://path/to/socket", "result": "localhost"},
105+
{"test": "npipe:////./pipe/foo", "result": "localhost"},
106+
]
107+
for endpoint in endpoints:
108+
return_value = (
109+
'[\n{\n"Endpoints":{\n"docker":{\n"Host": "%s"}\n}\n}\n]\n' % endpoint["test"]
110+
)
111+
attrs = {"communicate.return_value": (return_value.encode("utf-8"), None), "returncode": 0}
112+
process_mock.configure_mock(**attrs)
113+
m_subprocess.Popen.return_value = process_mock
114+
host = sagemaker.local.utils.get_docker_host()
115+
m_subprocess.Popen.assert_called_with(
116+
cmd, stdout=m_subprocess.PIPE, stderr=m_subprocess.PIPE
117+
)
118+
assert host == endpoint["result"]

0 commit comments

Comments
 (0)