1010
1111from __future__ import annotations
1212
13+ import asyncio
1314import os
1415import sys
1516import traceback
1617from argparse import ArgumentParser
17- from concurrent import futures
1818from dataclasses import dataclass
1919from typing import (
2020 Any ,
21+ AsyncGenerator ,
2122 Iterable ,
22- Iterator ,
2323)
2424
25- import grpc
26- from grpc import ServicerContext , StatusCode
25+ from grpc import StatusCode , aio , local_server_credentials
26+
27+ from isolate .connections .grpc .definitions import PartialRunResult
2728
2829try :
2930 from isolate import __version__ as agent_version
@@ -49,11 +50,11 @@ def __init__(self, log_fd: int | None = None):
4950 self ._run_cache : dict [str , Any ] = {}
5051 self ._log = sys .stdout if log_fd is None else os .fdopen (log_fd , "w" )
5152
52- def Run (
53+ async def Run (
5354 self ,
5455 request : definitions .FunctionCall ,
55- context : ServicerContext ,
56- ) -> Iterator [ definitions . PartialRunResult ]:
56+ context : aio . ServicerContext ,
57+ ) -> AsyncGenerator [ PartialRunResult , None ]:
5758 self .log (f"A connection has been established: { context .peer ()} !" )
5859 server_version = os .getenv ("ISOLATE_SERVER_VERSION" ) or "unknown"
5960 self .log (f"Isolate info: server { server_version } , agent { agent_version } " )
@@ -87,7 +88,7 @@ def Run(
8788 )
8889 raise AbortException ("The setup function has thrown an error." )
8990 except AbortException as exc :
90- return self .abort_with_msg (context , exc .message )
91+ self .abort_with_msg (context , exc .message )
9192 else :
9293 assert not was_it_raised
9394 self ._run_cache [cache_key ] = result
@@ -107,7 +108,7 @@ def Run(
107108 stringized_tb ,
108109 )
109110 except AbortException as exc :
110- return self .abort_with_msg (context , exc .message )
111+ self .abort_with_msg (context , exc .message )
111112
112113 def execute_function (
113114 self ,
@@ -195,7 +196,7 @@ def log(self, message: str) -> None:
195196
196197 def abort_with_msg (
197198 self ,
198- context : ServicerContext ,
199+ context : aio . ServicerContext ,
199200 message : str ,
200201 * ,
201202 code : StatusCode = StatusCode .INVALID_ARGUMENT ,
@@ -205,23 +206,26 @@ def abort_with_msg(
205206 return None
206207
207208
208- def create_server (address : str ) -> grpc .Server :
209+ def create_server (address : str ) -> aio .Server :
209210 """Create a new (temporary) gRPC server listening on the given
210211 address."""
211- server = grpc .server (
212- futures .ThreadPoolExecutor (max_workers = 1 ),
213- maximum_concurrent_rpcs = 1 ,
212+ # Use asyncio server so requests can run in the main thread and intercept signals
213+ # There seems to be a weird bug with grpcio that makes subsequent requests fail with
214+ # concurrent rpc limit exceeded if we set maximum_current_rpcs to 1. Setting it to 2
215+ # fixes it, even though in practice, we only run one request at a time.
216+ server = aio .server (
217+ maximum_concurrent_rpcs = 2 ,
214218 options = get_default_options (),
215219 )
216220
217221 # Local server credentials allow us to ensure that the
218222 # connection is established by a local process.
219- server_credentials = grpc . local_server_credentials ()
223+ server_credentials = local_server_credentials ()
220224 server .add_secure_port (address , server_credentials )
221225 return server
222226
223227
224- def run_agent (address : str , log_fd : int | None = None ) -> int :
228+ async def run_agent (address : str , log_fd : int | None = None ) -> int :
225229 """Run the agent servicer on the given address."""
226230 server = create_server (address )
227231 servicer = AgentServicer (log_fd = log_fd )
@@ -231,19 +235,19 @@ def run_agent(address: str, log_fd: int | None = None) -> int:
231235 # not have any global side effects.
232236 definitions .register_agent (servicer , server )
233237
234- server .start ()
235- server .wait_for_termination ()
238+ await server .start ()
239+ await server .wait_for_termination ()
236240 return 0
237241
238242
239- def main () -> int :
243+ async def main () -> int :
240244 parser = ArgumentParser ()
241245 parser .add_argument ("address" , type = str )
242246 parser .add_argument ("--log-fd" , type = int )
243247
244248 options = parser .parse_args ()
245- return run_agent (options .address , log_fd = options .log_fd )
249+ return await run_agent (options .address , log_fd = options .log_fd )
246250
247251
248252if __name__ == "__main__" :
249- main ()
253+ asyncio . run ( main () )
0 commit comments