Skip to content

Commit e9684da

Browse files
feat(BA-690): Make CSP configurable (#3682) (#3684)
Co-authored-by: HyeockJinKim <[email protected]> Co-authored-by: HyeockJinKim <[email protected]>
1 parent 41dcb9e commit e9684da

File tree

6 files changed

+51
-42
lines changed

6 files changed

+51
-42
lines changed

changes/3682.feature.md

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Make CSP configurable

configs/webserver/halfstack.conf

+8-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,14 @@ max_file_upload_size = 4294967296
3030

3131
[security]
3232
request_policies = ["reject_metadata_local_link_policy", "reject_access_for_unsafe_file_policy"]
33-
response_policies = ["add_self_content_security_policy", "set_content_type_nosniff_policy"]
33+
response_policies = ["set_content_type_nosniff_policy"]
34+
35+
[security.csp]
36+
default-src = ["'self'"]
37+
style-src = ["'self'", "'unsafe-inline'"]
38+
connect-src = ["'self'", "http://127.0.0.1:6021", "http://127.0.0.1:5050"]
39+
frame-ancestors = ["'none'"]
40+
form-action = ["'self'"]
3441

3542
[environments]
3643

src/ai/backend/web/config.py

+14
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,20 @@
7575
t.Key("security", default=_default_security_config): t.Dict({
7676
t.Key("request_policies", default=[]): t.List(t.String),
7777
t.Key("response_policies", default=[]): t.List(t.String),
78+
t.Key("csp", default=None): t.Null
79+
| t.Dict({
80+
t.Key("default-src", default=None): t.Null | t.List(t.String),
81+
t.Key("connect-src", default=None): t.Null | t.List(t.String),
82+
t.Key("img-src", default=None): t.Null | t.List(t.String),
83+
t.Key("media-src", default=None): t.Null | t.List(t.String),
84+
t.Key("font-src", default=None): t.Null | t.List(t.String),
85+
t.Key("script-src", default=None): t.Null | t.List(t.String),
86+
t.Key("style-src", default=None): t.Null | t.List(t.String),
87+
t.Key("frame-src", default=None): t.Null | t.List(t.String),
88+
t.Key("object-src", default=None): t.Null | t.List(t.String),
89+
t.Key("frame-ancestors", default=None): t.Null | t.List(t.String),
90+
t.Key("form-action", default=None): t.Null | t.List(t.String),
91+
}),
7892
}).allow_extra("*"),
7993
t.Key("resources"): t.Dict({
8094
t.Key("open_port_to_public", default=False): t.ToBool,

src/ai/backend/web/security.py

+23-12
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable, Iterable, Self
1+
from typing import Callable, Iterable, Mapping, Optional, Self
22

33
from aiohttp import web
44
from aiohttp.typedefs import Handler
@@ -30,31 +30,29 @@ def __init__(
3030

3131
@classmethod
3232
def from_config(
33-
cls, request_policy_config: list[str], response_policy_config: list[str]
33+
cls,
34+
request_policy_config: list[str],
35+
response_policy_config: list[str],
36+
csp_config: Optional[Mapping[str, Optional[list[str]]]] = None,
3437
) -> Self:
3538
request_policy_map = {
3639
"reject_metadata_local_link_policy": reject_metadata_local_link_policy,
3740
"reject_access_for_unsafe_file_policy": reject_access_for_unsafe_file_policy,
3841
}
3942
response_policy_map = {
40-
"add_self_content_security_policy": add_self_content_security_policy,
4143
"set_content_type_nosniff_policy": set_content_type_nosniff_policy,
4244
}
4345
try:
4446
request_policies = [
4547
request_policy_map[policy_name] for policy_name in request_policy_config
4648
]
47-
response_policies = [
48-
response_policy_map[policy_name] for policy_name in response_policy_config
49-
]
49+
response_policies: list[ResponsePolicy] = []
50+
for policy_name in response_policy_config:
51+
response_policies.append(response_policy_map[policy_name])
5052
except KeyError as e:
5153
raise ValueError(f"Unknown security policy name: {e}")
52-
return cls(request_policies, response_policies)
53-
54-
@classmethod
55-
def default_policy(cls) -> Self:
56-
request_policies = [reject_metadata_local_link_policy, reject_access_for_unsafe_file_policy]
57-
response_policies = [add_self_content_security_policy, set_content_type_nosniff_policy]
54+
if csp_config is not None:
55+
response_policies.append(csp_policy_builder(csp_config))
5856
return cls(request_policies, response_policies)
5957

6058
def check_request_policies(self, request: web.Request) -> None:
@@ -102,6 +100,19 @@ def add_self_content_security_policy(response: web.StreamResponse) -> web.Stream
102100
return response
103101

104102

103+
def csp_policy_builder(csp_config: Mapping[str, Optional[list[str]]]) -> ResponsePolicy:
104+
csp = [key + " " + " ".join(value) for key, value in csp_config.items() if value]
105+
csp_str = "; ".join(csp)
106+
if csp_str:
107+
csp_str = csp_str + ";"
108+
109+
def policy(response: web.StreamResponse) -> web.StreamResponse:
110+
response.headers["Content-Security-Policy"] = csp_str
111+
return response
112+
113+
return policy
114+
115+
105116
def set_content_type_nosniff_policy(response: web.StreamResponse) -> web.StreamResponse:
106117
response.headers["X-Content-Type-Options"] = "nosniff"
107118
return response

src/ai/backend/web/server.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from functools import partial
1515
from pathlib import Path
1616
from pprint import pprint
17-
from typing import Any, AsyncIterator, Tuple, cast
17+
from typing import Any, AsyncIterator, Mapping, Optional, Tuple, cast
1818

1919
import aiohttp_cors
2020
import aiotools
@@ -608,10 +608,11 @@ async def server_main(
608608
middlewares=[decrypt_payload, track_active_handlers, security_policy_middleware]
609609
)
610610
app["config"] = config
611-
request_policy_config = config["security"]["request_policies"]
612-
response_policy_config = config["security"]["response_policies"]
611+
request_policy_config: list[str] = config["security"]["request_policies"]
612+
response_policy_config: list[str] = config["security"]["response_policies"]
613+
csp_policy_config: Optional[Mapping[str, Optional[list[str]]]] = config["security"]["csp"]
613614
app["security_policy"] = SecurityPolicy.from_config(
614-
request_policy_config, response_policy_config
615+
request_policy_config, response_policy_config, csp_policy_config
615616
)
616617
j2env = jinja2.Environment(
617618
extensions=[

tests/webserver/test_security_policy.py

-25
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,6 @@
1313
)
1414

1515

16-
@pytest.fixture
17-
def default_app():
18-
app = web.Application(middlewares=[security_policy_middleware])
19-
app["security_policy"] = SecurityPolicy.default_policy()
20-
return app
21-
22-
2316
@pytest.fixture
2417
async def async_handler() -> Handler:
2518
async def handler(request):
@@ -28,24 +21,6 @@ async def handler(request):
2821
return handler
2922

3023

31-
async def test_default_security_policy_reject_metadata_local_link(
32-
default_app, async_handler
33-
) -> None:
34-
request = make_mocked_request("GET", "/", headers={"Host": "169.254.169.254"}, app=default_app)
35-
with pytest.raises(web.HTTPForbidden):
36-
await security_policy_middleware(request, async_handler)
37-
38-
39-
async def test_default_security_policy_response(default_app, async_handler) -> None:
40-
request = make_mocked_request("GET", "/", headers={"Host": "localhost"}, app=default_app)
41-
response = await security_policy_middleware(request, async_handler)
42-
assert (
43-
response.headers["Content-Security-Policy"]
44-
== "default-src 'self'; style-src 'self' 'unsafe-inline'; frame-ancestors 'none'; form-action 'self';"
45-
)
46-
assert response.headers["X-Content-Type-Options"] == "nosniff"
47-
48-
4924
@pytest.mark.parametrize(
5025
"meta_local_link",
5126
[

0 commit comments

Comments
 (0)