Skip to content

Commit

Permalink
More tests to improve coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
Tinche committed Dec 19, 2023
1 parent a306d60 commit f6df0b2
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 10 deletions.
8 changes: 4 additions & 4 deletions src/uapi/quart.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,13 +229,13 @@ def _make_quart_incanter(converter: Converter) -> Incanter:
)
res.register_hook_factory(
is_header,
lambda p: make_header_dependency(
lambda p: _make_header_dependency(
*get_header_type(p), p.name, converter, p.default
),
)
res.register_hook_factory(
lambda p: get_cookie_name(p.annotation, p.name) is not None,
lambda p: make_cookie_dependency(get_cookie_name(p.annotation, p.name), default=p.default), # type: ignore
lambda p: _make_cookie_dependency(get_cookie_name(p.annotation, p.name), default=p.default), # type: ignore
)

async def request_bytes() -> bytes:
Expand All @@ -259,7 +259,7 @@ async def request_bytes() -> bytes:
return res


def make_header_dependency(
def _make_header_dependency(
type: type,
headerspec: HeaderSpec,
name: str,
Expand Down Expand Up @@ -297,7 +297,7 @@ def read_opt_conv_header() -> Any:
return read_opt_conv_header


def make_cookie_dependency(cookie_name: str, default=Signature.empty):
def _make_cookie_dependency(cookie_name: str, default=Signature.empty):
if default is Signature.empty:

def read_cookie() -> str:
Expand Down
12 changes: 6 additions & 6 deletions src/uapi/starlette.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,13 +253,13 @@ def read_query(_request: FrameworkRequest) -> Any:
)
res.register_hook_factory(
is_header,
lambda p: make_header_dependency(
lambda p: _make_header_dependency(
*get_header_type(p), p.name, converter, p.default
),
)
res.register_hook_factory(
lambda p: get_cookie_name(p.annotation, p.name) is not None,
lambda p: make_cookie_dependency(get_cookie_name(p.annotation, p.name), default=p.default), # type: ignore
lambda p: _make_cookie_dependency(get_cookie_name(p.annotation, p.name), default=p.default), # type: ignore
)

async def request_bytes(_request: FrameworkRequest) -> bytes:
Expand All @@ -283,7 +283,7 @@ async def request_bytes(_request: FrameworkRequest) -> bytes:
return res


def make_header_dependency(
def _make_header_dependency(
type: type,
headerspec: HeaderSpec,
name: str,
Expand Down Expand Up @@ -321,7 +321,7 @@ def read_opt_conv_header(_request: FrameworkRequest) -> Any:
return read_opt_conv_header


def make_cookie_dependency(cookie_name: str, default=Signature.empty):
def _make_cookie_dependency(cookie_name: str, default=Signature.empty):
if default is Signature.empty:

def read_cookie(_request: FrameworkRequest) -> str:
Expand All @@ -335,7 +335,7 @@ def read_cookie_opt(_request: FrameworkRequest) -> Any:
return read_cookie_opt


def extract_cookies(headers: Headers) -> tuple[dict[str, str], list[str]]:
def _extract_cookies(headers: Headers) -> tuple[dict[str, str], list[str]]:
h = {}
cookies = []
for k, v in headers.items():
Expand All @@ -362,7 +362,7 @@ async def read_form(_request: FrameworkRequest) -> C:

def _framework_return_adapter(resp: BaseResponse) -> FrameworkResponse:
if resp.headers:
headers, cookies = extract_cookies(resp.headers)
headers, cookies = _extract_cookies(resp.headers)
res = FrameworkResponse(
resp.ret or b"", get_status_code(resp.__class__), headers # type: ignore
)
Expand Down
8 changes: 8 additions & 0 deletions tests/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,10 @@ async def header_str_default(test_header: Header[str] = "def") -> str:
async def header_default(test_header: Header[str | None] = None) -> str:
return test_header or "default"

@app.get("/header-nonstring")
async def non_str_header_no_default(test_header: Header[int]) -> str:
return str(test_header)

@app.get("/header-renamed")
async def header_renamed(
test_header: Annotated[str, HeaderSpec("test_header")]
Expand Down Expand Up @@ -366,6 +370,10 @@ def header_str_default(test_header: Header[str] = "def") -> str:
def header_default(test_header: Header[str | None] = None) -> str:
return test_header or "default"

@app.get("/header-nonstring")
def non_str_header_no_default(test_header: Header[int]) -> str:
return str(test_header)

@app.get("/header-renamed")
def header_renamed(test_header: Annotated[str, HeaderSpec("test_header")]) -> str:
return test_header
Expand Down
10 changes: 10 additions & 0 deletions tests/test_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,16 @@ async def test_header_with_default(server: int) -> None:
assert resp.text == "1"


async def test_nonstring_header(server: int) -> None:
"""Non-string headers without defaults work properly."""
async with AsyncClient() as client:
resp = await client.get(
f"http://localhost:{server}/header-nonstring", headers={"test-header": "1"}
)
assert resp.status_code == 200
assert resp.text == "1"


async def test_header_name_override(server: int) -> None:
"""Headers can override their names."""
async with AsyncClient() as client:
Expand Down

0 comments on commit f6df0b2

Please sign in to comment.