Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: handle sigterm #130

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crossplane/function/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def update(r: fnv1.Resource, source: dict | structpb.Struct | pydantic.BaseModel
# apiVersion is set to its default value 's3.aws.upbound.io/v1beta2'
# (and not explicitly provided during initialization), it will be
# excluded from the serialized output.
data['apiVersion'] = source.apiVersion
data['kind'] = source.kind
data["apiVersion"] = source.apiVersion
data["kind"] = source.kind
r.resource.update(data)
case structpb.Struct():
# TODO(negz): Use struct_to_dict and update to match other semantics?
Expand Down
14 changes: 13 additions & 1 deletion crossplane/function/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import asyncio
import os
import signal

import grpc
from grpc_reflection.v1alpha import reflection
Expand All @@ -25,6 +26,7 @@
import crossplane.function.proto.v1beta1.run_function_pb2 as fnv1beta1
import crossplane.function.proto.v1beta1.run_function_pb2_grpc as grpcv1beta1

GRACE_PERIOD = 5
SERVICE_NAMES = (
reflection.SERVICE_NAME,
fnv1.DESCRIPTOR.services_by_name["FunctionRunnerService"].full_name,
Expand Down Expand Up @@ -64,6 +66,10 @@ def load_credentials(tls_certs_dir: str) -> grpc.ServerCredentials:
)


async def _stop(server, grace=GRACE_PERIOD):
await server.stop(grace=grace)


def serve(
function: grpcv1.FunctionRunnerService,
address: str,
Expand All @@ -90,6 +96,11 @@ def serve(

server = grpc.aio.server()

signal.signal(
signal.SIGTERM,
lambda _, __: asyncio.create_task(_stop(server)),
)

grpcv1.add_FunctionRunnerServiceServicer_to_server(function, server)
grpcv1beta1.add_FunctionRunnerServiceServicer_to_server(
BetaFunctionRunner(wrapped=function), server
Expand All @@ -116,7 +127,8 @@ async def start():
try:
loop.run_until_complete(start())
finally:
loop.run_until_complete(server.stop(grace=5))
if server._server.is_running():
loop.run_until_complete(server.stop(grace=GRACE_PERIOD))
loop.close()


Expand Down
22 changes: 22 additions & 0 deletions tests/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import dataclasses
import os
import signal
import unittest

import grpc
Expand Down Expand Up @@ -52,6 +55,25 @@ class TestCase:

self.assertEqual(rsp, case.want, "-want, +got")

async def test_sigterm_handling(self) -> None:
async def mock_server():
await server.start()
await asyncio.sleep(1)
self.assertTrue(server._server.is_running(), "Server should be running")
os.kill(os.getpid(), signal.SIGTERM)
await server.wait_for_termination()
self.assertFalse(
server._server.is_running(),
"Server should have been stopped on SIGTERM",
)

server = grpc.aio.server()
signal.signal(
signal.SIGTERM,
lambda _, __: asyncio.create_task(runtime._stop(server)),
)
await mock_server()


class EchoRunner(grpcv1.FunctionRunnerService):
def __init__(self):
Expand Down