Skip to content

Commit

Permalink
Add support for UNIX sockets (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
kippandrew authored and vmagamedov committed Aug 21, 2018
1 parent 5e79508 commit dac7c60
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 16 deletions.
32 changes: 28 additions & 4 deletions grpclib/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,10 +395,31 @@ class Channel:
"""
_protocol = None

def __init__(self, host='127.0.0.1', port=50051, *, loop, codec=None):
def __init__(self, host=None, port=None, *, loop, path=None, codec=None):
"""Initialize connection to the server
:param host: server host name.
:param port: server port number.
:param path: server socket path. If specified, host and port should be
omitted (must be None).
"""
if path is not None and (host is not None or port is not None):
raise ValueError("The 'path' parameter can not be used with the "
"'host' or 'port' parameters.")
else:
if host is None:
host = '127.0.0.1'

if port is None:
port = 50051

self._host = host
self._port = port
self._loop = loop
self._path = path

self._codec = codec or ProtoCodec()

self._config = H2Configuration(client_side=True,
Expand All @@ -414,9 +435,12 @@ def _content_type(self):

async def __connect__(self):
if self._protocol is None or self._protocol.handler.connection_lost:
_, self._protocol = await self._loop.create_connection(
self._protocol_factory, self._host, self._port
)
if self._path is not None:
_, self._protocol = await self._loop.create_unix_connection(
self._protocol_factory, self._path)
else:
_, self._protocol = await self._loop.create_connection(
self._protocol_factory, self._host, self._port)
return self._protocol

def request(self, name, request_type, reply_type, *, timeout=None,
Expand Down
38 changes: 26 additions & 12 deletions grpclib/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ def __init__(self, handlers, *, loop, codec=None):
header_encoding='ascii',
)

self._tcp_server = None
self._server = None
self._handlers = set()

def __gc_collect__(self):
Expand All @@ -457,7 +457,7 @@ def _protocol_factory(self):
self._handlers.add(handler)
return H2Protocol(handler, self._config, loop=self._loop)

async def start(self, host=None, port=None, *,
async def start(self, host=None, port=None, *, path=None,
family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE,
sock=None, backlog=100, ssl=None, reuse_address=None,
reuse_port=None):
Expand All @@ -468,6 +468,9 @@ async def start(self, host=None, port=None, *,
:param port: port number.
:param path: UNIX domain socket path. If specified, host and port should
be omitted (must be None).
:param family: can be set to either :py:data:`python:socket.AF_INET` or
:py:data:`python:socket.AF_INET6` to force the socket to use IPv4 or
IPv6. If not set it will be determined from host.
Expand All @@ -492,33 +495,44 @@ async def start(self, host=None, port=None, *,
to the same port as other existing endpoints are bound to,
so long as they all set this flag when being created.
"""
if self._tcp_server is not None:
if path is not None and (host is not None or port is not None):
raise ValueError("The 'path' parameter can not be used with the "
"'host' or 'port' parameters.")

if self._server is not None:
raise RuntimeError('Server is already started')

self._tcp_server = await self._loop.create_server(
self._protocol_factory, host, port,
family=family, flags=flags, sock=sock, backlog=backlog, ssl=ssl,
reuse_address=reuse_address, reuse_port=reuse_port
)
if path is not None:
self._server = await self._loop.create_unix_server(
self._protocol_factory, path, sock=sock, backlog=backlog,
ssl=ssl
)

else:
self._server = await self._loop.create_server(
self._protocol_factory, host, port,
family=family, flags=flags, sock=sock, backlog=backlog, ssl=ssl,
reuse_address=reuse_address, reuse_port=reuse_port
)

def close(self):
"""Stops accepting new connections, cancels all currently running
requests. Request handlers are able to handle `CancelledError` and
exit properly.
"""
if self._tcp_server is None:
if self._server is None:
raise RuntimeError('Server is not started')
self._tcp_server.close()
self._server.close()
for handler in self._handlers:
handler.close()

async def wait_closed(self):
"""Coroutine to wait until all existing request handlers will exit
properly.
"""
if self._tcp_server is None:
if self._server is None:
raise RuntimeError('Server is not started')
await self._tcp_server.wait_closed()
await self._server.wait_closed()
if self._handlers:
await asyncio.wait({h.wait_closed() for h in self._handlers},
loop=self._loop)
42 changes: 42 additions & 0 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
import socket
import tempfile

import pytest

Expand Down Expand Up @@ -66,6 +68,38 @@ async def __aexit__(self, *exc_info):
self.channel.close()


class UnixClientServer:
temp = None
sock = None
server = None
channel = None

def __init__(self, *, loop):
self.loop = loop

async def __aenter__(self):
self.temp = tempfile.mkdtemp()
self.sock = os.path.join(self.temp, 'grpclib.sock')

dummy_service = DummyService()

self.server = Server([dummy_service], loop=self.loop)
await self.server.start(path=self.sock)

self.channel = Channel(path=self.sock, loop=self.loop)
dummy_stub = DummyServiceStub(self.channel)
return dummy_service, dummy_stub

async def __aexit__(self, *exc_info):
self.server.close()
await self.server.wait_closed()
self.channel.close()
if os.path.exists(self.sock):
os.unlink(self.sock)
if os.path.exists(self.temp):
os.rmdir(self.temp)


@pytest.mark.asyncio
async def test_close_empty_channel(loop):
async with ClientServer(loop=loop):
Expand All @@ -80,6 +114,14 @@ async def test_unary_unary_simple(loop):
assert handler.log == [DummyRequest(value='ping')]


@pytest.mark.asyncio
async def test_unary_unary_simple_unix(loop):
async with UnixClientServer(loop=loop) as (handler, stub):
reply = await stub.UnaryUnary(DummyRequest(value='ping'))
assert reply == DummyReply(value='pong')
assert handler.log == [DummyRequest(value='ping')]


@pytest.mark.asyncio
async def test_unary_unary_advanced(loop):
async with ClientServer(loop=loop) as (handler, stub):
Expand Down

0 comments on commit dac7c60

Please sign in to comment.