Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement connection checks using PING frame. Task #62 #94

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion grpclib/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,8 @@ def __init__(
codec: Optional[CodecBase] = None,
status_details_codec: Optional[StatusDetailsCodecBase] = None,
ssl: Union[None, bool, '_ssl.SSLContext'] = None,
ping_delay: float = 1,
ping_timeout: float = 10,
):
"""Initialize connection to the server

Expand Down Expand Up @@ -600,12 +602,17 @@ def __init__(

self.__dispatch__ = _DispatchChannelEvents()

self._ping_delay = ping_delay
self._ping_timeout = ping_timeout

def __repr__(self) -> str:
return ('Channel({!r}, {!r}, ..., path={!r})'
.format(self._host, self._port, self._path))

def _protocol_factory(self) -> H2Protocol:
return H2Protocol(Handler(), self._config, loop=self._loop)
return H2Protocol(Handler(), self._config, loop=self._loop,
ping_delay=self._ping_delay,
ping_timeout=self._ping_timeout)

async def _create_connection(self) -> H2Protocol:
if self._path is not None:
Expand Down
50 changes: 47 additions & 3 deletions grpclib/protocol.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import socket
import logging
import struct
import time

from io import BytesIO
from abc import ABC, abstractmethod
from typing import Optional, List, Tuple, Dict, NamedTuple, Callable
from typing import cast, TYPE_CHECKING
from asyncio import Transport, Protocol, Event, AbstractEventLoop, BaseTransport
from asyncio import Queue
from asyncio import Transport, Protocol, Event, AbstractEventLoop, \
BaseTransport, TimerHandle, Queue
from functools import partial
from collections import deque

Expand Down Expand Up @@ -178,12 +180,18 @@ def __init__(
self,
connection: H2Connection,
transport: Transport,
ping_delay: float = 1,
ping_timeout: float = 10,
*,
loop: AbstractEventLoop,
) -> None:
self._connection = connection
self._transport = transport
self._loop = loop
self._ping_delay = ping_delay
self._ping_timeout = ping_timeout
self._ping_handle: Optional[TimerHandle] = None
self._close_by_ping_timeout_handle: Optional[TimerHandle] = None

self.write_ready = Event(loop=self._loop)
self.write_ready.set()
Expand Down Expand Up @@ -219,13 +227,44 @@ def flush(self) -> None:
self._transport.write(data)

def close(self) -> None:
if self._close_by_ping_timeout_handle:
self._close_by_ping_timeout_handle.cancel()

if self._ping_handle:
self._ping_handle.cancel()

if hasattr(self, '_transport'):
self._transport.close()
# remove cyclic references to improve memory usage
del self._transport
if hasattr(self._connection, '_frame_dispatch_table'):
del self._connection._frame_dispatch_table

def _ping(self) -> None:
data = struct.pack('!Q', int(time.monotonic() * 10 ** 6))
self._connection.ping(data)
self.flush()
self._ping_handle = self._loop.call_later(
self._ping_delay,
self._ping
)

def initialize(self) -> None:
if self._ping_timeout > 0 and self._ping_delay > 0:
self._ping()
self._close_by_ping_timeout_handle = self._loop.call_later(
self._ping_timeout,
self.close
)

def pong_process(self) -> None:
if self._close_by_ping_timeout_handle:
self._close_by_ping_timeout_handle.cancel()
self._close_by_ping_timeout_handle = self._loop.call_later(
self._ping_timeout,
self.close
)


_Headers = List[Tuple[str, str]]

Expand Down Expand Up @@ -589,17 +628,20 @@ def process_ping_received(self, event: PingReceived) -> None:
pass

def process_ping_ack_received(self, event: PingAckReceived) -> None:
pass
self.connection.pong_process()


class H2Protocol(Protocol):
connection: Connection
processor: EventsProcessor

def __init__(self, handler: AbstractHandler, config: H2Configuration,
ping_delay: float = 1, ping_timeout: float = 10,
*, loop: AbstractEventLoop) -> None:
self.handler = handler
self.config = config
self._ping_delay = ping_delay
self._ping_timeout = ping_timeout
self.loop = loop

def connection_made(self, transport: BaseTransport) -> None:
Expand All @@ -611,7 +653,9 @@ def connection_made(self, transport: BaseTransport) -> None:
h2_conn.initiate_connection()

self.connection = Connection(h2_conn, cast(Transport, transport),
self._ping_delay, self._ping_timeout,
loop=self.loop)
self.connection.initialize()
self.connection.flush()

self.processor = EventsProcessor(self.handler, self.connection)
Expand Down
12 changes: 9 additions & 3 deletions tests/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def __init__(self, *, loop):
client_config = H2Configuration(client_side=True,
header_encoding='ascii')
self.client_proto = H2Protocol(client.Handler(), client_config,
loop=loop)
loop=loop,
ping_delay=0)
self.client_proto.connection_made(self.to_server_transport)

def server_flush(self):
Expand Down Expand Up @@ -129,11 +130,14 @@ class ClientServer:
server = None
channel = None

def __init__(self, handler_cls, stub_cls, *, loop, codec=None):
def __init__(self, handler_cls, stub_cls, *, loop, codec=None,
ping_delay=0, ping_timeout=0):
self.handler_cls = handler_cls
self.stub_cls = stub_cls
self.loop = loop
self.codec = codec
self.ping_delay = ping_delay
self.ping_timeout = ping_timeout

async def __aenter__(self):
host = '127.0.0.1'
Expand All @@ -146,7 +150,9 @@ async def __aenter__(self):
await self.server.start(host, port)

self.channel = client.Channel(host, port, loop=self.loop,
codec=self.codec)
codec=self.codec,
ping_delay=self.ping_delay,
ping_timeout=self.ping_timeout)
stub = self.stub_cls(self.channel)
return handler, stub

Expand Down
59 changes: 59 additions & 0 deletions tests/test_ping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import pytest
import asyncio
import async_timeout

import grpclib.const
import grpclib.server

from grpclib.client import UnaryStreamMethod
from grpclib.exceptions import StreamTerminatedError

from conn import ClientServer
from dummy_pb2 import DummyRequest, DummyReply


class PingServiceHandler:
async def UnaryStream(self, stream):
await stream.recv_message()
await stream.send_message(DummyReply(value='ping'))
await asyncio.sleep(0.1)
await stream.send_message(DummyReply(value='ping'))

def __mapping__(self):
return {
'/ping.PingService/UnaryStream': grpclib.const.Handler(
self.UnaryStream,
grpclib.const.Cardinality.UNARY_STREAM,
DummyRequest,
DummyReply,
),
}


class PingServiceStub:

def __init__(self, channel):
self.UnaryStream = UnaryStreamMethod(
channel,
'/ping.PingService/UnaryStream',
DummyRequest,
DummyReply,
)


@pytest.mark.asyncio
async def test_stream_ping(loop):
ctx = ClientServer(PingServiceHandler, PingServiceStub, loop=loop,
ping_delay=0.01, ping_timeout=0.1)
async with ctx as (handler, stub):
await stub.UnaryStream(DummyRequest(value='ping'))


@pytest.mark.asyncio
async def test_stream_cancel_by_ping(loop):
ctx = ClientServer(PingServiceHandler, PingServiceStub, loop=loop,
ping_delay=0.1, ping_timeout=0.01)
with pytest.raises(StreamTerminatedError):
with async_timeout.timeout(5):
async with ctx as (handler, stub):
await stub.UnaryStream(DummyRequest(value='ping'))