Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ addopts = [
"--numprocesses=auto",
"--timeout=30"
]
markers = [
"requires_cap_kill: tests that require CAP_KILL Linux capability",
]


[tool.coverage.run]
Expand All @@ -173,6 +176,9 @@ fail_under = 79
"src/openjd/sessions/_scripts/_windows/*.py",
"src/openjd/sessions/_windows*.py"
]
"sys_platform != 'linux'" = [
"src/openjd/sessions/_linux/*.py",
]

[tool.coverage.coverage_conditional_plugin.rules]
# This cannot be empty otherwise coverage-conditional-plugin crashes with:
Expand Down
23 changes: 23 additions & 0 deletions scripts/run_sudo_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,26 @@ if test "${BUILD_ONLY}" == "True"; then
fi

docker run --name test_openjd_sudo --rm ${ARGS} "${CONTAINER_IMAGE_TAG}:latest"

if test "${USE_LDAP}" != "True"; then
# Run capability tests
# First with CAP_KILL in effective and permitted capability sets
docker run --name test_openjd_sudo --user root --rm ${ARGS} "${CONTAINER_IMAGE_TAG}:latest" \
capsh \
--caps='cap_setuid,cap_setgid,cap_setpcap=ep cap_kill=eip' \
--keep=1 \
--user=hostuser \
--addamb=cap_kill \
-- \
-c 'capsh --noamb --caps=cap_kill=ep -- -c "hatch run test --no-cov -m requires_cap_kill"'
# Second with CAP_KILL in the permitted capability set but not in the effective capability set.
# This tests that OpenJD will add CAP_KILL to the effective capability set if needed.
docker run --name test_openjd_sudo --user root --rm ${ARGS} "${CONTAINER_IMAGE_TAG}:latest" \
capsh \
--caps='cap_setuid,cap_setgid,cap_setpcap=ep cap_kill=eip' \
--keep=1 \
--user=hostuser \
--addamb=cap_kill \
-- \
-c 'capsh --noamb --caps=cap_kill=p -- -c "hatch run test --no-cov -m requires_cap_kill"'
fi
1 change: 1 addition & 0 deletions src/openjd/sessions/_linux/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
242 changes: 242 additions & 0 deletions src/openjd/sessions/_linux/_capabilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

"""This module contains code for interacting with Linux capabilities. The module uses the ctypes
module from the Python standard library to wrap the libcap library.

See https://man7.org/linux/man-pages/man7/capabilities.7.html for details on this Linux kernel
feature.
"""

import ctypes
import os
import sys
from contextlib import contextmanager
from ctypes.util import find_library
from enum import Enum
from functools import cache
from typing import Any, Generator, Tuple, TYPE_CHECKING


from .._logging import LOG


# Capability sets. These are the values passed as the cap_flag_t argument of libcap's
# cap_get_flag() and cap_set_flag().
CAP_EFFECTIVE = 0
CAP_PERMITTED = 1
CAP_INHERITABLE = 2

# Capability bit numbers
CAP_KILL = 5

# Values for cap_flag_value_t arguments
CAP_CLEAR = 0
CAP_SET = 1

# ctypes aliases for the libcap C typedefs. They are used when declaring the libcap function
# prototypes in _get_libcap().
cap_flag_t = ctypes.c_int
cap_flag_value_t = ctypes.c_int
cap_value_t = ctypes.c_int


class CapabilitySetType(Enum):
    """The capability sets that a capability flag can be looked up in.

    Each member's value is the matching CAP_* set constant; _has_capability() passes that value
    to libcap's cap_get_flag() to select which set is inspected.
    """

    INHERITABLE = CAP_INHERITABLE
    PERMITTED = CAP_PERMITTED
    EFFECTIVE = CAP_EFFECTIVE


class UserCapHeader(ctypes.Structure):
    """ctypes layout of the "head" member of the Cap structure."""

    # NOTE(review): presumably mirrors the kernel's capability header struct — confirm against
    # the libcap/kernel headers before reading these fields directly.
    _fields_ = [
        ("version", ctypes.c_uint32),
        ("pid", ctypes.c_int),
    ]


class UserCapData(ctypes.Structure):
    """ctypes layout of the "data" member of the Cap structure.

    Holds one 32-bit field for each of the effective, permitted and inheritable sets.
    """

    # NOTE(review): presumably mirrors the kernel's capability data struct — confirm against
    # the libcap/kernel headers before reading these fields directly.
    _fields_ = [
        ("effective", ctypes.c_uint32),
        ("permitted", ctypes.c_uint32),
        ("inheritable", ctypes.c_uint32),
    ]


