diff --git a/README.md b/README.md index 338bb9b..7cf9dc5 100644 --- a/README.md +++ b/README.md @@ -167,8 +167,8 @@ class Subscription(graphene.ObjectType): def resolve_count_seconds( - root, - info, + root, + info, up_to=5 ): return Observable.interval(1000)\ @@ -202,4 +202,36 @@ from graphql_ws.django_channels import GraphQLSubscriptionConsumer channel_routing = [ route_class(GraphQLSubscriptionConsumer, path=r"^/subscriptions"), ] -``` \ No newline at end of file +``` + +### Tornado +```python +from asyncio import Queue +from tornado import web, ioloop, websocket + +from graphql_ws.tornado import TornadoSubscriptionServer + + +subscription_server = TornadoSubscriptionServer(schema) + + +class SubscriptionHandler(websocket.WebSocketHandler): + def initialize(self, sub_server): + self.subscription_server = subscription_server + self.queue = Queue() + + def select_subprotocol(self, subprotocols): + return 'graphql-ws' + + def open(self): + ioloop.IOLoop.current().spawn_callback(subscription_server.handle, self) + + async def on_message(self, message): + await self.queue.put(message) + + async def recv(self): + return await self.queue.get() + +app = web.Application([(r"/subscriptions", SubscriptionHandler)]).listen(8000) +ioloop.IOLoop.current().start() +``` diff --git a/examples/tornado/__init__.py b/examples/tornado/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/tornado/app.py b/examples/tornado/app.py new file mode 100644 index 0000000..f8d807d --- /dev/null +++ b/examples/tornado/app.py @@ -0,0 +1,47 @@ +from asyncio import Queue +from tornado import web, ioloop, websocket + +from graphene_tornado.tornado_graphql_handler import TornadoGraphQLHandler + +from graphql_ws.tornado import TornadoSubscriptionServer +from graphql_ws.constants import GRAPHQL_WS + +from .template import render_graphiql +from .schema import schema + + +class GraphiQLHandler(web.RequestHandler): + def get(self): + self.finish(render_graphiql()) + + +class SubscriptionHandler(websocket.WebSocketHandler): + def initialize(self, subscription_server): + self.subscription_server = subscription_server + self.queue = Queue(100) + + def select_subprotocol(self, subprotocols): + return GRAPHQL_WS + + def open(self): + ioloop.IOLoop.current().spawn_callback(self.subscription_server.handle, self) + + async def on_message(self, message): + await self.queue.put(message) + + async def recv(self): + return await self.queue.get() + + +subscription_server = TornadoSubscriptionServer(schema) + +app = web.Application([ + (r"/graphql$", TornadoGraphQLHandler, dict( + schema=schema)), + (r"/subscriptions", SubscriptionHandler, dict( + subscription_server=subscription_server)), + (r"/graphiql$", GraphiQLHandler), +]) + +app.listen(8000) +ioloop.IOLoop.current().start() diff --git a/examples/tornado/requirements.txt b/examples/tornado/requirements.txt new file mode 100644 index 0000000..0eae84e --- /dev/null +++ b/examples/tornado/requirements.txt @@ -0,0 +1,3 @@ +graphql_ws +tornado +graphene>=2.0 diff --git a/examples/tornado/schema.py b/examples/tornado/schema.py new file mode 100644 index 0000000..3c23d00 --- /dev/null +++ b/examples/tornado/schema.py @@ -0,0 +1,34 @@ +import random +import asyncio +import graphene + + +class Query(graphene.ObjectType): + base = graphene.String() + + +class RandomType(graphene.ObjectType): + seconds = graphene.Int() + random_int = graphene.Int() + + +class Subscription(graphene.ObjectType): + count_seconds = graphene.Float(up_to=graphene.Int()) + random_int = graphene.Field(RandomType) + + async def resolve_count_seconds(root, info, up_to=5): + for i in range(up_to): + print("YIELD SECOND", i) + yield i + await asyncio.sleep(1.) + yield up_to + + async def resolve_random_int(root, info): + i = 0 + while True: + yield RandomType(seconds=i, random_int=random.randint(0, 500)) + await asyncio.sleep(1.) + i += 1 + + +schema = graphene.Schema(query=Query, subscription=Subscription) diff --git a/examples/tornado/template.py b/examples/tornado/template.py new file mode 100644 index 0000000..0b74e96 --- /dev/null +++ b/examples/tornado/template.py @@ -0,0 +1,125 @@ + +from string import Template + + +def render_graphiql(): + return Template(''' + + + + + GraphiQL + + + + + + + + + + + + + +''').substitute( + GRAPHIQL_VERSION='0.10.2', + SUBSCRIPTIONS_TRANSPORT_VERSION='0.7.0', + subscriptionsEndpoint='ws://localhost:8000/subscriptions', + # subscriptionsEndpoint='ws://localhost:5000/', + endpointURL='/graphql', + ) diff --git a/graphql_ws/tornado.py b/graphql_ws/tornado.py new file mode 100644 index 0000000..212fa6a --- /dev/null +++ b/graphql_ws/tornado.py @@ -0,0 +1,114 @@ +from inspect import isawaitable + +from asyncio import ensure_future, wait, shield +from tornado.websocket import WebSocketClosedError +from graphql.execution.executors.asyncio import AsyncioExecutor + +from .base import ConnectionClosedException, BaseConnectionContext, BaseSubscriptionServer +from .observable_aiter import setup_observable_extension + +from .constants import ( + GQL_CONNECTION_ACK, + GQL_CONNECTION_ERROR, + GQL_COMPLETE +) + +setup_observable_extension() + + +class TornadoConnectionContext(BaseConnectionContext): + async def receive(self): + try: + msg = await self.ws.recv() + return msg + except WebSocketClosedError: + raise ConnectionClosedException() + + async def send(self, data): + if self.closed: + return + await self.ws.write_message(data) + + @property + def closed(self): + return self.ws.close_code is not None + + async def close(self, code): + await self.ws.close(code) + + +class TornadoSubscriptionServer(BaseSubscriptionServer): + def __init__(self, schema, keep_alive=True, loop=None): + self.loop = loop + super().__init__(schema, keep_alive) + + def get_graphql_params(self, *args, **kwargs): + params = super(TornadoSubscriptionServer, + self).get_graphql_params(*args, **kwargs) + return dict(params, return_promise=True, executor=AsyncioExecutor(loop=self.loop)) + + async def _handle(self, ws, request_context): + connection_context = TornadoConnectionContext(ws, request_context) + await self.on_open(connection_context) + pending = set() + while True: + try: + if connection_context.closed: + raise ConnectionClosedException() + message = await connection_context.receive() + except ConnectionClosedException: + break + finally: + if pending: + (_, pending) = await wait(pending, timeout=0, loop=self.loop) + + task = ensure_future( + self.on_message(connection_context, message), loop=self.loop) + pending.add(task) + + self.on_close(connection_context) + for task in pending: + task.cancel() + + async def handle(self, ws, request_context=None): + await shield(self._handle(ws, request_context), loop=self.loop) + + async def on_open(self, connection_context): + pass + + def on_close(self, connection_context): + remove_operations = list(connection_context.operations.keys()) + for op_id in remove_operations: + self.unsubscribe(connection_context, op_id) + + async def on_connect(self, connection_context, payload): + pass + + async def on_connection_init(self, connection_context, op_id, payload): + try: + await self.on_connect(connection_context, payload) + await self.send_message(connection_context, op_type=GQL_CONNECTION_ACK) + except Exception as e: + await self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR) + await connection_context.close(1011) + + async def on_start(self, connection_context, op_id, params): + execution_result = self.execute( + connection_context.request_context, params) + + if isawaitable(execution_result): + execution_result = await execution_result + + if not hasattr(execution_result, '__aiter__'): + await self.send_execution_result(connection_context, op_id, execution_result) + else: + iterator = await execution_result.__aiter__() + connection_context.register_operation(op_id, iterator) + async for single_result in iterator: + if not connection_context.has_operation(op_id): + break + await self.send_execution_result(connection_context, op_id, single_result) + await self.send_message(connection_context, op_id, GQL_COMPLETE) + + async def on_stop(self, connection_context, op_id): + self.unsubscribe(connection_context, op_id)