Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions src/isolate/connections/grpc/_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import socket
import subprocess
from contextlib import contextmanager
Expand All @@ -24,7 +25,9 @@ class AgentError(Exception):
"""An internal problem caused by (most probably) the agent."""


PROCESS_SHUTDOWN_TIMEOUT = 5 # seconds
PROCESS_SHUTDOWN_TIMEOUT_SECONDS = float(
os.getenv("ISOLATE_SHUTDOWN_GRACE_PERIOD", "60")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very good solution

)


@dataclass
Expand Down Expand Up @@ -139,13 +142,11 @@ def terminate_process(self, process: Union[None, subprocess.Popen]) -> None:
try:
print("Terminating the agent process...")
process.terminate()
process.wait(timeout=PROCESS_SHUTDOWN_TIMEOUT)
process.wait(timeout=PROCESS_SHUTDOWN_TIMEOUT_SECONDS)
print("Agent process shutdown gracefully")
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be excessive logging, I can take it out. But it was useful for debugging.

except Exception as exc:
print(f"Failed to shutdown the agent process gracefully: {exc}")
if process:
process.kill()
print("Agent process was forcefully killed")
process.kill()

def get_python_cmd(
self,
Expand Down
4 changes: 2 additions & 2 deletions src/isolate/connections/grpc/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from dataclasses import dataclass
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Iterable,
)

Expand Down Expand Up @@ -54,7 +54,7 @@ async def Run(
self,
request: definitions.FunctionCall,
context: aio.ServicerContext,
) -> AsyncGenerator[PartialRunResult, None]:
) -> AsyncIterator[PartialRunResult]:
self.log(f"A connection has been established: {context.peer()}!")
server_version = os.getenv("ISOLATE_SERVER_VERSION") or "unknown"
self.log(f"Isolate info: server {server_version}, agent {agent_version}")
Expand Down
21 changes: 15 additions & 6 deletions src/isolate/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,17 @@ class RunTask:

def cancel(self):
while True:
self.future.cancel()
# Cancelling a running future is not possible, and it sometimes blocks,
# which means we never terminate the agent. So check if it's not running
if self.future and not self.future.running():
self.future.cancel()
Comment on lines +184 to +185
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But if we dont cancel it, then what happens?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in almost all cases, nothing. But there could be rare race conditions where the future that hasn't started yet starts executing after this leading to an orphaned agent process (more likely to happen when server is handling multiple tasks).

The chances are quite low, but it's more correct to always cancel imo.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we can just do a log about this scenario then


if self.agent:
self.agent.terminate()