class Cap(ctypes.Structure):
    """ctypes structure that the cap_t pointer type refers to.

    In the code visible in this module, values of this type are only handled through pointers
    that libcap returns (cap_get_proc) and that are handed back to libcap (cap_get_flag,
    cap_set_flag, cap_set_proc); its fields are never read or written directly.
    """

    # NOTE(review): if fields are ever accessed directly, first confirm that this layout matches
    # the structure libcap actually allocates.
    _fields_ = [
        ("head", UserCapHeader),
        ("data", UserCapData),
    ]


# Pointer type aliases used in annotations and in the libcap prototypes.
# At runtime the types are created with ctypes.POINTER(...). The subscripted ctypes._Pointer[...]
# spelling is only evaluated under TYPE_CHECKING, presumably so that static type checkers see
# a usable generic type for these aliases.
if TYPE_CHECKING:
    cap_t = ctypes._Pointer[Cap]
    cap_flag_value_ptr = ctypes._Pointer[cap_flag_value_t]
    cap_value_ptr = ctypes._Pointer[cap_value_t]
    ssize_ptr_t = ctypes._Pointer[ctypes.c_ssize_t]
else:
    cap_t = ctypes.POINTER(Cap)
    cap_flag_value_ptr = ctypes.POINTER(cap_flag_value_t)
    cap_value_ptr = ctypes.POINTER(cap_value_t)
    ssize_ptr_t = ctypes.POINTER(ctypes.c_ssize_t)


def _cap_set_err_check(
    result: ctypes.c_int,
    func: Any,
    args: Tuple[Any, ...],
) -> ctypes.c_int:
    """ctypes ``errcheck`` hook for libcap calls that return zero on success.

    Args:
        result: The value returned by the foreign function.
        func: The foreign function object that was called (unused).
        args: The arguments the foreign function was called with (unused).

    Returns:
        ``result`` unchanged when it is zero.

    Raises:
        OSError: When ``result`` is non-zero; built from the errno value captured by ctypes.
    """
    if result == 0:
        return result

    # The library is loaded with use_errno=True, so ctypes keeps a private copy of errno.
    error_number = ctypes.get_errno()
    raise OSError(error_number, os.strerror(error_number))


def _cap_get_proc_err_check(
    result: cap_t,
    func: Any,
    # cap_get_proc() is declared with an empty argtypes list, so ctypes passes an empty tuple.
    args: Tuple[()],
) -> cap_t:
    """ctypes ``errcheck`` hook for ``cap_get_proc``, which returns a pointer.

    Args:
        result: The cap_t pointer returned by the foreign function.
        func: The foreign function object that was called (unused).
        args: The arguments the foreign function was called with (unused).

    Returns:
        ``result`` unchanged when it is a non-NULL pointer.

    Raises:
        OSError: When ``result`` is a NULL pointer; built from the errno value captured by
            ctypes.
    """
    # A NULL ctypes pointer is falsy.
    if not result:
        errno = ctypes.get_errno()
        raise OSError(errno, os.strerror(errno))
    return result


def _cap_get_flag_errcheck(
    result: ctypes.c_int, func: Any, args: Tuple[cap_t, cap_value_t, cap_flag_t, cap_flag_value_ptr]
) -> ctypes.c_int:
    """ctypes ``errcheck`` hook for ``cap_get_flag``, which returns zero on success.

    Args:
        result: The value returned by the foreign function.
        func: The foreign function object that was called (unused).
        args: The arguments the foreign function was called with (unused).

    Returns:
        ``result`` unchanged when it is zero.

    Raises:
        OSError: When ``result`` is non-zero; built from the errno value captured by ctypes.
    """
    if result == 0:
        return result

    error_number = ctypes.get_errno()
    raise OSError(error_number, os.strerror(error_number))


