Skip to content

Commit 6fcedec

Browse files
authored
feat(apigateway): add Router to allow large routing composition (aws-powertools#645)
1 parent 5c2444c commit 6fcedec

File tree

2 files changed

+206
-2
lines changed

2 files changed

+206
-2
lines changed

Diff for: aws_lambda_powertools/event_handler/api_gateway.py

+72-2
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
import traceback
77
import zlib
88
from enum import Enum
9-
from functools import partial
9+
from functools import partial, wraps
1010
from http import HTTPStatus
11-
from typing import Any, Callable, Dict, List, Optional, Set, Union
11+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
1212

1313
from aws_lambda_powertools.event_handler import content_types
1414
from aws_lambda_powertools.event_handler.exceptions import ServiceError
@@ -630,3 +630,73 @@ def _to_response(self, result: Union[Dict, Response]) -> Response:
630630

631631
def _json_dump(self, obj: Any) -> str:
632632
return self._serializer(obj)
633+
634+
def include_router(self, router: "Router", prefix: Optional[str] = None) -> None:
635+
"""Adds all routes defined in a router"""
636+
router._app = self
637+
for route, func in router.api.items():
638+
if prefix and route[0] == "/":
639+
route = (prefix, *route[1:])
640+
elif prefix:
641+
route = (f"{prefix}{route[0]}", *route[1:])
642+
self.route(*route)(func())
643+
644+
645+
class Router:
646+
"""Router helper class to allow splitting ApiGatewayResolver into multiple files"""
647+
648+
_app: ApiGatewayResolver
649+
650+
def __init__(self):
651+
self.api: Dict[tuple, Callable] = {}
652+
653+
@property
654+
def current_event(self) -> BaseProxyEvent:
655+
return self._app.current_event
656+
657+
@property
658+
def lambda_context(self) -> LambdaContext:
659+
return self._app.lambda_context
660+
661+
def route(
662+
self,
663+
rule: str,
664+
method: Union[str, Tuple[str], List[str]],
665+
cors: Optional[bool] = None,
666+
compress: bool = False,
667+
cache_control: Optional[str] = None,
668+
):
669+
def actual_decorator(func: Callable):
670+
@wraps(func)
671+
def wrapper():
672+
def inner_wrapper(**kwargs):
673+
return func(**kwargs)
674+
675+
return inner_wrapper
676+
677+
if isinstance(method, (list, tuple)):
678+
for item in method:
679+
self.api[(rule, item, cors, compress, cache_control)] = wrapper
680+
else:
681+
self.api[(rule, method, cors, compress, cache_control)] = wrapper
682+
683+
return actual_decorator
684+
685+
def get(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None):
686+
return self.route(rule, "GET", cors, compress, cache_control)
687+
688+
def post(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None):
689+
return self.route(rule, "POST", cors, compress, cache_control)
690+
691+
def put(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None):
692+
return self.route(rule, "PUT", cors, compress, cache_control)
693+
694+
def delete(
695+
self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None
696+
):
697+
return self.route(rule, "DELETE", cors, compress, cache_control)
698+
699+
def patch(
700+
self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None
701+
):
702+
return self.route(rule, "PATCH", cors, compress, cache_control)

Diff for: tests/functional/event_handler/test_api_gateway.py

+134
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
ProxyEventType,
1818
Response,
1919
ResponseBuilder,
20+
Router,
2021
)
2122
from aws_lambda_powertools.event_handler.exceptions import (
2223
BadRequestError,
@@ -860,3 +861,136 @@ def base():
860861
# THEN process event correctly
861862
assert result["statusCode"] == 200
862863
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
864+
865+
866+
def test_api_gateway_app_router():
867+
# GIVEN a Router with registered routes
868+
app = ApiGatewayResolver()
869+
router = Router()
870+
871+
@router.get("/my/path")
872+
def foo():
873+
return {}
874+
875+
app.include_router(router)
876+
# WHEN calling the event handler after applying routes from router object
877+
result = app(LOAD_GW_EVENT, {})
878+
879+
# THEN process event correctly
880+
assert result["statusCode"] == 200
881+
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
882+
883+
884+
def test_api_gateway_app_router_with_params():
885+
# GIVEN a Router with registered routes
886+
app = ApiGatewayResolver()
887+
router = Router()
888+
req = "foo"
889+
event = deepcopy(LOAD_GW_EVENT)
890+
event["resource"] = "/accounts/{account_id}"
891+
event["path"] = f"/accounts/{req}"
892+
lambda_context = {}
893+
894+
@router.route(rule="/accounts/<account_id>", method=["GET", "POST"])
895+
def foo(account_id):
896+
assert router.current_event.raw_event == event
897+
assert router.lambda_context == lambda_context
898+
assert account_id == f"{req}"
899+
return {}
900+
901+
app.include_router(router)
902+
# WHEN calling the event handler after applying routes from router object
903+
result = app(event, lambda_context)
904+
905+
# THEN process event correctly
906+
assert result["statusCode"] == 200
907+
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
908+
909+
910+
def test_api_gateway_app_router_with_prefix():
911+
# GIVEN a Router with registered routes
912+
# AND a prefix is defined during the registration
913+
app = ApiGatewayResolver()
914+
router = Router()
915+
916+
@router.get(rule="/path")
917+
def foo():
918+
return {}
919+
920+
app.include_router(router, prefix="/my")
921+
# WHEN calling the event handler after applying routes from router object
922+
result = app(LOAD_GW_EVENT, {})
923+
924+
# THEN process event correctly
925+
assert result["statusCode"] == 200
926+
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
927+
928+
929+
def test_api_gateway_app_router_with_prefix_equals_path():
930+
# GIVEN a Router with registered routes
931+
# AND a prefix is defined during the registration
932+
app = ApiGatewayResolver()
933+
router = Router()
934+
935+
@router.get(rule="/")
936+
def foo():
937+
return {}
938+
939+
app.include_router(router, prefix="/my/path")
940+
# WHEN calling the event handler after applying routes from router object
941+
# WITH the request path matching the registration prefix
942+
result = app(LOAD_GW_EVENT, {})
943+
944+
# THEN process event correctly
945+
assert result["statusCode"] == 200
946+
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
947+
948+
949+
def test_api_gateway_app_router_with_different_methods():
950+
# GIVEN a Router with all the possible HTTP methods
951+
app = ApiGatewayResolver()
952+
router = Router()
953+
954+
@router.get("/not_matching_get")
955+
def get_func():
956+
raise RuntimeError()
957+
958+
@router.post("/no_matching_post")
959+
def post_func():
960+
raise RuntimeError()
961+
962+
@router.put("/no_matching_put")
963+
def put_func():
964+
raise RuntimeError()
965+
966+
@router.delete("/no_matching_delete")
967+
def delete_func():
968+
raise RuntimeError()
969+
970+
@router.patch("/no_matching_patch")
971+
def patch_func():
972+
raise RuntimeError()
973+
974+
app.include_router(router)
975+
976+
# Also check check the route configurations
977+
routes = app._routes
978+
assert len(routes) == 5
979+
for route in routes:
980+
if route.func == get_func:
981+
assert route.method == "GET"
982+
elif route.func == post_func:
983+
assert route.method == "POST"
984+
elif route.func == put_func:
985+
assert route.method == "PUT"
986+
elif route.func == delete_func:
987+
assert route.method == "DELETE"
988+
elif route.func == patch_func:
989+
assert route.method == "PATCH"
990+
991+
# WHEN calling the handler
992+
# THEN return a 404
993+
result = app(LOAD_GW_EVENT, None)
994+
assert result["statusCode"] == 404
995+
# AND cors headers are not returned
996+
assert "Access-Control-Allow-Origin" not in result["headers"]

0 commit comments

Comments
 (0)