diff --git a/crossplane/function/resource.py b/crossplane/function/resource.py index 039af46..c9bfc69 100644 --- a/crossplane/function/resource.py +++ b/crossplane/function/resource.py @@ -45,8 +45,8 @@ def update(r: fnv1.Resource, source: dict | structpb.Struct | pydantic.BaseModel # apiVersion is set to its default value 's3.aws.upbound.io/v1beta2' # (and not explicitly provided during initialization), it will be # excluded from the serialized output. - data['apiVersion'] = source.apiVersion - data['kind'] = source.kind + data["apiVersion"] = source.apiVersion + data["kind"] = source.kind r.resource.update(data) case structpb.Struct(): # TODO(negz): Use struct_to_dict and update to match other semantics? diff --git a/crossplane/function/runtime.py b/crossplane/function/runtime.py index f91f392..3783bbc 100644 --- a/crossplane/function/runtime.py +++ b/crossplane/function/runtime.py @@ -16,6 +16,7 @@ import asyncio import os +import signal import grpc from grpc_reflection.v1alpha import reflection @@ -25,6 +26,7 @@ import crossplane.function.proto.v1beta1.run_function_pb2 as fnv1beta1 import crossplane.function.proto.v1beta1.run_function_pb2_grpc as grpcv1beta1 +GRACE_PERIOD = 5 SERVICE_NAMES = ( reflection.SERVICE_NAME, fnv1.DESCRIPTOR.services_by_name["FunctionRunnerService"].full_name, @@ -64,6 +66,10 @@ def load_credentials(tls_certs_dir: str) -> grpc.ServerCredentials: ) +async def _stop(server, grace=GRACE_PERIOD): + await server.stop(grace=grace) + + def serve( function: grpcv1.FunctionRunnerService, address: str, @@ -90,6 +96,11 @@ def serve( server = grpc.aio.server() + signal.signal( + signal.SIGTERM, + lambda _, __: asyncio.create_task(_stop(server)), + ) + grpcv1.add_FunctionRunnerServiceServicer_to_server(function, server) grpcv1beta1.add_FunctionRunnerServiceServicer_to_server( BetaFunctionRunner(wrapped=function), server @@ -116,7 +127,8 @@ async def start(): try: loop.run_until_complete(start()) finally: - loop.run_until_complete(server.stop(grace=5)) + if server._server.is_running(): + loop.run_until_complete(server.stop(grace=GRACE_PERIOD)) loop.close() diff --git a/tests/test_runtime.py b/tests/test_runtime.py index 26229aa..5c31269 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import dataclasses +import os +import signal import unittest import grpc @@ -52,6 +55,25 @@ class TestCase: self.assertEqual(rsp, case.want, "-want, +got") + async def test_sigterm_handling(self) -> None: + async def mock_server(): + await server.start() + await asyncio.sleep(1) + self.assertTrue(server._server.is_running(), "Server should be running") + os.kill(os.getpid(), signal.SIGTERM) + await server.wait_for_termination() + self.assertFalse( + server._server.is_running(), + "Server should have been stopped on SIGTERM", + ) + + server = grpc.aio.server() + signal.signal( + signal.SIGTERM, + lambda _, __: asyncio.create_task(runtime._stop(server)), + ) + await mock_server() + class EchoRunner(grpcv1.FunctionRunnerService): def __init__(self):