Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 107 additions & 28 deletions robyn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def add_route(
auth_required: bool = False,
openapi_name: str = "",
openapi_tags: Union[List[str], None] = None,
status_code: Optional[int] = None,
):
"""
Connect a URI to a handler
Expand All @@ -164,9 +165,7 @@ def add_route(
:param handler function: represents the sync or async function passed as a handler for the route
:param is_const bool: represents if the handler is a const function or not
:param auth_required bool: represents if the route needs authentication or not
"""

""" We will add the status code here only
:param status_code int|None: default HTTP status code for the response
"""
injected_dependencies = self.dependencies.get_dependency_map(self)

Expand Down Expand Up @@ -215,6 +214,7 @@ def add_route(
openapi_tags=list_openapi_tags,
exception_handler=self.exception_handler,
injected_dependencies=injected_dependencies,
default_status_code=status_code,
)

logger.info("Added route %s %s", route_type, normalized_endpoint)
Expand Down Expand Up @@ -371,6 +371,7 @@ def get(
auth_required: bool = False,
openapi_name: str = "",
openapi_tags: List[str] = ["get"],
status_code: Optional[int] = None,
):
"""
The @app.get decorator to add a route with the GET method
Expand All @@ -380,10 +381,11 @@ def get(
:param auth_required bool: represents if the route needs authentication or not
:param openapi_name: str -- the name of the endpoint in the openapi spec
:param openapi_tags: List[str] -- for grouping of endpoints in the openapi spec
:param status_code: int|None -- default HTTP status code for the response
"""

def inner(handler):
return self.add_route(HttpMethod.GET, endpoint, handler, const, auth_required, openapi_name, openapi_tags)
return self.add_route(HttpMethod.GET, endpoint, handler, const, auth_required, openapi_name, openapi_tags, status_code=status_code)

return inner

Expand All @@ -393,6 +395,7 @@ def post(
auth_required: bool = False,
openapi_name: str = "",
openapi_tags: List[str] = ["post"],
status_code: Optional[int] = None,
):
"""
The @app.post decorator to add a route with POST method
Expand All @@ -401,10 +404,13 @@ def post(
:param auth_required bool: represents if the route needs authentication or not
:param openapi_name: str -- the name of the endpoint in the openapi spec
:param openapi_tags: List[str] -- for grouping of endpoints in the openapi spec
:param status_code: int|None -- default HTTP status code for the response
"""

def inner(handler):
return self.add_route(HttpMethod.POST, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags)
return self.add_route(
HttpMethod.POST, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags, status_code=status_code
)

return inner

Expand All @@ -414,6 +420,7 @@ def put(
auth_required: bool = False,
openapi_name: str = "",
openapi_tags: List[str] = ["put"],
status_code: Optional[int] = None,
):
"""
The @app.put decorator to add a get route with PUT method
Expand All @@ -422,10 +429,13 @@ def put(
:param auth_required bool: represents if the route needs authentication or not
:param openapi_name: str -- the name of the endpoint in the openapi spec
:param openapi_tags: List[str] -- for grouping of endpoints in the openapi spec
:param status_code: int|None -- default HTTP status code for the response
"""

def inner(handler):
return self.add_route(HttpMethod.PUT, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags)
return self.add_route(
HttpMethod.PUT, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags, status_code=status_code
)

return inner

Expand All @@ -435,6 +445,7 @@ def delete(
auth_required: bool = False,
openapi_name: str = "",
openapi_tags: List[str] = ["delete"],
status_code: Optional[int] = None,
):
"""
The @app.delete decorator to add a route with DELETE method
Expand All @@ -443,10 +454,13 @@ def delete(
:param auth_required bool: represents if the route needs authentication or not
:param openapi_name: str -- the name of the endpoint in the openapi spec
:param openapi_tags: List[str] -- for grouping of endpoints in the openapi spec
:param status_code: int|None -- default HTTP status code for the response
"""

def inner(handler):
return self.add_route(HttpMethod.DELETE, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags)
return self.add_route(
HttpMethod.DELETE, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags, status_code=status_code
)

return inner

Expand All @@ -456,6 +470,7 @@ def patch(
auth_required: bool = False,
openapi_name: str = "",
openapi_tags: List[str] = ["patch"],
status_code: Optional[int] = None,
):
"""
The @app.patch decorator to add a route with PATCH method
Expand All @@ -464,10 +479,13 @@ def patch(
:param auth_required bool: represents if the route needs authentication or not
:param openapi_name: str -- the name of the endpoint in the openapi spec
:param openapi_tags: List[str] -- for grouping of endpoints in the openapi spec
:param status_code: int|None -- default HTTP status code for the response
"""