@cache
def _get_libcap() -> ctypes.CDLL:
    """Load libcap and declare the prototypes of the libcap functions this module calls.

    The result is cached, so the library lookup and prototype setup happen only on the first
    successful call.

    Returns:
        The libcap library handle with restype/argtypes/errcheck configured for cap_set_proc,
        cap_get_proc, cap_set_flag and cap_get_flag.

    Raises:
        OSError: If the platform is not Linux, or if the libcap shared library cannot be found.
    """
    if not sys.platform.startswith("linux"):
        raise OSError(f"libcap is only available on Linux, but found platform: {sys.platform}")

    libcap_path = find_library("cap")
    if libcap_path is None:
        # find_library() returns None when libcap is not installed. ctypes.CDLL(None) would not
        # fail: it returns a handle to the running process, and the missing libcap symbols would
        # only surface afterwards as a confusing AttributeError. Fail early and clearly instead.
        raise OSError("Could not locate the libcap shared library. Is libcap installed?")

    # use_errno=True makes ctypes capture errno after each call so the errcheck hooks can read it
    libcap = ctypes.CDLL(libcap_path, use_errno=True)

    # https://man7.org/linux/man-pages/man3/cap_set_proc.3.html
    libcap.cap_set_proc.restype = ctypes.c_int
    libcap.cap_set_proc.argtypes = [
        ctypes.POINTER(Cap),
    ]
    libcap.cap_set_proc.errcheck = _cap_set_err_check  # type: ignore

    # https://man7.org/linux/man-pages/man3/cap_get_proc.3.html
    libcap.cap_get_proc.restype = cap_t
    libcap.cap_get_proc.argtypes = []
    libcap.cap_get_proc.errcheck = _cap_get_proc_err_check  # type: ignore

    # https://man7.org/linux/man-pages/man3/cap_set_flag.3.html
    libcap.cap_set_flag.restype = ctypes.c_int
    libcap.cap_set_flag.argtypes = [
        cap_t,
        cap_flag_t,
        ctypes.c_int,
        cap_value_ptr,
        cap_flag_value_t,
    ]
    # cap_set_flag() reports failure through a non-zero return value like the other calls, so
    # give it the same error check rather than silently ignoring a failure.
    libcap.cap_set_flag.errcheck = _cap_set_err_check  # type: ignore

    # https://man7.org/linux/man-pages/man3/cap_get_flag.3.html
    libcap.cap_get_flag.restype = ctypes.c_int
    libcap.cap_get_flag.argtypes = (
        cap_t,
        cap_value_t,
        cap_flag_t,
        cap_flag_value_ptr,
    )
    libcap.cap_get_flag.errcheck = _cap_get_flag_errcheck  # type: ignore

    return libcap


def _has_capability(
    *,
    caps: cap_t,
    capability: int,
    capability_set_type: CapabilitySetType,
) -> bool:
    """Report whether a capability is set in one of the sets of a capability state.

    Args:
        caps: The capability state to inspect.
        capability: The capability bit number to look up (e.g. CAP_KILL).
        capability_set_type: Which capability set of ``caps`` to look in.

    Returns:
        True if libcap reports the flag as CAP_SET, False otherwise.
    """
    query_flag = _get_libcap().cap_get_flag
    # Output parameter: cap_get_flag() writes the flag's state into this value.
    flag_state = cap_flag_value_t()
    query_flag(
        caps,
        capability,
        capability_set_type.value,
        ctypes.byref(flag_state),
    )
    return CAP_SET == flag_state.value


@contextmanager
def try_use_cap_kill() -> Generator[bool, None, None]:
    """
    A context-manager that attempts to leverage the CAP_KILL Linux capability.

    If CAP_KILL is in the current thread's effective set, this context-manager takes no action and
    yields True.

    If CAP_KILL is not in the effective set but is in the permitted set, the context-manager:
        1.  adds CAP_KILL to the effective set before entering the context-manager
        2.  yields True
        3.  clears CAP_KILL from the effective set when exiting the context-manager

    Otherwise, the context-manager does nothing and yields False

    Yields:
        True if CAP_KILL is in the effective set for the duration of the context, False otherwise.

    Raises:
        OSError: If the platform is not Linux, or if a libcap call fails.
    """
    if not sys.platform.startswith("linux"):
        raise OSError(f"Only Linux is supported, but platform is {sys.platform}")

    libcap = _get_libcap()
    caps = libcap.cap_get_proc()
    try:
        if _has_capability(
            caps=caps, capability=CAP_KILL, capability_set_type=CapabilitySetType.EFFECTIVE
        ):
            LOG.debug("CAP_KILL is in the thread's effective set")
            # CAP_KILL is already in the effective set
            yield True
        elif _has_capability(
            caps=caps,
            capability=CAP_KILL,
            capability_set_type=CapabilitySetType.PERMITTED,
        ):
            # CAP_KILL is in the permitted set. We will temporarily add it to the effective set
            LOG.debug(
                "CAP_KILL is in the thread's permitted set. Temporarily adding to effective set"
            )
            # cap_set_flag() takes an array of capability values and its length
            cap_value_arr = (cap_value_t * 1)(CAP_KILL)
            libcap.cap_set_flag(
                caps,
                CAP_EFFECTIVE,
                1,
                cap_value_arr,
                CAP_SET,
            )
            libcap.cap_set_proc(caps)
            try:
                yield True
            finally:
                # Clear CAP_KILL from the effective set
                LOG.debug("Clearing CAP_KILL from the thread's effective set")
                libcap.cap_set_flag(
                    caps,
                    CAP_EFFECTIVE,
                    1,
                    cap_value_arr,
                    CAP_CLEAR,
                )
                libcap.cap_set_proc(caps)
        else:
            yield False
    finally:
        # cap_get_proc() allocates the capability state; it must be released with cap_free(),
        # otherwise every use of this context-manager leaks that allocation.
        libcap.cap_free(caps)


