Skip to content

Commit 864cd60

Browse files
authored
Merge pull request #1005 from Yelp/u/mpiano/SEC-19555
auth support for Tron APIs
2 parents c7f3063 + 54360bd commit 864cd60

File tree

5 files changed

+216
-0
lines changed

5 files changed

+216
-0
lines changed

tests/api/auth_test.py

+76
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from unittest.mock import MagicMock
2+
from unittest.mock import patch
3+
4+
import pytest
5+
from twisted.web.server import Request
6+
7+
from tron.api.auth import AuthorizationFilter
8+
from tron.api.auth import AuthorizationOutcome
9+
10+
11+
@pytest.fixture
12+
def mock_auth_filter():
13+
with patch("tron.api.auth.requests"):
14+
yield AuthorizationFilter("http://localhost:31337/whatever", True)
15+
16+
17+
def mock_request(path: str, token: str, method: str):
18+
res = MagicMock(spec=Request, path=path.encode(), method=method.encode())
19+
res.getHeader.return_value = token
20+
return res
21+
22+
23+
def test_is_request_authorized(mock_auth_filter):
24+
mock_auth_filter.session.post.return_value.json.return_value = {
25+
"result": {"allowed": True, "reason": "User allowed"}
26+
}
27+
assert mock_auth_filter.is_request_authorized(
28+
mock_request("/api/jobs/foobar.run.2", "aaa.bbb.ccc", "get")
29+
) == AuthorizationOutcome(True, "User allowed")
30+
mock_auth_filter.session.post.assert_called_once_with(
31+
url="http://localhost:31337/whatever",
32+
json={
33+
"input": {
34+
"path": "/api/jobs/foobar.run.2",
35+
"backend": "tron",
36+
"token": "aaa.bbb.ccc",
37+
"method": "get",
38+
"service": "foobar",
39+
}
40+
},
41+
timeout=2,
42+
)
43+
44+
45+
def test_is_request_authorized_fail(mock_auth_filter):
46+
mock_auth_filter.session.post.side_effect = Exception
47+
assert mock_auth_filter.is_request_authorized(
48+
mock_request("/allowed", "eee.ddd.fff", "get")
49+
) == AuthorizationOutcome(False, "Auth backend error")
50+
51+
52+
def test_is_request_authorized_malformed(mock_auth_filter):
53+
mock_auth_filter.session.post.return_value.json.return_value = {"foo": "bar"}
54+
assert mock_auth_filter.is_request_authorized(
55+
mock_request("/allowed", "eee.ddd.fff", "post")
56+
) == AuthorizationOutcome(False, "Malformed auth response")
57+
58+
59+
def test_is_request_authorized_no_enforce(mock_auth_filter):
60+
mock_auth_filter.session.post.return_value.json.return_value = {
61+
"result": {"allowed": False, "reason": "Missing token"}
62+
}
63+
with patch.object(mock_auth_filter, "enforce", False):
64+
assert mock_auth_filter.is_request_authorized(mock_request("/foobar", "", "post")) == AuthorizationOutcome(
65+
True, "Auth dry-run"
66+
)
67+
68+
69+
def test_is_request_authorized_disabled(mock_auth_filter):
70+
mock_auth_filter.session.post.return_value.json.return_value = {
71+
"result": {"allowed": False, "reason": "Missing token"}
72+
}
73+
with patch.object(mock_auth_filter, "endpoint", None):
74+
assert mock_auth_filter.is_request_authorized(mock_request("/buzz", "", "post")) == AuthorizationOutcome(
75+
True, "Auth not enabled"
76+
)

tron/api/auth.py

