diff --git a/pyproject.toml b/pyproject.toml index 58797c09..a750e944 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -148,6 +148,9 @@ addopts = [ "--numprocesses=auto", "--timeout=30" ] +markers = [ + "requires_cap_kill: tests that require CAP_KILL Linux capability", +] [tool.coverage.run] @@ -173,6 +176,9 @@ fail_under = 79 "src/openjd/sessions/_scripts/_windows/*.py", "src/openjd/sessions/_windows*.py" ] +"sys_platform != 'linux'" = [ + "src/openjd/sessions/_linux/*.py", +] [tool.coverage.coverage_conditional_plugin.rules] # This cannot be empty otherwise coverage-conditional-plugin crashes with: diff --git a/scripts/run_sudo_tests.sh b/scripts/run_sudo_tests.sh index d9bc9a47..6bbb5f7c 100755 --- a/scripts/run_sudo_tests.sh +++ b/scripts/run_sudo_tests.sh @@ -67,3 +67,26 @@ if test "${BUILD_ONLY}" == "True"; then fi docker run --name test_openjd_sudo --rm ${ARGS} "${CONTAINER_IMAGE_TAG}:latest" + +if test "${USE_LDAP}" != "True"; then + # Run capability tests + # First with CAP_KILL in effective and permitted capability sets + docker run --name test_openjd_sudo --user root --rm ${ARGS} "${CONTAINER_IMAGE_TAG}:latest" \ + capsh \ + --caps='cap_setuid,cap_setgid,cap_setpcap=ep cap_kill=eip' \ + --keep=1 \ + --user=hostuser \ + --addamb=cap_kill \ + -- \ + -c 'capsh --noamb --caps=cap_kill=ep -- -c "hatch run test --no-cov -m requires_cap_kill"' + # Second with CAP_KILL in permitted capability set but not effective capability set + # this tests that OpenJD will add CAP_KILL to the effective capability set if needed + docker run --name test_openjd_sudo --user root --rm ${ARGS} "${CONTAINER_IMAGE_TAG}:latest" \ + capsh \ + --caps='cap_setuid,cap_setgid,cap_setpcap=ep cap_kill=eip' \ + --keep=1 \ + --user=hostuser \ + --addamb=cap_kill \ + -- \ + -c 'capsh --noamb --caps=cap_kill=p -- -c "hatch run test --no-cov -m requires_cap_kill"' +fi diff --git a/src/openjd/sessions/_linux/__init__.py b/src/openjd/sessions/_linux/__init__.py new file mode 100644 index 00000000..8d929cc8 --- /dev/null +++ b/src/openjd/sessions/_linux/__init__.py @@ -0,0 +1 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. diff --git a/src/openjd/sessions/_linux/_capabilities.py b/src/openjd/sessions/_linux/_capabilities.py new file mode 100644 index 00000000..0a7983a1 --- /dev/null +++ b/src/openjd/sessions/_linux/_capabilities.py @@ -0,0 +1,258 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +"""This module contains code for interacting with Linux capabilities. The module uses the ctypes +module from the Python standard library to wrap the libcap library. + +See https://man7.org/linux/man-pages/man7/capabilities.7.html for details on this Linux kernel +feature. +""" + +import ctypes +import os +import sys +from contextlib import contextmanager +from ctypes.util import find_library +from enum import Enum +from functools import cache +from typing import Any, Generator, Optional, Tuple, TYPE_CHECKING + + +from .._logging import LOG + + +# Capability sets +CAP_EFFECTIVE = 0 +CAP_PERMITTED = 1 +CAP_INHERITABLE = 2 + +# Capability bit numbers +CAP_KILL = 5 + +# Values for cap_flag_value_t arguments +CAP_CLEAR = 0 +CAP_SET = 1 + +cap_flag_t = ctypes.c_int +cap_flag_value_t = ctypes.c_int +cap_value_t = ctypes.c_int + + +class CapabilitySetType(Enum): + INHERITABLE = CAP_INHERITABLE + PERMITTED = CAP_PERMITTED + EFFECTIVE = CAP_EFFECTIVE + + +class UserCapHeader(ctypes.Structure): + _fields_ = [ + ("version", ctypes.c_uint32), + ("pid", ctypes.c_int), + ] + + +class UserCapData(ctypes.Structure): + _fields_ = [ + ("effective", ctypes.c_uint32), + ("permitted", ctypes.c_uint32), + ("inheritable", ctypes.c_uint32), + ] + + +class Cap(ctypes.Structure): + _fields_ = [ + ("head", UserCapHeader), + ("data", UserCapData), + ] + + +if TYPE_CHECKING: + cap_t = ctypes._Pointer[Cap] + cap_flag_value_ptr = ctypes._Pointer[cap_flag_value_t] + cap_value_ptr = ctypes._Pointer[cap_value_t] + ssize_ptr_t = ctypes._Pointer[ctypes.c_ssize_t] +else: + cap_t = ctypes.POINTER(Cap) + cap_flag_value_ptr = ctypes.POINTER(cap_flag_value_t) + cap_value_ptr = ctypes.POINTER(cap_value_t) + ssize_ptr_t = ctypes.POINTER(ctypes.c_ssize_t) + + +def _cap_set_err_check( + result: ctypes.c_int, + func: Any, + args: Tuple[Any, ...], +) -> ctypes.c_int: + if result != 0: + errno = ctypes.get_errno() + raise OSError(errno, os.strerror(errno)) + return result + + +def _cap_get_proc_err_check( + result: cap_t, + func: Any, + args: Tuple[cap_t, cap_flag_t, ctypes.c_int, cap_value_ptr, cap_flag_value_t], +) -> cap_t: + if not result: + errno = ctypes.get_errno() + raise OSError(errno, os.strerror(errno)) + return result + + +def _cap_get_flag_errcheck( + result: ctypes.c_int, func: Any, args: Tuple[cap_t, cap_value_t, cap_flag_t, cap_flag_value_ptr] +) -> ctypes.c_int: + if result != 0: + errno = ctypes.get_errno() + raise OSError(errno, os.strerror(errno)) + return result + + +@cache +def _get_libcap() -> Optional[ctypes.CDLL]: + if not sys.platform.startswith("linux"): + raise OSError(f"libcap is only available on Linux, but found platform: {sys.platform}") + + libcap_path = find_library("cap") + if libcap_path is None: + LOG.info( + "Unable to locate libcap. Session action cancelation signals will be sent using sudo" + ) + return None + + libcap = ctypes.CDLL(libcap_path, use_errno=True) + + # https://man7.org/linux/man-pages/man3/cap_set_proc.3.html + libcap.cap_set_proc.restype = ctypes.c_int + libcap.cap_set_proc.argtypes = [ + cap_t, + ] + libcap.cap_set_proc.errcheck = _cap_set_err_check # type: ignore + + # https://man7.org/linux/man-pages/man3/cap_get_proc.3.html + libcap.cap_get_proc.restype = cap_t + libcap.cap_get_proc.argtypes = [] + libcap.cap_get_proc.errcheck = _cap_get_proc_err_check # type: ignore + + # https://man7.org/linux/man-pages/man3/cap_set_flag.3.html + libcap.cap_set_flag.restype = ctypes.c_int + libcap.cap_set_flag.argtypes = [ + cap_t, + cap_flag_t, + ctypes.c_int, + cap_value_ptr, + cap_flag_value_t, + ] + + # https://man7.org/linux/man-pages/man3/cap_get_flag.3.html + libcap.cap_get_flag.restype = ctypes.c_int + libcap.cap_get_flag.argtypes = ( + cap_t, + cap_value_t, + cap_flag_t, + cap_flag_value_ptr, + ) + libcap.cap_get_flag.errcheck = _cap_get_flag_errcheck # type: ignore + + return libcap + + +def _has_capability( + *, + libcap: ctypes.CDLL, + caps: cap_t, + capability: int, + capability_set_type: CapabilitySetType, +) -> bool: + flag_value = cap_flag_value_t() + libcap.cap_get_flag(caps, capability, capability_set_type.value, ctypes.byref(flag_value)) + return flag_value.value == CAP_SET + + +@contextmanager +def try_use_cap_kill() -> Generator[bool, None, None]: + """ + A context-manager that attempts to leverage the CAP_KILL Linux capability. + + If CAP_KILL is in the current thread's effective set, this context-manager takes no action and + yields True. + + If CAP_KILL is not in the effective set but is in the permitted set, the context-manager: + 1. adds CAP_KILL to the effective set before entering the context-manager + 2. yields True + 3. clears CAP_KILL from the effective set when exiting the context-manager + + Otherwise, the context-manager does nothing and yields False + + Returns: + A context manager that yields a bool. See above for details. + """ + if not sys.platform.startswith("linux"): + raise OSError(f"Only Linux is supported, but platform is {sys.platform}") + + libcap = _get_libcap() + # If libcap is not found, we yield False indicating we are not aware of having CAP_KILL + if not libcap: + yield False + return + + caps = libcap.cap_get_proc() + + if _has_capability( + libcap=libcap, + caps=caps, + capability=CAP_KILL, + capability_set_type=CapabilitySetType.EFFECTIVE, + ): + LOG.debug("CAP_KILL is in the thread's effective set") + # CAP_KILL is already in the effective set + yield True + elif _has_capability( + libcap=libcap, + caps=caps, + capability=CAP_KILL, + capability_set_type=CapabilitySetType.PERMITTED, + ): + # CAP_KILL is in the permitted set. We will temporarily add it to the effective set + LOG.debug("CAP_KILL is in the thread's permitted set. Temporarily adding to effective set") + cap_value_arr_t = cap_value_t * 1 + cap_value_arr = cap_value_arr_t() + cap_value_arr[0] = CAP_KILL + libcap.cap_set_flag( + caps, + CAP_EFFECTIVE, + 1, + cap_value_arr, + CAP_SET, + ) + libcap.cap_set_proc(caps) + try: + yield True + finally: + # Clear CAP_KILL from the effective set + LOG.debug("Clearing CAP_KILL from the thread's effective set") + libcap.cap_set_flag( + caps, + CAP_EFFECTIVE, + 1, + cap_value_arr, + CAP_CLEAR, + ) + libcap.cap_set_proc(caps) + else: + yield False + + +def main() -> None: + """A developer debugging entrypoint for testing the try_use_cap_kill() behaviour""" + import logging + + logging.basicConfig(level=logging.DEBUG) + logging.getLogger("openjd.sessions").setLevel(logging.DEBUG) + + with try_use_cap_kill() as has_cap_kill: + LOG.info("Has CAP_KILL: %s", has_cap_kill) + + +if __name__ == "__main__": + main() diff --git a/src/openjd/sessions/_linux/_sudo.py b/src/openjd/sessions/_linux/_sudo.py new file mode 100644 index 00000000..43a1fa8c --- /dev/null +++ b/src/openjd/sessions/_linux/_sudo.py @@ -0,0 +1,154 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +import glob +import os +import sys +import time +from subprocess import Popen, DEVNULL, PIPE, STDOUT, run +from typing import Optional + +from .._logging import LoggerAdapter, LogContent, LogExtraInfo +from .._os_checker import is_posix, is_linux + + +class FindSignalTargetError(Exception): + """Exception when unable to detect the signal target""" + + pass + + +def find_sudo_child_process_group_id( + *, + logger: LoggerAdapter, + sudo_process: Popen, + timeout_seconds: float = 1, +) -> Optional[int]: + # Hint to mypy to not raise module attribute errors (e.g. missing os.getpgid) + if sys.platform == "win32": + raise NotImplementedError("This method is for POSIX hosts only") + if not is_posix(): + raise NotImplementedError(f"Only POSIX supported, but running on {sys.platform}") + if timeout_seconds <= 0: + raise ValueError(f"Expected positive value for timeout_seconds but got {timeout_seconds}") + + # For cross-user support, we use sudo which creates an intermediate process: + # + # openjd-process + # | + # +-- sudo + # | + # +-- subprocess + # + # Sudo forwards signals that it is able to handle, but in the case of SIGKILL sudo cannot + # handle the signal and the kernel will kill it leaving the child orphaned. We need to + # send SIGKILL signals to the subprocess of sudo + start = time.monotonic() + now = start + sudo_pgid = os.getpgid(sudo_process.pid) + + # Repeatedly scan for child processes + # + # This is put in a retry loop, because it takes a non-zero amount of time before sudo and + # the kernel finish creating the subprocess. We cap this because the process may exit + # quickly and we may never find the child process. + sudo_child_pid: Optional[int] = None + sudo_child_pgid: Optional[int] = None + try: + while now - start < timeout_seconds: + if not sudo_child_pid: + if is_linux(): + sudo_child_pid = find_sudo_child_process_id_procfs( + sudo_pid=sudo_process.pid, + logger=logger, + ) + else: + sudo_child_pid = find_child_process_id_pgrep( + sudo_pid=sudo_process.pid, + ) + + if sudo_child_pid: + try: + sudo_child_pgid = os.getpgid(sudo_child_pid) + except ProcessLookupError: + # If the process has exited, we short-circuit + return None + # sudo first forks, then creates a new process group. There is a race condition + # where the process group ID we observe has not yet changed. If the PGID detected + # matches the PGID of sudo, then we retry again in the loop + if sudo_child_pgid == sudo_pgid: + sudo_child_pgid = None + else: + break + + # If we did not find any child processes yet, sleep for some time and retry + time.sleep(min(0.05, timeout_seconds - (now - start))) + now = time.monotonic() + if not sudo_child_pid or not sudo_child_pgid: + raise FindSignalTargetError("unable to detect subprocess before timeout") + except FindSignalTargetError as e: + logger.warning( + f"Unable to determine signal target: {e}", + extra=LogExtraInfo(openjd_log_content=LogContent.PROCESS_CONTROL), + ) + + if sudo_child_pgid: + logger.debug( + f"Signal target PGID = {sudo_child_pgid}", + extra=LogExtraInfo(openjd_log_content=LogContent.PROCESS_CONTROL), + ) + + return sudo_child_pgid + + +def find_sudo_child_process_id_procfs( + *, + logger: LoggerAdapter, + sudo_pid: int, +) -> Optional[int]: + # Look for the child process of sudo using procfs. See + # https://docs.kernel.org/filesystems/proc.html#proc-pid-task-tid-children-information-about-task-children + + child_pids: set[int] = set() + for task_children_path in glob.glob(f"/proc/{sudo_pid}/task/**/children"): + with open(task_children_path, "r") as f: + child_pids.update(int(pid_str) for pid_str in f.read().split()) + + # If we found exactly one child, we return it + if len(child_pids) == 1: + + child_pid = child_pids.pop() + + logger.debug( + f"Session action process (sudo child) PID is {child_pid}", + extra=LogExtraInfo(openjd_log_content=LogContent.PROCESS_CONTROL), + ) + return child_pid + # If we found multiple child processes, this violates our assumptions about how sudo + # works. We will fall-back to using pkill for signalling the process + elif len(child_pids) > 1: + raise FindSignalTargetError( + f"Expected single child processes of sudo, but found {child_pids}" + ) + return None + + +def find_child_process_id_pgrep( + *, + sudo_pid: int, +) -> Optional[int]: + pgrep_result = run( + ["pgrep", "-P", str(sudo_pid)], + stdout=PIPE, + stderr=STDOUT, + stdin=DEVNULL, + text=True, + ) + if pgrep_result.returncode != 0: + raise FindSignalTargetError("Unable to query child processes of sudo process") + results = pgrep_result.stdout.splitlines() + if len(results) > 1: + raise FindSignalTargetError(f"Expected a single child process of sudo, but found {results}") + elif len(results) == 0: + return None + sudo_subproc_pid = int(results[0]) + return sudo_subproc_pid diff --git a/src/openjd/sessions/_os_checker.py b/src/openjd/sessions/_os_checker.py index 6166326d..c42c2dda 100644 --- a/src/openjd/sessions/_os_checker.py +++ b/src/openjd/sessions/_os_checker.py @@ -1,11 +1,18 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. import os +import sys +LINUX = "linux" +MACOS = "darwin" POSIX = "posix" WINDOWS = "nt" +def is_linux() -> bool: + return sys.platform == LINUX + + def is_posix() -> bool: return os.name == POSIX diff --git a/src/openjd/sessions/_runner_base.py b/src/openjd/sessions/_runner_base.py index bced5d4b..17178448 100644 --- a/src/openjd/sessions/_runner_base.py +++ b/src/openjd/sessions/_runner_base.py @@ -376,18 +376,7 @@ def _run(self, args: Sequence[str], time_limit: Optional[timedelta] = None) -> N def _generate_command_shell_script(self, args: Sequence[str]) -> str: """Generate a shell script for running a command given by the args.""" script = list[str]() - script.append( - ( - "#!/bin/sh\n" - "_term() {\n" - " echo 'Caught SIGTERM'\n" - ' test "${CHILD_PID:-}" != "" && echo "Sending SIGTERM to ${CHILD_PID}" && kill -s TERM "${CHILD_PID}"\n' - ' wait "${CHILD_PID}"\n' - " exit $?\n" # The wait returns the exit code of the waited-for process - "}\n" - "trap _term TERM" - ) - ) + script.append("#!/bin/sh\n") if self._os_env_vars: for name, value in self._os_env_vars.items(): if value is None: @@ -398,8 +387,7 @@ def _generate_command_shell_script(self, args: Sequence[str]) -> str: # Note: Single quotes around the path as it may have spaces, and we don't want to # process any shell commands in the path. script.append(f"cd '{self._startup_directory}'") - script.append(shlex.join(args) + " &") - script.append(("CHILD_PID=$!\n" 'wait "$CHILD_PID"\n' "exit $?\n")) + script.append("exec " + shlex.join(args)) return "\n".join(script) def _materialize_files( diff --git a/src/openjd/sessions/_scripts/_posix/_signal_subprocess.sh b/src/openjd/sessions/_scripts/_posix/_signal_subprocess.sh index eaa609e0..75b3565b 100755 --- a/src/openjd/sessions/_scripts/_posix/_signal_subprocess.sh +++ b/src/openjd/sessions/_scripts/_posix/_signal_subprocess.sh @@ -31,35 +31,14 @@ set -x PID="$1" SIG="$2" -SIGNAL_CHILD="${3:-False}" -INCL_SUBPROCS="${4:-False}" [ -f /bin/kill ] && KILL=/bin/kill [ ! -n "${KILL:-}" ] && [ -f /usr/bin/kill ] && KILL=/usr/bin/kill -[ -f /bin/pgrep ] && PGREP=/bin/pgrep -[ ! -n "${PGREP:-}" ] && [ -f /usr/bin/pgrep ] && PGREP=/usr/bin/pgrep - if [ ! -n "${KILL:-}" ] then echo "ERROR - Could not find the 'kill' command under /bin or /usr/bin. Please install it." exit 1 fi -if [ ! -n "${PGREP:-}" ] -then - echo "ERROR - Could not find the 'pgrep' command under /bin or /usr/bin. Please install it." - exit 1 -fi - -if test "${SIGNAL_CHILD}" = "True" -then - PID=$( "${PGREP}" -P "${PID}" ) -fi - -if test "${INCL_SUBPROCS}" = "True" -then - PID=-"${PID}" -fi - exec "$KILL" -s "$SIG" -- "$PID" diff --git a/src/openjd/sessions/_subprocess.py b/src/openjd/sessions/_subprocess.py index 3242669e..ea29011c 100644 --- a/src/openjd/sessions/_subprocess.py +++ b/src/openjd/sessions/_subprocess.py @@ -2,25 +2,29 @@ import os import shlex +import signal +import sys import time -from ._os_checker import is_posix, is_windows - -if is_windows(): - from subprocess import CREATE_NEW_PROCESS_GROUP, CREATE_NO_WINDOW # type: ignore - from ._win32._popen_as_user import PopenWindowsAsUser # type: ignore - from ._windows_process_killer import kill_windows_process_tree +from contextlib import nullcontext +from datetime import timedelta +from pathlib import Path from queue import Queue, Empty -from typing import Any -from threading import Event, Thread -from ._logging import LoggerAdapter, LogContent, LogExtraInfo from subprocess import DEVNULL, PIPE, STDOUT, Popen, list2cmdline, run -from typing import Callable, Optional, Sequence, cast -from pathlib import Path -from datetime import timedelta -import sys +from threading import Event, Thread +from typing import Any +from typing import Callable, Literal, Optional, Sequence, cast +from ._linux._capabilities import try_use_cap_kill +from ._linux._sudo import find_sudo_child_process_group_id +from ._logging import LoggerAdapter, LogContent, LogExtraInfo +from ._os_checker import is_linux, is_posix, is_windows from ._session_user import PosixSessionUser, WindowsSessionUser, SessionUser +if is_windows(): # pragma: nocover + from subprocess import CREATE_NEW_PROCESS_GROUP, CREATE_NO_WINDOW # type: ignore + from ._win32._popen_as_user import PopenWindowsAsUser # type: ignore + from ._windows_process_killer import kill_windows_process_tree + __all__ = ("LoggingSubprocess",) @@ -61,6 +65,7 @@ class LoggingSubprocess(object): _working_dir: Optional[str] _pid: Optional[int] + _sudo_child_process_group_id: Optional[int] _returncode: Optional[int] def __init__( @@ -93,12 +98,11 @@ def __init__( self._has_started = Event() self._pid = None self._returncode = None + self._sudo_child_process_group_id = None @property def pid(self) -> Optional[int]: - if self._pid is not None: - return self._pid - return None + return self._pid @property def exit_code(self) -> Optional[int]: @@ -165,6 +169,17 @@ def run(self) -> None: self._pid = self._process.pid + # Would use is_posix(), but it doesn't short-circuit mypy which then complains + # about os.getpgid not being a valid attribute. + if not sys.platform == "win32": + if not self._user or self._user.is_process_user(): + self._sudo_child_process_group_id = os.getpgid(self._process.pid) + else: + self._sudo_child_process_group_id = find_sudo_child_process_group_id( + logger=self._logger, + sudo_process=self._process, + ) + self._logger.info( f"Command started as pid: {self._process.pid}", extra=LogExtraInfo(openjd_log_content=LogContent.PROCESS_CONTROL), @@ -200,7 +215,7 @@ def notify(self) -> None: """ if self._process is not None and self._process.poll() is None: if is_posix(): - self._posix_signal_subprocess(signal="term", signal_subprocesses=False) + self._posix_signal_subprocess(signal_name="term") else: self._windows_notify_subprocess() @@ -216,7 +231,7 @@ def terminate(self) -> None: """ if self._process is not None and self._process.poll() is None: if is_posix(): - self._posix_signal_subprocess(signal="kill", signal_subprocesses=True) + self._posix_signal_subprocess(signal_name="kill") else: self._logger.info( f"INTERRUPT: Start killing the process tree with the root pid: {self._process.pid}", @@ -298,7 +313,7 @@ def _start_subprocess(self) -> Optional[Popen]: ) return None - def _log_subproc_stdout(self): + def _log_subproc_stdout(self) -> None: """ Blocking call which logs the STDOUT of the running subproc until the subprocess exits. @@ -423,10 +438,16 @@ def _tosigned(n: int) -> int: extra=LogExtraInfo(openjd_log_content=LogContent.PROCESS_CONTROL), ) - def _posix_signal_subprocess(self, signal: str, signal_subprocesses: bool = False) -> None: - """Send a given named signal, via pkill, to the subprocess when it is running - as a different user than this process. - """ + def _posix_signal_subprocess( + self, + signal_name: Literal["term", "kill"], + ) -> None: + """Send a given named signal to the subprocess.""" + + # Hint to mypy to not raise module attribute errors (e.g. missing os.getpgid) + if sys.platform == "win32": + raise NotImplementedError("This method is for POSIX hosts only") + # We can run into a race condition where the process exits (and another thread sets self._process to None) # before the cancellation happens, so we swap to a local variable to ensure a cancellation that is not needed, # does not raise an exception here. @@ -450,44 +471,119 @@ def _posix_signal_subprocess(self, signal: str, signal_subprocesses: bool = Fals # b. When we run the command using `sudo` then we need to either run code that does the whole # algorithm as the other user, or `sudo` to send every process signal. - cmd = list[str]() - signal_child = False + numeric_signal = 0 + if signal_name == "term": + numeric_signal = signal.SIGTERM + # SIGTERM is the simpler case. In the cross-user sudo case, we can send a signal to the + # sudo process and it will forward the signal. For the same-user case, the subprocess + # is the one we want to signal. + self._logger.info( + f'INTERRUPT: Sending signal "{signal_name}" to process {process.pid}', + extra=LogExtraInfo(openjd_log_content=LogContent.PROCESS_CONTROL), + ) + try: + os.kill(process.pid, numeric_signal) + except OSError: + self._logger.warning( + f"INTERRUPT: Unable to send {signal_name} to {process.pid}", + extra=LogExtraInfo(openjd_log_content=LogContent.PROCESS_CONTROL), + ) + return + elif signal_name == "kill": + numeric_signal = signal.SIGKILL + else: + raise NotImplementedError(f"Unsupported signal: {signal_name}") + + kill_cmd = list[str]() if self._user is not None: user = cast(PosixSessionUser, self._user) # Only sudo if the user to run as is not the same as the current user. if not user.is_process_user(): - cmd.extend(["sudo", "-u", user.user, "-i"]) - signal_child = True + kill_cmd = ["sudo", "-u", user.user, "-i"] + + # If we were unable to detect sudo's child process PID after launching the + # subprocess, we try again now + if not self._sudo_child_process_group_id: + self._sudo_child_process_group_id = find_sudo_child_process_group_id( + logger=self._logger, + sudo_process=process, + ) - cmd.extend( + if not self._sudo_child_process_group_id: + self._logger.warning( + f"Failed to send signal '{signal_name}': Unable to determine child process of sudo", + extra=LogExtraInfo(openjd_log_content=LogContent.PROCESS_CONTROL), + ) + return + + # Try directly signaling the process(es) first + ctx_mgr = try_use_cap_kill() if is_linux() else nullcontext(enter_result=False) + with ctx_mgr as has_cap_kill: + if has_cap_kill or not self._user or self._user.is_process_user(): + try: + self._logger.info( + f'INTERRUPT: Sending signal "{signal_name}" to process group {self._sudo_child_process_group_id}', + extra=LogExtraInfo(openjd_log_content=LogContent.PROCESS_CONTROL), + ) + os.killpg(self._sudo_child_process_group_id, numeric_signal) + except OSError: + self._logger.info( + "Could not directly send signal {signal_name} to {self._posix_signal_target.pid}, trying sudo.", + extra=LogExtraInfo(openjd_log_content=LogContent.PROCESS_CONTROL), + ) + else: + return + else: + self._logger.info( + "Could not directly send signal {signal_name} to {process.pid}, trying sudo.", + extra=LogExtraInfo(openjd_log_content=LogContent.PROCESS_CONTROL), + ) + + # Uncomment to visualize process tree when debugging tests + # self._log_process_tree() + + kill_cmd.extend( [ - str(POSIX_SIGNAL_SUBPROC_SCRIPT_PATH), - str(process.pid), - signal, - str(signal_child), - str(signal_subprocesses), + "kill", + "-s", + signal_name, + "--", + f"-{self._sudo_child_process_group_id}", ] ) self._logger.info( - f"INTERRUPT: Running: {shlex.join(cmd)}", + f"INTERRUPT: Running: {shlex.join(kill_cmd)}", extra=LogExtraInfo(openjd_log_content=LogContent.PROCESS_CONTROL), ) result = run( - cmd, + kill_cmd, stdout=PIPE, stderr=STDOUT, stdin=DEVNULL, ) if result.returncode != 0: self._logger.warning( - f"Failed to send signal '{signal}' to subprocess {process.pid}: %s", + f"Failed to send signal '{signal_name}' to PGID {self._sudo_child_process_group_id}: %s", result.stdout.decode("utf-8"), extra=LogExtraInfo( openjd_log_content=LogContent.PROCESS_CONTROL | LogContent.EXCEPTION_INFO ), ) + def _log_process_tree(self) -> None: + """A developer method to visualize the process tree including PIDs and PGIDs when debuging tests""" + pstree_result = run(["pstree", "-pg"], stdout=PIPE, stderr=STDOUT, stdin=DEVNULL, text=True) + self._logger.debug( + f"pstree -pg output: {pstree_result.stdout}", + extra=LogExtraInfo(openjd_log_content=LogContent.PROCESS_CONTROL), + ) + ps_result = run(["ps", "-ejH"], stdout=PIPE, stderr=STDOUT, stdin=DEVNULL, text=True) + self._logger.debug( + f"ps -ejH output:\n{ps_result.stdout}", + extra=LogExtraInfo(openjd_log_content=LogContent.PROCESS_CONTROL), + ) + def _windows_notify_subprocess(self) -> None: """Sends a CTRL_BREAK_EVENT signal to the subprocess""" # Convince the type checker that accessing _process is okay diff --git a/test/openjd/sessions/conftest.py b/test/openjd/sessions/conftest.py index 28bec725..659e0dab 100644 --- a/test/openjd/sessions/conftest.py +++ b/test/openjd/sessions/conftest.py @@ -36,6 +36,25 @@ POSIX_SET_DISJOINT_USER_ENV_VARS_MESSAGE = f"Must define environment vars {POSIX_DISJOINT_USER_ENV_VAR} and {POSIX_DISJOINT_GROUP_ENV_VAR} to run target-user impersonation tests on posix." +def pytest_collection_modifyitems(config, items): + """This is a pytest hook that provides a default mark expression if one was not provided. By + default, we want to de-select tests that require the CAP_KILL Linux capability. + + Those tests should only be selected when running the Docker container test workflow + described in DEVELOPMENT.md which grant the necessary capabilities and specify a + mark expression. + + See: + - https://docs.pytest.org/en/8.3.x/reference/reference.html#pytest.hookspec.pytest_collection_modifyitems + - https://docs.pytest.org/en/8.3.x/reference/reference.html#command-line-flags + """ + mark_expr = config.getoption("markexpr", False) + if not mark_expr: + config.option.markexpr = "not requires_cap_kill" + else: + config.option.markexpr = mark_expr + + def build_logger(handler: QueueHandler) -> LoggerAdapter: charset = string.ascii_letters + string.digits + string.punctuation name_suffix = "".join(random.choices(charset, k=32)) diff --git a/test/openjd/sessions/support_files/output_signal_sender.c b/test/openjd/sessions/support_files/output_signal_sender.c new file mode 100644 index 00000000..87a0adc1 --- /dev/null +++ b/test/openjd/sessions/support_files/output_signal_sender.c @@ -0,0 +1,43 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +// This is a minimal C program that sleeps until it receives a SIGTERM signal +// and outputs the process ID of the sender. + +#include +#include +#include +#include +#include +#include +#include +#include + +static bool received_signal = false; + +static void signal_handler(int sig, siginfo_t *siginfo, void *context) { + // Output PID of the sending process + int pid_of_sending_process = (int) siginfo->si_pid; + printf("%d\n", pid_of_sending_process); + + // Tell main loop to exit + received_signal = true; +} + +int main(int argc, char *argv[]) { + // register signal handler + struct sigaction signal_action; + signal_action.sa_sigaction = *signal_handler; + // get details about the signal + signal_action.sa_flags |= SA_SIGINFO; + if(sigaction(SIGTERM, &signal_action, NULL) != 0) { + printf("Could not register signal handler\n"); + return errno; + } + + // sleep until SIGINT received + while(!received_signal) { + sleep(1); + } + + return 0; +} diff --git a/test/openjd/sessions/test_runner_base.py b/test/openjd/sessions/test_runner_base.py index 55bab0eb..03037874 100644 --- a/test/openjd/sessions/test_runner_base.py +++ b/test/openjd/sessions/test_runner_base.py @@ -787,6 +787,100 @@ def test_cancel_notify( delta_t = time_end - now assert timedelta(seconds=1) < delta_t < timedelta(seconds=3) + @pytest.mark.skipif(not is_posix(), reason="posix-only test") + @pytest.mark.xfail( + not has_posix_target_user(), + reason=POSIX_SET_TARGET_USER_ENV_VARS_MESSAGE, + ) + @pytest.mark.requires_cap_kill + def test_cancel_notify_direct_signal_with_cap_kill( + self, + tmp_path: Path, + message_queue: SimpleQueue, + queue_handler: QueueHandler, + ) -> None: + # Test for Linux hosts, that when CAP_KILL is in the permitted (and possibly effective) + # capability set(s), that the runner will: + # 1. directly signal the subprocess to notify + # 2. retain the status of CAP_KILL in the thread's effective capability set + + # GIVEN + logger = build_logger(queue_handler) + + from openjd.sessions._linux._capabilities import ( + _has_capability, + _get_libcap, + CAP_KILL, + CapabilitySetType, + ) + + # Record whether CAP_KILL is in the effective capability set before + # notifying the subprocess + libcap = _get_libcap() + assert libcap is not None, "Libcap not found" + caps = libcap.cap_get_proc() + cap_kill_was_effective = _has_capability( + libcap=libcap, + caps=caps, + capability=CAP_KILL, + capability_set_type=CapabilitySetType.EFFECTIVE, + ) + + with NotifyingRunner(logger=logger, session_working_directory=tmp_path) as runner: + # Path to compiled C program that outputs the PID of the process sending the signal + output_signal_sender_app_loc = ( + Path(__file__).parent / "support_files" / "output_signal_sender" + ).resolve() + assert output_signal_sender_app_loc.exists(), "output_signal_sender is not compiled." + runner._run([str(output_signal_sender_app_loc)]) + + # WHEN + secs = 2 if not is_windows() else 5 + time.sleep(secs) # Give the process a little time to do something + runner.cancel(time_limit=timedelta(seconds=2)) + + # THEN + for _ in range(10): + if runner.state == ScriptRunnerState.CANCELED: + break + time.sleep(1) + else: + # Terminate the subprocess + runner.cancel() + assert False, "output_signal_sender process did not exit when sent SIGTERM" + assert runner.exit_code == 0 + + # Collect stdout lines. Based on the code of output_signal_sender.c, only a single + # line should be output with the PID of the process that sent the SIGINT signal. + # Extracting the log line requires finding the preceeding log line emitted by the runner, + # then taking the following line and parsing it as an integer + messages = collect_queue_messages(message_queue) + for i, message in enumerate(messages): + if message.startswith('INTERRUPT: Sending signal "term" to process '): + break + else: + assert False, "Could not find log line before stdout" + pid_line = messages[i + 1] + signal_sender_pid = int(pid_line) + + current_pid = os.getpid() + assert ( + current_pid == signal_sender_pid + ), "The runner's subprocess was not directly signalled" + + # Assert that the presence/absence of CAP_KILL in the effective capability set + # is unchanged after calling Runner.cancel() + caps = libcap.cap_get_proc() + cap_kill_effective_after_cancel = _has_capability( + libcap=libcap, + caps=caps, + capability=CAP_KILL, + capability_set_type=CapabilitySetType.EFFECTIVE, + ) + assert ( + cap_kill_was_effective == cap_kill_effective_after_cancel + ), "CAP_KILL added/removed from effetive set and persisted after cancelation" + @pytest.mark.usefixtures("message_queue", "queue_handler") def test_cancel_double_cancel_notify( self, diff --git a/testing_containers/localuser_sudo_environment/Dockerfile b/testing_containers/localuser_sudo_environment/Dockerfile index 5ef00112..140f8da2 100644 --- a/testing_containers/localuser_sudo_environment/Dockerfile +++ b/testing_containers/localuser_sudo_environment/Dockerfile @@ -20,7 +20,7 @@ ENV PIP_INDEX_URL=$PIP_INDEX_URL # hostuser: hostuser, sharedgroup # targetuser: targetuser, sharedgroup # disjointuser: disjointuser, disjointgroup -RUN apt-get update && apt-get install -y sudo && \ +RUN apt-get update && apt-get install -y gcc libcap2-bin psmisc sudo && \ # Clean up apt cache rm -rf /var/lib/apt/lists/* && \ apt-get clean && \ @@ -39,6 +39,8 @@ USER hostuser COPY --chown=hostuser:hostuser . /code/ WORKDIR /code -RUN hatch env create +RUN hatch env create && \ + # compile the output_signal_sender program which outputs the PID of a process sending a SIGTERM signal \ + gcc -Wall /code/test/openjd/sessions/support_files/output_signal_sender.c -o /code/test/openjd/sessions/support_files/output_signal_sender CMD ["hatch", "run", "test", "--no-cov"] \ No newline at end of file