Skip to content

Commit

Permalink
Merge pull request #15300 from SergeyYakubov/oidc_tokens
Browse files Browse the repository at this point in the history
OIDC tokens
  • Loading branch information
mvdbeek committed Jun 21, 2023
2 parents 7a8fdd0 + a71f205 commit acc7136
Show file tree
Hide file tree
Showing 16 changed files with 352 additions and 23 deletions.
42 changes: 42 additions & 0 deletions client/src/schema/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,13 @@ export interface paths {
*/
get: operations["show_api_jobs__id__get"];
};
"/api/jobs/{job_id}/oidc-tokens": {
/**
* Get a fresh OIDC token
* @description Allows remote job running mechanisms to get a fresh OIDC token that can be used on remote side to authorize user. It is not meant to represent part of Galaxy's stable, user facing API
*/
get: operations["get_token_api_jobs__job_id__oidc_tokens_get"];
};
"/api/libraries": {
/**
* Returns a list of summary data for all libraries.
Expand Down Expand Up @@ -13759,6 +13766,41 @@ export interface operations {
};
};
};
get_token_api_jobs__job_id__oidc_tokens_get: {
/**
* Get a fresh OIDC token
* @description Allows remote job running mechanisms to get a fresh OIDC token that can be used on remote side to authorize user. It is not meant to represent part of Galaxy's stable, user facing API
*/
parameters: {
/** @description A key used to authenticate this request as acting onbehalf or a job runner for the specified job */
/** @description OIDC provider name */
query: {
job_key: string;
provider: string;
};
/** @description The user ID that will be used to effectively make this API call. Only admins and designated users can make API calls on behalf of other users. */
header?: {
"run-as"?: string;
};
path: {
job_id: string;
};
};
responses: {
/** @description Successful Response */
200: {
content: {
"text/plain": string;
};
};
/** @description Validation Error */
422: {
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
index_api_libraries_get: {
/**
* Returns a list of summary data for all libraries.
Expand Down
3 changes: 3 additions & 0 deletions lib/galaxy/authnz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ def __init__(self, provider, config, backend_config):
"""
raise NotImplementedError()

def refresh(self, trans, token):
raise NotImplementedError()

def authenticate(self, provider, trans):
"""Runs for authentication process. Checks the database if a
valid identity exists in the database; if yes, then the user
Expand Down
94 changes: 75 additions & 19 deletions lib/galaxy/authnz/custos_authnz.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import logging
import os
import time
from datetime import (
datetime,
timedelta,
Expand All @@ -14,7 +15,10 @@
from oauthlib.common import generate_nonce
from requests_oauthlib import OAuth2Session

from galaxy import util
from galaxy import (
exceptions,
util,
)
from galaxy.model import (
CustosAuthnzToken,
User,
Expand Down Expand Up @@ -69,6 +73,43 @@ def __init__(self, provider, oidc_config, oidc_backend_config, idphint=None):
def _decode_token_no_signature(self, token):
return jwt.decode(token, audience=self.config["client_id"], options={"verify_signature": False})

def refresh(self, trans, custos_authnz_token):
if custos_authnz_token is None:
raise exceptions.AuthenticationFailed("cannot find authorized user while refreshing token")
id_token_decoded = self._decode_token_no_signature(custos_authnz_token.id_token)
# do not refresh tokens if they didn't reach their half lifetime
if int(id_token_decoded["iat"]) + int(id_token_decoded["exp"]) > 2 * int(time.time()):
return False
log.info(custos_authnz_token.access_token)
oauth2_session = self._create_oauth2_session()
token_endpoint = self.config["token_endpoint"]
if self.config.get("iam_client_secret"):
client_secret = self.config["iam_client_secret"]
else:
client_secret = self.config["client_secret"]
clientIdAndSec = f"{self.config['client_id']}:{self.config['client_secret']}" # for custos

params = {
"client_secret": client_secret,
"refresh_token": custos_authnz_token.refresh_token,
"headers": {
"Authorization": f"Basic {util.unicodify(base64.b64encode(util.smart_str(clientIdAndSec)))}"
}, # for custos
}

token = oauth2_session.refresh_token(token_endpoint, **params)
processed_token = self._process_token(trans, oauth2_session, token, False)

custos_authnz_token.access_token = processed_token["access_token"]
custos_authnz_token.id_token = processed_token["id_token"]
custos_authnz_token.refresh_token = processed_token["refresh_token"]
custos_authnz_token.expiration_time = processed_token["expiration_time"]
custos_authnz_token.refresh_expiration_time = processed_token["refresh_expiration_time"]

trans.sa_session.add(custos_authnz_token)
trans.sa_session.flush()
return True

def authenticate(self, trans, idphint=None):
base_authorize_url = self.config["authorization_endpoint"]
scopes = ["openid", "email", "profile"]
Expand Down Expand Up @@ -97,35 +138,51 @@ def authenticate(self, trans, idphint=None):
trans.set_cookie(value=nonce, name=NONCE_COOKIE_NAME)
return authorization_url

def callback(self, state_token, authz_code, trans, login_redirect_url):
# Take state value to validate from token. OAuth2Session.fetch_token
# will validate that the state query parameter value on the URL matches
# this value.
state_cookie = trans.get_cookie(name=STATE_COOKIE_NAME)
oauth2_session = self._create_oauth2_session(state=state_cookie)
token = self._fetch_token(oauth2_session, trans)
access_token = token["access_token"]
id_token = token["id_token"]
refresh_token = token["refresh_token"] if "refresh_token" in token else None
expiration_time = datetime.now() + timedelta(seconds=token.get("expires_in", 3600))
refresh_expiration_time = (
def _process_token(self, trans, oauth2_session, token, validate_nonce=True):
processed_token = {}
processed_token["access_token"] = token["access_token"]
processed_token["id_token"] = token["id_token"]
processed_token["refresh_token"] = token["refresh_token"] if "refresh_token" in token else None
processed_token["expiration_time"] = datetime.now() + timedelta(seconds=token.get("expires_in", 3600))
processed_token["refresh_expiration_time"] = (
(datetime.now() + timedelta(seconds=token["refresh_expires_in"])) if "refresh_expires_in" in token else None
)

# Get nonce from token['id_token'] and validate. 'nonce' in the
# id_token is a hash of the nonce stored in the NONCE_COOKIE_NAME
# cookie.
id_token_decoded = self._decode_token_no_signature(id_token)
nonce_hash = id_token_decoded["nonce"]
self._validate_nonce(trans, nonce_hash)
id_token_decoded = self._decode_token_no_signature(processed_token["id_token"])
if validate_nonce:
nonce_hash = id_token_decoded["nonce"]
self._validate_nonce(trans, nonce_hash)

# Get userinfo and lookup/create Galaxy user record
if id_token_decoded.get("email", None):
userinfo = id_token_decoded
else:
userinfo = self._get_userinfo(oauth2_session)
email = userinfo["email"]
user_id = userinfo["sub"]
processed_token["email"] = userinfo["email"]
processed_token["user_id"] = userinfo["sub"]
processed_token["username"] = self._username_from_userinfo(trans, userinfo)
return processed_token

def callback(self, state_token, authz_code, trans, login_redirect_url):
# Take state value to validate from token. OAuth2Session.fetch_token
# will validate that the state query parameter value on the URL matches
# this value.
state_cookie = trans.get_cookie(name=STATE_COOKIE_NAME)
oauth2_session = self._create_oauth2_session(state=state_cookie)
token = self._fetch_token(oauth2_session, trans)
processed_token = self._process_token(trans, oauth2_session, token)

user_id = processed_token["user_id"]
email = processed_token["email"]
username = processed_token["username"]
access_token = processed_token["access_token"]
id_token = processed_token["id_token"]
refresh_token = processed_token["refresh_token"]
expiration_time = processed_token["expiration_time"]
refresh_expiration_time = processed_token["refresh_expiration_time"]

# Create or update custos_authnz_token record
custos_authnz_token = self._get_custos_authnz_token(trans.sa_session, user_id, self.config["provider"])
Expand Down Expand Up @@ -160,7 +217,6 @@ def callback(self, state_token, authz_code, trans, login_redirect_url):
login_redirect_url = f"{login_redirect_url}login/start?confirm=true&provider_token={json.dumps(token)}&provider={self.config['provider']}"
return login_redirect_url, None
else:
username = self._username_from_userinfo(trans, userinfo)
user = trans.app.user_manager.create(email=email, username=username)
if trans.app.config.user_activation_on:
trans.app.user_manager.send_activation_email(trans, email, username)
Expand Down
27 changes: 27 additions & 0 deletions lib/galaxy/authnz/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ def _parse_idp_config(self, config_xml):
rtv["icon"] = config_xml.find("icon").text
if config_xml.find("extra_scopes") is not None:
rtv["extra_scopes"] = listify(config_xml.find("extra_scopes").text)
if config_xml.find("tenant_id") is not None:
rtv["tenant_id"] = config_xml.find("tenant_id").text
if config_xml.find("pkce_support") is not None:
rtv["pkce_support"] = asbool(config_xml.find("pkce_support").text)

Expand Down Expand Up @@ -317,6 +319,31 @@ def try_get_authz_config(sa_session, user_id, authz_id):
raise exceptions.ItemAccessibilityException(msg)
return qres

def refresh_expiring_oidc_tokens_for_provider(self, trans, auth):
try:
success, message, backend = self._get_authnz_backend(auth.provider)
if success is False:
msg = f"An error occurred when refreshing user token on `{auth.provider}` identity provider: {message}"
log.error(msg)
return False
refreshed = backend.refresh(trans, auth)
if refreshed:
log.debug(f"Refreshed user token via `{auth.provider}` identity provider")
return True
except Exception as e:
msg = f"An error occurred when refreshing user token: {e}"
log.error(msg)
return False

def refresh_expiring_oidc_tokens(self, trans, user=None):
user = trans.user or user
if not isinstance(user, model.User):
return
for auth in user.custos_auth or []:
self.refresh_expiring_oidc_tokens_for_provider(trans, auth)
for auth in user.social_auth or []:
self.refresh_expiring_oidc_tokens_for_provider(trans, auth)

def authenticate(self, provider, trans, idphint=None):
"""
:type provider: string
Expand Down
44 changes: 42 additions & 2 deletions lib/galaxy/authnz/psa_authnz.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import json
import logging
import time

import jwt
import requests
from msal import ConfidentialClientApplication
from social_core.actions import (
do_auth,
do_complete,
Expand Down Expand Up @@ -35,15 +39,15 @@
"globus": "social_core.backends.globus.GlobusOpenIdConnect",
"elixir": "social_core.backends.elixir.ElixirOpenIdConnect",
"okta": "social_core.backends.okta_openidconnect.OktaOpenIdConnect",
"azure": "social_core.backends.azuread_tenant.AzureADTenantOAuth2",
"azure": "social_core.backends.azuread_tenant.AzureADV2TenantOAuth2",
}

BACKENDS_NAME = {
"google": "google-openidconnect",
"globus": "globus",
"elixir": "elixir",
"okta": "okta-openidconnect",
"azure": "azuread-tenant-oauth2",
"azure": "azuread-v2-tenant-oauth2",
}

AUTH_PIPELINE = (
Expand Down Expand Up @@ -122,6 +126,7 @@ def _setup_idp(self, oidc_backend_config):
self.config[setting_name("AUTH_EXTRA_ARGUMENTS")] = {"access_type": "offline"}
self.config["KEY"] = oidc_backend_config.get("client_id")
self.config["SECRET"] = oidc_backend_config.get("client_secret")
self.config["TENANT_ID"] = oidc_backend_config.get("tenant_id")
self.config["redirect_uri"] = oidc_backend_config.get("redirect_uri")
self.config["EXTRA_SCOPES"] = oidc_backend_config.get("extra_scopes")
if oidc_backend_config.get("prompt") is not None:
Expand All @@ -143,6 +148,40 @@ def _load_backend(self, strategy, redirect_uri):
def _login_user(self, backend, user, social_user):
self.config["user"] = user

def refresh_azure(self, user_authnz_token):
logging.getLogger("msal").setLevel(logging.WARN)
old_extra_data = user_authnz_token.extra_data
app = ConfidentialClientApplication(
self.config["KEY"],
self.config["SECRET"],
authority="https://login.microsoftonline.com/" + self.config["TENANT_ID"],
)
extra_data = app.acquire_token_by_refresh_token(
old_extra_data["refresh_token"], scopes=["https://graph.microsoft.com/.default"]
)
decoded_token = jwt.decode(extra_data["id_token"], options={"verify_signature": False})
if "auth_time" not in extra_data:
extra_data["auth_time"] = decoded_token["iat"]
expires = decoded_token["exp"]
extra_data["expires"] = int(expires - time.time())
user_authnz_token.set_extra_data(extra_data)

def refresh(self, trans, user_authnz_token):
if not user_authnz_token or not user_authnz_token.extra_data:
return False
# refresh tokens if they reached their half lifetime
if int(user_authnz_token.extra_data["auth_time"]) + int(user_authnz_token.extra_data["expires"]) / 2 <= int(
time.time()
):
on_the_fly_config(trans.sa_session)
if self.config["provider"] == "azure":
self.refresh_azure(user_authnz_token)
else:
strategy = Strategy(trans.request, trans.session, Storage, self.config)
user_authnz_token.refresh_token(strategy)
return True
return False

def authenticate(self, trans):
on_the_fly_config(trans.sa_session)
strategy = Strategy(trans.request, trans.session, Storage, self.config)
Expand Down Expand Up @@ -171,6 +210,7 @@ def callback(self, state_token, authz_code, trans, login_redirect_url):
user=trans.user,
state=state_token,
)

return redirect_url, self.config.get("user", None)

def disconnect(self, provider, trans, disconnect_redirect_url=None, association_id=None):
Expand Down
11 changes: 11 additions & 0 deletions lib/galaxy/authnz/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from .custos_authnz import KEYCLOAK_BACKENDS
from .psa_authnz import BACKENDS_NAME


def provider_name_to_backend(provider):
if provider.lower() in KEYCLOAK_BACKENDS:
return provider.lower()
for k, v in BACKENDS_NAME.items():
if k.lower() == provider:
return v
return None
1 change: 1 addition & 0 deletions lib/galaxy/dependencies/pinned-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ mdurl==0.1.2 ; python_version >= "3.7" and python_version < "3.12"
mercurial==6.4.4 ; python_version >= "3.7" and python_version < "3.12"
mistune==2.0.5 ; python_version >= "3.7" and python_version < "3.12"
mrcfile==1.4.3 ; python_version >= "3.7" and python_version < "3.12"
msal==1.21.0 ; python_version >= "3.7" and python_version < "3.12"
msgpack==1.0.5 ; python_version >= "3.7" and python_version < "3.12"
multidict==6.0.4 ; python_version >= "3.7" and python_version < "3.12"
mypy-extensions==1.0.0 ; python_version >= "3.7" and python_version < "3.12"
Expand Down
11 changes: 10 additions & 1 deletion lib/galaxy/jobs/runners/pulsar.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@
map=specs.to_str_or_none,
default=None,
),
secret=dict(
map=specs.to_str_or_none,
default=None,
),
pulsar_config=dict(
map=specs.to_str_or_none,
default=None,
Expand Down Expand Up @@ -610,7 +614,12 @@ def get_client(self, job_destination_params, job_id, env=None):
if self.app.config.nginx_upload_job_files_path:
endpoint_base = "%s" + self.app.config.nginx_upload_job_files_path + "?job_id=%s&job_key=%s"
files_endpoint = endpoint_base % (self.galaxy_url, encoded_job_id, job_key)
get_client_kwds = dict(job_id=str(job_id), files_endpoint=files_endpoint, env=env)
secret = job_destination_params.get("job_secret_base", "jobs_token")
job_key = self.app.security.encode_id(job_id, kind=secret)
token_endpoint = f"{self.galaxy_url}/api/jobs/{encoded_job_id}/oidc-tokens?job_key={job_key}"
get_client_kwds = dict(
job_id=str(job_id), files_endpoint=files_endpoint, token_endpoint=token_endpoint, env=env
)
# Turn MutableDict into standard dict for pulsar consumption
job_destination_params = dict(job_destination_params.items())
return self.client_manager.get_client(job_destination_params, **get_client_kwds)
Expand Down
Loading

0 comments on commit acc7136

Please sign in to comment.