+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import logging
2+
import os
3+
from functools import lru_cache
4+
from typing import NamedTuple
5+
from typing import Optional
6+
7+
import cachetools.func
8+
import requests
9+
from twisted.web.server import Request
10+
11+
12+
logger = logging.getLogger(__name__)
13+
AUTH_CACHE_SIZE = 50000
14+
AUTH_CACHE_TTL = 30 * 60
15+
16+
17+
class AuthorizationOutcome(NamedTuple):
18+
authorized: bool
19+
reason: str
20+
21+
22+
class AuthorizationFilter:
23+
"""API request authorization via external system"""
24+
25+
def __init__(self, endpoint: str, enforce: bool):
26+
"""Constructor
27+
28+
:param str endpoint: HTTP endpoint of external authorization system
29+
:param bool enforce: whether to enforce authorization decisions
30+
"""
31+
self.endpoint = endpoint
32+
self.enforce = enforce
33+
self.session = requests.Session()
34+
35+
@classmethod
36+
@lru_cache(maxsize=1)
37+
def get_from_env(cls) -> "AuthorizationFilter":
38+
return cls(
39+
endpoint=os.getenv("API_AUTH_ENDPOINT", ""),
40+
enforce=bool(os.getenv("API_AUTH_ENFORCE", "")),
41+
)
42+
43+
def is_request_authorized(self, request: Request) -> AuthorizationOutcome:
44+
"""Check if API request is authorized
45+
46+
:param Request request: API request object
47+
:return: auth outcome
48+
"""
49+
if not self.endpoint:
50+
return AuthorizationOutcome(True, "Auth not enabled")
51+
token = (request.getHeader("Authorization") or "").strip()
52+
token = token.split()[-1] if token else "" # removes "Bearer" prefix
53+
url_path = request.path.decode()
54+
service = url_path.split("/")[-1].split(".", 1)[0] if "/jobs/" in url_path else None
55+
auth_outcome = self._is_request_authorized_impl(
56+
# path and method are byte arrays in twisted
57+
path=url_path,
58+
token=token,
59+
method=request.method.decode(),
60+
service=service,
61+
)
62+
return auth_outcome if self.enforce else AuthorizationOutcome(True, "Auth dry-run")
63+
64+
@cachetools.func.ttl_cache(maxsize=AUTH_CACHE_SIZE, ttl=AUTH_CACHE_TTL)
65+
def _is_request_authorized_impl(
66+
self,
67+
path: str,
68+
token: str,
69+
method: str,
70+
service: Optional[str],
71+
) -> AuthorizationOutcome:
72+
"""Check if API request is authorized
73+
74+
:param str path: API path
75+
:param str token: authentication token
76+
:param str method: http method
77+
:return: auth outcome
78+
"""
79+
try:
80+
response = self.session.post(
81+
url=self.endpoint,
82+
json={
83+
"input": {
84+
"path": path,
85+
"backend": "tron",
86+
"token": token,
87+
"method": method.lower(),
88+
"service": service,
89+
},
90+
},
91+
timeout=2,
92+
).json()
93+
except Exception as e:
94+
logger.exception(f"Issue communicating with auth endpoint: {e}")
95+
return AuthorizationOutcome(False, "Auth backend error")
96+
97+
auth_result_allowed = response.get("result", {}).get("allowed")
98+
if auth_result_allowed is None:
99+
return AuthorizationOutcome(False, "Malformed auth response")
100+
101+
if not auth_result_allowed:
102+
reason = response["result"].get("reason", "Denied")
103+
return AuthorizationOutcome(False, reason)
104+
105+
reason = response["result"].get("reason", "Ok")
106+
return AuthorizationOutcome(True, reason)

tron/api/resource.py

+13
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from tron.api import adapter, controller
2727
from tron.api import requestargs
2828
from tron.api.async_resource import AsyncResource
29+
from tron.api.auth import AuthorizationFilter
2930
from tron.metrics import view_all_metrics
3031
from tron.metrics import meter
3132
from tron.utils import maybe_decode
@@ -514,6 +515,18 @@ def render_GET(self, request):
514515
}
515516
return respond(request=request, response=response)
516517

518+
def render(self, request):
519+
"""Overriding base `render` method to support auth"""
520+
auth_outcome = AuthorizationFilter.get_from_env().is_request_authorized(request)
521+
if not auth_outcome.authorized:
522+
return respond(
523+
request=request,
524+
response={"reason": auth_outcome.reason},
525+
code=http.FORBIDDEN,
526+
headers={"X-Auth-Failure-Reason": auth_outcome.reason},
527+
)
528+
return super().render(request)
529+
517530

518531
class RootResource(resource.Resource):
519532
def __init__(self, mcp, web_path):

tron/commands/client.py

+18
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,27 @@ class RequestError(ValueError):
3737
}
3838

3939

40+
def get_sso_auth_token() -> str:
41+
"""Generate an authentication token for the calling user from the Single Sign On provider, if configured"""
42+
43+
# These imports are here because:
44+
# - okta-auth is an internal library, so can never be imported or type-checked in public builds
45+
# - there's an annoying circular import with the cmd_utils module
46+
from okta_auth import get_and_cache_jwt_default # type: ignore
47+
from tron.commands.cmd_utils import get_client_config
48+
49+
client_id = get_client_config().get("auth_sso_oidc_client_id")
50+
return get_and_cache_jwt_default(client_id) if client_id else "" # type: ignore
51+
52+
4053
def build_url_request(uri, data, headers=None, method=None):
4154
headers = headers or default_headers
4255
enc_data = urllib.parse.urlencode(data).encode() if data else None
56+
# Currently implementing auth only for management actions (i.e. POST requests)
57+
if os.getenv("TRONCTL_API_AUTH") and (data or method.upper() == "POST"):
58+
token = get_sso_auth_token()
59+
if token:
60+
headers["Authorization"] = f"Bearer {token}"
4361
return urllib.request.Request(uri, enc_data, headers=headers, method=method)
4462

4563

+3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
clusterman-metrics==2.2.1 # used by tron for pre-scaling for Spark runs
22
logreader==1.2.0 # used by tron logreader
3+
okta-auth==1.0.2 # used for API auth
4+
pyjwt==2.9.0 # required by okta-auth
5+
saml-helper==2.5.3 # required by okta-auth
36
simplejson==3.19.2 # required by tron CLI
47
yelp-meteorite==2.1.1 # used by task-processing to emit metrics, clusterman-metrics dependency

0 commit comments

Comments
 (0)