def inner(handler):
return self.add_route(HttpMethod.PATCH, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags)
return self.add_route(
HttpMethod.PATCH, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags, status_code=status_code
)

return inner

Expand All @@ -477,6 +495,7 @@ def head(
auth_required: bool = False,
openapi_name: str = "",
openapi_tags: List[str] = ["head"],
status_code: Optional[int] = None,
):
"""
The @app.head decorator to add a route with HEAD method
Expand All @@ -485,10 +504,13 @@ def head(
:param auth_required bool: represents if the route needs authentication or not
:param openapi_name: str -- the name of the endpoint in the openapi spec
:param openapi_tags: List[str] -- for grouping of endpoints in the openapi spec
:param status_code: int|None -- default HTTP status code for the response
"""

def inner(handler):
return self.add_route(HttpMethod.HEAD, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags)
return self.add_route(
HttpMethod.HEAD, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags, status_code=status_code
)

return inner

Expand All @@ -498,6 +520,7 @@ def options(
auth_required: bool = False,
openapi_name: str = "",
openapi_tags: List[str] = ["options"],
status_code: Optional[int] = None,
):
"""
The @app.options decorator to add a route with OPTIONS method
Expand All @@ -506,10 +529,19 @@ def options(
:param auth_required bool: represents if the route needs authentication or not
:param openapi_name: str -- the name of the endpoint in the openapi spec
:param openapi_tags: List[str] -- for grouping of endpoints in the openapi spec
:param status_code: int|None -- default HTTP status code for the response
"""

def inner(handler):
return self.add_route(HttpMethod.OPTIONS, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags)
return self.add_route(
HttpMethod.OPTIONS,
endpoint,
handler,
auth_required=auth_required,
openapi_name=openapi_name,
openapi_tags=openapi_tags,
status_code=status_code,
)

return inner

Expand All @@ -519,6 +551,7 @@ def connect(
auth_required: bool = False,
openapi_name: str = "",
openapi_tags: List[str] = ["connect"],
status_code: Optional[int] = None,
):
"""
The @app.connect decorator to add a route with CONNECT method
Expand All @@ -527,10 +560,19 @@ def connect(
:param auth_required bool: represents if the route needs authentication or not
:param openapi_name: str -- the name of the endpoint in the openapi spec
:param openapi_tags: List[str] -- for grouping of endpoints in the openapi spec
:param status_code: int|None -- default HTTP status code for the response
"""

def inner(handler):
return self.add_route(HttpMethod.CONNECT, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags)
return self.add_route(
HttpMethod.CONNECT,
endpoint,
handler,
auth_required=auth_required,
openapi_name=openapi_name,
openapi_tags=openapi_tags,
status_code=status_code,
)

return inner

Expand All @@ -540,6 +582,7 @@ def trace(
auth_required: bool = False,
openapi_name: str = "",
openapi_tags: List[str] = ["trace"],
status_code: Optional[int] = None,
):
"""
The @app.trace decorator to add a route with TRACE method
Expand All @@ -548,10 +591,13 @@ def trace(
:param auth_required bool: represents if the route needs authentication or not
:param openapi_name: str -- the name of the endpoint in the openapi spec
:param openapi_tags: List[str] -- for grouping of endpoints in the openapi spec
:param status_code: int|None -- default HTTP status code for the response
"""

def inner(handler):
return self.add_route(HttpMethod.TRACE, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags)
return self.add_route(
HttpMethod.TRACE, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags, status_code=status_code
)

return inner

Expand Down Expand Up @@ -705,29 +751,62 @@ def __add_prefix(self, endpoint: str):

return f"{normalized_prefix}{normalized_endpoint}"

def get(self, endpoint: str, const: bool = False, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["get"]):
return super().get(endpoint=self.__add_prefix(endpoint), const=const, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags)
def get(
self,
endpoint: str,
const: bool = False,
auth_required: bool = False,
openapi_name: str = "",
openapi_tags: List[str] = ["get"],
status_code: Optional[int] = None,
):
return super().get(
endpoint=self.__add_prefix(endpoint),
const=const,
auth_required=auth_required,
openapi_name=openapi_name,
openapi_tags=openapi_tags,
status_code=status_code,
)

