From 3d028e30a5f98c21748b9b3f3a46551935a9610c Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Thu, 24 Apr 2025 14:53:25 -0700 Subject: [PATCH 1/2] Adding AppSync events --- .../event_handler/__init__.py | 2 + .../event_handler/api_gateway.py | 26 +- .../event_handler/appsync.py | 33 +- .../event_handler/events_appsync/__init__.py | 5 + .../event_handler/events_appsync/_registry.py | 92 + .../events_appsync/appsync_events.py | 422 +++++ .../event_handler/events_appsync/base.py | 42 + .../events_appsync/exceptions.py | 25 + .../event_handler/events_appsync/functions.py | 108 ++ .../event_handler/events_appsync/router.py | 199 ++ .../event_handler/events_appsync/types.py | 21 + .../event_handler/exception_handling.py | 118 ++ .../utilities/data_classes/__init__.py | 2 + .../data_classes/appsync_resolver_event.py | 75 +- .../appsync_resolver_events_event.py | 56 + docs/core/event_handler/appsync_events.md | 24 +- .../src/accessing_event_and_context.py | 13 +- .../getting_started_with_publish_events.py | 2 +- .../getting_started_with_subscribe_events.py | 12 +- .../getting_started_with_testing_publish.py | 2 +- .../getting_started_with_testing_subscribe.py | 2 +- .../src/working_with_aggregated_events.py | 32 +- .../src/working_with_error_handling.py | 8 +- .../working_with_error_handling_multiple.py | 6 +- .../src/working_with_wildcard_resolvers.py | 2 +- tests/events/appSyncEventsEvent.json | 70 + .../appsync/test_appsync_events_resolvers.py | 1614 +++++++++++++++++ .../test_appsync_events_event.py | 16 + .../_required_dependencies/__init__.py | 0 .../appsync_events/__init__.py | 145 ++ .../appsync_events/test_functions.py | 0 .../test_exception_handler_manager.py | 176 ++ 32 files changed, 3222 insertions(+), 128 deletions(-) create mode 100644 aws_lambda_powertools/event_handler/events_appsync/__init__.py create mode 100644 aws_lambda_powertools/event_handler/events_appsync/_registry.py create mode 100644 
aws_lambda_powertools/event_handler/events_appsync/appsync_events.py create mode 100644 aws_lambda_powertools/event_handler/events_appsync/base.py create mode 100644 aws_lambda_powertools/event_handler/events_appsync/exceptions.py create mode 100644 aws_lambda_powertools/event_handler/events_appsync/functions.py create mode 100644 aws_lambda_powertools/event_handler/events_appsync/router.py create mode 100644 aws_lambda_powertools/event_handler/events_appsync/types.py create mode 100644 aws_lambda_powertools/event_handler/exception_handling.py create mode 100644 aws_lambda_powertools/utilities/data_classes/appsync_resolver_events_event.py create mode 100644 tests/events/appSyncEventsEvent.json create mode 100644 tests/functional/event_handler/required_dependencies/appsync/test_appsync_events_resolvers.py create mode 100644 tests/unit/data_classes/required_dependencies/test_appsync_events_event.py create mode 100644 tests/unit/event_handler/_required_dependencies/__init__.py create mode 100644 tests/unit/event_handler/_required_dependencies/appsync_events/__init__.py create mode 100644 tests/unit/event_handler/_required_dependencies/appsync_events/test_functions.py create mode 100644 tests/unit/event_handler/_required_dependencies/test_exception_handler_manager.py diff --git a/aws_lambda_powertools/event_handler/__init__.py b/aws_lambda_powertools/event_handler/__init__.py index ffbb2abe4ae..8bcf2d6636c 100644 --- a/aws_lambda_powertools/event_handler/__init__.py +++ b/aws_lambda_powertools/event_handler/__init__.py @@ -12,6 +12,7 @@ ) from aws_lambda_powertools.event_handler.appsync import AppSyncResolver from aws_lambda_powertools.event_handler.bedrock_agent import BedrockAgentResolver +from aws_lambda_powertools.event_handler.events_appsync.appsync_events import AppSyncEventsResolver from aws_lambda_powertools.event_handler.lambda_function_url import ( LambdaFunctionUrlResolver, ) @@ -19,6 +20,7 @@ __all__ = [ "AppSyncResolver", + "AppSyncEventsResolver", 
"APIGatewayRestResolver", "APIGatewayHttpResolver", "ALBResolver", diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 399ec8052d1..a5c5a7bb053 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -17,6 +17,7 @@ from typing_extensions import override from aws_lambda_powertools.event_handler import content_types +from aws_lambda_powertools.event_handler.exception_handling import ExceptionHandlerManager from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError from aws_lambda_powertools.event_handler.openapi.config import OpenAPIConfig from aws_lambda_powertools.event_handler.openapi.constants import ( @@ -1576,6 +1577,7 @@ def __init__( self.processed_stack_frames = [] self._response_builder_class = ResponseBuilder[BaseProxyEvent] self.openapi_config = OpenAPIConfig() # starting an empty dataclass + self.exception_handler_manager = ExceptionHandlerManager() self._has_response_validation_error = response_validation_error_http_code is not None self._response_validation_error_http_code = self._validate_response_validation_error_http_code( response_validation_error_http_code, @@ -2498,7 +2500,7 @@ def not_found_handler(): return Response(status_code=204, content_type=None, headers=_headers, body="") # Customer registered 404 route? Call it. 
- custom_not_found_handler = self._lookup_exception_handler(NotFoundError) + custom_not_found_handler = self.exception_handler_manager.lookup_exception_handler(NotFoundError) if custom_not_found_handler: return custom_not_found_handler(NotFoundError()) @@ -2571,26 +2573,10 @@ def not_found(self, func: Callable | None = None): return self.exception_handler(NotFoundError)(func) def exception_handler(self, exc_class: type[Exception] | list[type[Exception]]): - def register_exception_handler(func: Callable): - if isinstance(exc_class, list): # pragma: no cover - for exp in exc_class: - self._exception_handlers[exp] = func - else: - self._exception_handlers[exc_class] = func - return func - - return register_exception_handler - - def _lookup_exception_handler(self, exp_type: type) -> Callable | None: - # Use "Method Resolution Order" to allow for matching against a base class - # of an exception - for cls in exp_type.__mro__: - if cls in self._exception_handlers: - return self._exception_handlers[cls] - return None + return self.exception_handler_manager.exception_handler(exc_class=exc_class) def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuilder | None: - handler = self._lookup_exception_handler(type(exp)) + handler = self.exception_handler_manager.lookup_exception_handler(type(exp)) if handler: try: return self._response_builder_class(response=handler(exp), serializer=self._serializer, route=route) @@ -2686,7 +2672,7 @@ def include_router(self, router: Router, prefix: str | None = None) -> None: self._router_middlewares = self._router_middlewares + router._router_middlewares logger.debug("Appending Router exception_handler into App exception_handler.") - self._exception_handlers.update(router._exception_handlers) + self.exception_handler_manager.update_exception_handlers(router._exception_handlers) # use pointer to allow context clearance after event is processed e.g., resolve(evt, ctx) router.context = self.context diff --git 
a/aws_lambda_powertools/event_handler/appsync.py b/aws_lambda_powertools/event_handler/appsync.py index c7b48b6a4d4..29c48d71cb1 100644 --- a/aws_lambda_powertools/event_handler/appsync.py +++ b/aws_lambda_powertools/event_handler/appsync.py @@ -5,6 +5,7 @@ import warnings from typing import TYPE_CHECKING, Any +from aws_lambda_powertools.event_handler.exception_handling import ExceptionHandlerManager from aws_lambda_powertools.event_handler.graphql_appsync.exceptions import InvalidBatchResponse, ResolverNotFoundError from aws_lambda_powertools.event_handler.graphql_appsync.router import Router from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent @@ -55,6 +56,7 @@ def __init__(self): """ super().__init__() self.context = {} # early init as customers might add context before event resolution + self.exception_handler_manager = ExceptionHandlerManager() self._exception_handlers: dict[type, Callable] = {} def __call__( @@ -153,7 +155,7 @@ def lambda_handler(event, context): Router.current_event = data_model(event) response = self._call_single_resolver(event=event, data_model=data_model) except Exception as exp: - response_builder = self._lookup_exception_handler(type(exp)) + response_builder = self.exception_handler_manager.lookup_exception_handler(type(exp)) if response_builder: return response_builder(exp) raise @@ -495,31 +497,4 @@ def exception_handler(self, exc_class: type[Exception] | list[type[Exception]]): A decorator function that registers the exception handler. """ - def register_exception_handler(func: Callable): - if isinstance(exc_class, list): # pragma: no cover - for exp in exc_class: - self._exception_handlers[exp] = func - else: - self._exception_handlers[exc_class] = func - return func - - return register_exception_handler - - def _lookup_exception_handler(self, exp_type: type) -> Callable | None: - """ - Looks up the registered exception handler for the given exception type or its base classes. 
- - Parameters - ---------- - exp_type (type): - The exception type to look up the handler for. - - Returns - ------- - Callable | None: - The registered exception handler function if found, otherwise None. - """ - for cls in exp_type.__mro__: - if cls in self._exception_handlers: - return self._exception_handlers[cls] - return None + return self.exception_handler_manager.exception_handler(exc_class=exc_class) diff --git a/aws_lambda_powertools/event_handler/events_appsync/__init__.py b/aws_lambda_powertools/event_handler/events_appsync/__init__.py new file mode 100644 index 00000000000..64387723526 --- /dev/null +++ b/aws_lambda_powertools/event_handler/events_appsync/__init__.py @@ -0,0 +1,5 @@ +from aws_lambda_powertools.event_handler.events_appsync.appsync_events import AppSyncEventsResolver + +__all__ = [ + "AppSyncEventsResolver", +] diff --git a/aws_lambda_powertools/event_handler/events_appsync/_registry.py b/aws_lambda_powertools/event_handler/events_appsync/_registry.py new file mode 100644 index 00000000000..8c682327706 --- /dev/null +++ b/aws_lambda_powertools/event_handler/events_appsync/_registry.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import logging +import warnings +from typing import TYPE_CHECKING + +from aws_lambda_powertools.event_handler.events_appsync.functions import find_best_route, is_valid_path +from aws_lambda_powertools.warnings import PowertoolsUserWarning + +if TYPE_CHECKING: + from collections.abc import Callable + + from aws_lambda_powertools.event_handler.events_appsync.types import ResolverTypeDef + + +logger = logging.getLogger(__name__) + + +class ResolverEventsRegistry: + def __init__(self, kind_resolver: str): + self.resolvers: dict[str, ResolverTypeDef] = {} + self.kind_resolver = kind_resolver + + def register( + self, + path: str = "/default/*", + aggregate: bool = False, + ) -> Callable | None: + """Registers the resolver for path that includes namespace + channel + + Parameters + ---------- + path : str + 
Path including namespace + channel + aggregate: bool + A flag indicating whether the batch items should be processed at once or individually. + If True, the resolver will process all items as a single event. + If False (default), the resolver will process each item individually. + + Return + ---------- + Callable + A Callable + """ + + def _register(func) -> Callable | None: + if not is_valid_path(path): + warnings.warn( + f"The path `{path}` registered for `{self.kind_resolver}` is not valid and will be skipped." + f"A path should always have a namespace starting with '/'" + "A path can have multiple namespaces, all separated by '/'." + "Wildcards are allowed only at the end of the path.", + stacklevel=2, + category=PowertoolsUserWarning, + ) + return None + + logger.debug( + f"Adding resolver `{func.__name__}` for path `{path}` and kind_resolver `{self.kind_resolver}`", + ) + self.resolvers[f"{path}"] = { + "func": func, + "aggregate": aggregate, + } + return func + + return _register + + def find_resolver(self, path: str) -> ResolverTypeDef | None: + """Find resolver based on type_name and field_name + + Parameters + ---------- + path : str + Type name + Return + ---------- + dict | None + A dictionary with the resolver and if this is aggregated or not + """ + logger.debug(f"Looking for resolver for path `{path}` and kind_resolver `{self.kind_resolver}`") + return self.resolvers.get(find_best_route(self.resolvers, path)) + + def merge(self, other_registry: ResolverEventsRegistry): + """Update current registry with incoming registry + + Parameters + ---------- + other_registry : ResolverRegistry + Registry to merge from + """ + self.resolvers.update(**other_registry.resolvers) diff --git a/aws_lambda_powertools/event_handler/events_appsync/appsync_events.py b/aws_lambda_powertools/event_handler/events_appsync/appsync_events.py new file mode 100644 index 00000000000..ee03db5c625 --- /dev/null +++ 
b/aws_lambda_powertools/event_handler/events_appsync/appsync_events.py @@ -0,0 +1,422 @@ +from __future__ import annotations + +import asyncio +import logging +import warnings +from typing import TYPE_CHECKING, Any + +from aws_lambda_powertools.event_handler.events_appsync.exceptions import UnauthorizedException +from aws_lambda_powertools.event_handler.events_appsync.router import Router +from aws_lambda_powertools.utilities.data_classes.appsync_resolver_events_event import AppSyncResolverEventsEvent +from aws_lambda_powertools.warnings import PowertoolsUserWarning + +if TYPE_CHECKING: + from collections.abc import Callable + + from aws_lambda_powertools.event_handler.events_appsync.types import ResolverTypeDef + from aws_lambda_powertools.utilities.typing.lambda_context import LambdaContext + + +logger = logging.getLogger(__name__) + + +class AppSyncEventsResolver(Router): + """ + AppSync Events API Resolver for handling publish and subscribe operations. + + This class extends the Router to process AppSync real-time API events, managing + both synchronous and asynchronous resolvers for event publishing and subscribing. 
+ + Attributes + ---------- + context: dict + Dictionary to store context information accessible across resolvers + lambda_context: LambdaContext + Lambda context from the AWS Lambda function + current_event: AppSyncResolverEventsEvent + Current event being processed + + Examples + -------- + Define a simple AppSync events resolver for a chat application: + + >>> from aws_lambda_powertools.event_handler import AppSyncEventsResolver + >>> app = AppSyncEventsResolver() + >>> + >>> # Using aggregate mode to process multiple messages at once + >>> @app.on_publish(channel_path="/default/*", aggregate=True) + >>> def handle_batch_messages(payload): + >>> processed_messages = [] + >>> for message in payload: + >>> # Process each message + >>> processed_messages.append({ + >>> "messageId": f"msg-{message.get('id')}", + >>> "processed": True + >>> }) + >>> return processed_messages + >>> + >>> # Asynchronous resolver + >>> @app.async_on_publish(channel_path="/default/*") + >>> async def handle_async_messages(event): + >>> # Perform async operations (e.g., DB queries, HTTP calls) + >>> await asyncio.sleep(0.1) # Simulate async work + >>> return { + >>> "messageId": f"async-{event.get('id')}", + >>> "processed": True + >>> } + >>> + >>> # Lambda handler + >>> def lambda_handler(event, context): + >>> return events.resolve(event, context) + """ + + def __init__(self): + """Initialize the AppSyncEventsResolver.""" + super().__init__() + self.context = {} # early init as customers might add context before event resolution + self._exception_handlers: dict[type, Callable] = {} + + def __call__( + self, + event: dict | AppSyncResolverEventsEvent, + context: LambdaContext, + ) -> Any: + """ + Implicit lambda handler which internally calls `resolve`. 
+ + Parameters + ---------- + event: dict or AppSyncResolverEventsEvent + The AppSync event to process + context: LambdaContext + The Lambda context + + Returns + ------- + Any + The resolver's response + """ + return self.resolve(event, context) + + def resolve( + self, + event: dict | AppSyncResolverEventsEvent, + context: LambdaContext, + ) -> Any: + """ + Resolves the response based on the provided event and decorator operation. + + Parameters + ---------- + event: dict or AppSyncResolverEventsEvent + The AppSync event to process + context: LambdaContext + The Lambda context + + Returns + ------- + Any + The resolver's response based on the operation type + + Examples + -------- + >>> events = AppSyncEventsResolver() + >>> + >>> # Explicit call to resolve in Lambda handler + >>> def lambda_handler(event, context): + >>> return events.resolve(event, context) + """ + + self._setup_context(event, context) + + if self.current_event.info.operation == "PUBLISH": + response = self._publish_events(payload=self.current_event.events) + else: + response = self._subscribe_events() + + self.clear_context() + + return response + + def _subscribe_events(self) -> Any: + """ + Handle subscribe events. + + Returns + ------- + Any + Any response + """ + channel_path = self.current_event.info.channel_path + logger.debug(f"Processing subscribe events for path {channel_path}") + + resolver = self._subscribe_registry.find_resolver(channel_path) + if resolver: + try: + resolver["func"]() + return None # Must return None in subscribe events + except UnauthorizedException: + raise + except Exception as error: + return {"error": self._format_error_response(error)} + + self._warn_no_resolver("subscribe", channel_path) + return None + + def _publish_events(self, payload: list[dict[str, Any]]) -> list[dict[str, Any]] | dict[str, Any]: + """ + Handle publish events. 
+ + Parameters + ---------- + payload: list[dict[str, Any]] + The events payload to process + + Returns + ------- + list[dict[str, Any]] or dict[str, Any] + Processed events or error response + """ + + channel_path = self.current_event.info.channel_path + + logger.debug(f"Processing publish events for path {channel_path}") + + resolver = self._publish_registry.find_resolver(channel_path) + async_resolver = self._async_publish_registry.find_resolver(channel_path) + + if resolver and async_resolver: + warnings.warn( + f"Both synchronous and asynchronous resolvers found for the same event and field." + f"The synchronous resolver takes precedence. Executing: {resolver['func'].__name__}", + stacklevel=2, + category=PowertoolsUserWarning, + ) + + if resolver: + logger.debug(f"Found sync resolver: {resolver}") + return self._process_publish_event_sync_resolver(resolver) + + if async_resolver: + logger.debug(f"Found async resolver: {async_resolver}") + return asyncio.run(self._call_publish_event_async_resolver(async_resolver)) + + # No resolver found + # Warning and returning AS IS + self._warn_no_resolver("publish", channel_path, return_as_is=True) + return {"events": payload} + + def _process_publish_event_sync_resolver( + self, + resolver: ResolverTypeDef, + ) -> list[dict[str, Any]] | dict[str, Any]: + """ + Process events using a synchronous resolver. + + Parameters + ---------- + resolver : ResolverTypeDef + The resolver to use for processing events + + Returns + ------- + list[dict[str, Any]] or dict[str, Any] + Processed events or error response + + Notes + ----- + If the resolver is configured with aggregate=True, all events are processed + as a batch. Otherwise, each event is processed individually. 
+ """ + + # Checks whether the entire batch should be processed at once + if resolver["aggregate"]: + try: + # Process the entire batch + response = resolver["func"](payload=self.current_event.events) + + if not isinstance(response, list): + warnings.warn( + "Response must be a list when using aggregate, AppSync will drop those events.", + stacklevel=2, + category=PowertoolsUserWarning, + ) + + return {"events": response} + except UnauthorizedException: + raise + except Exception as error: + return {"error": self._format_error_response(error)} + + # By default, we gracefully append `None` for any records that failed processing + results = [] + for idx, event in enumerate(self.current_event.events): + try: + result_return = resolver["func"](payload=event.get("payload")) + results.append({"id": event.get("id"), "payload": result_return}) + except Exception as error: + logger.debug(f"Failed to process event number {idx}") + error_return = {"id": event.get("id"), "error": self._format_error_response(error)} + results.append(error_return) + + return {"events": results} + + async def _call_publish_event_async_resolver( + self, + resolver: ResolverTypeDef, + ) -> list[dict[str, Any]] | dict[str, Any]: + """ + Process events using an asynchronous resolver. + + Parameters + ---------- + resolver: ResolverTypeDef + The async resolver to use for processing events + + Returns + ------- + list[Any] + Processed events or error responses + + Notes + ----- + If the resolver is configured with aggregate=True, all events are processed + as a batch. Otherwise, each event is processed individually and in parallel. 
+ """ + + # Checks whether the entire batch should be processed at once + if resolver["aggregate"]: + try: + # Process the entire batch + response = await resolver["func"](payload=self.current_event.events) + if not isinstance(response, list): + warnings.warn( + "Response must be a list when using aggregate, AppSync will drop those events.", + stacklevel=2, + category=PowertoolsUserWarning, + ) + + return {"events": response} + except UnauthorizedException: + raise + except Exception as error: + return {"error": self._format_error_response(error)} + + response_async: list = [] + + # Prime coroutines + tasks = [resolver["func"](payload=e.get("payload")) for e in self.current_event.events] + + # Aggregate results and exceptions, then filter them out + # Use `None` upon exception for graceful error handling at GraphQL engine level + # + # NOTE: asyncio.gather(return_exceptions=True) catches and includes exceptions in the results + # this will become useful when we support exception handling in AppSync resolver + # Aggregate results and exceptions, then filter them out + results = await asyncio.gather(*tasks, return_exceptions=True) + response_async.extend( + [ + ( + {"id": e.get("id"), "error": self._format_error_response(ret)} + if isinstance(ret, Exception) + else {"id": e.get("id"), "payload": ret} + ) + for e, ret in zip(self.current_event.events, results) + ], + ) + + return {"events": response_async} + + def include_router(self, router: Router) -> None: + """ + Add all resolvers defined in a router to this resolver. 
+ + Parameters + ---------- + router : Router + A router containing resolvers to include + + Examples + -------- + >>> # Create main resolver and a router + >>> app = AppSyncEventsResolver() + >>> router = Router() + >>> + >>> # Define resolvers in the router + >>> @router.publish(path="/chat/message") + >>> def handle_chat_message(payload): + >>> return {"processed": True, "messageId": payload.get("id")} + >>> + >>> # Include the router in the main resolver + >>> app.include_router(chat_router) + >>> + >>> # Now events can handle "/chat/message" channel_path + """ + + # Merge app and router context + logger.debug("Merging router and app context") + self.context.update(**router.context) + + # use pointer to allow context clearance after event is processed e.g., resolve(evt, ctx) + router.context = self.context + + logger.debug("Merging router resolver registries") + self._publish_registry.merge(router._publish_registry) + self._async_publish_registry.merge(router._async_publish_registry) + self._subscribe_registry.merge(router._subscribe_registry) + + def _format_error_response(self, error=None) -> str: + """ + Format error responses consistently. + + Parameters + ---------- + error: Exception or None + The error to format + + Returns + ------- + str + Formatted error message + """ + if isinstance(error, Exception): + return f"{error.__class__.__name__} - {str(error)}" + return "An unknown error occurred" + + def _warn_no_resolver(self, operation_type: str, path: str, return_as_is: bool = False) -> None: + """ + Generate consistent warning messages for missing resolvers. + + Parameters + ---------- + operation_type : str + Type of operation (e.g., "publish", "subscribe") + path : str + The channel path that's missing a resolver + return_as_is : bool, optional + Whether payload will be returned as is, by default False + """ + message = ( + f"No resolvers were found for {operation_type} operations with path {path}" + f"{'. 
We will return the entire payload as is' if return_as_is else ''}" + ) + warnings.warn(message, stacklevel=3, category=PowertoolsUserWarning) + + def _setup_context(self, event: dict | AppSyncResolverEventsEvent, context: LambdaContext) -> None: + """ + Set up the context and event for processing. + + Parameters + ---------- + event : dict or AppSyncResolverEventsEvent + The AppSync event to process + context : LambdaContext + The Lambda context + """ + self.lambda_context = context + Router.lambda_context = context + + Router.current_event = ( + event if isinstance(event, AppSyncResolverEventsEvent) else AppSyncResolverEventsEvent(event) + ) + self.current_event = Router.current_event diff --git a/aws_lambda_powertools/event_handler/events_appsync/base.py b/aws_lambda_powertools/event_handler/events_appsync/base.py new file mode 100644 index 00000000000..0973553cda4 --- /dev/null +++ b/aws_lambda_powertools/event_handler/events_appsync/base.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Callable + + +class BaseRouter(ABC): + """Abstract base class for Router (resolvers)""" + + @abstractmethod + def on_publish( + self, + path: str = "/default/*", + aggregate: bool = True, + ) -> Callable: + raise NotImplementedError + + @abstractmethod + def async_on_publish( + self, + path: str = "/default/*", + aggregate: bool = True, + ) -> Callable: + raise NotImplementedError + + @abstractmethod + def on_subscribe( + self, + path: str = "/default/*", + ) -> Callable: + raise NotImplementedError + + def append_context(self, **additional_context) -> None: + """ + Appends context information available under any route. + + Parameters + ----------- + **additional_context: dict + Additional context key-value pairs to append. 
+ """ + raise NotImplementedError diff --git a/aws_lambda_powertools/event_handler/events_appsync/exceptions.py b/aws_lambda_powertools/event_handler/events_appsync/exceptions.py new file mode 100644 index 00000000000..5093c68c603 --- /dev/null +++ b/aws_lambda_powertools/event_handler/events_appsync/exceptions.py @@ -0,0 +1,25 @@ +from __future__ import annotations + + +class UnauthorizedException(Exception): + """ + Error to be thrown to communicate the subscription is unauthorized. + + When this error is raised, the client will receive a 40x error code + and the subscription will be closed. + + Attributes: + message (str): The error message describing the unauthorized access. + """ + + def __init__(self, message: str | None = None, *args, **kwargs): + """ + Initialize the UnauthorizedException. + + Args: + message (str): A descriptive error message. + *args: Variable positional arguments. + **kwargs: Variable keyword arguments. + """ + super().__init__(message, *args, **kwargs) + self.name = "UnauthorizedException" diff --git a/aws_lambda_powertools/event_handler/events_appsync/functions.py b/aws_lambda_powertools/event_handler/events_appsync/functions.py new file mode 100644 index 00000000000..7f4952a0dd7 --- /dev/null +++ b/aws_lambda_powertools/event_handler/events_appsync/functions.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +import re +from functools import lru_cache +from typing import Any + +PATH_REGEX = re.compile(r"^\/([^\/\*]+)(\/[^\/\*]+)*(\/\*)?$") + + +def is_valid_path(path: str) -> bool: + """ + Checks if a given path is valid based on specific rules. 
+ + Parameters + ---------- + path: str + The path to validate + + Returns: + -------- + bool: + True if the path is valid, False otherwise + + Examples: + >>> is_valid_path('/*') + True + >>> is_valid_path('/users') + True + >>> is_valid_path('/users/profile') + True + >>> is_valid_path('/users/*/details') + False + >>> is_valid_path('/users/*') + True + >>> is_valid_path('users') + False + """ + if path == "/*": + return True + return bool(PATH_REGEX.fullmatch(path)) + + +def find_best_route(routes: dict[str, Any], path: str): + """ + Find the most specific matching route for a given path. + + Examples of matches: + Route: /default/v1/* Path: /default/v1/users -> MATCH + Route: /default/v1/* Path: /default/v1/users/students -> MATCH + Route: /default/v1/users/* Path: /default/v1/users/123 -> MATCH (this wins over /default/v1/*) + Route: /* Path: /anything/here -> MATCH (lowest priority) + + Parameters + ---------- + routes: dict[str, Any] + Dictionary containing routes and their handlers + Format: { + 'resolvers': { + '/path/*': {'func': callable, 'aggregate': bool}, + '/path/specific/*': {'func': callable, 'aggregate': bool} + } + } + path: str + Actual path to match (e.g., '/default/v1/users') + + Returns + ------- + str: Most specific matching route or None if no match + """ + + @lru_cache(maxsize=1024) + def pattern_to_regex(route): + """ + Convert a route pattern to a regex pattern with caching. 
+ Examples: + /default/v1/* -> ^/default/v1/[^/]+$ + /default/v1/users/* -> ^/default/v1/users/.*$ + + Parameters + ---------- + route: str + Route pattern with wildcards + + Returns + ------- + Pattern: + Compiled regex pattern + """ + # Escape special regex chars but convert * to regex pattern + pattern = re.escape(route).replace("\\*", "[^/]+") + + # If pattern ends with [^/]+, replace with .* for multi-segment match + if pattern.endswith("[^/]+"): + pattern = pattern[:-6] + ".*" + + # Compile and return the regex pattern + return re.compile(f"^{pattern}$") + + # Find all matching routes + matches = [route for route in routes.keys() if pattern_to_regex(route).match(path)] + + # Return the most specific route (longest length minus wildcards) + # Examples of specificity: + # - '/default/v1/users' -> score: 14 (len=14, wildcards=0) + # - '/default/v1/users/*' -> score: 14 (len=15, wildcards=1) + # - '/default/v1/*' -> score: 8 (len=9, wildcards=1) + # - '/*' -> score: 0 (len=2, wildcards=1) + return max(matches, key=lambda x: len(x) - x.count("*"), default=None) diff --git a/aws_lambda_powertools/event_handler/events_appsync/router.py b/aws_lambda_powertools/event_handler/events_appsync/router.py new file mode 100644 index 00000000000..45d7c81ddbb --- /dev/null +++ b/aws_lambda_powertools/event_handler/events_appsync/router.py @@ -0,0 +1,199 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from aws_lambda_powertools.event_handler.events_appsync._registry import ResolverEventsRegistry +from aws_lambda_powertools.event_handler.events_appsync.base import BaseRouter + +if TYPE_CHECKING: + from collections.abc import Callable + + from aws_lambda_powertools.utilities.data_classes.appsync_resolver_events_event import AppSyncResolverEventsEvent + from aws_lambda_powertools.utilities.typing.lambda_context import LambdaContext + + +class Router(BaseRouter): + """ + Router for AppSync real-time API event handling. 
+ + This class provides decorators to register resolver functions for publish and subscribe + operations in AppSync real-time APIs. + + Parameters + ---------- + context : dict + Dictionary to store context information accessible across resolvers + current_event : AppSyncResolverEventsEvent + Current event being processed + lambda_context : LambdaContext + Lambda context from the AWS Lambda function + + Examples + -------- + Create a router and define resolvers: + + >>> chat_router = Router() + >>> + >>> # Register a resolver for publish operations + >>> @chat_router.on_publish(path="/chat/message") + >>> def handle_message(payload): + >>> # Process message + >>> return {"success": True, "messageId": payload.get("id")} + >>> + >>> # Register an async resolver for publish operations + >>> @chat_router.async_on_publish(path="/chat/typing") + >>> async def handle_typing(event): + >>> # Process typing indicator + >>> await some_async_operation() + >>> return {"processed": True} + >>> + >>> # Register a resolver for subscribe operations + >>> @chat_router.on_subscribe(path="/chat/room/*") + >>> def handle_subscribe(event): + >>> # Handle subscription setup + >>> return {"allowed": True} + """ + + context: dict + current_event: AppSyncResolverEventsEvent + lambda_context: LambdaContext + + def __init__(self): + """ + Initialize a new Router instance. + + Sets up empty context and registry containers for different types of resolvers. + """ + self.context = {} # early init as customers might add context before event resolution + self._publish_registry = ResolverEventsRegistry(kind_resolver="on_publish") + self._async_publish_registry = ResolverEventsRegistry(kind_resolver="async_on_publish") + self._subscribe_registry = ResolverEventsRegistry(kind_resolver="on_subscribe") + + def on_publish( + self, + path: str = "/default/*", + aggregate: bool = False, + ) -> Callable: + """ + Register a resolver function for publish operations. 
+ + Parameters + ---------- + path : str, optional + The channel path pattern to match for this resolver, by default "/default/*" + aggregate : bool, optional + Whether to process events in aggregate (batch) mode, by default False + + Returns + ------- + Callable + Decorator function that registers the resolver + + Examples + -------- + >>> router = Router() + >>> + >>> # Basic usage + >>> @router.on_publish(path="/notifications/new") + >>> def handle_notification(payload): + >>> # Process a single notification + >>> return {"processed": True, "notificationId": payload.get("id")} + >>> + >>> # Aggregate mode for batch processing + >>> @router.on_publish(path="/notifications/batch", aggregate=True) + >>> def handle_batch_notifications(payload): + >>> # Process multiple notifications at once + >>> results = [] + >>> for item in payload: + >>> # Process each item + >>> results.append({"processed": True, "id": item.get("id")}) + >>> return results + """ + return self._publish_registry.register(path=path, aggregate=aggregate) + + def async_on_publish( + self, + path: str = "/default/*", + aggregate: bool = False, + ) -> Callable: + """ + Register an asynchronous resolver function for publish operations. 
+ + Parameters + ---------- + path : str, optional + The channel path pattern to match for this resolver, by default "/default/*" + aggregate : bool, optional + Whether to process events in aggregate (batch) mode, by default False + + Returns + ------- + Callable + Decorator function that registers the async resolver + + Examples + -------- + >>> router = Router() + >>> + >>> # Basic async usage + >>> @router.async_on_publish(path="/messages/send") + >>> async def handle_message(event): + >>> # Perform async operations + >>> result = await database.save_message(event) + >>> return {"saved": True, "messageId": result.id} + >>> + >>> # Aggregate mode for batch processing + >>> @router.async_on_publish(path="/messages/batch", aggregate=True) + >>> async def handle_batch_messages(events): + >>> # Process multiple messages asynchronously + >>> tasks = [database.save_message(e) for e in events] + >>> results = await asyncio.gather(*tasks) + >>> return [{"saved": True, "id": r.id} for r in results] + """ + return self._async_publish_registry.register(path=path, aggregate=aggregate) + + def on_subscribe( + self, + path: str = "/default/*", + ) -> Callable: + """ + Register a resolver function for subscribe operations. 
+ + Parameters + ---------- + path : str, optional + The channel path pattern to match for this resolver, by default "/default/*" + + Returns + ------- + Callable + Decorator function that registers the resolver + + Examples + -------- + >>> router = Router() + >>> + >>> # Handle subscription request + >>> @router.on_subscribe(path="/chat/room/*") + >>> def authorize_subscription(event): + >>> # Verify if the client can subscribe to this room + >>> room_id = event.info.channel_path.split('/')[-1] + >>> user_id = event.identity.username + >>> + >>> # Check if user is allowed in this room + >>> is_allowed = check_permission(user_id, room_id) + >>> + >>> return { + >>> "allowed": is_allowed, + >>> "roomId": room_id + >>> } + """ + return self._subscribe_registry.register(path=path) + + def append_context(self, **additional_context): + """Append key=value data as routing context""" + self.context.update(**additional_context) + + def clear_context(self): + """Resets routing context""" + self.context.clear() diff --git a/aws_lambda_powertools/event_handler/events_appsync/types.py b/aws_lambda_powertools/event_handler/events_appsync/types.py new file mode 100644 index 00000000000..708e8df8a8c --- /dev/null +++ b/aws_lambda_powertools/event_handler/events_appsync/types.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, TypedDict + +if TYPE_CHECKING: + from collections.abc import Callable + + +class ResolverTypeDef(TypedDict): + """ + Type definition for resolver dictionary + Parameters + ---------- + func: Callable[..., Any] + Resolver function + aggregate: bool + Aggregation flag or method + """ + + func: Callable[..., Any] + aggregate: bool diff --git a/aws_lambda_powertools/event_handler/exception_handling.py b/aws_lambda_powertools/event_handler/exception_handling.py new file mode 100644 index 00000000000..acd8eb95bc6 --- /dev/null +++ b/aws_lambda_powertools/event_handler/exception_handling.py @@ -0,0 +1,118 @@ +from 
__future__ import annotations + +from typing import TYPE_CHECKING, Mapping + +if TYPE_CHECKING: + from collections.abc import Callable + + +class ExceptionHandlerManager: + """ + A class to manage exception handlers for different exception types. + This class allows registering handler functions for specific exception types + and looking up the appropriate handler when an exception occurs. + Example usage: + ------------- + handler_manager = ExceptionHandlerManager() + @handler_manager.exception_handler(ValueError) + def handle_value_error(e): + print(f"Handling ValueError: {e}") + return "Error handled" + # To handle multiple exception types with the same handler: + @handler_manager.exception_handler([KeyError, TypeError]) + def handle_multiple_errors(e): + print(f"Handling {type(e).__name__}: {e}") + return "Multiple error types handled" + # To find and execute a handler: + try: + # some code that might raise an exception + raise ValueError("Invalid value") + except Exception as e: + handler = handler_manager.lookup_exception_handler(type(e)) + if handler: + result = handler(e) + """ + + def __init__(self): + """Initialize an empty dictionary to store exception handlers.""" + self._exception_handlers: dict[type[Exception], Callable] = {} + + def exception_handler(self, exc_class: type[Exception] | list[type[Exception]]): + """ + A decorator function that registers a handler for one or more exception types. + Parameters + ---------- + exc_class : type[Exception] | list[type[Exception]] + A single exception type or a list of exception types. + Returns + ------- + Callable + A decorator function that registers the exception handler. 
+ """ + + def register_exception_handler(func: Callable): + if isinstance(exc_class, list): + for exp in exc_class: + self._exception_handlers[exp] = func + else: + self._exception_handlers[exc_class] = func + return func + + return register_exception_handler + + def lookup_exception_handler(self, exp_type: type) -> Callable | None: + """ + Looks up the registered exception handler for the given exception type or its base classes. + Parameters + ---------- + exp_type : type + The exception type to look up the handler for. + Returns + ------- + Callable | None + The registered exception handler function if found, otherwise None. + """ + for cls in exp_type.__mro__: + if cls in self._exception_handlers: + return self._exception_handlers[cls] + return None + + def update_exception_handlers(self, handlers: Mapping[type[Exception], Callable]) -> None: + """ + Updates the exception handlers dictionary with new handler mappings. + This method allows bulk updates of exception handlers by providing a dictionary + mapping exception types to handler functions. + Parameters + ---------- + handlers : Mapping[Type[Exception], Callable] + A dictionary mapping exception types to handler functions. + Example + ------- + >>> def handle_value_error(e): + ... print(f"Value error: {e}") + ... + >>> def handle_key_error(e): + ... print(f"Key error: {e}") + ... + >>> handler_manager.update_exception_handlers({ + ... ValueError: handle_value_error, + ... KeyError: handle_key_error + ... }) + """ + self._exception_handlers.update(handlers) + + def get_registered_handlers(self) -> dict[type[Exception], Callable]: + """ + Returns all registered exception handlers. + Returns + ------- + Dict[Type[Exception], Callable] + A dictionary mapping exception types to their handler functions. + """ + return self._exception_handlers.copy() + + def clear_handlers(self) -> None: + """ + Clears all registered exception handlers. 
+ """ + self._exception_handlers.clear() diff --git a/aws_lambda_powertools/utilities/data_classes/__init__.py b/aws_lambda_powertools/utilities/data_classes/__init__.py index 2757725dc62..7c1b67e6fa0 100644 --- a/aws_lambda_powertools/utilities/data_classes/__init__.py +++ b/aws_lambda_powertools/utilities/data_classes/__init__.py @@ -6,6 +6,7 @@ from .api_gateway_proxy_event import APIGatewayProxyEvent, APIGatewayProxyEventV2 from .api_gateway_websocket_event import APIGatewayWebSocketEvent from .appsync_resolver_event import AppSyncResolverEvent +from .appsync_resolver_events_event import AppSyncResolverEventsEvent from .aws_config_rule_event import AWSConfigRuleEvent from .bedrock_agent_event import BedrockAgentEvent from .cloud_watch_alarm_event import ( @@ -55,6 +56,7 @@ "APIGatewayWebSocketEvent", "SecretsManagerEvent", "AppSyncResolverEvent", + "AppSyncResolverEventsEvent", "ALBEvent", "BedrockAgentEvent", "CloudWatchAlarmData", diff --git a/aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py b/aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py index 83d266b119f..af9568325a5 100644 --- a/aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py +++ b/aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py @@ -26,6 +26,46 @@ def get_identity_object(identity: dict | None) -> Any: return AppSyncIdentityIAM(identity) +class AppSyncEventBase(DictWrapper): + """AppSync resolver event base to work with AppSync GraphQL + Events""" + + @property + def request_headers(self) -> dict[str, str]: + """Request headers""" + return CaseInsensitiveDict(self["request"]["headers"]) + + @property + def domain_name(self) -> str | None: + """The domain name when using custom domain""" + return self["request"].get("domainName") + + @property + def prev_result(self) -> dict[str, Any] | None: + """It represents the result of whatever previous operation was executed in a pipeline resolver.""" + prev = self.get("prev") 
+ return prev.get("result") if prev else None + + @property + def stash(self) -> dict: + """The stash is a map that is made available inside each resolver and function mapping template. + The same stash instance lives through a single resolver execution. This means that you can use the + stash to pass arbitrary data across request and response mapping templates, and across functions in + a pipeline resolver.""" + return self.get("stash") or {} + + @property + def identity(self) -> AppSyncIdentityIAM | AppSyncIdentityCognito | None: + """An object that contains information about the caller. + Depending on the type of identity found: + - API_KEY authorization - returns None + - AWS_IAM authorization - returns AppSyncIdentityIAM + - AMAZON_COGNITO_USER_POOLS authorization - returns AppSyncIdentityCognito + - AWS_LAMBDA authorization - returns None - NEED TO TEST + - OPENID_CONNECT authorization - returns None - NEED TO TEST + """ + return get_identity_object(self.get("identity")) + + class AppSyncIdentityIAM(DictWrapper): """AWS_IAM authorization""" @@ -141,7 +181,7 @@ def selection_set_graphql(self) -> str | None: return self.get("selectionSetGraphQL") -class AppSyncResolverEvent(DictWrapper): +class AppSyncResolverEvent(AppSyncEventBase): """AppSync resolver event **NOTE:** AppSync Resolver Events can come in various shapes this data class @@ -178,49 +218,16 @@ def arguments(self) -> dict[str, Any]: """A map that contains all GraphQL arguments for this field.""" return self["arguments"] - @property - def identity(self) -> AppSyncIdentityIAM | AppSyncIdentityCognito | None: - """An object that contains information about the caller.
- - Depending on the type of identify found: - - - API_KEY authorization - returns None - - AWS_IAM authorization - returns AppSyncIdentityIAM - - AMAZON_COGNITO_USER_POOLS authorization - returns AppSyncIdentityCognito - """ - return get_identity_object(self.get("identity")) - @property def source(self) -> dict[str, Any]: """A map that contains the resolution of the parent field.""" return self.get("source") or {} - @property - def request_headers(self) -> dict[str, str]: - """Request headers""" - return CaseInsensitiveDict(self["request"]["headers"]) - - @property - def prev_result(self) -> dict[str, Any] | None: - """It represents the result of whatever previous operation was executed in a pipeline resolver.""" - prev = self.get("prev") - if not prev: - return None - return prev.get("result") - @property def info(self) -> AppSyncResolverEventInfo: """The info section contains information about the GraphQL request.""" return self._info - @property - def stash(self) -> dict: - """The stash is a map that is made available inside each resolver and function mapping template. - The same stash instance lives through a single resolver execution. 
This means that you can use the - stash to pass arbitrary data across request and response mapping templates, and across functions in - a pipeline resolver.""" - return self.get("stash") or {} - @overload def get_header_value( self, diff --git a/aws_lambda_powertools/utilities/data_classes/appsync_resolver_events_event.py b/aws_lambda_powertools/utilities/data_classes/appsync_resolver_events_event.py new file mode 100644 index 00000000000..20f354f819f --- /dev/null +++ b/aws_lambda_powertools/utilities/data_classes/appsync_resolver_events_event.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from typing import Any + +from aws_lambda_powertools.utilities.data_classes.appsync_resolver_event import AppSyncEventBase +from aws_lambda_powertools.utilities.data_classes.common import DictWrapper + + +class AppSyncResolverEventsInfo(DictWrapper): + @property + def channel(self) -> dict[str, Any]: + """Channel details including path and segments""" + return self["channel"] + + @property + def channel_path(self) -> str: + """Provides direct access to the 'path' attribute within the 'channel' object.""" + return self["channel"]["path"] + + @property + def channel_segments(self) -> list[str]: + """Provides direct access to the 'segments' attribute within the 'channel' object.""" + return self["channel"]["segments"] + + @property + def channel_namespace(self) -> dict: + """Namespace configuration for the channel""" + return self["channelNamespace"] + + @property + def operation(self) -> str: + """The operation being performed (e.g., PUBLISH, SUBSCRIBE)""" + return self["operation"] + + +class AppSyncResolverEventsEvent(AppSyncEventBase): + """AppSync resolver event events + Documentation: + ------------- + - TBD + """ + + @property + def events(self) -> list[dict[str, Any]]: + """The payload sent to Lambda""" + return self.get("events") or [{}] + + @property + def out_errors(self) -> list: + """The outErrors property""" + return self.get("outErrors") or [] + + 
@property + def info(self) -> AppSyncResolverEventsInfo: + "The info containing information about channel, namespace, and event" + return AppSyncResolverEventsInfo(self["info"]) diff --git a/docs/core/event_handler/appsync_events.md b/docs/core/event_handler/appsync_events.md index 3eeee7459fd..82e87f60eda 100644 --- a/docs/core/event_handler/appsync_events.md +++ b/docs/core/event_handler/appsync_events.md @@ -65,25 +65,25 @@ AppSync Events uses a specific event format for Lambda requests and responses. I === "payload_request.json" - ```python hl_lines="5 10 12" + ```json hl_lines="13 22 32-45" --8<-- "examples/event_handler_appsync_events/src/payload_request.json" ``` === "payload_response.json" - ```python hl_lines="5 10 12" + ```json hl_lines="4-7 10-13" --8<-- "examples/event_handler_appsync_events/src/payload_response.json" ``` === "payload_response_with_error.json" - ```python hl_lines="5 10 12" + ```json hl_lines="4" --8<-- "examples/event_handler_appsync_events/src/payload_response_with_error.json" ``` === "payload_response_fail_request.json" - ```python hl_lines="5 10 12" + ```json hl_lines="2" --8<-- "examples/event_handler_appsync_events/src/payload_response_fail_request.json" ``` @@ -104,13 +104,13 @@ You can define your handlers for different event types using the `app.on_publish === "getting_started_with_publish_events.py" - ```python hl_lines="5 10 12" + ```python hl_lines="5 10 13" --8<-- "examples/event_handler_appsync_events/src/getting_started_with_publish_events.py" ``` === "getting_started_with_subscribe_events.py" - ```python hl_lines="5 6 13" + ```python hl_lines="6 7 13 17" --8<-- "examples/event_handler_appsync_events/src/getting_started_with_subscribe_events.py" ``` @@ -134,7 +134,7 @@ When an event matches with multiple handlers, the most specific pattern takes pr === "working_with_wildcard_resolvers.py" - ```python hl_lines="5 6 13" + ```python hl_lines="5 10 13 19 26" --8<-- 
"examples/event_handler_appsync_events/src/working_with_wildcard_resolvers.py" ``` @@ -155,7 +155,7 @@ You can enable this with the `aggregate` parameter: === "working_with_aggregated_events.py" - ```python hl_lines="5 6 13" + ```python hl_lines="8 15 22" --8<-- "examples/event_handler_appsync_events/src/working_with_aggregated_events.py" ``` @@ -169,13 +169,13 @@ When processing items individually with `aggregate=False`, you can raise an exce === "working_with_error_handling.py" - ```python hl_lines="5 6 13" + ```python hl_lines="5 13 17 20" --8<-- "examples/event_handler_appsync_events/src/working_with_error_handling.py" ``` === "working_with_error_handling_response.json" - ```python hl_lines="5 6 13" + ```json hl_lines="4" --8<-- "examples/event_handler_appsync_events/src/working_with_error_handling_response.json" ``` @@ -185,13 +185,13 @@ When processing batch of items with `aggregate=True`, you must format the payloa === "working_with_error_handling_multiple.py" - ```python hl_lines="5 6 13" + ```python hl_lines="5 10 13 24-29" --8<-- "examples/event_handler_appsync_events/src/working_with_error_handling_multiple.py" ``` === "working_with_error_handling_response.json" - ```python hl_lines="5 6 13" + ```json hl_lines="4" --8<-- "examples/event_handler_appsync_events/src/working_with_error_handling_response.json" ``` diff --git a/examples/event_handler_appsync_events/src/accessing_event_and_context.py b/examples/event_handler_appsync_events/src/accessing_event_and_context.py index db6f456e704..85d48c23d85 100644 --- a/examples/event_handler_appsync_events/src/accessing_event_and_context.py +++ b/examples/event_handler_appsync_events/src/accessing_event_and_context.py @@ -2,8 +2,8 @@ from typing import TYPE_CHECKING, Any -from aws_lambda_powertools.event_handler import AppSyncEventsResolver # type: ignore[attr-defined] -from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEventsEvent # type: ignore[attr-defined] +from
aws_lambda_powertools.event_handler import AppSyncEventsResolver +from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEventsEvent if TYPE_CHECKING: from aws_lambda_powertools.utilities.typing import LambdaContext @@ -19,18 +19,13 @@ class ValidationError(Exception): def handle_channel1_publish(payload: dict[str, Any]): # Access the full event and context lambda_event: AppSyncResolverEventsEvent = app.current_event - lambda_context: LambdaContext = app.context # Access request headers - headers = lambda_event.get("request", {}).get("headers", {}) - - # Check remaining time - remaining_time = lambda_context.get_remaining_time_in_millis() + header_user_agent = lambda_event.request_headers["user-agent"] return { "originalMessage": payload, - "userAgent": headers.get("User-Agent"), - "timeRemaining": remaining_time, + "userAgent": header_user_agent, } diff --git a/examples/event_handler_appsync_events/src/getting_started_with_publish_events.py b/examples/event_handler_appsync_events/src/getting_started_with_publish_events.py index 10b0e73160e..bd4fa00142f 100644 --- a/examples/event_handler_appsync_events/src/getting_started_with_publish_events.py +++ b/examples/event_handler_appsync_events/src/getting_started_with_publish_events.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any -from aws_lambda_powertools.event_handler import AppSyncEventsResolver # type: ignore[attr-defined] +from aws_lambda_powertools.event_handler import AppSyncEventsResolver if TYPE_CHECKING: from aws_lambda_powertools.utilities.typing import LambdaContext diff --git a/examples/event_handler_appsync_events/src/getting_started_with_subscribe_events.py b/examples/event_handler_appsync_events/src/getting_started_with_subscribe_events.py index 6626c36ab4a..1e4b7e69d05 100644 --- a/examples/event_handler_appsync_events/src/getting_started_with_subscribe_events.py +++ b/examples/event_handler_appsync_events/src/getting_started_with_subscribe_events.py @@ -2,12 +2,16 @@ from 
typing import TYPE_CHECKING -from aws_lambda_powertools.event_handler import AppSyncEventsResolver # type: ignore[attr-defined] +from aws_lambda_powertools import Metrics +from aws_lambda_powertools.event_handler import AppSyncEventsResolver +from aws_lambda_powertools.event_handler.events_appsync.exceptions import UnauthorizedException +from aws_lambda_powertools.metrics import MetricUnit if TYPE_CHECKING: from aws_lambda_powertools.utilities.typing import LambdaContext app = AppSyncEventsResolver() +metrics = Metrics(namespace="AppSyncEvents", service="GettingStartedWithSubscribeEvents") @app.on_subscribe("/*") @@ -16,7 +20,10 @@ def handle_all_subscriptions(): # Perform access control checks if not is_authorized(path): - raise Exception("You are not authorized to subscribe to this channel") + raise UnauthorizedException("You are not authorized to subscribe to this channel") + + metrics.add_dimension(name="channel", value=path) + metrics.add_metric(name="subscription", unit=MetricUnit.Count, value=1) return True @@ -26,5 +33,6 @@ def is_authorized(path: str): return path != "not_allowed_path_here" +@metrics.log_metrics(capture_cold_start_metric=True) def lambda_handler(event: dict, context: LambdaContext): return app.resolve(event, context) diff --git a/examples/event_handler_appsync_events/src/getting_started_with_testing_publish.py b/examples/event_handler_appsync_events/src/getting_started_with_testing_publish.py index 248447f5ff1..9d9eaefbb78 100644 --- a/examples/event_handler_appsync_events/src/getting_started_with_testing_publish.py +++ b/examples/event_handler_appsync_events/src/getting_started_with_testing_publish.py @@ -1,7 +1,7 @@ import json from pathlib import Path -from aws_lambda_powertools.event_handler import AppSyncEventsResolver # type: ignore[attr-defined] +from aws_lambda_powertools.event_handler import AppSyncEventsResolver class LambdaContext: diff --git a/examples/event_handler_appsync_events/src/getting_started_with_testing_subscribe.py 
b/examples/event_handler_appsync_events/src/getting_started_with_testing_subscribe.py index d91ff76b38b..54ef103183b 100644 --- a/examples/event_handler_appsync_events/src/getting_started_with_testing_subscribe.py +++ b/examples/event_handler_appsync_events/src/getting_started_with_testing_subscribe.py @@ -1,7 +1,7 @@ import json from pathlib import Path -from aws_lambda_powertools.event_handler import AppSyncEventsResolver # type: ignore[attr-defined] +from aws_lambda_powertools.event_handler import AppSyncEventsResolver class LambdaContext: diff --git a/examples/event_handler_appsync_events/src/working_with_aggregated_events.py b/examples/event_handler_appsync_events/src/working_with_aggregated_events.py index 1d238027797..6e59ba9718b 100644 --- a/examples/event_handler_appsync_events/src/working_with_aggregated_events.py +++ b/examples/event_handler_appsync_events/src/working_with_aggregated_events.py @@ -2,24 +2,38 @@ from typing import TYPE_CHECKING, Any -from aws_lambda_powertools.event_handler import AppSyncEventsResolver # type: ignore[attr-defined] +import boto3 +from boto3.dynamodb.types import TypeSerializer + +from aws_lambda_powertools.event_handler import AppSyncEventsResolver if TYPE_CHECKING: from aws_lambda_powertools.utilities.typing import LambdaContext +dynamodb = boto3.client("dynamodb") +serializer = TypeSerializer() app = AppSyncEventsResolver() -@app.on_publish("/default/*", aggregate=True) -def handle_default_namespace_batch(payload_list: list[dict[str, Any]]): - results: list = [] +def marshall(item: dict[str, Any]) -> dict[str, Any]: + return {k: serializer.serialize(v) for k, v in item.items()} + + +@app.on_publish("/default/foo/*", aggregate=True) +async def handle_default_namespace_batch(payload: list[dict[str, Any]]): + write_operations: list = [] + + write_operations.extend({"PutRequest": {"Item": marshall(item)}} for item in payload) - # Process all events in the batch together - for event in payload_list: - # Process each event - 
results.append({"id": event.get("id"), "payload": {"processed": True, "originalEvent": event}}) + # Execute the batch write operation against DynamoDB + if write_operations: + dynamodb.batch_write_item( + RequestItems={ + "your-table-name": write_operations, + }, + ) - return results + return payload def lambda_handler(event: dict, context: LambdaContext): diff --git a/examples/event_handler_appsync_events/src/working_with_error_handling.py b/examples/event_handler_appsync_events/src/working_with_error_handling.py index 459cf07a819..af34fdb7fa4 100644 --- a/examples/event_handler_appsync_events/src/working_with_error_handling.py +++ b/examples/event_handler_appsync_events/src/working_with_error_handling.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any -from aws_lambda_powertools.event_handler import AppSyncEventsResolver # type: ignore[attr-defined] +from aws_lambda_powertools.event_handler import AppSyncEventsResolver if TYPE_CHECKING: from aws_lambda_powertools.utilities.typing import LambdaContext @@ -19,16 +19,12 @@ def handle_channel1_publish(payload: dict[str, Any]): if not is_valid_payload(payload): raise ValidationError("Invalid payload format") - return process_payload(payload) + return {"processed": payload["data"]} def is_valid_payload(payload: dict[str, Any]): return "data" in payload -def process_payload(payload: dict[str, Any]): - return {"processed": payload["data"]} - - def lambda_handler(event: dict, context: LambdaContext): return app.resolve(event, context) diff --git a/examples/event_handler_appsync_events/src/working_with_error_handling_multiple.py b/examples/event_handler_appsync_events/src/working_with_error_handling_multiple.py index 73165b08029..cb24e820a4a 100644 --- a/examples/event_handler_appsync_events/src/working_with_error_handling_multiple.py +++ b/examples/event_handler_appsync_events/src/working_with_error_handling_multiple.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any -from aws_lambda_powertools.event_handler import
AppSyncEventsResolver # type: ignore[attr-defined] +from aws_lambda_powertools.event_handler import AppSyncEventsResolver if TYPE_CHECKING: from aws_lambda_powertools.utilities.typing import LambdaContext @@ -11,11 +11,11 @@ @app.on_publish("/default/*", aggregate=True) -def handle_default_namespace_batch(payload_list: list[dict[str, Any]]): +def handle_default_namespace_batch(payload: list[dict[str, Any]]): results: list = [] # Process all events in the batch together - for event in payload_list: + for event in payload: try: # Process each event results.append({"id": event.get("id"), "payload": {"processed": True, "originalEvent": event}}) diff --git a/examples/event_handler_appsync_events/src/working_with_wildcard_resolvers.py b/examples/event_handler_appsync_events/src/working_with_wildcard_resolvers.py index 3a53c0f480a..c6f2447c744 100644 --- a/examples/event_handler_appsync_events/src/working_with_wildcard_resolvers.py +++ b/examples/event_handler_appsync_events/src/working_with_wildcard_resolvers.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any -from aws_lambda_powertools.event_handler import AppSyncEventsResolver # type: ignore[attr-defined] +from aws_lambda_powertools.event_handler import AppSyncEventsResolver if TYPE_CHECKING: from aws_lambda_powertools.utilities.typing import LambdaContext diff --git a/tests/events/appSyncEventsEvent.json b/tests/events/appSyncEventsEvent.json new file mode 100644 index 00000000000..7691855dce5 --- /dev/null +++ b/tests/events/appSyncEventsEvent.json @@ -0,0 +1,70 @@ +{ + "identity":"None", + "result":"None", + "request":{ + "headers": { + "x-forwarded-for": "1.1.1.1, 2.2.2.2", + "cloudfront-viewer-country": "US", + "cloudfront-is-tablet-viewer": "false", + "via": "2.0 xxxxxxxxxxxxxxxx.cloudfront.net (CloudFront)", + "cloudfront-forwarded-proto": "https", + "origin": "https://us-west-1.console.aws.amazon.com", + "content-length": "217", + "accept-language": "en-US,en;q=0.9", + "host": 
"xxxxxxxxxxxxxxxx.appsync-api.us-west-1.amazonaws.com", + "x-forwarded-proto": "https", + "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_6) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/85.0.4183.83 Safari/537.36", + "accept": "*/*", + "cloudfront-is-mobile-viewer": "false", + "cloudfront-is-smarttv-viewer": "false", + "accept-encoding": "gzip, deflate, br", + "referer": "https://us-west-1.console.aws.amazon.com/appsync/home?region=us-west-1", + "content-type": "application/json", + "sec-fetch-mode": "cors", + "x-amz-cf-id": "3aykhqlUwQeANU-HGY7E_guV5EkNeMMtwyOgiA==", + "x-amzn-trace-id": "Root=1-5f512f51-fac632066c5e848ae714", + "authorization": "eyJraWQiOiJScWFCSlJqYVJlM0hrSnBTUFpIcVRXazNOW...", + "sec-fetch-dest": "empty", + "x-amz-user-agent": "AWS-Console-AppSync/", + "cloudfront-is-desktop-viewer": "true", + "sec-fetch-site": "cross-site", + "x-forwarded-port": "443" + }, + "domainName":"None" + }, + "info":{ + "channel":{ + "path":"/default/channel", + "segments":[ + "default", + "channel" + ] + }, + "channelNamespace":{ + "name":"default" + }, + "operation":"PUBLISH" + }, + "error":"None", + "prev":"None", + "stash":{ + + }, + "outErrors":[ + + ], + "events":[ + { + "payload":{ + "event_1":"data_1" + }, + "id":"1" + }, + { + "payload":{ + "event_2":"data_2" + }, + "id":"2" + } + ] + } diff --git a/tests/functional/event_handler/required_dependencies/appsync/test_appsync_events_resolvers.py b/tests/functional/event_handler/required_dependencies/appsync/test_appsync_events_resolvers.py new file mode 100644 index 00000000000..887adb08fe8 --- /dev/null +++ b/tests/functional/event_handler/required_dependencies/appsync/test_appsync_events_resolvers.py @@ -0,0 +1,1614 @@ +import asyncio +from copy import deepcopy + +import pytest + +from aws_lambda_powertools.event_handler import AppSyncEventsResolver +from aws_lambda_powertools.event_handler.events_appsync.exceptions import UnauthorizedException +from 
aws_lambda_powertools.event_handler.events_appsync.router import Router +from aws_lambda_powertools.warnings import PowertoolsUserWarning +from tests.functional.utils import load_event + + +class LambdaContext: + def __init__(self): + self.function_name = "test-func" + self.memory_limit_in_mb = 128 + self.invoked_function_arn = "arn:aws:lambda:eu-west-1:809313241234:function:test-func" + self.aws_request_id = "52fdfc07-2182-154f-163f-5f0f9a621d72" + + def get_remaining_time_in_millis(self) -> int: + return 1000 + + +@pytest.fixture(scope="module") +def lambda_context() -> LambdaContext: + """Create a new LambdaContext instance for each test module.""" + return LambdaContext() + + +@pytest.fixture(scope="module") +def mock_event(): + """Load a sample AppSyncEventsEvent for each test module.""" + return load_event("appSyncEventsEvent.json") + + +def test_publish_event_with_synchronous_resolver(lambda_context, mock_event): + """Test handling a publish event with a synchronous resolver.""" + # GIVEN a sample publish event + mock_event["events"] = [ + {"id": "123", "payload": {"data": "test data"}}, + ] + + # GIVEN an AppSyncEventsResolver with a synchronous resolver + app = AppSyncEventsResolver() + + @app.on_publish(path="/default/*") + def test_handler(payload): + return {"processed": True, "data": payload["data"]} + + # WHEN we resolve the event + result = app.resolve(mock_event, lambda_context) + + # THEN we should get the correct response + expected_result = { + "events": [ + {"id": "123", "payload": {"processed": True, "data": "test data"}}, + ], + } + assert result == expected_result + + +def test_publish_event_with_async_resolver(lambda_context, mock_event): + """Test handling a publish event with an asynchronous resolver.""" + # GIVEN a sample publish event + mock_event["events"] = [ + {"id": "123", "payload": {"data": "test data"}}, + ] + + # GIVEN an AppSyncEventsResolver with an asynchronous resolver + app = AppSyncEventsResolver() + + 
@app.async_on_publish(path="/default/*") + async def test_handler(payload): + await asyncio.sleep(0.01) # Simulate async work + return {"processed": True, "data": payload["data"]} + + # WHEN we resolve the event + result = app.resolve(mock_event, lambda_context) + + # THEN we should get the correct response + assert "events" in result + assert len(result["events"]) == 1 + assert result["events"][0]["payload"]["processed"] is True + assert result["events"][0]["payload"]["data"] == "test data" + + +def test_publish_event_with_error_handling(lambda_context, mock_event): + """Test error handling during publish event processing.""" + # GIVEN a sample publish event + mock_event["events"] = [ + {"id": "123", "payload": {"data": "test data"}}, + ] + + # GIVEN an AppSyncEventsResolver with a resolver that raises an exception + app = AppSyncEventsResolver() + + @app.on_publish(path="/default/*") + def test_handler(payload): + raise ValueError("Test error") + + # WHEN we resolve the event + result = app.resolve(mock_event, lambda_context) + + # THEN we should get an error response + assert "events" in result + assert "error" in result["events"][0] + assert "ValueError - Test error" in result["events"][0]["error"] + assert result["events"][0]["id"] == "123" + + +def test_publish_event_with_router_inclusion(lambda_context, mock_event): + """Test including a router in the AppSyncEventsResolver.""" + # GIVEN a sample publish event + mock_event["events"] = [ + {"id": "123", "payload": {"data": "test data", "from_router": True}}, + ] + + # GIVEN a router with a resolver + router = Router() + + @router.on_publish(path="/chat/*") + def router_handler(payload): + return {"from_router": True, "data": payload["data"]} + + # GIVEN an AppSyncEventsResolver that includes the router + app = AppSyncEventsResolver() + app.include_router(router) + + # WHEN we resolve the event + result = app.resolve(mock_event, lambda_context) + + # THEN we should get the response from the router's handler + 
expected_result = { + "events": [ + {"id": "123", "payload": {"from_router": True, "data": "test data"}}, + ], + } + assert result == expected_result + + +def test_publish_event_with_custom_context(lambda_context, mock_event): + """Test resolving events with custom context data.""" + # GIVEN a sample publish event + mock_event["events"] = [ + {"id": "123", "payload": {"data": "test data"}}, + ] + + # GIVEN an AppSyncEventsResolver with custom context + app = AppSyncEventsResolver() + + @app.on_publish(path="/default/*") + def test_handler(payload): + # Access the context within the handler + return { + "processed": True, + "data": payload["data"], + "user_id": app.context.get("user_id"), + "role": app.context.get("role"), + } + + # WHEN we resolve the event + app.append_context(user_id="test-user", role="admin") + result = app.resolve(mock_event, lambda_context) + + # THEN we should get the response with context data + expected_result = { + "events": [ + { + "id": "123", + "payload": { + "processed": True, + "data": "test data", + "user_id": "test-user", + "role": "admin", + }, + }, + ], + } + assert result == expected_result + + +def test_publish_event_with_aggregate_mode(lambda_context, mock_event): + """Test handling a publish event with aggregate mode enabled.""" + # GIVEN a sample publish event with multiple items + mock_event["events"] = [ + {"id": "123", "payload": {"data": "test data 1"}}, + {"id": "456", "payload": {"data": "test data 2"}}, + ] + + # GIVEN an AppSyncEventsResolver with an aggregate resolver + app = AppSyncEventsResolver() + + @app.on_publish(path="/default/*", aggregate=True) + def test_batch_handler(payload): + # Process all events at once + return [{"batch_processed": True, "data": item["payload"]["data"]} for item in payload] + + # WHEN we resolve the event + result = app.resolve(mock_event, lambda_context) + + # THEN we should get the batch processed response + expected_result = { + "events": [ + {"batch_processed": True, "data": "test 
data 1"}, + {"batch_processed": True, "data": "test data 2"}, + ], + } + assert result == expected_result + + +def test_async_publish_event_with_aggregate_mode(lambda_context, mock_event): + """Test handling an async publish event with aggregate mode enabled.""" + # GIVEN a sample publish event with multiple items + mock_event["events"] = [ + {"id": "123", "payload": {"data": "test data 1"}}, + {"id": "456", "payload": {"data": "test data 2"}}, + ] + + # GIVEN an AppSyncEventsResolver with an async aggregate resolver + app = AppSyncEventsResolver() + + @app.async_on_publish(path="/default/*", aggregate=True) + async def test_async_batch_handler(payload): + # Simulate async processing of all events + await asyncio.sleep(0.01) + return [{"async_batch_processed": True, "data": item["payload"]["data"]} for item in payload] + + # WHEN we resolve the event + result = app.resolve(mock_event, lambda_context) + + # THEN we should get the batch processed response + expected_result = { + "events": [ + {"async_batch_processed": True, "data": "test data 1"}, + {"async_batch_processed": True, "data": "test data 2"}, + ], + } + assert result == expected_result + + +def test_publish_event_no_matching_resolver(lambda_context, mock_event): + """Test handling a publish event when no matching resolver is found.""" + # GIVEN a sample publish event + mock_event["info"]["channel"]["path"] = "/unknown/path" + mock_event["events"] = [ + {"id": "123", "payload": {"data": "test data"}}, + ] + + # GIVEN an AppSyncEventsResolver with no matching resolver + app = AppSyncEventsResolver() + + @app.on_publish(path="/default/*") + def test_handler(payload): + return {"processed": True} + + # WHEN we resolve the event with a warning + with pytest.warns(PowertoolsUserWarning, match="No resolvers were found for publish operations"): + result = app.resolve(mock_event, lambda_context) + + # THEN we should get the original payload returned as is + expected_result = { + "events": [ + {"id": "123", 
"payload": {"data": "test data"}}, + ], + } + assert result == expected_result + + +def test_multiple_resolvers_for_same_path(lambda_context, mock_event): + """Test behavior when both sync and async resolvers exist for the same path.""" + # GIVEN a sample publish event + mock_event["info"]["channel"]["path"] = "/default/test" + mock_event["events"] = [ + {"id": "123", "payload": {"sync_processed": True, "data": "test data"}}, + ] + + # GIVEN an AppSyncEventsResolver with both sync and async resolvers for the same path + app = AppSyncEventsResolver() + + @app.on_publish(path="/default/*") + def sync_handler(payload): + return {"sync_processed": True, "data": payload["data"]} + + @app.async_on_publish(path="/default/*") + async def async_handler(event): + await asyncio.sleep(0.01) + return {"async_processed": True, "data": event["data"]} + + # WHEN we resolve the event, with a warning expected + with pytest.warns(PowertoolsUserWarning, match="Both synchronous and asynchronous resolvers found"): + result = app.resolve(mock_event, lambda_context) + + # THEN the sync resolver should be used (takes precedence) + expected_result = { + "events": [ + {"id": "123", "payload": {"sync_processed": True, "data": "test data"}}, + ], + } + assert result == expected_result + + +def test_custom_exception_handling(lambda_context, mock_event): + """Test handling custom exceptions during event processing.""" + # GIVEN a sample publish event + mock_event["events"] = [ + {"id": "123", "payload": {"sync_processed": True, "data": "test data"}}, + ] + + # GIVEN a custom exception class + class NotAuthorized(Exception): + pass + + # GIVEN an AppSyncEventsResolver with a resolver that raises a custom exception + app = AppSyncEventsResolver() + + @app.on_publish(path="/default/*") + def test_handler(payload): + if payload["data"] == "test data": + raise NotAuthorized("Not authorized") + return {"processed": True} + + # WHEN we resolve the event + result = app.resolve(mock_event, 
lambda_context) + + # THEN we should get an error response with our custom exception + assert "events" in result + assert "error" in result["events"][0] + assert "NotAuthorized - Not authorized" in result["events"][0]["error"] + assert result["events"][0]["id"] == "123" + + +def test_async_resolver_with_error_handling(lambda_context, mock_event): + """Test error handling with async resolvers during publish event processing.""" + # GIVEN a sample publish event + mock_event["events"] = [ + {"id": "123", "payload": {"sync_processed": True, "data": "test data"}}, + ] + + # GIVEN an AppSyncEventsResolver with an async resolver that raises an exception + app = AppSyncEventsResolver() + + @app.async_on_publish(path="/default/*") + async def test_handler(payload): + await asyncio.sleep(0.01) # Simulate async work + raise ValueError("Async test error") + + # WHEN we resolve the event + result = app.resolve(mock_event, lambda_context) + + # THEN we should get an error response + assert "events" in result + assert len(result["events"]) == 1 + assert "error" in result["events"][0] + assert "ValueError - Async test error" in result["events"][0]["error"] + + +def test_lambda_handler_with_call_method(lambda_context, mock_event): + """Test that the lambda handler function properly processes events.""" + # GIVEN a sample publish event + mock_event["events"] = [ + {"id": "123", "payload": {"sync_processed": True, "data": "test data"}}, + ] + + # GIVEN an AppSyncEventsResolver setup + app = AppSyncEventsResolver() + + @app.on_publish(path="/default/*") + def test_handler(payload): + return {"lambda_processed": True, "data": payload["data"]} + + # WHEN we use the AppSyncEventsResolver as a Lambda handler + result = app(mock_event, lambda_context) # Using __call__ method which calls resolve() + + # THEN we should get the processed response + expected_result = { + "events": [ + {"id": "123", "payload": {"lambda_processed": True, "data": "test data"}}, + ], + } + assert result == 
expected_result + + +def test_event_with_mixed_success_and_errors(lambda_context, mock_event): + """Test handling a batch of events with mixed success and failure outcomes.""" + # GIVEN a sample publish event with multiple items + mock_event["events"] = [ + {"id": "123", "payload": {"data": "good data"}}, + {"id": "456", "payload": {"data": "bad data"}}, + {"id": "789", "payload": {"data": "good data again"}}, + ] + + # GIVEN an AppSyncEventsResolver with a resolver that conditionally fails + app = AppSyncEventsResolver() + + @app.on_publish(path="/default/*") + def test_handler(payload): + if payload["data"] == "bad data": + raise ValueError("Bad data detected") + return {"success": True, "data": payload["data"]} + + # WHEN we resolve the event + result = app.resolve(mock_event, lambda_context) + + # THEN we should get mixed results with success and error responses + assert "events" in result + assert len(result["events"]) == 3 + + # First event should be successful + assert "payload" in result["events"][0] + assert result["events"][0]["payload"]["success"] is True + assert result["events"][0]["payload"]["data"] == "good data" + + # Second event should have an error + assert "error" in result["events"][1] + assert "ValueError - Bad data detected" in result["events"][1]["error"] + + # Third event should be successful + assert "payload" in result["events"][2] + assert result["events"][2]["payload"]["success"] is True + assert result["events"][2]["payload"]["data"] == "good data again" + + +def test_router_with_context_sharing(lambda_context, mock_event): + """Test that context is properly shared between routers and the main resolver.""" + # GIVEN a sample publish event + mock_event["info"]["channel"]["path"] = "/chat/message" + mock_event["events"] = [ + {"id": "123", "payload": {"data": "test data"}}, + ] + + # GIVEN a router with context + router = Router() + router.append_context(service="chat") + + @router.on_publish(path="/chat/*") + def 
router_handler(payload): + # Access shared context + return { + "from_router": True, + "service": router.context.get("service"), + "tenant": router.context.get("tenant"), + } + + # GIVEN an AppSyncEventsResolver with its own context + app = AppSyncEventsResolver() + app.append_context(tenant="acme") + + # Include the router and merge contexts + app.include_router(router) + + # WHEN we resolve the event + result = app.resolve(mock_event, lambda_context) + + # THEN the handler should have access to merged context from both sources + expected_result = { + "events": [ + { + "id": "123", + "payload": { + "from_router": True, + "service": "chat", + "tenant": "acme", + }, + }, + ], + } + assert result == expected_result + + +def test_context_cleared_after_resolution(lambda_context, mock_event): + """Test that context is properly cleared after event resolution.""" + # GIVEN a sample publish event + mock_event["events"] = [ + {"id": "123", "payload": {"sync_processed": True, "data": "test data"}}, + ] + + # GIVEN an AppSyncEventsResolver with context data + app = AppSyncEventsResolver() + app.append_context(request_id="12345") + + @app.on_publish(path="/default/*") + def test_handler(payload): + # Verify context exists during handler execution + assert app.context.get("request_id") == "12345" + return {"processed": True} + + # WHEN we resolve the event + app.resolve(mock_event, lambda_context) + + # THEN the context should be cleared afterward + assert app.context == {} + + +def test_path_matching_mechanism(mocker, lambda_context, mock_event): + """Test the path matching mechanism for resolvers.""" + + mock_find_resolver = mocker.patch( + "aws_lambda_powertools.event_handler.events_appsync._registry.ResolverEventsRegistry.find_resolver", + ) + # GIVEN a resolver that should be found + mock_resolver = { + "func": lambda payload: {"matched": True}, + "aggregate": False, + } + mock_find_resolver.return_value = mock_resolver + + # GIVEN a sample publish event + 
mock_event["info"]["channel"]["path"] = "/chat/room/123/message" + mock_event["events"] = [ + {"id": "123", "payload": {"data": "test data"}}, + ] + + # GIVEN an AppSyncEventsResolver + app = AppSyncEventsResolver() + + # WHEN we resolve the event + app.resolve(mock_event, lambda_context) + + # THEN the registry should be queried with the correct path + mock_find_resolver.assert_called_with("/chat/room/123/message") + + +def test_async_aggregate_with_parallel_processing(lambda_context, mock_event): + """Test that async aggregate handlers can process events in parallel.""" + # GIVEN a sample publish event with multiple items + mock_event["info"]["channel"]["path"] = "/default/process" + mock_event["events"] = [ + {"id": "123", "payload": {"sync_processed": True, "data": "item 1", "delay": 0.03}}, + {"id": "456", "payload": {"sync_processed": True, "data": "item 2", "delay": 0.02}}, + {"id": "789", "payload": {"sync_processed": True, "data": "item 3", "delay": 0.01}}, + ] + + # GIVEN an AppSyncEventsResolver with an async aggregate handler + app = AppSyncEventsResolver() + + @app.async_on_publish(path="/default/*", aggregate=True) + async def test_async_handler(payload): + # Create tasks for each event with different delays + tasks = [] + for idx_event in payload: + tasks.append(process_single_event(idx_event["payload"])) + + # Process all events in parallel + results = await asyncio.gather(*tasks) + return results + + async def process_single_event(payload): + # Simulate variable processing time + await asyncio.sleep(payload["delay"]) + return {"processed": True, "data": payload["data"]} + + # WHEN we resolve the event + result = app.resolve(mock_event, lambda_context) + + # THEN all events should be processed + assert "events" in result + assert len(result["events"]) == 3 + + # Check all items were processed + processed_data = [item["data"] for item in result["events"]] + assert "item 1" in processed_data + assert "item 2" in processed_data + assert "item 3" in 
processed_data + + +def test_both_app_and_router_for_same_path(lambda_context, mock_event): + """Test precedence when both app and router have resolvers for the same path.""" + # GIVEN a sample publish event + mock_event["info"]["channel"]["path"] = "/default/duplicate" + mock_event["events"] = [ + {"id": "123", "payload": {"data": "test data"}}, + ] + + # GIVEN a router with a resolver + router = Router() + + @router.on_publish(path="/default/duplicate") + def router_handler(payload): + return {"source": "router"} + + # GIVEN an AppSyncEventsResolver with a resolver for the same path + app = AppSyncEventsResolver() + + @app.on_publish(path="/default/duplicate") + def app_handler(payload): + return {"source": "app"} + + # Include the router after defining the app handler + app.include_router(router) + + # WHEN we resolve the event + result = app.resolve(mock_event, lambda_context) + + # THEN the router's handler should take precedence as it was registered last + expected_result = { + "events": [ + {"id": "123", "payload": {"source": "router"}}, + ], + } + assert result == expected_result + + +def test_event_with_real_world_example(lambda_context, mock_event): + """Test handling a more complex, real-world-like example.""" + # GIVEN a more realistic publish event with multiple items + mock_event["info"]["channel"]["path"] = "/chat/messages" + mock_event["events"] = [ + { + "id": "message-123", + "payload": { + "type": "text", + "content": "Hello, world!", + "timestamp": 1636718400000, + "sender": "user1", + }, + }, + { + "id": "message-456", + "payload": { + "type": "image", + "content": "https://example.com/image.jpg", + "timestamp": 1636718500000, + "sender": "user2", + }, + }, + ] + + # GIVEN a router for chat-related operations + chat_router = Router() + + @chat_router.on_publish(path="/chat/*") + def process_message(payload): + # Process message based on type + if payload["type"] == "text": + return { + "processed": True, + "messageType": "text", + 
"displayContent": payload["content"], + "timestamp": payload["timestamp"], + "sender": payload["sender"], + } + elif payload["type"] == "image": + return { + "processed": True, + "messageType": "image", + "displayContent": f"[Image] {payload['content']}", + "timestamp": payload["timestamp"], + "sender": payload["sender"], + } + else: + return { + "processed": False, + "error": "Unsupported message type", + } + + # GIVEN an AppSyncEventsResolver that includes the router + app = AppSyncEventsResolver() + app.include_router(chat_router) + + # WHEN we resolve the event + result = app.resolve(mock_event, lambda_context) + + # THEN we should get properly processed messages + assert "events" in result + assert len(result["events"]) == 2 + + # Check text message + assert result["events"][0]["id"] == "message-123" + assert result["events"][0]["payload"]["processed"] is True + assert result["events"][0]["payload"]["messageType"] == "text" + assert result["events"][0]["payload"]["displayContent"] == "Hello, world!" 
+ + # Check image message + assert result["events"][1]["id"] == "message-456" + assert result["events"][1]["payload"]["processed"] is True + assert result["events"][1]["payload"]["messageType"] == "image" + assert result["events"][1]["payload"]["displayContent"] == "[Image] https://example.com/image.jpg" + + +def test_event_response_with_custom_error_handling(lambda_context, mock_event): + """Test handling events with custom error handling logic.""" + # GIVEN a sample publish event + mock_event["info"]["channel"]["path"] = "/default/test" + mock_event["events"] = [ + {"id": "123", "payload": {"data": "sensitive data"}}, + ] + + # GIVEN a custom exception and a router with an async handler + class CustomSecurityException(Exception): + pass + + router = Router() + + @router.async_on_publish(path="/default/*") + async def security_check(payload): + # Simulate a security check that blocks certain IDs + blocked_data = ["sensitive data"] + if payload["data"] in blocked_data: + raise CustomSecurityException("Security check failed: Blocked ID") + + await asyncio.sleep(0.01) # Simulate async work + return {"security_verified": True, "data": payload["data"]} + + # GIVEN an AppSyncEventsResolver + app = AppSyncEventsResolver() + app.include_router(router) + + # WHEN we resolve the event + result = app.resolve(mock_event, lambda_context) + + # THEN we should get a security error response + assert "events" in result + assert len(result["events"]) == 1 + assert "error" in result["events"][0] + assert "CustomSecurityException - Security check failed" in result["events"][0]["error"] + assert result["events"][0]["id"] == "123" + + +def test_pattern_matching_no_valid_paths(lambda_context, mock_event): + """Test that path pattern matching works correctly with wildcards.""" + # GIVEN a sample publish event + mock_event["info"]["channel"]["path"] = "/users/123/notifications/new" + mock_event["events"] = [ + {"id": "123", "payload": {"data": "user notification data"}}, + ] + + # GIVEN 
an AppSyncEventsResolver with wildcard path patterns + app = AppSyncEventsResolver() + + # Define multiple resolvers with different path patterns + @app.on_publish(path="/users/*/notifications/*") # Should not match + def user_notification_handler(payload): + return {"handler": "wildcard_match", "data": "modified data 1"} + + @app.on_publish(path="/users/123/messages/*") # Should not match + def user_message_handler(payload): + return {"handler": "wrong_path", "data": "modified data 2"} + + @app.on_publish(path="/*/*/*") # should not match + def generic_handler(payload): + return {"handler": "generic", "data": "modified data 3"} + + # WHEN we resolve the event + result = app.resolve(mock_event, lambda_context) + + # THEN no resolver is found and we return as is + expected_result = { + "events": [ + {"id": "123", "payload": {"data": "user notification data"}}, + ], + } + assert result == expected_result + + +def test_nested_async_functions(lambda_context, mock_event): + """Test that nested async functions work correctly within resolvers.""" + # GIVEN a sample publish event + mock_event["info"]["channel"]["path"] = "/default/nested" + mock_event["events"] = [ + {"id": "123", "payload": {"data": "test data"}}, + ] + + # GIVEN an AppSyncEventsResolver with a resolver that uses nested async functions + app = AppSyncEventsResolver() + + @app.async_on_publish(path="/default/*") + async def outer_handler(payload): + # Define nested async functions + async def validate_data(data): + await asyncio.sleep(0.01) # Simulate validation + return data.strip() != "" + + async def transform_data(data): + await asyncio.sleep(0.01) # Simulate transformation + return data.upper() + + # Use nested async functions + is_valid = await validate_data(payload["data"]) + if not is_valid: + return {"error": "Invalid data"} + + transformed = await transform_data(payload["data"]) + return {"validated": is_valid, "transformed": transformed} + + # WHEN we resolve the event + result = 
app.resolve(mock_event, lambda_context) + + # THEN the nested async functions should execute correctly + assert "events" in result + assert len(result["events"]) == 1 + assert result["events"][0]["payload"]["validated"] is True + assert result["events"][0]["payload"]["transformed"] == "TEST DATA" + + +def test_concurrent_event_processing(lambda_context, mock_event): + """Test that multiple events are processed concurrently with async handlers.""" + # GIVEN a sample publish event with multiple items that take different times to process + mock_event["info"]["channel"]["path"] = "/default/concurrent" + mock_event["events"] = [ + {"id": "123", "payload": {"data": "fast data", "delay": 0.01}}, + {"id": "456", "payload": {"data": "slow data", "delay": 0.03}}, + {"id": "789", "payload": {"data": "medium data", "delay": 0.02}}, + ] + + # GIVEN an AppSyncEventsResolver with an async handler + app = AppSyncEventsResolver() + + @app.async_on_publish(path="/default/*") + async def process_with_variable_delay(payload): + # Simulate processing with different delays + await asyncio.sleep(payload["delay"]) + return { + "processed": True, + "data": payload["data"], + "processing_time": payload["delay"], + } + + # WHEN we resolve the event + import time + + start_time = time.time() + result = app.resolve(mock_event, lambda_context) + end_time = time.time() + + # THEN all events should be processed + assert "events" in result + assert len(result["events"]) == 3 + + # The total time should be roughly equal to the longest individual delay + # (not the sum of all delays, which would indicate sequential processing) + processing_time = end_time - start_time + assert processing_time < 0.1 # Should be close to the max delay (0.03) plus overhead + + # Check all events were processed + ids = [event.get("id") for event in result["events"]] + assert set(ids) == {"123", "456", "789"} + + +def test_handler_with_implicit_call_method_in_lambda_function(lambda_context, mock_event): + """Test that 
the __call__ method works correctly as an implicit Lambda handler.""" + # GIVEN a sample publish event + mock_event["events"] = [ + {"id": "123", "payload": {"data": "test data"}}, + ] + + # GIVEN an AppSyncEventsResolver + app = AppSyncEventsResolver() + + @app.on_publish(path="/default/*") + def test_handler(payload): + return {"processed": True, "data": payload["data"]} + + # Define a Lambda handler using the app directly + def lambda_handler(event, context): + return app(event, context) # Using __call__ method + + # WHEN we call the lambda handler + result = lambda_handler(mock_event, lambda_context) + + # THEN we should get the expected result + expected_result = { + "events": [ + {"id": "123", "payload": {"processed": True, "data": "test data"}}, + ], + } + assert result == expected_result + + +def test_middleware_like_functionality(lambda_context, mock_event): + """Test implementing middleware-like functionality with context.""" + # GIVEN a sample publish event + mock_event["events"] = [ + {"id": "123", "payload": {"data": "test data"}}, + ] + + # GIVEN an AppSyncEventsResolver + app = AppSyncEventsResolver() + + # Simulate middleware by adding context before processing + def add_request_metadata(event, context, app): + app.append_context( + request_id="req-123", + timestamp=123456789, + user_agent="test-agent", + ) + + # Handler that uses the context added by middleware + @app.on_publish(path="/default/*") + def handler_with_middleware_data(payload): + return { + "processed": True, + "data": payload["data"], + "metadata": { + "request_id": app.context.get("request_id"), + "timestamp": app.context.get("timestamp"), + "user_agent": app.context.get("user_agent"), + }, + } + + # WHEN we add middleware data and resolve the event + add_request_metadata(mock_event, lambda_context, app) + result = app.resolve(mock_event, lambda_context) + + # THEN the handler should have access to middleware-added context + expected_metadata = { + "request_id": "req-123", + 
"timestamp": 123456789, + "user_agent": "test-agent", + } + + assert result["events"][0]["payload"]["metadata"] == expected_metadata + + +def test_handler_with_event_transformation(lambda_context, mock_event): + """Test handlers that transform event data before processing.""" + # GIVEN a sample publish event + mock_event["info"]["channel"]["path"] = "/default/transform" + mock_event["events"] = [ + {"id": "123", "payload": {"user_data": {"name": "John", "age": 30}}}, + {"id": "456", "payload": {"user_data": {"name": "Jane", "age": 16}}}, + ] + + # GIVEN an AppSyncEventsResolver with a router + router = Router() + + # Add middleware context to transform data + @router.on_publish(path="/default/*", aggregate=True) + def transform_and_process(payload): + # Transform the payload structure + transformed = [] + for item in payload: + transformed.append( + { + "id": item["id"], + "payload": { + "user_data": { + "fullName": item["payload"]["user_data"]["name"], + "userAge": item["payload"]["user_data"]["age"], + "isAdult": item["payload"]["user_data"]["age"] >= 18, + }, + }, + }, + ) + return transformed + + app = AppSyncEventsResolver() + app.include_router(router) + + # WHEN we resolve the event + result = app.resolve(mock_event, lambda_context) + + # THEN the data should be transformed + assert "events" in result + assert len(result["events"]) == 2 + + # Check transformation results + assert result["events"][0]["id"] == "123" + assert result["events"][0]["payload"]["user_data"]["fullName"] == "John" + assert result["events"][0]["payload"]["user_data"]["userAge"] == 30 + assert result["events"][0]["payload"]["user_data"]["isAdult"] is True + + assert result["events"][1]["id"] == "456" + assert result["events"][1]["payload"]["user_data"]["fullName"] == "Jane" + assert result["events"][1]["payload"]["user_data"]["userAge"] == 16 + assert result["events"][1]["payload"]["user_data"]["isAdult"] is False + + +def test_empty_events_payload(lambda_context, mock_event): + """Test 
handling events with an empty payload.""" + # GIVEN a sample publish event with empty events + mock_event["events"] = [] + + # GIVEN an AppSyncEventsResolver + app = AppSyncEventsResolver() + + @app.on_publish(path="/default/*", aggregate=True) + def handle_events(payload): + # Should handle empty payload gracefully + if payload == [{}]: + return [] + return [{"processed": True} for _ in payload] + + # WHEN we resolve the event + result = app.resolve(mock_event, lambda_context) + + # THEN we should get an empty events list + assert "events" in result + assert result["events"] == [] + + +def test_multiple_related_routes_with_precedence(lambda_context, mock_event): + """Test event routing when multiple paths could match an event.""" + # GIVEN a sample publish event + mock_event["info"]["channel"]["path"] = "/products/electronics/phones/123" + mock_event["events"] = [ + {"id": "123", "payload": {"level": "phones", "data": "product data"}}, + ] + + # GIVEN an AppSyncEventsResolver with multiple related routes + app = AppSyncEventsResolver() + + # Define resolvers with varying specificity + @app.on_publish(path="/products/*") + def general_product_handler(payload): + return {"level": "general", "data": payload["data"]} + + @app.on_publish(path="/products/electronics/*") + def electronics_handler(payload): + return {"level": "electronics", "data": payload["data"]} + + @app.on_publish(path="/products/electronics/phones/*") + def phones_handler(payload): + return {"level": "phones", "data": payload["data"]} + + # WHEN we resolve the event + result = app.resolve(mock_event, lambda_context) + + # THEN the most specific matching path should be used + expected_result = { + "events": [ + {"id": "123", "payload": {"level": "phones", "data": "product data"}}, + ], + } + assert result == expected_result + + +def test_integration_with_external_service(lambda_context, mock_event): + """Test integration with an external service using mocks.""" + # GIVEN a sample publish event + 
mock_event["info"]["channel"]["path"] = "/orders/process" + mock_event["events"] = [ + {"id": "123", "payload": {"id": "order-123", "product_id": "prod-456", "quantity": 2}}, + ] + + # Mock an external service + class MockOrderService: + @staticmethod + async def process_order(order_id, product_id, quantity): + # Simulate processing delay + await asyncio.sleep(0.01) + return { + "order_id": order_id, + "status": "processed", + "total_amount": quantity * 10.99, + } + + order_service = MockOrderService() + + # GIVEN an AppSyncEventsResolver with an async resolver using the service + app = AppSyncEventsResolver() + + @app.async_on_publish(path="/orders/*") + async def process_order(payload): + # Call the external service + result = await order_service.process_order( + order_id=payload["id"], + product_id=payload["product_id"], + quantity=payload["quantity"], + ) + return { + "order_processed": True, + "order_details": result, + } + + # WHEN we resolve the event + result = app.resolve(mock_event, lambda_context) + + # THEN the order should be processed with the external service + assert "events" in result + assert result["events"][0]["payload"]["order_processed"] is True + assert result["events"][0]["payload"]["order_details"]["order_id"] == "order-123" + assert result["events"][0]["payload"]["order_details"]["status"] == "processed" + assert result["events"][0]["payload"]["order_details"]["total_amount"] == 21.98 # 2 * 10.99 + + +def test_complex_resolver_hierarchy(lambda_context, mock_event): + """Test a complex setup with multiple routers and nested paths.""" + # GIVEN a complex event + mock_event["info"]["channel"]["path"] = "/api/v1/users/profile/update" + mock_event["events"] = [ + {"id": "123", "payload": {"profile": {"name": "John Doe", "email": "john@example.com"}}}, + ] + + # GIVEN multiple routers for different API parts + base_router = Router() + users_router = Router() + profiles_router = Router() + + # Add handlers to each router + 
@base_router.on_publish(path="/api/*") + def api_base_handler(payload): + return {"source": "base", "data": payload} + + @users_router.on_publish(path="/api/v1/users/*") + def users_handler(payload): + return {"source": "users", "data": payload} + + @profiles_router.on_publish(path="/api/v1/users/profile/*") + def profile_handler(payload): + # Do some profile-specific processing + return { + "source": "profiles", + "updated": True, + "profile": { + "fullName": payload["profile"]["name"], + "email": payload["profile"]["email"], + "timestamp": "2023-01-01T00:00:00Z", + }, + } + + # GIVEN an AppSyncEventsResolver with included routers + app = AppSyncEventsResolver() + app.include_router(base_router) + app.include_router(users_router) + app.include_router(profiles_router) + + # WHEN we resolve the event + result = app.resolve(mock_event, lambda_context) + + # THEN the most specific router's handler should be used + assert "events" in result + assert result["events"][0]["id"] == "123" + assert result["events"][0]["payload"]["source"] == "profiles" + assert result["events"][0]["payload"]["updated"] is True + assert "fullName" in result["events"][0]["payload"]["profile"] + assert result["events"][0]["payload"]["profile"]["fullName"] == "John Doe" + + +def test_warning_behavior_with_no_matching_resolver(lambda_context, mock_event): + """Test warning behavior when no matching resolver is found.""" + # GIVEN a sample publish event with an unmatched path + mock_event["info"]["channel"]["path"] = "/unmatched/path" + mock_event["events"] = [ + {"id": "123", "payload": {"data": "test data"}}, + ] + + # GIVEN an AppSyncEventsResolver with a resolver for a different path + app = AppSyncEventsResolver() + + @app.on_publish(path="/matched/path") + def test_handler(payload): + return {"processed": True} + + # WHEN we resolve the event + # THEN a warning should be generated + with pytest.warns(UserWarning, match="No resolvers were found for publish operations with path 
/unmatched/path"): + result = app.resolve(mock_event, lambda_context) + + # AND the payload should be returned as is + assert result == {"events": [{"id": "123", "payload": {"data": "test data"}}]} + + +def test_resolver_precedence_with_exact_match(lambda_context, mock_event): + """Test that exact path matches have precedence over wildcard matches.""" + # GIVEN a sample publish event + mock_event["info"]["channel"]["path"] = "/notifications/system" + mock_event["events"] = [ + {"id": "123", "payload": {"message": "System notification"}}, + ] + + # GIVEN an AppSyncEventsResolver with both wildcard and exact path resolvers + app = AppSyncEventsResolver() + + @app.on_publish(path="/notifications/*") + def wildcard_handler(payload): + return {"source": "wildcard", "message": payload["message"]} + + @app.on_publish(path="/notifications/system") + def exact_handler(payload): + return {"source": "exact", "message": payload["message"]} + + # WHEN we resolve the event + result = app.resolve(mock_event, lambda_context) + + # THEN the exact path match should take precedence + expected_result = { + "events": [ + {"id": "123", "payload": {"source": "exact", "message": "System notification"}}, + ], + } + assert result == expected_result + + +def test_custom_routing_patterns(lambda_context, mock_event): + """Test custom routing patterns beyond simple wildcards.""" + # GIVEN events with different path formats + event1 = deepcopy(mock_event) + event2 = deepcopy(mock_event) + + event1["info"]["channel"]["path"] = "/users/123/posts/456" + event1["events"] = [ + {"id": "123", "payload": {"data": "user post data"}}, + ] + + event2["info"]["channel"]["path"] = "/organizations/abc/members/xyz" + event2["events"] = [ + {"id": "123", "payload": {"data": "organization member data"}}, + ] + + # GIVEN an AppSyncEventsResolver with pattern-based routing + app = AppSyncEventsResolver() + + # Define resolvers for different entity patterns + @app.on_publish(path="/users/*") + def 
user_resource_handler(payload): + path = app.current_event.info.channel_path + segments = path.split("/") + user_id = segments[2] + resource_type = segments[3] + + return {"entity_type": "user", "entity_id": user_id, "resource_type": resource_type, "data": payload["data"]} + + @app.on_publish(path="/organizations/*") + def org_resource_handler(payload): + path = app.current_event.info.channel_path + segments = path.split("/") + org_id = segments[2] + resource_type = segments[3] + + return { + "entity_type": "organization", + "entity_id": org_id, + "resource_type": resource_type, + "data": payload["data"], + } + + # WHEN we resolve the events + result1 = app.resolve(event1, lambda_context) + result2 = app.resolve(event2, lambda_context) + + # THEN each event should be handled by the appropriate pattern-based resolver + assert result1["events"][0]["payload"]["entity_type"] == "user" + assert result1["events"][0]["payload"]["entity_id"] == "123" + assert result1["events"][0]["payload"]["resource_type"] == "posts" + + assert result2["events"][0]["payload"]["entity_type"] == "organization" + assert result2["events"][0]["payload"]["entity_id"] == "abc" + assert result2["events"][0]["payload"]["resource_type"] == "members" + + +def test_warning_on_invalid_response_format(lambda_context, mock_event): + """Test warning generation for invalid response formats.""" + # GIVEN a sample publish event + mock_event["info"]["channel"]["path"] = "/default/test" + mock_event["events"] = [ + {"id": "123", "payload": {"data": "test data"}}, + {"id": "456", "payload": {"data": "more data"}}, + ] + + # GIVEN an AppSyncEventsResolver with an aggregate handler that returns non-list + app = AppSyncEventsResolver() + + @app.on_publish(path="/default/*", aggregate=True) + def invalid_format_handler(payload): + # Incorrectly return a dict instead of a list + return {"processed": True, "count": len(payload)} + + # WHEN we resolve the event + # THEN a warning should be generated about the 
response format + with pytest.warns(UserWarning, match="Response must be a list when using aggregate"): + result = app.resolve(mock_event, lambda_context) + + # The result should still contain what was returned + assert "events" in result + assert result["events"]["processed"] is True + assert result["events"]["count"] == 2 + + +def test_router_and_resolver_clear_context_after_resolution(lambda_context, mock_event): + """Test that both router and resolver's context are cleared after resolution.""" + # GIVEN a sample publish event + mock_event["events"] = [ + {"id": "123", "payload": {"data": "test data"}}, + ] + + # GIVEN a router with context data + router = Router() + router.append_context(router_key="router_value") + + @router.on_publish(path="/default/*") + def router_handler(payload): + assert router.context["router_key"] == "router_value" + assert router.context["test_var"] == "app_value" + return {"processed": True} + + # GIVEN an AppSyncEventsResolver with context data + app = AppSyncEventsResolver() + app.append_context(test_var="app_value") + + # Include the router and merge contexts + app.include_router(router) + + # WHEN we resolve the event + app.resolve(mock_event, lambda_context) + + # THEN both contexts should be cleared + assert app.context == {} + assert router.context == {} + + +def test_sync_and_async_router_inclusion(lambda_context, mock_event): + """Test including multiple routers with both sync and async handlers.""" + # GIVEN a sample publish event + mock_event["info"]["channel"]["path"] = "/notifications/test" + mock_event["events"] = [ + {"id": "123", "payload": {"message": "test notification"}}, + ] + + # GIVEN a router with synchronous handlers + sync_router = Router() + + @sync_router.on_publish(path="/notifications/*") + def sync_handler(payload): + return {"sync": True, "message": payload["message"]} + + # GIVEN another router with asynchronous handlers + async_router = Router() + + 
@async_router.async_on_publish(path="/notifications/*") + async def async_handler(event): + await asyncio.sleep(0.01) + return {"async": True, "message": event["message"]} + + # GIVEN an AppSyncEventsResolver that includes both routers + app = AppSyncEventsResolver() + app.include_router(sync_router) + app.include_router(async_router) + + # WHEN we resolve the event + with pytest.warns(UserWarning, match="Both synchronous and asynchronous resolvers found"): + result = app.resolve(mock_event, lambda_context) + + # THEN the sync handler should take precedence + expected_result = { + "events": [ + {"id": "123", "payload": {"sync": True, "message": "test notification"}}, + ], + } + assert result == expected_result + + +def test_aws_lambda_context_availability_in_handlers(lambda_context, mock_event): + """Test that Lambda context is available in handlers.""" + # GIVEN a sample publish event + mock_event["info"]["channel"]["path"] = "/default/test" + mock_event["events"] = [ + {"id": "123", "payload": {"data": "test data"}}, + ] + + # GIVEN an AppSyncEventsResolver with a handler that uses Lambda context + app = AppSyncEventsResolver() + + @app.on_publish(path="/default/*") + def context_aware_handler(payload): + # Access Lambda context information + return { + "processed": True, + "function_name": app.lambda_context.function_name, + "request_id": app.lambda_context.aws_request_id, + "function_arn": app.lambda_context.invoked_function_arn, + "payload_data": payload["data"], + } + + # WHEN we resolve the event + result = app.resolve(mock_event, lambda_context) + + # THEN Lambda context information should be included in the result + assert result["events"][0]["payload"]["function_name"] == lambda_context.function_name + assert result["events"][0]["payload"]["request_id"] == lambda_context.aws_request_id + assert result["events"][0]["payload"]["function_arn"] == lambda_context.invoked_function_arn + assert result["events"][0]["payload"]["payload_data"] == "test data" + + 
+def test_router_lambda_context_shared(lambda_context, mock_event): + """Test that Lambda context is shared with included routers.""" + # GIVEN a sample publish event + mock_event["info"]["channel"]["path"] = "/router/test" + mock_event["events"] = [ + {"id": "123", "payload": {"data": "test data"}}, + ] + + # GIVEN a router with a handler that uses Lambda context + router = Router() + + @router.on_publish(path="/router/*") + def router_context_handler(payload): + # Access Lambda context from the router + return { + "from_router": True, + "function_name": router.lambda_context.function_name, + "request_id": router.lambda_context.aws_request_id, + "payload_data": payload["data"], + } + + # GIVEN an AppSyncEventsResolver that includes the router + app = AppSyncEventsResolver() + app.include_router(router) + + # WHEN we resolve the event + result = app.resolve(mock_event, lambda_context) + + # THEN the router should have access to the same Lambda context + assert result["events"][0]["payload"]["from_router"] is True + assert result["events"][0]["payload"]["function_name"] == lambda_context.function_name + assert result["events"][0]["payload"]["request_id"] == lambda_context.aws_request_id + assert result["events"][0]["payload"]["payload_data"] == "test data" + + +def test_current_event_availability(lambda_context, mock_event): + """Test that current_event is properly available to handlers.""" + # GIVEN a sample publish event with extra metadata + mock_event["info"]["channel"]["path"] = "/default/test" + mock_event["events"] = [ + {"id": "123", "payload": {"data": "test data"}}, + ] + + # GIVEN an AppSyncEventsResolver with a handler that accesses current_event + app = AppSyncEventsResolver() + + @app.on_publish(path="/default/*") + def event_aware_handler(payload): + # Access the full event object for additional context + return { + "processed": True, + "x-forwarded-for": app.current_event.request_headers["x-forwarded-for"], + "payload_data": payload["data"], + } + + 
# WHEN we resolve the event + result = app.resolve(mock_event, lambda_context) + + # THEN the handler should have access to the full event information + assert result["events"][0]["payload"]["processed"] is True + assert result["events"][0]["payload"]["x-forwarded-for"] == mock_event["request"]["headers"]["x-forwarded-for"] + assert result["events"][0]["payload"]["payload_data"] == "test data" + + +def test_router_current_event_shared(lambda_context, mock_event): + """Test that current_event is shared with included routers.""" + # GIVEN a sample publish event with extra metadata + mock_event["info"]["channel"]["path"] = "/router/test" + mock_event["events"] = [ + {"id": "123", "payload": {"data": "test data"}}, + ] + + # GIVEN a router with a handler that accesses current_event + router = Router() + + @router.on_publish(path="/router/*") + def router_event_handler(payload): + # Access event information from the router + return { + "processed": True, + "x-forwarded-for": app.current_event.request_headers["x-forwarded-for"], + "payload_data": payload["data"], + } + + # GIVEN an AppSyncEventsResolver that includes the router + app = AppSyncEventsResolver() + app.include_router(router) + + # WHEN we resolve the event + result = app.resolve(mock_event, lambda_context) + + # THEN the router should have access to the same event information + assert result["events"][0]["payload"]["processed"] is True + assert result["events"][0]["payload"]["x-forwarded-for"] == mock_event["request"]["headers"]["x-forwarded-for"] + assert result["events"][0]["payload"]["payload_data"] == "test data" + + +@pytest.mark.skip(reason="Not implemented yet") +def test_channel_path_normalization(lambda_context, mock_event): + """Test that channel paths are properly normalized before matching.""" + # GIVEN sample publish events with different path formats + event1 = deepcopy(mock_event) + event2 = deepcopy(mock_event) + + event1["info"]["channel"]["path"] = "/test" + event1["events"] = [ + {"id": 
"123", "payload": {"data": "data1"}}, + ] + + event2["info"]["channel"]["path"] = "/test/" + event2["events"] = [ + {"id": "456", "payload": {"data": "data2"}}, + ] + + # GIVEN an AppSyncEventsResolver with a handler + app = AppSyncEventsResolver() + + @app.on_publish(path="/test") # Register with path without trailing slash + def test_handler(payload): + return {"normalized": True, "data": payload["data"]} + + # WHEN we resolve both events + result1 = app.resolve(event1, lambda_context) + result2 = app.resolve(event2, lambda_context) + + # THEN both events should be handled consistently + expected_result1 = { + "events": [ + {"id": "123", "payload": {"normalized": True, "data": "data1"}}, + ], + } + assert result1 == expected_result1 + + # With proper normalization, this should also match + expected_result2 = { + "events": [ + {"id": "456", "payload": {"normalized": True, "data": "data2"}}, + ], + } + assert result2 == expected_result2 + + +def test_subscribe_event_with_error_handling(lambda_context, mock_event): + """Test error handling during publish event processing.""" + # GIVEN a sample publish event + mock_event["info"]["operation"] = "SUBSCRIBE" + mock_event["info"]["channel"]["path"] = "/default/powertools" + del mock_event["events"] # SUBSCRIBE events are not supported + + # GIVEN an AppSyncEventsResolver with a resolver that raises an exception + app = AppSyncEventsResolver() + + @app.on_subscribe(path="/default/*") + def test_handler(): + raise ValueError("Test error") + + # WHEN we resolve the event + result = app.resolve(mock_event, lambda_context) + + # THEN we should get an error response + assert "error" in result + assert "ValueError - Test error" in result["error"] + + +def test_subscribe_event_with_valid_return(lambda_context, mock_event): + """Test error handling during publish event processing.""" + # GIVEN a sample publish event + mock_event["info"]["operation"] = "SUBSCRIBE" + mock_event["info"]["channel"]["path"] = "/default/powertools" + + 
# GIVEN an AppSyncEventsResolver with a resolver that returns ok
+    app = AppSyncEventsResolver()
+
+    @app.on_subscribe(path="/default/*")
+    def test_handler():
+        return 1
+
+    # WHEN we resolve the event
+    result = app.resolve(mock_event, lambda_context)
+
+    # THEN we should return None because subscribe resolvers must always return None
+    assert result is None
+
+
+def test_subscribe_event_with_no_resolver(lambda_context, mock_event):
+    """Test subscribe event when no resolver matches the channel path."""
+    # GIVEN a sample subscribe event
+    mock_event["info"]["operation"] = "SUBSCRIBE"
+    mock_event["info"]["channel"]["path"] = "/default/powertools"
+
+    # GIVEN an AppSyncEventsResolver with a resolver registered for a different path
+    app = AppSyncEventsResolver()
+
+    @app.on_subscribe(path="/test")
+    def test_handler():
+        return 1
+
+    # WHEN we resolve the event
+    result = app.resolve(mock_event, lambda_context)
+
+    # THEN no result should be returned
+    assert not result
+
+
+def test_publish_events_throw_unauthorized_exception(lambda_context, mock_event):
+    """Test publish events raising UnauthorizedException."""
+    # GIVEN a sample publish event
+    mock_event["info"]["operation"] = "PUBLISH"
+    mock_event["info"]["channel"]["path"] = "/default/test"
+    mock_event["events"] = [
+        {"id": "123", "payload": {"data": "test data"}},
+    ]
+
+    # GIVEN an AppSyncEventsResolver
+    app = AppSyncEventsResolver()
+
+    @app.on_publish(path="/default/*", aggregate=True)
+    def handle_events(payload):
+        raise UnauthorizedException
+
+    # WHEN we resolve the event with unauthorized route
+    with pytest.raises(UnauthorizedException):
+        app.resolve(mock_event, lambda_context)
+
+
+def test_subscribe_events_throw_unauthorized_exception(lambda_context, mock_event):
+    """Test subscribe events raising UnauthorizedException."""
+    # GIVEN a sample subscribe event
+    mock_event["info"]["operation"] = "SUBSCRIBE"
+    mock_event["info"]["channel"]["path"] = "/default/test"
+
+    # GIVEN an AppSyncEventsResolver
+    app = 
AppSyncEventsResolver() + + @app.on_subscribe(path="/default/*") + def handle_events(): + raise UnauthorizedException + + # WHEN we resolve the event with unauthorized route + with pytest.raises(UnauthorizedException): + app.resolve(mock_event, lambda_context) diff --git a/tests/unit/data_classes/required_dependencies/test_appsync_events_event.py b/tests/unit/data_classes/required_dependencies/test_appsync_events_event.py new file mode 100644 index 00000000000..0e716dca38f --- /dev/null +++ b/tests/unit/data_classes/required_dependencies/test_appsync_events_event.py @@ -0,0 +1,16 @@ +from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEventsEvent +from tests.functional.utils import load_event + + +def test_appsync_resolver_event(): + raw_event = load_event("appSyncEventsEvent.json") + parsed_event = AppSyncResolverEventsEvent(raw_event) + + assert parsed_event.events == raw_event["events"] + assert parsed_event.out_errors == raw_event["outErrors"] + assert parsed_event.domain_name == raw_event["request"]["domainName"] + assert parsed_event.info.channel == raw_event["info"]["channel"] + assert parsed_event.info.channel_path == raw_event["info"]["channel"]["path"] + assert parsed_event.info.channel_segments == raw_event["info"]["channel"]["segments"] + assert parsed_event.info.channel_namespace == raw_event["info"]["channelNamespace"] + assert parsed_event.info.operation == raw_event["info"]["operation"] diff --git a/tests/unit/event_handler/_required_dependencies/__init__.py b/tests/unit/event_handler/_required_dependencies/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/unit/event_handler/_required_dependencies/appsync_events/__init__.py b/tests/unit/event_handler/_required_dependencies/appsync_events/__init__.py new file mode 100644 index 00000000000..c344f9ad421 --- /dev/null +++ b/tests/unit/event_handler/_required_dependencies/appsync_events/__init__.py @@ -0,0 +1,145 @@ +import pytest + +from 
aws_lambda_powertools.event_handler.events_appsync.functions import find_best_route, is_valid_path + + +@pytest.mark.parametrize( + "path,expected,description", + [ + ("/*", True, "Root wildcard path is valid"), + ("/users", True, "Simple path with one segment is valid"), + ("/users/profile/settings", True, "Path with multiple segments is valid"), + ("/users/*", True, "Path ending with /* is valid"), + ("/users/*/details", False, "Path with wildcard in the middle is invalid"), + ("users/profile", False, "Path without leading slash is invalid"), + ("/users/", False, "Path with trailing slash is invalid"), + ("", False, "Empty path is invalid"), + ("/", False, "Root path / is invalid according to the regex"), + ], +) +def test_path_validation(path, expected, description): + """Test various path validation scenarios.""" + # Given a path (provided by parametrize) + + # When validating + result = is_valid_path(path) + + # Then must match the regexp + assert result is expected, description + + +def test_path_with_non_string_input(): + """Test that non-string input raises an appropriate error.""" + # Given + path = None + + # When/Then + with pytest.raises(TypeError): + is_valid_path(path) + + +@pytest.mark.parametrize( + "routes, path, expected_route, description", + [ + ( + { + "/default/v1/*": {"func": lambda x: x, "aggregate": False}, + "/default/v1/users/*": {"func": lambda x: x, "aggregate": False}, + "/default/v1/users/active/*": {"func": lambda x: x, "aggregate": False}, + }, + "/default/v1/users/active/123", + "/default/v1/users/active/*", + "Most specific route with wildcard should be matched", + ), + ], +) +def test_find_best_route_specific_wildcard(routes, path, expected_route, description): + """Test that find_best_route selects most specific wildcard path.""" + # GIVEN + + # WHEN + result = find_best_route(routes, path) + + # THEN + assert result == expected_route, description + + +@pytest.mark.parametrize( + "routes, path, expected_route, description", + [ 
+ ( + { + "/default/v1/users": {"func": lambda x: x, "aggregate": False}, + "/default/v1/*": {"func": lambda x: x, "aggregate": False}, + }, + "/default/v1/users", + "/default/v1/users", + "Exact match wins over wildcard match", + ), + ], +) +def test_find_best_route_exact_match(routes, path, expected_route, description): + """Test that find_best_route prefers exact matches over wildcard matches.""" + # GIVEN + + # WHEN + result = find_best_route(routes, path) + + # THEN + assert result == expected_route, description + + +@pytest.mark.parametrize( + "routes, path, expected_route, description", + [ + ( + { + "/*": {"func": lambda x: x, "aggregate": False}, + "/other/*": {"func": lambda x: x, "aggregate": False}, + }, + "/default/v1/users", + "/*", + "Fallback to /* when no specific matches", + ), + ], +) +def test_find_best_route_fallback(routes, path, expected_route, description): + """Test that find_best_route falls back to /* when no specific route matches.""" + # GIVEN + + # WHEN + result = find_best_route(routes, path) + + # THEN + assert result == expected_route, description + + +@pytest.mark.parametrize( + "routes, path, expected_route, description", + [ + ( + { + "/api/v1/users/*": {"func": lambda x: x, "aggregate": False}, + "/api/v1/posts/*": {"func": lambda x: x, "aggregate": False}, + }, + "/api/v2/users/123", + None, + "No match should return None", + ), + ( + {}, + "/any/path", + None, + "Empty routes dictionary should return None", + ), + ], +) +def test_find_best_route_no_match(routes, path, expected_route, description): + """Test that find_best_route returns None when no routes match.""" + # GIVEN + + # WHEN + result = find_best_route(routes, path) + + # THEN + assert result == expected_route, description diff --git a/tests/unit/event_handler/_required_dependencies/appsync_events/test_functions.py b/tests/unit/event_handler/_required_dependencies/appsync_events/test_functions.py new file mode 100644 index 00000000000..e69de29bb2d diff --git 
a/tests/unit/event_handler/_required_dependencies/test_exception_handler_manager.py b/tests/unit/event_handler/_required_dependencies/test_exception_handler_manager.py
new file mode 100644
index 00000000000..e1eed042205
--- /dev/null
+++ b/tests/unit/event_handler/_required_dependencies/test_exception_handler_manager.py
@@ -0,0 +1,176 @@
+import pytest
+
+from aws_lambda_powertools.event_handler.exception_handling import (
+    ExceptionHandlerManager,
+)
+
+
+@pytest.fixture
+def exception_manager() -> ExceptionHandlerManager:
+    """Fixture to provide a fresh ExceptionHandlerManager instance for each test."""
+    return ExceptionHandlerManager()
+
+
+# ----- Tests for exception_handler decorator -----
+
+
+def test_decorator_registers_single_exception_handler(exception_manager):
+    """
+    GIVEN a function decorated with @manager.exception_handler(ValueError)
+    WHEN the exception_handler decorator is used with a single exception type
+    THEN the function is registered as a handler for ValueError
+    """
+
+    @exception_manager.exception_handler(ValueError)
+    def handle_value_error(e):
+        return "ValueError handled"
+
+    handlers = exception_manager.get_registered_handlers()
+    assert ValueError in handlers
+    assert handlers[ValueError] == handle_value_error
+
+
+def test_decorator_registers_multiple_exception_handlers(exception_manager):
+    """
+    GIVEN a function decorated with @manager.exception_handler([KeyError, TypeError])
+    WHEN the exception_handler decorator is used with multiple exception types
+    THEN the function is registered as a handler for both KeyError and TypeError
+    """
+
+    @exception_manager.exception_handler([KeyError, TypeError])
+    def handle_multiple_errors(e):
+        return f"{type(e).__name__} handled"
+
+    handlers = exception_manager.get_registered_handlers()
+    assert KeyError in handlers
+    assert TypeError in handlers
+    assert handlers[KeyError] == handle_multiple_errors
+    assert handlers[TypeError] == 
handle_multiple_errors
+
+
+def test_lookup_uses_inheritance_hierarchy(exception_manager):
+    # GIVEN a handler has been registered for a base exception type
+    @exception_manager.exception_handler(Exception)
+    def handle_exception(e):
+        return "Exception handled"
+
+    # WHEN lookup_exception_handler is called with a derived exception type
+    # THEN the handler for the base exception is returned
+    handler = exception_manager.lookup_exception_handler(ValueError)
+    assert handler == handle_exception
+
+
+def test_lookup_returns_none_for_unregistered_handler(exception_manager):
+    """
+    GIVEN no handler has been registered for an exception type or its base classes
+    WHEN lookup_exception_handler is called with that exception type
+    THEN None is returned
+    """
+    handler = exception_manager.lookup_exception_handler(ValueError)
+    assert handler is None
+
+
+def test_register_handler_for_multiple_exceptions(exception_manager):
+    # GIVEN a valid handler function
+    @exception_manager.exception_handler([ValueError, KeyError])
+    def handle_error(e):
+        return "Error handled"
+
+    # THEN the handler is properly registered for all exceptions in the list
+    handlers = exception_manager.get_registered_handlers()
+    assert KeyError in handlers
+    assert ValueError in handlers
+    assert handlers[KeyError] == handle_error
+    assert handlers[ValueError] == handle_error
+
+
+def test_update_exception_handlers_with_dictionary(exception_manager):
+    """
+    GIVEN a dictionary mapping exception types to handler functions
+    WHEN update_exception_handlers is called with that dictionary
+    THEN all handlers in the dictionary are properly registered
+    """
+
+    def handle_value_error(e):
+        return "ValueError handled"
+
+    def handle_key_error(e):
+        return "KeyError handled"
+
+    # Update with a dictionary of handlers
+    exception_manager.update_exception_handlers(
+        {
+            ValueError: handle_value_error,
+            KeyError: handle_key_error,
+        },
+    )
+
+    handlers = exception_manager.get_registered_handlers()
+    assert ValueError in 
handlers + assert KeyError in handlers + assert handlers[ValueError] == handle_value_error + assert handlers[KeyError] == handle_key_error + + +def test_clear_handlers_removes_all_handlers(exception_manager): + # GIVEN handlers have been registered + @exception_manager.exception_handler([ValueError, KeyError]) + def handle_error(e): + return "Error handled" + + # Verify handlers are registered + assert len(exception_manager.get_registered_handlers()) == 2 + + # WHEN clear_handlers is called + exception_manager.clear_handlers() + + # THEN all handlers are removed + assert len(exception_manager.get_registered_handlers()) == 0 + + +def test_get_registered_handlers_returns_copy(exception_manager): + # WHEN get_registered_handlers is called + @exception_manager.exception_handler(ValueError) + def handle_error(e): + return "Error handled" + + # GIVEN handlers have been registered + handlers_copy = exception_manager.get_registered_handlers() + + # THEN a copy of the handlers dictionary is returned that doesn't affect the original + handlers_copy[KeyError] = lambda e: "Not registered properly" + assert KeyError not in exception_manager.get_registered_handlers() + + +def test_handler_executes_correctly(exception_manager): + # GIVEN a registered handler is executed with an exception + @exception_manager.exception_handler(ValueError) + def handle_value_error(e): + return f"Handled: {str(e)}" + + # WHEN an exception happens + # THEN the handler processes the exception correctly + try: + raise ValueError("Test error") + except Exception as e: + handler = exception_manager.lookup_exception_handler(type(e)) + result = handler(e) + assert result == "Handled: Test error" + + +def test_registering_new_handler_overrides_previous(exception_manager): + # WHEN a new handler is registered for an exception type + @exception_manager.exception_handler(ValueError) + def first_handler(e): + return "First handler" + + # GIVEN a handler was already registered for that type + 
@exception_manager.exception_handler(ValueError) + def second_handler(e): + return "Second handler" + + # THEN the new handler replaces the previous one + # Check that the second handler overrode the first + handler = exception_manager.lookup_exception_handler(ValueError) + assert handler == second_handler + assert handler != first_handler From 1b5e05cb9fb823a6028e722de6fd49217f1fdd5c Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Thu, 24 Apr 2025 15:00:08 -0700 Subject: [PATCH 2/2] Adding AppSync events --- .../event_handler/events_appsync/base.py | 8 +++++--- .../event_handler/events_appsync/functions.py | 4 +--- .../event_handler/events_appsync/router.py | 8 ++++---- .../appsync/test_appsync_events_resolvers.py | 4 ++-- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/aws_lambda_powertools/event_handler/events_appsync/base.py b/aws_lambda_powertools/event_handler/events_appsync/base.py index 0973553cda4..86a1e140d5d 100644 --- a/aws_lambda_powertools/event_handler/events_appsync/base.py +++ b/aws_lambda_powertools/event_handler/events_appsync/base.py @@ -3,6 +3,8 @@ from abc import ABC, abstractmethod from typing import Callable +DEFAULT_ROUTE = "/default/*" + class BaseRouter(ABC): """Abstract base class for Router (resolvers)""" @@ -10,7 +12,7 @@ class BaseRouter(ABC): @abstractmethod def on_publish( self, - path: str = "/default/*", + path: str = DEFAULT_ROUTE, aggregate: bool = True, ) -> Callable: raise NotImplementedError @@ -18,7 +20,7 @@ def on_publish( @abstractmethod def async_on_publish( self, - path: str = "/default/*", + path: str = DEFAULT_ROUTE, aggregate: bool = True, ) -> Callable: raise NotImplementedError @@ -26,7 +28,7 @@ def async_on_publish( @abstractmethod def on_subscribe( self, - path: str = "/default/*", + path: str = DEFAULT_ROUTE, ) -> Callable: raise NotImplementedError diff --git a/aws_lambda_powertools/event_handler/events_appsync/functions.py b/aws_lambda_powertools/event_handler/events_appsync/functions.py 
index 7f4952a0dd7..0d7ddf2518f 100644 --- a/aws_lambda_powertools/event_handler/events_appsync/functions.py +++ b/aws_lambda_powertools/event_handler/events_appsync/functions.py @@ -35,9 +35,7 @@ def is_valid_path(path: str) -> bool: >>> is_valid_path('users') False """ - if path == "/*": - return True - return bool(PATH_REGEX.fullmatch(path)) + return True if path == "/*" else bool(PATH_REGEX.fullmatch(path)) def find_best_route(routes: dict[str, Any], path: str): diff --git a/aws_lambda_powertools/event_handler/events_appsync/router.py b/aws_lambda_powertools/event_handler/events_appsync/router.py index 45d7c81ddbb..167403e30fe 100644 --- a/aws_lambda_powertools/event_handler/events_appsync/router.py +++ b/aws_lambda_powertools/event_handler/events_appsync/router.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING from aws_lambda_powertools.event_handler.events_appsync._registry import ResolverEventsRegistry -from aws_lambda_powertools.event_handler.events_appsync.base import BaseRouter +from aws_lambda_powertools.event_handler.events_appsync.base import DEFAULT_ROUTE, BaseRouter if TYPE_CHECKING: from collections.abc import Callable @@ -71,7 +71,7 @@ def __init__(self): def on_publish( self, - path: str = "/default/*", + path: str = DEFAULT_ROUTE, aggregate: bool = False, ) -> Callable: """ @@ -113,7 +113,7 @@ def on_publish( def async_on_publish( self, - path: str = "/default/*", + path: str = DEFAULT_ROUTE, aggregate: bool = False, ) -> Callable: """ @@ -154,7 +154,7 @@ def async_on_publish( def on_subscribe( self, - path: str = "/default/*", + path: str = DEFAULT_ROUTE, ) -> Callable: """ Register a resolver function for subscribe operations. 
diff --git a/tests/functional/event_handler/required_dependencies/appsync/test_appsync_events_resolvers.py b/tests/functional/event_handler/required_dependencies/appsync/test_appsync_events_resolvers.py index 887adb08fe8..4d53c3cb934 100644 --- a/tests/functional/event_handler/required_dependencies/appsync/test_appsync_events_resolvers.py +++ b/tests/functional/event_handler/required_dependencies/appsync/test_appsync_events_resolvers.py @@ -1041,7 +1041,7 @@ async def process_order(order_id, product_id, quantity): return { "order_id": order_id, "status": "processed", - "total_amount": quantity * 10.99, + "total_amount": quantity * 10, } order_service = MockOrderService() @@ -1070,7 +1070,7 @@ async def process_order(payload): assert result["events"][0]["payload"]["order_processed"] is True assert result["events"][0]["payload"]["order_details"]["order_id"] == "order-123" assert result["events"][0]["payload"]["order_details"]["status"] == "processed" - assert result["events"][0]["payload"]["order_details"]["total_amount"] == 21.98 # 2 * 10.99 + assert result["events"][0]["payload"]["order_details"]["total_amount"] == 20 # 2 * 10 def test_complex_resolver_hierarchy(lambda_context, mock_event):