diff --git a/src/sagemaker/local/utils.py b/src/sagemaker/local/utils.py index 2c2a5a1c90..858ef60113 100644 --- a/src/sagemaker/local/utils.py +++ b/src/sagemaker/local/utils.py @@ -153,7 +153,8 @@ def get_child_process_ids(pid): def get_docker_host(): """Discover remote docker host address (if applicable) or use "localhost" - Use "docker context inspect" to read current docker host endpoint url, + When rootlessDocker is enabled (Cgroup Driver: none), use fixed SageMaker IP. + Otherwise, Use "docker context inspect" to read current docker host endpoint url, url must start with "tcp://" Args: @@ -161,6 +162,27 @@ def get_docker_host(): Returns: docker_host (str): Docker host DNS or IP address """ + # Check if using SageMaker rootless Docker by examining storage driver + try: + cmd = ["docker", "info"] + process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + output, err = process.communicate() + if process.returncode == 0: # Check return code instead of stderr + output_text = output.decode("utf-8") + # Check for rootless Docker by looking at Cgroup Driver + if "Cgroup Driver: none" in output_text: + # log the result of check + logger.warning("RootlessDocker detected (Cgroup Driver: none), returning fixed IP.") + # SageMaker rootless Docker detected - return fixed IP + return "172.17.0.1" + else: + logger.warning( + "RootlessDocker not detected, falling back to remote host IP or localhost." + ) + except subprocess.SubprocessError: + pass + + # Fallback to existing logic for remote Docker hosts cmd = "docker context inspect".split() process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) output, err = process.communicate() diff --git a/tests/unit/sagemaker/local/test_local_utils.py b/tests/unit/sagemaker/local/test_local_utils.py index a9aae53fb2..82e3207266 100644 --- a/tests/unit/sagemaker/local/test_local_utils.py +++ b/tests/unit/sagemaker/local/test_local_utils.py @@ -135,6 +135,68 @@ def test_get_docker_host(m_subprocess): assert host == endpoint["result"] +@patch("sagemaker.local.utils.subprocess") +def test_get_docker_host_rootless_docker(m_subprocess): + """Test that rootless Docker is detected and returns fixed IP""" + # Mock docker info process for rootless Docker + info_process_mock = Mock() + info_attrs = {"communicate.return_value": (b"Cgroup Driver: none", b""), "returncode": 0} + info_process_mock.configure_mock(**info_attrs) + m_subprocess.Popen.return_value = info_process_mock + + host = sagemaker.local.utils.get_docker_host() + assert host == "172.17.0.1" + + # Verify docker info was called + m_subprocess.Popen.assert_called_with( + ["docker", "info"], stdout=m_subprocess.PIPE, stderr=m_subprocess.PIPE + ) + + +@patch("sagemaker.local.utils.subprocess") +def test_get_docker_host_traditional_docker(m_subprocess): + """Test that traditional Docker falls back to existing logic""" + scenarios = [ + { + "docker_info": b"Cgroup Driver: cgroupfs", + "context_host": "tcp://host:port", + "result": "host", + }, + { + "docker_info": b"Cgroup Driver: cgroupfs", + "context_host": "unix:///var/run/docker.sock", + "result": "localhost", + }, + { + "docker_info": b"Cgroup Driver: cgroupfs", + "context_host": "fd://something", + "result": "localhost", + }, + ] + + for scenario in scenarios: + # Mock docker info process for traditional Docker + info_process_mock = Mock() + info_attrs = {"communicate.return_value": (scenario["docker_info"], b""), "returncode": 0} + info_process_mock.configure_mock(**info_attrs) + + # Mock docker context inspect process + context_return_value = ( + '[\n{\n"Endpoints":{\n"docker":{\n"Host": "%s"}\n}\n}\n]\n' % scenario["context_host"] + ) + context_process_mock = Mock() + context_attrs = { + "communicate.return_value": (context_return_value.encode("utf-8"), None), + "returncode": 0, + } + context_process_mock.configure_mock(**context_attrs) + + m_subprocess.Popen.side_effect = [info_process_mock, context_process_mock] + + host = sagemaker.local.utils.get_docker_host() + assert host == scenario["result"] + + @pytest.mark.parametrize( "json_path, expected", [