Skip to content

change: When rootlessDocker is enabled, return a fixed SageMaker IP #5236

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion src/sagemaker/local/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,36 @@ def get_child_process_ids(pid):
def get_docker_host():
"""Discover remote docker host address (if applicable) or use "localhost"

Use "docker context inspect" to read current docker host endpoint url,
When rootlessDocker is enabled (Cgroup Driver: none), use fixed SageMaker IP.
Otherwise, Use "docker context inspect" to read current docker host endpoint url,
url must start with "tcp://"

Args:

Returns:
docker_host (str): Docker host DNS or IP address
"""
# Check if using SageMaker rootless Docker by examining storage driver
try:
cmd = ["docker", "info"]
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, err = process.communicate()
if process.returncode == 0: # Check return code instead of stderr
output_text = output.decode("utf-8")
# Check for rootless Docker by looking at Cgroup Driver
if "Cgroup Driver: none" in output_text:
# log the result of check
logger.warning("RootlessDocker detected (Cgroup Driver: none), returning fixed IP.")
# SageMaker rootless Docker detected - return fixed IP
return "172.17.0.1"
else:
logger.warning(
"RootlessDocker not detected, falling back to remote host IP or localhost."
)
except subprocess.SubprocessError:
pass

# Fallback to existing logic for remote Docker hosts
cmd = "docker context inspect".split()
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, err = process.communicate()
Expand Down
62 changes: 62 additions & 0 deletions tests/unit/sagemaker/local/test_local_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,68 @@ def test_get_docker_host(m_subprocess):
assert host == endpoint["result"]


@patch("sagemaker.local.utils.subprocess")
def test_get_docker_host_rootless_docker(m_subprocess):
"""Test that rootless Docker is detected and returns fixed IP"""
# Mock docker info process for rootless Docker
info_process_mock = Mock()
info_attrs = {"communicate.return_value": (b"Cgroup Driver: none", b""), "returncode": 0}
info_process_mock.configure_mock(**info_attrs)
m_subprocess.Popen.return_value = info_process_mock

host = sagemaker.local.utils.get_docker_host()
assert host == "172.17.0.1"

# Verify docker info was called
m_subprocess.Popen.assert_called_with(
["docker", "info"], stdout=m_subprocess.PIPE, stderr=m_subprocess.PIPE
)


@patch("sagemaker.local.utils.subprocess")
def test_get_docker_host_traditional_docker(m_subprocess):
"""Test that traditional Docker falls back to existing logic"""
scenarios = [
{
"docker_info": b"Cgroup Driver: cgroupfs",
"context_host": "tcp://host:port",
"result": "host",
},
{
"docker_info": b"Cgroup Driver: cgroupfs",
"context_host": "unix:///var/run/docker.sock",
"result": "localhost",
},
{
"docker_info": b"Cgroup Driver: cgroupfs",
"context_host": "fd://something",
"result": "localhost",
},
]

for scenario in scenarios:
# Mock docker info process for traditional Docker
info_process_mock = Mock()
info_attrs = {"communicate.return_value": (scenario["docker_info"], b""), "returncode": 0}
info_process_mock.configure_mock(**info_attrs)

# Mock docker context inspect process
context_return_value = (
'[\n{\n"Endpoints":{\n"docker":{\n"Host": "%s"}\n}\n}\n]\n' % scenario["context_host"]
)
context_process_mock = Mock()
context_attrs = {
"communicate.return_value": (context_return_value.encode("utf-8"), None),
"returncode": 0,
}
context_process_mock.configure_mock(**context_attrs)

m_subprocess.Popen.side_effect = [info_process_mock, context_process_mock]

host = sagemaker.local.utils.get_docker_host()
assert host == scenario["result"]


@pytest.mark.parametrize(
"json_path, expected",
[
Expand Down