From 587d926400a70011190064d39653736d18811dd3 Mon Sep 17 00:00:00 2001 From: MEHRSHAD MIRSHEKARY Date: Sun, 9 Feb 2025 23:03:59 +0330 Subject: [PATCH 1/4] feat(exceptions): Add centralized ExceptionHandler class for consistent error responses - Introduces the `ExceptionHandler` class to handle exceptions and return consistent JSON error responses across the application. - Maps specific Django exceptions (e.g., `ObjectDoesNotExist`, `ValidationError`, `DatabaseError`) to appropriate HTTP status codes and messages. - Provides detailed error information in debug mode, including exception type, message, and traceback. - Supports handling of common Django exceptions (e.g., `PermissionDenied`, `BadRequest`, `IntegrityError`) and falls back to a generic 500 Internal Server Er ror for unhandled exceptions. Closes #29 --- response_shaper/exceptions.py | 219 ++++++++++++++++++++++++++++++++++ 1 file changed, 219 insertions(+) create mode 100644 response_shaper/exceptions.py diff --git a/response_shaper/exceptions.py b/response_shaper/exceptions.py new file mode 100644 index 0000000..4e650f6 --- /dev/null +++ b/response_shaper/exceptions.py @@ -0,0 +1,219 @@ +from typing import Any, Dict, Union + +from django.conf import settings +from django.core.exceptions import ( + BadRequest, + DisallowedHost, + DisallowedRedirect, + EmptyResultSet, + FieldDoesNotExist, + FieldError, + ImproperlyConfigured, + MiddlewareNotUsed, + MultipleObjectsReturned, + ObjectDoesNotExist, + PermissionDenied, + SuspiciousOperation, + ValidationError, +) +from django.db import ( + DatabaseError, + DataError, + IntegrityError, + InternalError, + OperationalError, + ProgrammingError, +) +from django.http import JsonResponse + + +class ExceptionHandler: + """Handles exception responses consistently across the application. + + This class provides a centralized way to handle exceptions and + return consistent JSON error responses. It maps specific exceptions + to appropriate HTTP status codes and messages, and includes detailed + error information in debug mode. + + """ + + @staticmethod + def build_error_response(status_code: int, message: str) -> JsonResponse: + """Helper method to build error responses consistently. + + Args: + status_code (int): The HTTP status code for the error response. + message (str): The error message or data to be included in the response. + + Returns: + JsonResponse: A JSON response containing the error details, structured as: + { + "status": False, + "status_code": , + "error": , + "data": {} + } + + """ + return JsonResponse( + {"status": False, "status_code": status_code, "error": message, "data": {}}, + status=status_code, + ) + + @classmethod + def handle(cls, exception: Exception) -> JsonResponse: + """Processes exceptions and returns structured error responses. + + This method maps specific exceptions to appropriate HTTP status codes + and messages. It uses a dictionary of exception handlers to determine + the appropriate response. If the exception is not explicitly handled, + it falls back to a generic 500 Internal Server Error response. + + Args: + exception (Exception): The exception that was raised. + + Returns: + JsonResponse: A JSON response containing the error details. + + """ + + # pylint: disable=W0108 + # Helper functions for response consistency + def bad_request(msg="Bad request"): + """Handles 400 Bad Request errors, including detailed messages in + debug mode.""" + error_message = ( + cls._get_detailed_error_info(exception) if settings.DEBUG else msg + ) + return cls.build_error_response(400, error_message) + + def not_found(msg="Resource not found"): + """Handles 404 Not Found errors.""" + return cls.build_error_response(404, msg) + + def server_error(exc): + """Handles 500 Internal Server Errors, including detailed messages + in debug mode.""" + error_message = ( + cls._get_detailed_error_info(exc) + if settings.DEBUG + else "Internal Server Error" + ) + return cls.build_error_response(500, error_message) + + def db_error(exc, status=500): + """Handles database errors, including detailed messages in debug + mode.""" + error_message = ( + cls._get_detailed_error_info(exc) + if settings.DEBUG + else "A Database Error Occurred" + ) + return cls.build_error_response(status, error_message) + + # Exception mapping dictionary + exception_handlers = { + # Not Found + FieldDoesNotExist: lambda e: not_found("Field does not exist"), + ObjectDoesNotExist: lambda e: not_found("Object not found"), + EmptyResultSet: lambda e: not_found("No results found"), + # Bad Request + MultipleObjectsReturned: lambda e: bad_request("Multiple objects returned"), + SuspiciousOperation: lambda e: bad_request("Suspicious operation detected"), + DisallowedHost: lambda e: bad_request("Invalid host header"), + DisallowedRedirect: lambda e: bad_request("Disallowed redirect"), + BadRequest: lambda e: bad_request(), + # Permission Issues + PermissionDenied: lambda e: cls.build_error_response( + 403, "Permission denied" + ), + # Configuration & Middleware Errors + MiddlewareNotUsed: lambda e: server_error(e), + ImproperlyConfigured: lambda e: server_error(e), + # Field & Validation Errors + FieldError: lambda e: bad_request("Field error"), + ValidationError: lambda e: bad_request(cls.extract_first_error(e)), + # Database Errors + IntegrityError: lambda e: db_error(e, 400), + ProgrammingError: lambda e: db_error(e), + OperationalError: lambda e: db_error(e, 503), + DataError: lambda e: db_error(e, 400), + InternalError: lambda e: db_error(e), + DatabaseError: lambda e: db_error(e), + } + + # Use explicit lookup first + handler = exception_handlers.get(type(exception)) + if handler: + return handler(exception) + + # Fallback for subclass-based exceptions + for exc_class, handler in exception_handlers.items(): + if isinstance(exception, exc_class): + return handler(exception) + + # Catch-all for unexpected exceptions + message = ( + cls._get_detailed_error_info(exception) + if settings.DEBUG + else "Internal Server Error" + ) + return cls.build_error_response(500, message) + + @staticmethod + def extract_first_error(error_data: Any) -> Union[Any, Dict]: + """Extract the first error message from various data structures (dict, + list, string). Stops at the first error encountered. + + This method is useful for extracting the first error message from + complex error data structures, such as those returned by Django's + validation framework. + + Args: + error_data (Any): The error data structure, which can be a string, + list, or dictionary. + + Returns: + Union[str, dict]: The extracted error message or structure. If the + input is a list, it returns the first element. If the input is + a dictionary, it returns the first key-value pair. If the input + is a string, it returns the string itself. + + """ + if isinstance(error_data, str): + return error_data + if isinstance(error_data, list) and error_data: + return ExceptionHandler.extract_first_error(error_data[0]) + if isinstance(error_data, dict): + for key, value in error_data.items(): + return {key: ExceptionHandler.extract_first_error(value)} + return str(error_data) + + @staticmethod + def _get_detailed_error_info(exception: Exception) -> Dict: + """Extract detailed error information including the exception message + and traceback. + + This method is used to provide detailed error information in debug mode, + including the exception type, message, and traceback. + + Args: + exception (Exception): The exception that occurred. + + Returns: + dict: A dictionary containing the error details, structured as: + { + "message": , + "type": , + "traceback": (if DEBUG is True) + } + + """ + import traceback + + error_detail = { + "message": f"Internal Server Error: {str(exception)}", + "type": type(exception).__name__, + "traceback": traceback.format_exc() if settings.DEBUG else None, + } + return error_detail From e10a8fe2509477b768272eb1caea99e3e2d97829 Mon Sep 17 00:00:00 2001 From: MEHRSHAD MIRSHEKARY Date: Sun, 9 Feb 2025 23:08:35 +0330 Subject: [PATCH 2/4] :rotating_light::heavy_check_mark: tests(exeptions): Add ExceptionHanlder logic tests --- response_shaper/tests/test_exceptions.py | 211 +++++++++++++++++++++++ 1 file changed, 211 insertions(+) create mode 100644 response_shaper/tests/test_exceptions.py diff --git a/response_shaper/tests/test_exceptions.py b/response_shaper/tests/test_exceptions.py new file mode 100644 index 0000000..b98b6f8 --- /dev/null +++ b/response_shaper/tests/test_exceptions.py @@ -0,0 +1,211 @@ +import sys +import json + +import pytest +from django.core.exceptions import ( + ObjectDoesNotExist, + FieldDoesNotExist, + MultipleObjectsReturned, + SuspiciousOperation, + DisallowedHost, + DisallowedRedirect, + EmptyResultSet, + FieldError, + BadRequest, + PermissionDenied, + MiddlewareNotUsed, + ImproperlyConfigured, + ValidationError, +) +from django.db import ( + IntegrityError, + ProgrammingError, + OperationalError, + DataError, + InternalError, + DatabaseError, +) +from django.conf import settings +from django.http import JsonResponse + +from response_shaper.exceptions import ExceptionHandler +from response_shaper.tests.constants import PYTHON_VERSION, PYTHON_VERSION_REASON + +pytestmark = [ + pytest.mark.exceptions, + pytest.mark.skipif(sys.version_info < PYTHON_VERSION, reason=PYTHON_VERSION_REASON), +] + + +class TestExceptionHandler: + """Test suite for the ExceptionHandler class.""" + + def parse_json_response(self, response: JsonResponse) -> dict: + """ + Helper method to parse the content of a JsonResponse. + + :param response: The response object to parse. + :return: The parsed content as a Python dictionary. + """ + return json.loads(response.content.decode("utf-8")) + + @pytest.mark.parametrize( + "exception, expected_status_code, expected_error_message", + [ + # Not Found Exceptions + (ObjectDoesNotExist("Object not found"), 404, "Object not found"), + (FieldDoesNotExist("Field does not exist"), 404, "Field does not exist"), + (EmptyResultSet("No results found"), 404, "No results found"), + # Bad Request Exceptions + ( + MultipleObjectsReturned("Multiple objects returned"), + 400, + "Multiple objects returned", + ), + ( + SuspiciousOperation("Suspicious operation detected"), + 400, + "Suspicious operation detected", + ), + (DisallowedHost("Invalid host header"), 400, "Invalid host header"), + (DisallowedRedirect("Disallowed redirect"), 400, "Disallowed redirect"), + (BadRequest("Bad request"), 400, "Bad request"), + # Permission Issues + (PermissionDenied("Permission denied"), 403, "Permission denied"), + # Configuration & Middleware Errors + ( + MiddlewareNotUsed("Middleware not used"), + 500, + "Internal Server Error", + ), + ( + ImproperlyConfigured("Improperly configured"), + 500, + "Internal Server Error", + ), + # Field & Validation Errors + (FieldError("Field error"), 400, "Field error"), + # Database Errors + (IntegrityError("Integrity error"), 400, "A Database Error Occurred"), + (ProgrammingError("Programming error"), 500, "A Database Error Occurred"), + (OperationalError("Operational error"), 503, "A Database Error Occurred"), + (DataError("Data error"), 400, "A Database Error Occurred"), + (InternalError("Internal error"), 500, "A Database Error Occurred"), + (DatabaseError("Database error"), 500, "A Database Error Occurred"), + ], + ) + def test_handle_exceptions( + self, exception, expected_status_code, expected_error_message + ): + """Test that the ExceptionHandler correctly handles various exceptions. + + Args: + exception (Exception): The exception to handle. + expected_status_code (int): The expected HTTP status code. + expected_error_message (str or dict): The expected error message. + """ + response = ExceptionHandler.handle(exception) + response_data = self.parse_json_response(response) + + assert response.status_code == expected_status_code + assert response_data["status"] is False + assert response_data["status_code"] == expected_status_code + assert response_data["error"] == expected_error_message + assert response_data["data"] == {} + + def test_extract_first_error(self): + """Test the extract_first_error method with various data structures.""" + # Test with a string + assert ExceptionHandler.extract_first_error("Error message") == "Error message" + + # Test with a list + assert ( + ExceptionHandler.extract_first_error(["First error", "Second error"]) + == "First error" + ) + + # Test with a dictionary + assert ExceptionHandler.extract_first_error({"field": ["Invalid value"]}) == { + "field": "Invalid value" + } + + # Test with nested structures + assert ExceptionHandler.extract_first_error( + {"field": [{"nested": "Invalid value"}]} + ) == {"field": {"nested": "Invalid value"}} + + def test_get_detailed_error_info(self): + """Test the _get_detailed_error_info method in debug mode.""" + settings.DEBUG = True + exception = ValueError("Test error") + error_detail = ExceptionHandler._get_detailed_error_info(exception) + + assert error_detail["message"] == "Internal Server Error: Test error" + assert error_detail["type"] == "ValueError" + + def test_get_detailed_error_info_no_debug(self): + """Test the _get_detailed_error_info method when debug mode is off.""" + settings.DEBUG = False + exception = ValueError("Test error") + error_detail = ExceptionHandler._get_detailed_error_info(exception) + + assert error_detail["message"] == "Internal Server Error: Test error" + assert error_detail["type"] == "ValueError" + assert error_detail["traceback"] is None + + def test_build_error_response(self): + """Test the build_error_response method.""" + response = ExceptionHandler.build_error_response(400, "Bad request") + response_data = self.parse_json_response(response) + + assert response.status_code == 400 + assert response_data == { + "status": False, + "status_code": 400, + "error": "Bad request", + "data": {}, + } + + def test_unexpected_exception(self): + """Test that unexpected exceptions are handled with a 500 status code.""" + exception = Exception("Unexpected error") + response = ExceptionHandler.handle(exception) + response_data = self.parse_json_response(response) + + assert response.status_code == 500 + assert response_data["status"] is False + assert response_data["status_code"] == 500 + assert response_data["error"] == "Internal Server Error" + assert response_data["data"] == {} + + def test_subclass_based_exceptions(self): + """Test that subclass-based exceptions are handled correctly by the fallback logic.""" + + # Create a custom exception that inherits from a handled exception + class CustomDatabaseError(DatabaseError): + pass + + class CustomValidationError(ValidationError): + pass + + # Test with a custom database error + custom_db_error = CustomDatabaseError("Custom database error") + response = ExceptionHandler.handle(custom_db_error) + response_data = self.parse_json_response(response) + + assert response.status_code == 500 + assert response_data["status"] is False + assert response_data["status_code"] == 500 + assert response_data["error"] == "A Database Error Occurred" + assert response_data["data"] == {} + + # Test with a custom validation error + custom_validation_error = CustomValidationError({"field": ["Invalid value"]}) + response = ExceptionHandler.handle(custom_validation_error) + response_data = self.parse_json_response(response) + + assert response.status_code == 400 + assert response_data["status"] is False + assert response_data["status_code"] == 400 + assert response_data["error"] is not None + assert response_data["data"] == {} From 2403c43ed66fe26030d84746652a80e6ce6abaf2 Mon Sep 17 00:00:00 2001 From: MEHRSHAD MIRSHEKARY Date: Sun, 9 Feb 2025 23:43:37 +0330 Subject: [PATCH 3/4] feat(middleware): Add async support to DynamicResponseMiddleware - Added async-aware logic to handle both synchronous and asynchronous requests seamlessly using Django's `sync_to_async` utility. - Updated the `BaseMiddleware` class to support async mode by checking if the `get_response` function is a coroutine and marking the middleware as async-capa ble. - Ensured compatibility with Django's ASGI stack by properly handling async responses and exceptions. - Maintained existing synchronous functionality while extending support for async workflows, ensuring backward compatibility. - Added async-specific logic to process JSON responses consistently for both success and error cases. - Improved exception handling to work seamlessly in both sync and async contexts, leveraging the `ExceptionHandler` class for consistent error responses. Closes #28 --- response_shaper/middleware.py | 267 +++++++++++++++-------- response_shaper/tests/test_middleware.py | 214 +++++++++++++++--- 2 files changed, 365 insertions(+), 116 deletions(-) diff --git a/response_shaper/middleware.py b/response_shaper/middleware.py index 81b91a4..210cfcd 100644 --- a/response_shaper/middleware.py +++ b/response_shaper/middleware.py @@ -1,21 +1,138 @@ import json -from typing import Any, Callable, Optional, Union +from typing import Awaitable, Callable, Optional, Union -from django.conf import settings -from django.core.exceptions import ObjectDoesNotExist, ValidationError -from django.db import IntegrityError +from asgiref.sync import iscoroutinefunction, markcoroutinefunction, sync_to_async from django.http import HttpRequest, HttpResponse, HttpResponseBase, JsonResponse -from django.utils.deprecation import MiddlewareMixin -from rest_framework.views import exception_handler +from response_shaper.exceptions import ExceptionHandler from response_shaper.settings.conf import response_shaper_config -class DynamicResponseMiddleware(MiddlewareMixin): - """A middleware to structure API responses in a consistent format based on - dynamic settings.""" +class BaseMiddleware: + """Base middleware class that supports both synchronous and asynchronous + modes. + + This class provides a foundation for creating middleware that can handle both + synchronous and asynchronous requests. Subclasses must implement the `__sync_call__` + and `__acall__` methods to define their behavior. + + Attributes: + sync_capable (bool): Indicates whether the middleware can handle synchronous requests. + async_capable (bool): Indicates whether the middleware can handle asynchronous requests. + + """ + + sync_capable: bool = True + async_capable: bool = True + + def __init__( + self, + get_response: Callable[ + [HttpRequest], Union[HttpResponseBase, Awaitable[HttpResponseBase]] + ], + ) -> None: + """Initialize the middleware. + + Args: + get_response: The next middleware or view to call. This can be either + synchronous or asynchronous. + + """ + self.get_response = get_response + self.async_mode = iscoroutinefunction(self.get_response) + if self.async_mode: + markcoroutinefunction(self) + + def __repr__(self) -> str: + """Provides a string representation of the middleware. + + Returns: + str: A string representation of the middleware, including the name of the + `get_response` function or class. + + """ + ger_response = getattr( + self.get_response, + "__qualname__", + self.get_response.__class__.__name__, + ) + return f"<{self.__class__.__qualname__} get_response={ger_response}>" + + def __call__( + self, request: HttpRequest + ) -> Union[HttpResponseBase, Awaitable[HttpResponseBase]]: + """Handles the incoming request, determining whether it's synchronous + or asynchronous. + + Args: + request (HttpRequest): The incoming HTTP request. + + Returns: + Union[HttpResponseBase, Awaitable[HttpResponseBase]]: The HTTP response, either + synchronous or asynchronous. + + """ + if self.async_mode: + return self.__acall__(request) + return self.__sync_call__(request) - def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]): + def __sync_call__(self, request: HttpRequest) -> HttpResponseBase: + """Processes synchronous requests. + + Subclasses must implement this method to define how synchronous requests are handled. + + Args: + request (HttpRequest): The incoming HTTP request. + + Returns: + HttpResponseBase: The HTTP response. + + Raises: + NotImplementedError: If the method is not implemented by the subclass. + + """ + raise NotImplementedError("__sync_call__ must be implemented by subclass") + + async def __acall__(self, request: HttpRequest) -> HttpResponseBase: + """Processes asynchronous requests. + + Subclasses must implement this method to define how asynchronous requests are handled. + + Args: + request (HttpRequest): The incoming HTTP request. + + Returns: + HttpResponseBase: The HTTP response. + + Raises: + NotImplementedError: If the method is not implemented by the subclass. + + """ + raise NotImplementedError("__acall__ must be implemented by subclass") + + +class DynamicResponseMiddleware(BaseMiddleware): + """A middleware to structure API responses in a consistent format based on + dynamic settings. + + This middleware modifies API responses to follow a consistent JSON structure for both + success and error cases. It can be configured to exclude certain paths and supports + custom success and error handlers. + + Attributes: + excluded_paths (list): Paths for which response shaping should be skipped. + debug (bool): Whether debug mode is enabled. + success_handler (Callable): The handler for successful responses. + error_handler (Callable): The handler for error responses. + + """ + + def __init__( + self, + get_response: Callable[ + [HttpRequest], Union[HttpResponseBase, Awaitable[HttpResponseBase]] + ], + ): """Initialize the middleware with dynamic settings. Args: @@ -32,7 +149,7 @@ def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]): response_shaper_config.error_handler, self._default_error_handler ) - def __call__(self, request: HttpRequest) -> HttpResponseBase: + def __sync_call__(self, request: HttpRequest) -> HttpResponseBase: """Process the request and response. Args: @@ -45,6 +162,19 @@ def __call__(self, request: HttpRequest) -> HttpResponseBase: response = self.get_response(request) return self.process_response(request, response) + async def __acall__(self, request: HttpRequest) -> HttpResponseBase: + """Process the request and response asynchronously. + + Args: + request: The incoming HTTP request. + + Returns: + HttpResponseBase: The structured HTTP response. + + """ + response = await self.get_response(request) + return await self.process_response_async(request, response) + def process_response( self, request: HttpRequest, response: HttpResponseBase ) -> HttpResponseBase: @@ -74,55 +204,50 @@ def process_response( else: return self.error_handler(response) - def process_exception( - self, request: HttpRequest, exception: Exception - ) -> Optional[HttpResponseBase]: - """Handle exceptions and structure error responses consistently. + async def process_response_async( + self, request: HttpRequest, response: HttpResponseBase + ) -> HttpResponseBase: + """Processes async responses, structuring JSON responses. Args: request: The incoming HTTP request. - exception: The raised exception to handle. + response: The original HTTP response. Returns: - Optional[HttpResponse]: The structured error response or None. + HttpResponseBase: The processed HTTP response, with structured JSON if applicable. """ - response = exception_handler(exception, None) - if self.shape_is_not_allowed(request): return response - # Handle specific Django exceptions explicitly - if isinstance(exception, IntegrityError): - return self._build_error_response(400, str(exception)) - if isinstance(exception, ValidationError): - return self._build_error_response(400, self._extract_first_error(exception)) - if isinstance(exception, ObjectDoesNotExist): - return self._build_error_response(404, "Object not found") + content_type = response.headers.get("Content-Type", "") - # Generic 500 Internal Server Error - detailed_error_message = self._get_detailed_error_info(exception) - return self._build_error_response(500, detailed_error_message) + if not content_type.startswith("application/json"): + return response - def _get_detailed_error_info(self, exception: Exception) -> dict: - """Extract detailed error information including the exception message - and traceback. + return await sync_to_async( + self.success_handler + if 200 <= response.status_code < 300 + else self.error_handler + )(response) + + def process_exception( + self, request: HttpRequest, exception: Exception + ) -> Optional[HttpResponseBase]: + """Handle exceptions and structure error responses consistently. Args: - exception: The exception that occurred. + request: The incoming HTTP request. + exception: The raised exception to handle. Returns: - dict: A dictionary containing the error details and traceback. + Optional[HttpResponse]: The structured error response or None. """ - import traceback + if self.shape_is_not_allowed(request): + return None # pass to let Django handle the exception - error_detail = { - "message": f"Internal Server Error: {str(exception)}", - "type": type(exception).__name__, - "traceback": traceback.format_exc() if settings.DEBUG else None, - } - return error_detail + return ExceptionHandler.handle(exception) def _default_success_handler(self, response: HttpResponse) -> JsonResponse: """Default handler for successful responses. @@ -160,55 +285,15 @@ def _default_error_handler(self, response: HttpResponse) -> JsonResponse: """ if hasattr(response, "data"): - error_message = self._extract_first_error(response.data) + error_message = ExceptionHandler.extract_first_error(response.data) else: # Decode content if 'data' is not available - error_message = json.loads(response.content.decode("utf-8"))["error"] - - return self._build_error_response(response.status_code, error_message) - - def _build_error_response( - self, status_code: int, message: Union[str, dict] - ) -> JsonResponse: - """Helper method to build error responses consistently. - - Args: - status_code: The HTTP status code for the error response. - message: The error message or data. - - Returns: - JsonResponse: The structured error response. + error_message = json.loads(response.content.decode("utf-8")).get("error") - """ - return JsonResponse( - {"status": False, "status_code": status_code, "error": message, "data": {}}, - status=status_code, + return ExceptionHandler.build_error_response( + response.status_code, error_message ) - def _extract_first_error(self, error_data: Any) -> Union[str, dict]: - """Extract the first error message from various data structures (dict, - list, string). Stops at the first error encountered. - - Args: - error_data: The error data structure. - - Returns: - Union[str, dict]: The extracted error message or structure. - - """ - if isinstance(error_data, str): - return error_data - elif isinstance(error_data, list): - if error_data: - return self._extract_first_error(error_data[0]) - elif isinstance(error_data, dict): - for key, value in error_data.items(): - first_error = self._extract_first_error(value) - if isinstance(first_error, str): - return {key: first_error} # Return as a dict with the field name - - return str(error_data) - def get_dynamic_handler( self, handler_path: str, default_handler: Callable ) -> Callable: @@ -235,7 +320,8 @@ def shape_is_not_allowed(self, request: HttpRequest) -> bool: request. This method checks whether the middleware should skip response shaping - based on the `debug` mode or if the request path is in the list of excluded paths. + based on the `debug` mode or if the request path starts with any of the + excluded paths. Args: request (HttpRequest): The incoming HTTP request object. @@ -244,4 +330,11 @@ def shape_is_not_allowed(self, request: HttpRequest) -> bool: bool: True if response shaping is not allowed (i.e., should be skipped), False otherwise. """ - return self.debug or request.path in self.excluded_paths + if self.debug: + return True + + for excluded_path in self.excluded_paths: + if request.path.startswith(excluded_path): + return True + + return False diff --git a/response_shaper/tests/test_middleware.py b/response_shaper/tests/test_middleware.py index 2bc855f..59601cb 100644 --- a/response_shaper/tests/test_middleware.py +++ b/response_shaper/tests/test_middleware.py @@ -1,12 +1,16 @@ +from asyncio import iscoroutinefunction + import pytest import json import sys -from unittest.mock import patch -from django.http import JsonResponse, HttpResponse +from unittest.mock import patch, Mock +from django.http import JsonResponse, HttpResponse, HttpRequest, HttpResponseBase from django.core.exceptions import ObjectDoesNotExist, ValidationError from django.db import IntegrityError from django.test import RequestFactory -from response_shaper.middleware import DynamicResponseMiddleware + +from response_shaper.exceptions import ExceptionHandler +from response_shaper.middleware import DynamicResponseMiddleware, BaseMiddleware from response_shaper.settings.conf import response_shaper_config from response_shaper.tests.constants import PYTHON_VERSION, PYTHON_VERSION_REASON from typing import Dict, Callable, List @@ -17,6 +21,61 @@ ] +class TestBaseMiddleware: + """ + Test suite for the BaseMiddleware class. + """ + + def test_sync_mode(self) -> None: + """ + Test that the middleware correctly identifies and handles synchronous requests. + This test verifies that when the `get_response` function is synchronous, + the middleware calls the `__sync_call__` method. + """ + # Mock synchronous get_response + mock_get_response = Mock(spec=Callable[[HttpRequest], HttpResponseBase]) + + # Create an instance of the middleware + middleware = BaseMiddleware(mock_get_response) + + # Ensure that it is in synchronous mode + assert not iscoroutinefunction(middleware.get_response) + assert not middleware.async_mode + + # Test that calling the middleware raises NotImplementedError (since __sync_call__ is not implemented) + with pytest.raises( + NotImplementedError, match="__sync_call__ must be implemented by subclass" + ): + request = HttpRequest() + middleware(request) + + @pytest.mark.asyncio + async def test_async_mode(self) -> None: + """ + Test that the middleware correctly identifies and handles asynchronous requests. + This test verifies that when the `get_response` function is asynchronous, + the middleware calls the `__acall__` method. + """ + + # Mock asynchronous get_response + async def mock_get_response(request: HttpRequest) -> HttpResponseBase: + return Mock(spec=HttpResponseBase) + + # Create an instance of the middleware + middleware = BaseMiddleware(mock_get_response) + + # Ensure that it is in asynchronous mode + assert iscoroutinefunction(middleware.get_response) + assert middleware.async_mode + + # Test that calling the middleware raises NotImplementedError (since __acall__ is not implemented) + with pytest.raises( + NotImplementedError, match="__acall__ must be implemented by subclass" + ): + request = HttpRequest() + await middleware(request) + + @pytest.mark.django_db class TestDynamicResponseMiddleware: """ @@ -123,9 +182,7 @@ def test_process_exception( middleware = DynamicResponseMiddleware(get_response) # Simulate an IntegrityError exception - with patch.object( - DynamicResponseMiddleware, "_build_error_response" - ) as mock_build_error: + with patch.object(ExceptionHandler, "build_error_response") as mock_build_error: mock_build_error.return_value = JsonResponse( { "status": False, @@ -248,9 +305,7 @@ def test_process_object_does_not_exist_exception( middleware = DynamicResponseMiddleware(get_response) # Simulate ObjectDoesNotExist exception - with patch.object( - DynamicResponseMiddleware, "_build_error_response" - ) as mock_build_error: + with patch.object(ExceptionHandler, "build_error_response") as mock_build_error: mock_build_error.return_value = JsonResponse( { "status": False, @@ -277,70 +332,62 @@ def test_extract_string_error(self) -> None: """ Test that the middleware correctly extracts a string error message. """ - middleware = DynamicResponseMiddleware(HttpResponse) error = "This is an error message" - result = middleware._extract_first_error(error) + result = ExceptionHandler.extract_first_error(error) assert result == "This is an error message" def test_extract_list_of_errors(self) -> None: """ Test that the middleware correctly extracts the first error from a list of errors. """ - middleware = DynamicResponseMiddleware(HttpResponse) errors = ["First error", "Second error"] - result = middleware._extract_first_error(errors) + result = ExceptionHandler.extract_first_error(errors) assert result == "First error" def test_extract_nested_list_of_errors(self) -> None: """ Test that the middleware correctly extracts the first error from a nested list of errors. """ - middleware = DynamicResponseMiddleware(HttpResponse) errors = [["Nested error", "Another error"], "Second error"] - result = middleware._extract_first_error(errors) + result = ExceptionHandler.extract_first_error(errors) assert result == "Nested error" def test_extract_dict_of_errors(self) -> None: """ Test that the middleware correctly extracts the first error from a dictionary of errors. """ - middleware = DynamicResponseMiddleware(HttpResponse) errors = {"field1": "Field1 error", "field2": "Field2 error"} - result = middleware._extract_first_error(errors) + result = ExceptionHandler.extract_first_error(errors) assert result == {"field1": "Field1 error"} def test_extract_nested_dict_of_errors(self) -> None: """ Test that the middleware correctly extracts the first error from a nested dictionary of errors. """ - middleware = DynamicResponseMiddleware(HttpResponse) errors = {"field1": {"subfield": "Subfield error"}, "field2": "Field2 error"} - result = middleware._extract_first_error(errors) - assert result == {"field2": "Field2 error"} + result = ExceptionHandler.extract_first_error(errors) + assert result == {"field1": {"subfield": "Subfield error"}} def test_extract_empty_list(self) -> None: """ Test that the middleware correctly handles an empty list of errors. """ - middleware = DynamicResponseMiddleware(HttpResponse) errors: List = [] - result = middleware._extract_first_error(errors) + result = ExceptionHandler.extract_first_error(errors) assert result == "[]" def test_extract_empty_dict(self) -> None: """ Test that the middleware correctly handles an empty dictionary of errors. """ - middleware = DynamicResponseMiddleware(HttpResponse) errors: Dict = {} - result = middleware._extract_first_error(errors) + result = ExceptionHandler.extract_first_error(errors) assert result == "{}" def test_extract_complex_structure(self) -> None: """ Test that the middleware correctly extracts errors from a complex nested structure. """ - middleware = DynamicResponseMiddleware(HttpResponse) errors = { "field1": [ {"subfield": ["Subfield error", "Another error"]}, @@ -348,8 +395,8 @@ def test_extract_complex_structure(self) -> None: ], "field2": "Field2 error", } - result = middleware._extract_first_error(errors) - assert result == {"field2": "Field2 error"} + result = ExceptionHandler.extract_first_error(errors) + assert result == {"field1": {"subfield": "Subfield error"}} def test_skip_non_json_content_type( self, request_factory: RequestFactory, get_response: Callable @@ -411,7 +458,7 @@ def test_process_exception_handling( response_data = json.loads(processed_response.content) assert response_data["status"] is False assert response_data["status_code"] == 500 - assert "Internal Server Error" in response_data["error"]["message"] + assert "Internal Server Error" in response_data["error"] def test_process_validation_error( self, request_factory: RequestFactory, get_response: Callable @@ -452,7 +499,7 @@ def test_process_django_errors( processed_response = middleware.process_exception(request, integrity_error) response_data = json.loads(processed_response.content) assert processed_response.status_code == 400 - assert response_data["error"] == "Integrity error occurred" + assert response_data["error"] == "A Database Error Occurred" # Test ObjectDoesNotExist handling object_not_found = ObjectDoesNotExist() @@ -460,3 +507,112 @@ def test_process_django_errors( response_data = json.loads(processed_response.content) assert processed_response.status_code == 404 assert response_data["error"] == "Object not found" + + @pytest.mark.asyncio + async def test_async_success_response( + self, request_factory: RequestFactory + ) -> None: + """ + Test that the middleware correctly processes an asynchronous successful response. + + :param request_factory: Fixture to generate mock requests. + """ + + # Mock asynchronous get_response + async def mock_get_response(request: HttpRequest) -> HttpResponseBase: + return JsonResponse({"key": "value"}, status=200) + + request = request_factory.get("/api/test/") + middleware = DynamicResponseMiddleware(mock_get_response) + + # Call the middleware + response = await middleware(request) + response_data = self.parse_json_response(response) + + # Ensure the response is structured correctly + assert response.status_code == 200 + assert response_data == { + "status": True, + "status_code": 200, + "error": None, + "data": {"key": "value"}, + } + + @pytest.mark.asyncio + async def test_async_error_response(self, request_factory: RequestFactory) -> None: + """ + Test that the middleware correctly processes an asynchronous error response. + + :param request_factory: Fixture to generate mock requests. + """ + + # Mock asynchronous get_response + async def mock_get_response(request: HttpRequest) -> HttpResponseBase: + return JsonResponse({"error": "Some error occurred"}, status=400) + + request = request_factory.get("/api/test/") + middleware = DynamicResponseMiddleware(mock_get_response) + + # Call the middleware + response = await middleware(request) + response_data = self.parse_json_response(response) + + # Ensure the response is structured correctly + assert response.status_code == 400 + assert response_data == { + "status": False, + "status_code": 400, + "error": "Some error occurred", + "data": {}, + } + + @pytest.mark.asyncio + async def test_async_skip_non_json_content_type( + self, request_factory: RequestFactory + ) -> None: + """ + Test that non-JSON responses are skipped by the middleware in async mode. + + :param request_factory: Fixture to generate mock requests. + """ + + # Mock asynchronous get_response + async def mock_get_response(request: HttpRequest) -> HttpResponseBase: + return HttpResponse("", content_type="text/html") + + request = request_factory.get("/api/test/") + middleware = DynamicResponseMiddleware(mock_get_response) + + # Call the middleware + response = await middleware(request) + + # Ensure the middleware returns the original response for non-JSON content + assert response.content == b"" + assert response.status_code == 200 + assert response.headers["Content-Type"] == "text/html" + + @pytest.mark.asyncio + async def test_async_excluded_paths(self, request_factory: RequestFactory) -> None: + """ + Test that excluded paths are skipped by the middleware in async mode. + + :param request_factory: Fixture to generate mock requests. + """ + # Mock the config to include an excluded path + with patch.object( + response_shaper_config, "excluded_paths", new=["/api/excluded/"] + ): + # Mock asynchronous get_response + async def mock_get_response(request: HttpRequest) -> HttpResponseBase: + return JsonResponse({"key": "value"}, status=200) + + request = request_factory.get("/api/excluded/") + middleware = DynamicResponseMiddleware(mock_get_response) + + # Call the middleware for an excluded path + response = await middleware(request) + response_data = self.parse_json_response(response) + + # Ensure that the middleware does not alter the response + assert response.status_code == 200 + assert response_data == {"key": "value"} From 2162af562552f9e758aaa0c24ebee075326eeff3 Mon Sep 17 00:00:00 2001 From: MEHRSHAD MIRSHEKARY Date: Sun, 9 Feb 2025 23:47:28 +0330 Subject: [PATCH 4/4] :zap: Update(tox) config file to add pytest-asyncio as new dependency --- tox.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/tox.ini b/tox.ini index 8f7773e..59503e4 100644 --- a/tox.ini +++ b/tox.ini @@ -20,6 +20,7 @@ env_list = description = Run Pytest tests with multiple django and drf versions deps = pytest + pytest-asyncio pytest-cov pytest-django django40: django<5.0,>=4.2