-
Notifications
You must be signed in to change notification settings - Fork 8
feat: Allow agent to handle signals for graceful termination #176
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
e77e535
4961a58
201d4b9
42d04ec
3a3d108
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,4 @@ | ||
| import os | ||
| import socket | ||
| import subprocess | ||
| from contextlib import contextmanager | ||
|
|
@@ -24,7 +25,9 @@ class AgentError(Exception): | |
| """An internal problem caused by (most probably) the agent.""" | ||
|
|
||
|
|
||
| PROCESS_SHUTDOWN_TIMEOUT = 5 # seconds | ||
| PROCESS_SHUTDOWN_TIMEOUT_SECONDS = float( | ||
| os.getenv("ISOLATE_SHUTDOWN_GRACE_PERIOD", "60") | ||
| ) | ||
|
|
||
|
|
||
| @dataclass | ||
|
|
@@ -139,13 +142,11 @@ def terminate_process(self, process: Union[None, subprocess.Popen]) -> None: | |
| try: | ||
| print("Terminating the agent process...") | ||
| process.terminate() | ||
| process.wait(timeout=PROCESS_SHUTDOWN_TIMEOUT) | ||
| process.wait(timeout=PROCESS_SHUTDOWN_TIMEOUT_SECONDS) | ||
| print("Agent process shutdown gracefully") | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This might be excessive logging; I can take it out, but it was useful for debugging. |
||
| except Exception as exc: | ||
| print(f"Failed to shutdown the agent process gracefully: {exc}") | ||
| if process: | ||
| process.kill() | ||
| print("Agent process was forcefully killed") | ||
| process.kill() | ||
|
|
||
| def get_python_cmd( | ||
| self, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -179,11 +179,17 @@ class RunTask: | |
|
|
||
| def cancel(self): | ||
| while True: | ||
| self.future.cancel() | ||
| # Cancelling a running future is not possible, and it sometimes blocks, | ||
| # which means we never terminate the agent. So check if it's not running | ||
| if self.future and not self.future.running(): | ||
| self.future.cancel() | ||
|
Comment on lines
+184
to
+185
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. But if we don't cancel it, then what happens?
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think in almost all cases, nothing. But there could be rare race conditions where a future that hasn't started yet starts executing after this, leading to an orphaned agent process (more likely to happen when the server is handling multiple tasks). The chances are quite low, but it's more correct to always cancel, imo.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I guess we can just log this scenario then. |
||
|
|
||
| if self.agent: | ||
| self.agent.terminate() | ||
|
|
||
| try: | ||
| self.future.exception(timeout=0.1) | ||
| if self.future: | ||
| self.future.exception(timeout=0.1) | ||
| return | ||
| except futures.TimeoutError: | ||
| pass | ||
|
|
@@ -424,12 +430,13 @@ def Cancel( | |
|
|
||
| def shutdown(self) -> None: | ||
| if self._shutting_down: | ||
| print("Shutdown already in progress...") | ||
| return | ||
|
|
||
| self._shutting_down = True | ||
| print("Shutting down, canceling all tasks...") | ||
| task_count = len(self.background_tasks) | ||
| print(f"Shutting down, canceling {task_count} tasks...") | ||
| self.cancel_tasks() | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This cancels in sequence, not in parallel, but it should be OK because we only run one task at a time, AFAIK? |
||
| self._thread_pool.shutdown(wait=True) | ||
| print("All tasks canceled.") | ||
|
|
||
| def watch_queue_until_completed( | ||
|
|
@@ -599,6 +606,7 @@ def termination() -> None: | |
| self.servicer.shutdown() | ||
| # Stop the server after the Run task is finished | ||
| self.server.stop(grace=0.1) | ||
| print("Server stopped") | ||
|
|
||
| elif is_submit: | ||
| # Wait until the task_id is assigned | ||
|
|
@@ -625,6 +633,7 @@ def _stop(*args): | |
| print("Stopping server since the task is finished") | ||
| self.servicer.shutdown() | ||
| self.server.stop(grace=0.1) | ||
| print("Server stopped") | ||
|
|
||
| # Add a callback which will stop the server | ||
| # after the task is finished | ||
|
|
@@ -693,8 +702,8 @@ def handle_termination(*args): | |
| signal.signal(signal.SIGINT, handle_termination) | ||
| signal.signal(signal.SIGTERM, handle_termination) | ||
|
|
||
| server.add_insecure_port("[::]:50001") | ||
| print("Started listening at localhost:50001") | ||
| server.add_insecure_port(f"[::]:{options.port}") | ||
| print(f"Started listening at {options.host}:{options.port}") | ||
|
|
||
| server.start() | ||
| server.wait_for_termination() | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,179 @@ | ||
| """End-to-end tests for graceful shutdown behavior of IsolateServicer.""" | ||
|
|
||
| import functools | ||
| import os | ||
| import signal | ||
| import subprocess | ||
| import sys | ||
| import threading | ||
| import time | ||
| from unittest.mock import Mock | ||
|
|
||
| import grpc | ||
| import pytest | ||
| from isolate.server.definitions.server_pb2 import BoundFunction, EnvironmentDefinition | ||
| from isolate.server.definitions.server_pb2_grpc import IsolateStub | ||
| from isolate.server.interface import to_serialized_object | ||
| from isolate.server.server import BridgeManager, IsolateServicer, RunnerAgent, RunTask | ||
|
|
||
|
|
||
| def create_run_request(func, *args, **kwargs): | ||
| """Convert a Python function into a BoundFunction request for stub.Run().""" | ||
| bound_function = functools.partial(func, *args, **kwargs) | ||
| serialized_function = to_serialized_object(bound_function, method="cloudpickle") | ||
|
|
||
| env_def = EnvironmentDefinition() | ||
| env_def.kind = "local" | ||
|
|
||
| request = BoundFunction() | ||
| request.function.CopyFrom(serialized_function) | ||
| request.environments.append(env_def) | ||
| request.stream_logs = True | ||
|
|
||
| return request | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def servicer(): | ||
| """Create a real IsolateServicer instance for testing.""" | ||
| with BridgeManager() as bridge_manager: | ||
| servicer = IsolateServicer(bridge_manager) | ||
| yield servicer | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def isolate_server_subprocess(monkeypatch): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice |
||
| """Set up a gRPC server with the IsolateServicer for testing.""" | ||
| # Find a free port | ||
| import socket | ||
|
|
||
| monkeypatch.setenv("ISOLATE_SHUTDOWN_GRACE_PERIOD", "2") | ||
|
|
||
| with socket.socket() as s: | ||
| s.bind(("", 0)) | ||
| port = s.getsockname()[1] | ||
|
|
||
| process = subprocess.Popen( | ||
| [ | ||
| sys.executable, | ||
| "-m", | ||
| "isolate.server.server", | ||
| "--single-use", | ||
| "--port", | ||
| str(port), | ||
| ] | ||
| ) | ||
|
|
||
| time.sleep(5) # Wait for server to start | ||
| yield process, port | ||
|
|
||
| # Cleanup | ||
| if process.poll() is None: | ||
| process.terminate() | ||
| process.wait(timeout=10) | ||
|
|
||
|
|
||
| def consume_responses(responses): | ||
| def _consume(): | ||
| try: | ||
| for response in responses: | ||
| pass | ||
| except grpc.RpcError: | ||
| # Expected when connection is closed | ||
| pass | ||
|
|
||
| response_thread = threading.Thread(target=_consume, daemon=True) | ||
| response_thread.start() | ||
|
|
||
|
|
||
| def test_shutdown_with_terminate(servicer): | ||
| task = RunTask(request=Mock(), future=Mock()) | ||
| servicer.background_tasks["TEST_BLOCKING"] = task | ||
| task.agent = RunnerAgent(Mock(), Mock(), Mock(), Mock()) | ||
| task.agent.terminate = Mock(wraps=task.agent.terminate) | ||
| servicer.shutdown() | ||
| task.agent.terminate.assert_called_once() # agent should be terminated | ||
|
|
||
|
|
||
| def test_exit_on_client_close(isolate_server_subprocess): | ||
| """Connect with grpc client, run a task and then close the client.""" | ||
| process, port = isolate_server_subprocess | ||
| channel = grpc.insecure_channel(f"localhost:{port}") | ||
| stub = IsolateStub(channel) | ||
|
|
||
| def fn(): | ||
| import time | ||
|
|
||
| time.sleep(30) # Simulate long-running task | ||
|
|
||
| responses = stub.Run(create_run_request(fn)) | ||
| consume_responses(responses) | ||
|
|
||
| # Give task time to start | ||
| time.sleep(2) | ||
|
|
||
| # there is a running grpc client connected to an isolate servicer which is | ||
| # emitting responses from an agent running an infinite loop | ||
| assert process.poll() is None, "Server should be running while client is connected" | ||
|
|
||
| # Close the channel to simulate client disconnect | ||
| channel.close() | ||
|
|
||
| # Give time for the channel close to propagate and trigger termination | ||
| time.sleep(1.0) | ||
|
|
||
| try: | ||
| # Wait for server process to exit | ||
| process.wait(timeout=5) | ||
| except subprocess.TimeoutExpired: | ||
| raise AssertionError("Server did not shut down after client disconnect") | ||
|
|
||
| assert ( | ||
| process.poll() is not None | ||
| ), "Server should have shut down after client disconnect" | ||
|
|
||
|
|
||
| def test_running_function_receives_sigterm(isolate_server_subprocess, tmp_path): | ||
| """Test that the user provided code receives the SIGTERM""" | ||
| process, port = isolate_server_subprocess | ||
| channel = grpc.insecure_channel(f"localhost:{port}") | ||
| stub = IsolateStub(channel) | ||
|
|
||
| # Send SIGTERM to the current process | ||
| assert process.poll() is None, "Server should be running initially" | ||
|
|
||
| def func_with_sigterm_handler(filepath): | ||
| import os | ||
| import pathlib | ||
| import signal | ||
| import time | ||
|
|
||
| def handle_term(signum, frame): | ||
| print("Received SIGTERM, exiting gracefully...") | ||
| pathlib.Path(filepath).touch() | ||
| os._exit(0) | ||
|
|
||
| signal.signal(signal.SIGTERM, handle_term) | ||
|
|
||
| time.sleep(30) # Simulate long-running task | ||
|
|
||
| sigterm_file_path = tmp_path.joinpath("sigterm_test") | ||
|
|
||
| assert not sigterm_file_path.exists() | ||
|
|
||
| responses = stub.Run( | ||
| create_run_request(func_with_sigterm_handler, str(sigterm_file_path)) | ||
| ) | ||
| consume_responses(responses) | ||
| time.sleep(2) # Give task time to start | ||
|
|
||
| os.kill(process.pid, signal.SIGTERM) | ||
| process.wait(timeout=5) | ||
| assert process.poll() is not None, "Server should have shut down after SIGTERM" | ||
| assert ( | ||
| sigterm_file_path.exists() | ||
| ), "Function should have received SIGTERM and created the file" | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| pytest.main([__file__, "-v"]) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
very good solution