def post(self, endpoint: str, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["post"]):
return super().post(endpoint=self.__add_prefix(endpoint), auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags)
def post(self, endpoint: str, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["post"], status_code: Optional[int] = None):
return super().post(
endpoint=self.__add_prefix(endpoint), auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags, status_code=status_code
)

def put(self, endpoint: str, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["put"]):
return super().put(endpoint=self.__add_prefix(endpoint), auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags)
def put(self, endpoint: str, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["put"], status_code: Optional[int] = None):
return super().put(
endpoint=self.__add_prefix(endpoint), auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags, status_code=status_code
)

def delete(self, endpoint: str, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["delete"]):
return super().delete(endpoint=self.__add_prefix(endpoint), auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags)
def delete(
self, endpoint: str, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["delete"], status_code: Optional[int] = None
):
return super().delete(
endpoint=self.__add_prefix(endpoint), auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags, status_code=status_code
)

def patch(self, endpoint: str, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["patch"]):
return super().patch(endpoint=self.__add_prefix(endpoint), auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags)
def patch(self, endpoint: str, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["patch"], status_code: Optional[int] = None):
return super().patch(
endpoint=self.__add_prefix(endpoint), auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags, status_code=status_code
)

def head(self, endpoint: str, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["head"]):
return super().head(endpoint=self.__add_prefix(endpoint), auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags)
def head(self, endpoint: str, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["head"], status_code: Optional[int] = None):
return super().head(
endpoint=self.__add_prefix(endpoint), auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags, status_code=status_code
)

def trace(self, endpoint: str, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["trace"]):
return super().trace(endpoint=self.__add_prefix(endpoint), auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags)
def trace(self, endpoint: str, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["trace"], status_code: Optional[int] = None):
return super().trace(
endpoint=self.__add_prefix(endpoint), auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags, status_code=status_code
)

def options(self, endpoint: str, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["options"]):
return super().options(endpoint=self.__add_prefix(endpoint), auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags)
def options(
self, endpoint: str, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["options"], status_code: Optional[int] = None
):
return super().options(
endpoint=self.__add_prefix(endpoint), auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags, status_code=status_code
)

def websocket(self, endpoint: str):
"""
Expand Down
11 changes: 8 additions & 3 deletions robyn/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,9 @@ def __post_init__(self):
"externalDocs": asdict(self.info.externalDocs) if self.info.externalDocs.url else None,
}

def add_openapi_path_obj(self, route_type: str, endpoint: str, openapi_name: str, openapi_tags: List[str], handler: Callable):
def add_openapi_path_obj(
self, route_type: str, endpoint: str, openapi_name: str, openapi_tags: List[str], handler: Callable, status_code: Optional[int] = None
):
"""
Adds the given path to openapi spec

Expand All @@ -177,6 +179,7 @@ def add_openapi_path_obj(self, route_type: str, endpoint: str, openapi_name: str
@param openapi_name: str the name of the endpoint
@param openapi_tags: List[str] for grouping of endpoints
@param handler: Callable the handler function for the endpoint
@param status_code: Optional[int] default response status code
"""

if self.openapi_file_override:
Expand Down Expand Up @@ -224,7 +227,7 @@ def add_openapi_path_obj(self, route_type: str, endpoint: str, openapi_name: str
return_annotation = signature.return_annotation

modified_endpoint, path_obj = self.get_path_obj(
endpoint, openapi_name, openapi_description, openapi_tags, query_params, request_body, return_annotation
endpoint, openapi_name, openapi_description, openapi_tags, query_params, request_body, return_annotation, status_code=status_code
)

if modified_endpoint not in self.openapi_spec["paths"]:
Expand Down Expand Up @@ -274,6 +277,7 @@ def get_path_obj(
query_params: Optional[str_typed_dict],
request_body: Optional[str_typed_dict],
return_annotation: Optional[str_typed_dict],
status_code: Optional[int] = None,
) -> Tuple[str, dict]:
"""
Get the "path" openapi object according to spec
Expand Down Expand Up @@ -369,7 +373,8 @@ def get_path_obj(
response_type = "application/json"
response_schema = self.get_schema_object("response object", return_annotation)

openapi_path_object["responses"] = {"200": {"description": "Successful Response", "content": {response_type: {"schema": response_schema}}}}
response_key = str(status_code) if status_code is not None else "200"
openapi_path_object["responses"] = {response_key: {"description": "Successful Response", "content": {response_type: {"schema": response_schema}}}}

return endpoint_with_path_params_wrapped_in_braces, openapi_path_object

Expand Down
Loading
Loading