Skip to content

Commit 7b6cdce

Browse files
amin-farjadiAmin Farjadileandrodamascena
authored
fix(openapi): correct response validation for falsy objects (#7990)
* fix(openapi): correct response validation for falsy objects * chore: remove skip from falsy return type tests --------- Co-authored-by: Amin Farjadi <amin.farjadi@eonnext.com> Co-authored-by: Leandro Damascena <lcdama@amazon.pt>
1 parent f7290ae commit 7b6cdce

File tree

3 files changed

+100
-39
lines changed

3 files changed

+100
-39
lines changed

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

Lines changed: 40 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -228,20 +228,27 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
228228
return self._handle_response(route=route, response=response)
229229

230230
def _handle_response(self, *, route: Route, response: Response):
231-
# Process the response body if it exists
232-
if response.body and response.is_json():
233-
response.body = self._serialize_response(
234-
field=route.dependant.return_param,
231+
field = route.dependant.return_param
232+
233+
if field is None:
234+
if not response.is_json():
235+
return response
236+
else:
237+
# JSON serialize the body without validation
238+
response.body = jsonable_encoder(response.body, custom_serializer=self._validation_serializer)
239+
else:
240+
response.body = self._serialize_response_with_validation(
241+
field=field,
235242
response_content=response.body,
236243
has_route_custom_response_validation=route.custom_response_validation_http_code is not None,
237244
)
238245

239246
return response
240247

241-
def _serialize_response(
248+
def _serialize_response_with_validation(
242249
self,
243250
*,
244-
field: ModelField | None = None,
251+
field: ModelField,
245252
response_content: Any,
246253
include: IncEx | None = None,
247254
exclude: IncEx | None = None,
@@ -254,45 +261,42 @@ def _serialize_response(
254261
"""
255262
Serialize the response content according to the field type.
256263
"""
257-
if field:
258-
errors: list[dict[str, Any]] = []
259-
value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors)
260-
if errors:
261-
# route-level validation must take precedence over app-level
262-
if has_route_custom_response_validation:
263-
raise ResponseValidationError(
264-
errors=_normalize_errors(errors),
265-
body=response_content,
266-
source="route",
267-
)
268-
if self._has_response_validation_error:
269-
raise ResponseValidationError(errors=_normalize_errors(errors), body=response_content, source="app")
270-
271-
raise RequestValidationError(errors=_normalize_errors(errors), body=response_content)
272-
273-
if hasattr(field, "serialize"):
274-
return field.serialize(
275-
value,
276-
include=include,
277-
exclude=exclude,
278-
by_alias=by_alias,
279-
exclude_unset=exclude_unset,
280-
exclude_defaults=exclude_defaults,
281-
exclude_none=exclude_none,
264+
errors: list[dict[str, Any]] = []
265+
value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors)
266+
if errors:
267+
# route-level validation must take precedence over app-level
268+
if has_route_custom_response_validation:
269+
raise ResponseValidationError(
270+
errors=_normalize_errors(errors),
271+
body=response_content,
272+
source="route",
282273
)
283-
return jsonable_encoder(
274+
if self._has_response_validation_error:
275+
raise ResponseValidationError(errors=_normalize_errors(errors), body=response_content, source="app")
276+
277+
raise RequestValidationError(errors=_normalize_errors(errors), body=response_content)
278+
279+
if hasattr(field, "serialize"):
280+
return field.serialize(
284281
value,
285282
include=include,
286283
exclude=exclude,
287284
by_alias=by_alias,
288285
exclude_unset=exclude_unset,
289286
exclude_defaults=exclude_defaults,
290287
exclude_none=exclude_none,
291-
custom_serializer=self._validation_serializer,
292288
)
293-
else:
294-
# Just serialize the response content returned from the handler.
295-
return jsonable_encoder(response_content, custom_serializer=self._validation_serializer)
289+
290+
return jsonable_encoder(
291+
value,
292+
include=include,
293+
exclude=exclude,
294+
by_alias=by_alias,
295+
exclude_unset=exclude_unset,
296+
exclude_defaults=exclude_defaults,
297+
exclude_none=exclude_none,
298+
custom_serializer=self._validation_serializer,
299+
)
296300

297301
def _prepare_response_content(
298302
self,

tests/functional/event_handler/_pydantic/test_http_resolver_pydantic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ def search(
209209
# =============================================================================
210210

211211

212+
@pytest.mark.skip("Due to issue #7981.")
212213
@pytest.mark.asyncio
213214
async def test_async_handler_with_validation():
214215
# GIVEN an app with async handler and validation

tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1632,7 +1632,39 @@ def handler(user_id: int = 123):
16321632
assert result["statusCode"] == 200
16331633

16341634

1635-
@pytest.mark.skipif(reason="Test temporarily disabled until falsy return is fixed")
1635+
def test_validate_list_response(gw_event):
1636+
# GIVEN an APIGatewayRestResolver with validation enabled
1637+
app = APIGatewayRestResolver(enable_validation=True)
1638+
1639+
class Model(BaseModel):
1640+
name: str
1641+
age: int
1642+
1643+
response_before_validation = [
1644+
{
1645+
"name": "Joe",
1646+
"age": 20,
1647+
},
1648+
{
1649+
"name": "Jane",
1650+
"age": 20,
1651+
},
1652+
]
1653+
1654+
@app.get("/list_response_with_same_element_types")
1655+
def handler_different_list() -> List[Model]:
1656+
return response_before_validation
1657+
1658+
# WHEN returning list with the same element type as the non-Optional return type
1659+
gw_event["path"] = "/list_response_with_same_element_types"
1660+
result = app(gw_event, {})
1661+
body = json.loads(result["body"])
1662+
1663+
# THEN it should return a validation error
1664+
assert result["statusCode"] == 200
1665+
assert body == response_before_validation
1666+
1667+
16361668
def test_validation_error_none_returned_non_optional_type(gw_event):
16371669
# GIVEN an APIGatewayRestResolver with validation enabled
16381670
app = APIGatewayRestResolver(enable_validation=True)
@@ -1656,6 +1688,32 @@ def handler_none_not_allowed() -> Model:
16561688
assert body["detail"][0]["loc"] == ["response"]
16571689

16581690

1691+
def test_validation_error_different_list_returned_non_optional_type(gw_event):
1692+
# GIVEN an APIGatewayRestResolver with validation enabled
1693+
app = APIGatewayRestResolver(enable_validation=True)
1694+
1695+
class Model(BaseModel):
1696+
name: str
1697+
age: int
1698+
1699+
different_list_response = ["a", "b", "c"]
1700+
1701+
@app.get("/list_response_with_different_element_types")
1702+
def handler_different_list() -> List[Model]:
1703+
return different_list_response
1704+
1705+
# WHEN returning list with the different element type as the non-Optional return type
1706+
gw_event["path"] = "/list_response_with_different_element_types"
1707+
result = app(gw_event, {})
1708+
1709+
# THEN it should return a validation error
1710+
assert result["statusCode"] == 422
1711+
body = json.loads(result["body"])
1712+
assert len(body["detail"]) == len(different_list_response)
1713+
assert body["detail"][0]["type"] == "model_attributes_type"
1714+
assert body["detail"][0]["loc"] == ["response", 0]
1715+
1716+
16591717
def test_validation_error_incomplete_model_returned_non_optional_type(gw_event):
16601718
# GIVEN an APIGatewayRestResolver with validation enabled
16611719
app = APIGatewayRestResolver(enable_validation=True)
@@ -1700,7 +1758,6 @@ def handler_none_allowed() -> Optional[Model]:
17001758
assert result["body"] == "null"
17011759

17021760

1703-
@pytest.mark.skipif(reason="Test temporarily disabled until falsy return is fixed")
17041761
@pytest.mark.parametrize(
17051762
"path, body",
17061763
[
@@ -1756,7 +1813,6 @@ def handler_valid_response() -> Model:
17561813
assert body == {"name": "Joe", "age": 18}
17571814

17581815

1759-
@pytest.mark.skipif(reason="Test temporarily disabled until falsy return is fixed")
17601816
@pytest.mark.parametrize(
17611817
"http_code",
17621818
(422, 500, 510),

0 commit comments

Comments
 (0)