Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide execution context as an argument to schema extensions #3640

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fc
nrbnlulu committed Sep 23, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
commit 8185fc1b30d30098f75e2d36d25a52793888838b
12 changes: 6 additions & 6 deletions strawberry/extensions/base_extension.py
Original file line number Diff line number Diff line change
@@ -19,33 +19,33 @@ class LifecycleStep(Enum):


class SchemaExtension:
execution_context: ExecutionContext

# to support extensions that still use the old signature
# we have an optional argument here for ease of initialization.
def __init__(
self, *, execution_context: ExecutionContext | None = None
) -> None: ...

def on_operation( # type: ignore
self,
self, execution_context: ExecutionContext
) -> AsyncIteratorOrIterator[None]: # pragma: no cover
"""Called before and after a GraphQL operation (query / mutation) starts."""
yield None

def on_validate( # type: ignore
self,
self, execution_context: ExecutionContext
) -> AsyncIteratorOrIterator[None]: # pragma: no cover
"""Called before and after the validation step."""
yield None

def on_parse( # type: ignore
self,
self, execution_context: ExecutionContext
) -> AsyncIteratorOrIterator[None]: # pragma: no cover
"""Called before and after the parsing step."""
yield None

def on_execute( # type: ignore
self,
self, execution_context: ExecutionContext
) -> AsyncIteratorOrIterator[None]: # pragma: no cover
"""Called before and after the execution step."""
yield None
@@ -60,7 +60,7 @@ def resolve(
) -> AwaitableOrValue[object]:
return _next(root, info, *args, **kwargs)

def get_results(self) -> AwaitableOrValue[Dict[str, Any]]:
def get_results(self, execution_context: ExecutionContext) -> AwaitableOrValue[Dict[str, Any]]:
return {}

@classmethod
13 changes: 8 additions & 5 deletions strawberry/extensions/context.py
Original file line number Diff line number Diff line change
@@ -21,6 +21,7 @@
)

from strawberry.extensions import SchemaExtension
from strawberry.types.execution import ExecutionContext
from strawberry.utils.await_maybe import AwaitableOrValue, await_maybe

if TYPE_CHECKING:
@@ -31,7 +32,7 @@

class WrappedHook(NamedTuple):
extension: SchemaExtension
hook: Callable[..., Union[AsyncContextManager[None], ContextManager[None]]]
hook: Callable[[ExecutionContext], Union[AsyncContextManager[None], ContextManager[None]]]
is_async: bool


@@ -42,6 +43,7 @@ class ExtensionContextManagerBase:
"default_hook",
"async_exit_stack",
"exit_stack",
"execution_context",
)

def __init_subclass__(cls) -> None:
@@ -56,8 +58,9 @@ def __init_subclass__(cls) -> None:
LEGACY_ENTER: str
LEGACY_EXIT: str

def __init__(self, extensions: List[SchemaExtension]) -> None:
def __init__(self, extensions: List[SchemaExtension], execution_context: ExecutionContext) -> None:
self.hooks: List[WrappedHook] = []
self.execution_context = execution_context
self.default_hook: Hook = getattr(SchemaExtension, self.HOOK_NAME)
for extension in extensions:
hook = self.get_hook(extension)
@@ -175,7 +178,7 @@ def __enter__(self) -> None:
"failed to complete synchronously."
)
else:
self.exit_stack.enter_context(hook.hook()) # type: ignore
self.exit_stack.enter_context(hook.hook(self.execution_context))

def __exit__(
self,
@@ -192,9 +195,9 @@ async def __aenter__(self) -> None:

for hook in self.hooks:
if hook.is_async:
await self.async_exit_stack.enter_async_context(hook.hook()) # type: ignore
await self.async_exit_stack.enter_async_context(hook.hook(self.execution_context)) # type: ignore
else:
self.async_exit_stack.enter_context(hook.hook()) # type: ignore
self.async_exit_stack.enter_context(hook.hook(self.execution_context)) # type: ignore

async def __aexit__(
self,
8 changes: 4 additions & 4 deletions strawberry/extensions/runner.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import inspect
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast

from strawberry.extensions.context import (
ExecutingContextManager,
@@ -40,21 +40,21 @@ def parsing(self) -> ParsingContextManager:
def executing(self) -> ExecutingContextManager:
return ExecutingContextManager(self.extensions)

def get_extensions_results_sync(self) -> Dict[str, Any]:
def get_extensions_results_sync(self, ctx: ExecutionContext) -> Dict[str, Any]:
data: Dict[str, Any] = {}
for extension in self.extensions:
if inspect.iscoroutinefunction(extension.get_results):
msg = "Cannot use async extension hook during sync execution"
raise RuntimeError(msg)
data.update(extension.get_results()) # type: ignore
data.update(cast(Dict[str, Any], extension.get_results(ctx)))

return data

async def get_extensions_results(self, ctx: ExecutionContext) -> Dict[str, Any]:
data: Dict[str, Any] = {}

for extension in self.extensions:
data.update(await await_maybe(extension.get_results()))
data.update(await await_maybe(extension.get_results(ctx)))

data.update(ctx.extensions_results)
return data