1010
1111from __future__ import annotations
1212
13+ import asyncio
1314import os
1415import sys
1516import traceback
1617from argparse import ArgumentParser
17- from concurrent import futures
1818from dataclasses import dataclass
1919from typing import (
2020 Any ,
21+ AsyncGenerator ,
2122 Iterable ,
22- Iterator ,
2323)
2424
25- import grpc
26- from grpc import ServicerContext , StatusCode
25+ from grpc import StatusCode , aio , local_server_credentials
26+
27+ from isolate .connections .grpc .definitions import PartialRunResult
2728
2829try :
2930 from isolate import __version__ as agent_version
@@ -49,11 +50,11 @@ def __init__(self, log_fd: int | None = None):
4950 self ._run_cache : dict [str , Any ] = {}
5051 self ._log = sys .stdout if log_fd is None else os .fdopen (log_fd , "w" )
5152
52- def Run (
53+ async def Run (
5354 self ,
5455 request : definitions .FunctionCall ,
55- context : ServicerContext ,
56- ) -> Iterator [ definitions . PartialRunResult ]:
56+ context : aio . ServicerContext ,
57+ ) -> AsyncGenerator [ PartialRunResult , None ]:
5758 self .log (f"A connection has been established: { context .peer ()} !" )
5859 server_version = os .getenv ("ISOLATE_SERVER_VERSION" ) or "unknown"
5960 self .log (f"Isolate info: server { server_version } , agent { agent_version } " )
@@ -87,7 +88,8 @@ def Run(
8788 )
8889 raise AbortException ("The setup function has thrown an error." )
8990 except AbortException as exc :
90- return self .abort_with_msg (context , exc .message )
91+ self .abort_with_msg (context , exc .message )
92+ return
9193 else :
9294 assert not was_it_raised
9395 self ._run_cache [cache_key ] = result
@@ -107,7 +109,8 @@ def Run(
107109 stringized_tb ,
108110 )
109111 except AbortException as exc :
110- return self .abort_with_msg (context , exc .message )
112+ self .abort_with_msg (context , exc .message )
113+ return
111114
112115 def execute_function (
113116 self ,
@@ -195,7 +198,7 @@ def log(self, message: str) -> None:
195198
196199 def abort_with_msg (
197200 self ,
198- context : ServicerContext ,
201+ context : aio . ServicerContext ,
199202 message : str ,
200203 * ,
201204 code : StatusCode = StatusCode .INVALID_ARGUMENT ,
@@ -205,23 +208,26 @@ def abort_with_msg(
205208 return None
206209
207210
208- def create_server (address : str ) -> grpc .Server :
211+ def create_server (address : str ) -> aio .Server :
209212 """Create a new (temporary) gRPC server listening on the given
210213 address."""
211- server = grpc .server (
212- futures .ThreadPoolExecutor (max_workers = 1 ),
213- maximum_concurrent_rpcs = 1 ,
214+ # Use asyncio server so requests can run in the main thread and intercept signals
215+ # There seems to be a weird bug with grpcio that makes subsequent requests fail with
216+ # concurrent rpc limit exceeded if we set maximum_current_rpcs to 1. Setting it to 2
217+ # fixes it, even though in practice, we only run one request at a time.
218+ server = aio .server (
219+ maximum_concurrent_rpcs = 2 ,
214220 options = get_default_options (),
215221 )
216222
217223 # Local server credentials allow us to ensure that the
218224 # connection is established by a local process.
219- server_credentials = grpc . local_server_credentials ()
225+ server_credentials = local_server_credentials ()
220226 server .add_secure_port (address , server_credentials )
221227 return server
222228
223229
224- def run_agent (address : str , log_fd : int | None = None ) -> int :
230+ async def run_agent (address : str , log_fd : int | None = None ) -> int :
225231 """Run the agent servicer on the given address."""
226232 server = create_server (address )
227233 servicer = AgentServicer (log_fd = log_fd )
@@ -231,19 +237,19 @@ def run_agent(address: str, log_fd: int | None = None) -> int:
231237 # not have any global side effects.
232238 definitions .register_agent (servicer , server )
233239
234- server .start ()
235- server .wait_for_termination ()
240+ await server .start ()
241+ await server .wait_for_termination ()
236242 return 0
237243
238244
239- def main () -> int :
245+ async def main () -> int :
240246 parser = ArgumentParser ()
241247 parser .add_argument ("address" , type = str )
242248 parser .add_argument ("--log-fd" , type = int )
243249
244250 options = parser .parse_args ()
245- return run_agent (options .address , log_fd = options .log_fd )
251+ return await run_agent (options .address , log_fd = options .log_fd )
246252
247253
248254if __name__ == "__main__" :
249- main ()
255+ asyncio . run ( main () )
0 commit comments