try:
self.future.exception(timeout=0.1)
if self.future:
self.future.exception(timeout=0.1)
return
except futures.TimeoutError:
pass
Expand Down Expand Up @@ -424,12 +430,13 @@ def Cancel(

def shutdown(self) -> None:
if self._shutting_down:
print("Shutdown already in progress...")
return

self._shutting_down = True
print("Shutting down, canceling all tasks...")
task_count = len(self.background_tasks)
print(f"Shutting down, canceling {task_count} tasks...")
self.cancel_tasks()
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This cancels in sequence, and not parallel but it should be ok because we only run one at a time AFAIK?

self._thread_pool.shutdown(wait=True)
print("All tasks canceled.")

def watch_queue_until_completed(
Expand Down Expand Up @@ -599,6 +606,7 @@ def termination() -> None:
self.servicer.shutdown()
# Stop the server after the Run task is finished
self.server.stop(grace=0.1)
print("Server stopped")

elif is_submit:
# Wait until the task_id is assigned
Expand All @@ -625,6 +633,7 @@ def _stop(*args):
print("Stopping server since the task is finished")
self.servicer.shutdown()
self.server.stop(grace=0.1)
print("Server stopped")

# Add a callback which will stop the server
# after the task is finished
Expand Down Expand Up @@ -693,8 +702,8 @@ def handle_termination(*args):
signal.signal(signal.SIGINT, handle_termination)
signal.signal(signal.SIGTERM, handle_termination)

server.add_insecure_port("[::]:50001")
print("Started listening at localhost:50001")
server.add_insecure_port(f"[::]:{options.port}")
print(f"Started listening at {options.host}:{options.port}")

server.start()
server.wait_for_termination()
Expand Down
179 changes: 179 additions & 0 deletions tests/test_shutdown.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
"""End-to-end tests for graceful shutdown behavior of IsolateServicer."""

import functools
import os
import signal
import subprocess
import sys
import threading
import time
from unittest.mock import Mock

import grpc
import pytest
from isolate.server.definitions.server_pb2 import BoundFunction, EnvironmentDefinition
from isolate.server.definitions.server_pb2_grpc import IsolateStub
from isolate.server.interface import to_serialized_object
from isolate.server.server import BridgeManager, IsolateServicer, RunnerAgent, RunTask


def create_run_request(func, *args, **kwargs):
"""Convert a Python function into a BoundFunction request for stub.Run()."""
bound_function = functools.partial(func, *args, **kwargs)
serialized_function = to_serialized_object(bound_function, method="cloudpickle")

env_def = EnvironmentDefinition()
env_def.kind = "local"

request = BoundFunction()
request.function.CopyFrom(serialized_function)
request.environments.append(env_def)
request.stream_logs = True

return request


@pytest.fixture
def servicer():
"""Create a real IsolateServicer instance for testing."""
with BridgeManager() as bridge_manager:
servicer = IsolateServicer(bridge_manager)
yield servicer


@pytest.fixture
def isolate_server_subprocess(monkeypatch):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

"""Set up a gRPC server with the IsolateServicer for testing."""
# Find a free port
import socket

monkeypatch.setenv("ISOLATE_SHUTDOWN_GRACE_PERIOD", "2")

with socket.socket() as s:
s.bind(("", 0))
port = s.getsockname()[1]

process = subprocess.Popen(
[
sys.executable,
"-m",
"isolate.server.server",
"--single-use",
"--port",
str(port),
]
)

time.sleep(5) # Wait for server to start
yield process, port

# Cleanup
if process.poll() is None:
process.terminate()
process.wait(timeout=10)


def consume_responses(responses):
def _consume():
try:
for response in responses:
pass
except grpc.RpcError:
# Expected when connection is closed
pass

response_thread = threading.Thread(target=_consume, daemon=True)
response_thread.start()


def test_shutdown_with_terminate(servicer):
task = RunTask(request=Mock(), future=Mock())
servicer.background_tasks["TEST_BLOCKING"] = task
task.agent = RunnerAgent(Mock(), Mock(), Mock(), Mock())
task.agent.terminate = Mock(wraps=task.agent.terminate)
servicer.shutdown()
task.agent.terminate.assert_called_once() # agent should be terminated


def test_exit_on_client_close(isolate_server_subprocess):
"""Connect with grpc client, run a task and then close the client."""
process, port = isolate_server_subprocess
channel = grpc.insecure_channel(f"localhost:{port}")
stub = IsolateStub(channel)

def fn():
import time

time.sleep(30) # Simulate long-running task

responses = stub.Run(create_run_request(fn))
consume_responses(responses)

# Give task time to start
time.sleep(2)

# there is a running grpc client connected to an isolate servicer which is
# emitting responses from an agent running a infinite loop
assert process.poll() is None, "Server should be running while client is connected"

# Close the channel to simulate client disconnect
channel.close()

# Give time for the channel close to propagate and trigger termination
time.sleep(1.0)

try:
# Wait for server process to exit
process.wait(timeout=5)
except subprocess.TimeoutExpired:
raise AssertionError("Server did not shut down after client disconnect")

assert (
process.poll() is not None
), "Server should have shut down after client disconnect"


def test_running_function_receives_sigterm(isolate_server_subprocess, tmp_path):
"""Test that the user provided code receives the SIGTERM"""
process, port = isolate_server_subprocess
channel = grpc.insecure_channel(f"localhost:{port}")
stub = IsolateStub(channel)

# Send SIGTERM to the current process
assert process.poll() is None, "Server should be running initially"

def func_with_sigterm_handler(filepath):
import os
import pathlib
import signal
import time

def handle_term(signum, frame):
print("Received SIGTERM, exiting gracefully...")
pathlib.Path(filepath).touch()
os._exit(0)

signal.signal(signal.SIGTERM, handle_term)

time.sleep(30) # Simulate long-running task

sigterm_file_path = tmp_path.joinpath("sigterm_test")

assert not sigterm_file_path.exists()

responses = stub.Run(
create_run_request(func_with_sigterm_handler, str(sigterm_file_path))
)
consume_responses(responses)
time.sleep(2) # Give task time to start

os.kill(process.pid, signal.SIGTERM)
process.wait(timeout=5)
assert process.poll() is not None, "Server should have shut down after SIGTERM"
assert (
sigterm_file_path.exists()
), "Function should have received SIGTERM and created the file"


if __name__ == "__main__":
pytest.main([__file__, "-v"])
18 changes: 1 addition & 17 deletions tools/isolate_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import argparse
import os
import sys

import grpc
Expand All @@ -13,24 +12,9 @@


def func_to_submit() -> str:
import pathlib
import signal
import time

path = pathlib.Path("sigterm_received")
if path.exists():
os.unlink(path)

def handle_sigterm(signum, frame):
print("Received SIGTERM, exiting gracefully...")
path.touch()
sys.exit(0)

try:
signal.signal(signal.SIGTERM, handle_sigterm)
except Exception as e:
print(f"Failed to set signal handler: {e}")

print("Task started, sleeping for 10 seconds...")
time.sleep(10)
return "hello"

Expand Down
Loading