def main() -> None:
    """Developer entrypoint for manually exercising try_use_cap_kill().

    Turns on DEBUG logging, enters the context-manager once and logs the value it yields.
    """
    import logging

    debug_level = logging.DEBUG
    logging.basicConfig(level=debug_level)
    sessions_logger = logging.getLogger("openjd.sessions")
    sessions_logger.setLevel(debug_level)

    with try_use_cap_kill() as has_cap_kill:
        LOG.info("Has CAP_KILL: %s", has_cap_kill)


if __name__ == "__main__":
    main()
7 changes: 7 additions & 0 deletions src/openjd/sessions/_os_checker.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

import os
import sys

LINUX = "linux"
MACOS = "darwin"
POSIX = "posix"
WINDOWS = "nt"


def is_linux() -> bool:
    """Return True when ``sys.platform`` is exactly equal to the LINUX constant ("linux")."""
    return sys.platform == LINUX


def is_posix() -> bool:
    """Return True when ``os.name`` is equal to the POSIX constant ("posix")."""
    return os.name == POSIX

Expand Down
16 changes: 2 additions & 14 deletions src/openjd/sessions/_runner_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,18 +376,7 @@ def _run(self, args: Sequence[str], time_limit: Optional[timedelta] = None) -> N
def _generate_command_shell_script(self, args: Sequence[str]) -> str:
"""Generate a shell script for running a command given by the args."""
script = list[str]()
script.append(
(
"#!/bin/sh\n"
"_term() {\n"
" echo 'Caught SIGTERM'\n"
' test "${CHILD_PID:-}" != "" && echo "Sending SIGTERM to ${CHILD_PID}" && kill -s TERM "${CHILD_PID}"\n'
' wait "${CHILD_PID}"\n'
" exit $?\n" # The wait returns the exit code of the waited-for process
"}\n"
"trap _term TERM"
)
)
script.append("#!/bin/sh\n")
if self._os_env_vars:
for name, value in self._os_env_vars.items():
if value is None:
Expand All @@ -398,8 +387,7 @@ def _generate_command_shell_script(self, args: Sequence[str]) -> str:
# Note: Single quotes around the path as it may have spaces, and we don't want to
# process any shell commands in the path.
script.append(f"cd '{self._startup_directory}'")
script.append(shlex.join(args) + " &")
script.append(("CHILD_PID=$!\n" 'wait "$CHILD_PID"\n' "exit $?\n"))
script.append("exec " + shlex.join(args))
return "\n".join(script)

def _materialize_files(
Expand Down
21 changes: 0 additions & 21 deletions src/openjd/sessions/_scripts/_posix/_signal_subprocess.sh
Original file line number Diff line number Diff line change
Expand Up @@ -31,35 +31,14 @@ set -x

PID="$1"
SIG="$2"
SIGNAL_CHILD="${3:-False}"
INCL_SUBPROCS="${4:-False}"

[ -f /bin/kill ] && KILL=/bin/kill
[ ! -n "${KILL:-}" ] && [ -f /usr/bin/kill ] && KILL=/usr/bin/kill

[ -f /bin/pgrep ] && PGREP=/bin/pgrep
[ ! -n "${PGREP:-}" ] && [ -f /usr/bin/pgrep ] && PGREP=/usr/bin/pgrep

if [ ! -n "${KILL:-}" ]
then
echo "ERROR - Could not find the 'kill' command under /bin or /usr/bin. Please install it."
exit 1
fi

if [ ! -n "${PGREP:-}" ]
then
echo "ERROR - Could not find the 'pgrep' command under /bin or /usr/bin. Please install it."
exit 1
fi

if test "${SIGNAL_CHILD}" = "True"
then
PID=$( "${PGREP}" -P "${PID}" )
fi

if test "${INCL_SUBPROCS}" = "True"
then
PID=-"${PID}"
fi

exec "$KILL" -s "$SIG" -- "$PID"
Loading