Skip to content

Commit e77e535

Browse files
committed
feat: Allow agent to handle signals for graceful termination
1 parent 76d64d3 commit e77e535

File tree

4 files changed

+232
-23
lines changed

4 files changed

+232
-23
lines changed

src/isolate/connections/grpc/_base.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import socket
2+
import subprocess
23
from contextlib import contextmanager
34
from dataclasses import dataclass
45
from pathlib import Path
@@ -23,6 +24,9 @@ class AgentError(Exception):
2324
"""An internal problem caused by (most probably) the agent."""
2425

2526

27+
PROCESS_SHUTDOWN_TIMEOUT = 5 # seconds
28+
29+
2630
@dataclass
2731
class GRPCExecutionBase(EnvironmentConnection):
2832
"""A customizable gRPC-based execution backend."""
@@ -128,9 +132,20 @@ def find_free_port() -> Tuple[str, int]:
128132
with self.start_process(address) as process:
129133
yield address, grpc.local_channel_credentials()
130134
finally:
131-
if process is not None:
132-
# TODO: should we check the status code here?
135+
self.terminate_process(process)
136+
137+
def terminate_process(self, process: Union[None, subprocess.Popen]) -> None:
    """Stop the agent process, escalating from terminate to kill if needed.

    The process is first asked to terminate and is given
    ``PROCESS_SHUTDOWN_TIMEOUT`` seconds to exit on its own. If that fails
    for any reason (most commonly the wait timing out), the process is
    killed and then waited on so that it is reaped instead of lingering as
    a zombie.

    Args:
        process: The agent process to stop, or ``None`` if no process was
            started, in which case this is a no-op.
    """
    if process is None:
        return

    print("Terminating the agent process...")
    try:
        process.terminate()
        process.wait(timeout=PROCESS_SHUTDOWN_TIMEOUT)
    except Exception as exc:
        # Deliberately broad: this runs during cleanup, and any failure of
        # the graceful path should fall through to the forceful one.
        print(f"Failed to shutdown the agent process gracefully: {exc}")
    else:
        print("Agent process shutdown gracefully")
        return

    process.kill()
    # Collect the exit status of the killed process; without this wait the
    # child stays in the process table until the parent exits.
    process.wait()
    print("Agent process was forcefully killed")
134149

135150
def get_python_cmd(
136151
self,

src/isolate/connections/grpc/agent.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,21 @@
1010

1111
from __future__ import annotations
1212

13+
import asyncio
1314
import os
1415
import sys
1516
import traceback
1617
from argparse import ArgumentParser
17-
from concurrent import futures
1818
from dataclasses import dataclass
1919
from typing import (
2020
Any,
21+
AsyncGenerator,
2122
Iterable,
22-
Iterator,
2323
)
2424

25-
import grpc
26-
from grpc import ServicerContext, StatusCode
25+
from grpc import StatusCode, aio, local_server_credentials
26+
27+
from isolate.connections.grpc.definitions import PartialRunResult
2728

2829
try:
2930
from isolate import __version__ as agent_version
@@ -49,11 +50,11 @@ def __init__(self, log_fd: int | None = None):
4950
self._run_cache: dict[str, Any] = {}
5051
self._log = sys.stdout if log_fd is None else os.fdopen(log_fd, "w")
5152

52-
def Run(
53+
async def Run(
5354
self,
5455
request: definitions.FunctionCall,
55-
context: ServicerContext,
56-
) -> Iterator[definitions.PartialRunResult]:
56+
context: aio.ServicerContext,
57+
) -> AsyncGenerator[PartialRunResult, None]:
5758
self.log(f"A connection has been established: {context.peer()}!")
5859
server_version = os.getenv("ISOLATE_SERVER_VERSION") or "unknown"
5960
self.log(f"Isolate info: server {server_version}, agent {agent_version}")
@@ -87,7 +88,8 @@ def Run(
8788
)
8889
raise AbortException("The setup function has thrown an error.")
8990
except AbortException as exc:
90-
return self.abort_with_msg(context, exc.message)
91+
self.abort_with_msg(context, exc.message)
92+
return
9193
else:
9294
assert not was_it_raised
9395
self._run_cache[cache_key] = result
@@ -107,7 +109,8 @@ def Run(
107109
stringized_tb,
108110
)
109111
except AbortException as exc:
110-
return self.abort_with_msg(context, exc.message)
112+
self.abort_with_msg(context, exc.message)
113+
return
111114

112115
def execute_function(
113116
self,
@@ -195,7 +198,7 @@ def log(self, message: str) -> None:
195198

196199
def abort_with_msg(
197200
self,
198-
context: ServicerContext,
201+
context: aio.ServicerContext,
199202
message: str,
200203
*,
201204
code: StatusCode = StatusCode.INVALID_ARGUMENT,
@@ -205,23 +208,26 @@ def abort_with_msg(
205208
return None
206209

207210

208-
def create_server(address: str) -> grpc.Server:
211+
def create_server(address: str) -> aio.Server:
209212
"""Create a new (temporary) gRPC server listening on the given
210213
address."""
211-
server = grpc.server(
212-
futures.ThreadPoolExecutor(max_workers=1),
213-
maximum_concurrent_rpcs=1,
214+
# Use an asyncio server so requests can run in the main thread and intercept signals.
215+
# There seems to be a weird bug with grpcio that makes subsequent requests fail with
216+
# concurrent rpc limit exceeded if we set maximum_concurrent_rpcs to 1. Setting it to 2
217+
# fixes it, even though in practice, we only run one request at a time.
218+
server = aio.server(
219+
maximum_concurrent_rpcs=2,
214220
options=get_default_options(),
215221
)
216222

217223
# Local server credentials allow us to ensure that the
218224
# connection is established by a local process.
219-
server_credentials = grpc.local_server_credentials()
225+
server_credentials = local_server_credentials()
220226
server.add_secure_port(address, server_credentials)
221227
return server
222228

223229

224-
def run_agent(address: str, log_fd: int | None = None) -> int:
230+
async def run_agent(address: str, log_fd: int | None = None) -> int:
225231
"""Run the agent servicer on the given address."""
226232
server = create_server(address)
227233
servicer = AgentServicer(log_fd=log_fd)
@@ -231,19 +237,19 @@ def run_agent(address: str, log_fd: int | None = None) -> int:
231237
# not have any global side effects.
232238
definitions.register_agent(servicer, server)
233239

234-
server.start()
235-
server.wait_for_termination()
240+
await server.start()
241+
await server.wait_for_termination()
236242
return 0
237243

238244

239-
def main() -> int:
245+
async def main() -> int:
240246
parser = ArgumentParser()
241247
parser.add_argument("address", type=str)
242248
parser.add_argument("--log-fd", type=int)
243249

244250
options = parser.parse_args()
245-
return run_agent(options.address, log_fd=options.log_fd)
251+
return await run_agent(options.address, log_fd=options.log_fd)
246252

247253

248254
if __name__ == "__main__":
249-
main()
255+
asyncio.run(main())

src/isolate/server/server.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import functools
44
import os
5+
import signal
56
import threading
67
import time
78
import traceback
@@ -197,6 +198,7 @@ class IsolateServicer(definitions.IsolateServicer):
197198
bridge_manager: BridgeManager
198199
default_settings: IsolateSettings = field(default_factory=IsolateSettings)
199200
background_tasks: dict[str, RunTask] = field(default_factory=dict)
201+
_shutting_down: bool = field(default=False)
200202

201203
_thread_pool: futures.ThreadPoolExecutor = field(
202204
default_factory=lambda: futures.ThreadPoolExecutor(max_workers=MAX_THREADS)
@@ -420,6 +422,16 @@ def Cancel(
420422

421423
return definitions.CancelResponse()
422424

425+
def shutdown(self) -> None:
    """Cancel all tasks and shut the thread pool down, once.

    The first call sets ``_shutting_down``, cancels the tasks and waits for
    the thread pool to shut down. Any later call sees the flag already set
    and returns without doing anything.
    """
    already_requested = self._shutting_down
    if not already_requested:
        # Set the flag before doing any work so repeated calls return early.
        self._shutting_down = True
        print("Shutting down, canceling all tasks...")
        self.cancel_tasks()
        self._thread_pool.shutdown(wait=True)
        print("All tasks canceled.")
434+
423435
def watch_queue_until_completed(
424436
self, queue: Queue, is_completed: Callable[[], bool]
425437
) -> Iterator[definitions.PartialRunResult]:
@@ -584,6 +596,7 @@ def _wrapper(request: Any, context: grpc.ServicerContext) -> Any:
584596
def termination() -> None:
585597
if is_run:
586598
print("Stopping server since run is finished")
599+
self.servicer.shutdown()
587600
# Stop the server after the Run task is finished
588601
self.server.stop(grace=0.1)
589602

@@ -610,6 +623,7 @@ def _stop(*args):
610623
# Small sleep to make sure the cancellation is processed
611624
time.sleep(0.1)
612625
print("Stopping server since the task is finished")
626+
self.servicer.shutdown()
613627
self.server.stop(grace=0.1)
614628

615629
# Add a callback which will stop the server
@@ -671,11 +685,20 @@ def main(argv: list[str] | None = None) -> None:
671685
definitions.register_isolate(servicer, server)
672686
health.register_health(HealthServicer(), server)
673687

688+
def handle_termination(*args):
689+
print("Termination signal received, shutting down...")
690+
servicer.shutdown()
691+
server.stop(grace=0.1)
692+
693+
signal.signal(signal.SIGINT, handle_termination)
694+
signal.signal(signal.SIGTERM, handle_termination)
695+
674696
server.add_insecure_port("[::]:50001")
675697
print("Started listening at localhost:50001")
676698

677699
server.start()
678700
server.wait_for_termination()
701+
print("Server shut down")
679702

680703

681704
if __name__ == "__main__":

0 commit comments

Comments
 (0)