Skip to content

Commit

Permalink
Added interface for managing token caches
Browse files Browse the repository at this point in the history
  • Loading branch information
thisaltennakoon committed Sep 23, 2024
1 parent 322d28a commit 097ae45
Show file tree
Hide file tree
Showing 10 changed files with 216 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@ public class AuthorizationCodeHandler extends OAuthHandler {

public AuthorizationCodeHandler(String tokenApiUrl, String clientId, String clientSecret,
String refreshToken, String authMode, int connectionTimeout,
int connectionRequestTimeout, int socketTimeout) {
int connectionRequestTimeout, int socketTimeout,
TokenCacheProvider tokenCacheProvider) {

super(tokenApiUrl, clientId, clientSecret, authMode, connectionTimeout, connectionRequestTimeout,
socketTimeout);
socketTimeout, tokenCacheProvider);
this.refreshToken = refreshToken;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@
public class ClientCredentialsHandler extends OAuthHandler {

public ClientCredentialsHandler(String tokenApiUrl, String clientId, String clientSecret, String authMode,
int connectionTimeout, int connectionRequestTimeout, int socketTimeout) {
int connectionTimeout, int connectionRequestTimeout, int socketTimeout,
TokenCacheProvider tokenCacheProvider) {

super(tokenApiUrl, clientId, clientSecret, authMode, connectionTimeout, connectionRequestTimeout, socketTimeout);
super(tokenApiUrl, clientId, clientSecret, authMode, connectionTimeout, connectionRequestTimeout, socketTimeout,
tokenCacheProvider);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
import org.apache.http.impl.NoConnectionReuseStrategy;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.impl.conn.BasicHttpClientConnectionManager;
import org.apache.synapse.MessageContext;
import org.apache.synapse.core.axis2.Axis2MessageContext;
Expand Down Expand Up @@ -100,28 +99,30 @@ public class OAuthClient {
public static String generateToken(String tokenApiUrl, String payload, String credentials,
MessageContext messageContext, Map<String, String> customHeaders,
int connectionTimeout, int connectionRequestTimeout, int socketTimeout) throws AuthException, IOException {
CloseableHttpClient httpClient = getSecureClient(tokenApiUrl, messageContext, connectionTimeout,
connectionRequestTimeout, socketTimeout);

if (log.isDebugEnabled()) {
log.debug("Initializing token generation request: [token-endpoint] " + tokenApiUrl);
}

HttpPost httpPost = new HttpPost(tokenApiUrl);
httpPost.setHeader(AuthConstants.CONTENT_TYPE_HEADER, AuthConstants.APPLICATION_X_WWW_FORM_URLENCODED);
if (!(customHeaders == null || customHeaders.isEmpty())) {
for (Map.Entry<String, String> entry : customHeaders.entrySet()) {
httpPost.setHeader(entry.getKey(), entry.getValue());
try (CloseableHttpClient httpClient = getSecureClient(tokenApiUrl, messageContext, connectionTimeout,
connectionRequestTimeout, socketTimeout)) {
HttpPost httpPost = new HttpPost(tokenApiUrl);
httpPost.setHeader(AuthConstants.CONTENT_TYPE_HEADER, AuthConstants.APPLICATION_X_WWW_FORM_URLENCODED);
if (!(customHeaders == null || customHeaders.isEmpty())) {
for (Map.Entry<String, String> entry : customHeaders.entrySet()) {
httpPost.setHeader(entry.getKey(), entry.getValue());
}
}
}
if (credentials != null) {
httpPost.setHeader(AuthConstants.AUTHORIZATION_HEADER, AuthConstants.BASIC + credentials);
}
httpPost.setEntity(new StringEntity(payload));
if (credentials != null) {
httpPost.setHeader(AuthConstants.AUTHORIZATION_HEADER, AuthConstants.BASIC + credentials);
}
httpPost.setEntity(new StringEntity(payload));

try (CloseableHttpResponse response = httpClient.execute(httpPost)) {
return extractToken(response);
} finally {
httpPost.releaseConnection();
try (CloseableHttpResponse response = httpClient.execute(httpPost)) {
return extractToken(response);
} finally {
httpPost.releaseConnection();
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;

/**
* This abstract class is to be used by OAuth handlers
Expand All @@ -55,9 +53,11 @@ public abstract class OAuthHandler implements AuthHandler {
protected final int connectionTimeout;
protected final int connectionRequestTimeout;
protected final int socketTimeout;
private final TokenCacheProvider tokenCacheProvider;

protected OAuthHandler(String tokenApiUrl, String clientId, String clientSecret, String authMode,
int connectionTimeout, int connectionRequestTimeout, int socketTimeout) {
int connectionTimeout, int connectionRequestTimeout, int socketTimeout,
TokenCacheProvider tokenCacheProvider) {

this.id = OAuthUtils.getRandomOAuthHandlerID();
this.tokenApiUrl = tokenApiUrl;
Expand All @@ -67,6 +67,7 @@ protected OAuthHandler(String tokenApiUrl, String clientId, String clientSecret,
this.connectionTimeout = connectionTimeout;
this.connectionRequestTimeout = connectionRequestTimeout;
this.socketTimeout = socketTimeout;
this.tokenCacheProvider = tokenCacheProvider;
}

@Override
Expand All @@ -87,18 +88,25 @@ public void setAuthHeader(MessageContext messageContext) throws AuthException {
*/
private String getToken(final MessageContext messageContext) throws AuthException {

try {
return TokenCache.getInstance().getToken(getId(messageContext), new Callable<String>() {
@Override
public String call() throws AuthException, IOException {
return OAuthClient.generateToken(OAuthUtils.resolveExpression(tokenApiUrl, messageContext),
// Check if the token is already cached
String token = tokenCacheProvider.getToken(getId(messageContext));

synchronized (getId(messageContext).intern()) {
if (StringUtils.isEmpty(token)) {
// If no token found, generate a new one
try {
token = OAuthClient.generateToken(OAuthUtils.resolveExpression(tokenApiUrl, messageContext),
buildTokenRequestPayload(messageContext), getEncodedCredentials(messageContext),
messageContext, getResolvedCustomHeadersMap(customHeadersMap, messageContext), connectionTimeout,
connectionRequestTimeout, socketTimeout);
messageContext, getResolvedCustomHeadersMap(customHeadersMap, messageContext),
connectionTimeout, connectionRequestTimeout, socketTimeout);

// Cache the newly generated token
tokenCacheProvider.putToken(getId(messageContext), token);
} catch (IOException e) {
throw new AuthException("Error generating token", e);
}
});
} catch (ExecutionException e) {
throw new AuthException(e.getCause());
}
return token;
}
}

Expand Down Expand Up @@ -133,15 +141,15 @@ public int compare(String o1, String o2) {
*/
public void removeTokenFromCache(MessageContext messageContext) throws AuthException {

TokenCache.getInstance().removeToken(getId(messageContext));
tokenCacheProvider.removeToken(getId(messageContext));
}

/**
* Method to remove the token from the cache when the endpoint is destroyed.
*/
public void removeTokensFromCache() {

TokenCache.getInstance().removeTokens(id.concat("_"));
tokenCacheProvider.removeTokens(id.concat("_"));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ private static AuthorizationCodeHandler getAuthorizationCodeHandler(OMElement au
return null;
}
AuthorizationCodeHandler handler = new AuthorizationCodeHandler(tokenApiUrl, clientId, clientSecret,
refreshToken, authMode, connectionTimeout, connectionRequestTimeout, socketTimeout);
refreshToken, authMode, connectionTimeout, connectionRequestTimeout, socketTimeout,
TokenCacheFactory.getTokenCache());
if (hasRequestParameters(authCodeElement)) {
Map<String, String> requestParameters = getRequestParameters(authCodeElement);
if (requestParameters == null) {
Expand Down Expand Up @@ -170,7 +171,7 @@ private static ClientCredentialsHandler getClientCredentialsHandler(
return null;
}
ClientCredentialsHandler handler = new ClientCredentialsHandler(tokenApiUrl, clientId, clientSecret, authMode,
connectionTimeout, connectionRequestTimeout, socketTimeout);
connectionTimeout, connectionRequestTimeout, socketTimeout, TokenCacheFactory.getTokenCache());
if (hasRequestParameters(clientCredentialsElement)) {
Map<String, String> requestParameters = getRequestParameters(clientCredentialsElement);
if (requestParameters == null) {
Expand Down Expand Up @@ -213,7 +214,8 @@ private static PasswordCredentialsHandler getPasswordCredentialsHandler(
return null;
}
PasswordCredentialsHandler handler = new PasswordCredentialsHandler(tokenApiUrl, clientId, clientSecret,
username, password, authMode, connectionTimeout, connectionRequestTimeout, socketTimeout);
username, password, authMode, connectionTimeout, connectionRequestTimeout, socketTimeout,
TokenCacheFactory.getTokenCache());
if (hasRequestParameters(passwordCredentialsElement)) {
Map<String, String> requestParameters = getRequestParameters(passwordCredentialsElement);
if (requestParameters == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ public class PasswordCredentialsHandler extends OAuthHandler {

protected PasswordCredentialsHandler(String tokenApiUrl, String clientId, String clientSecret, String username,
String password, String authMode, int connectionTimeout,
int connectionRequestTimeout, int socketTimeout) {
int connectionRequestTimeout, int socketTimeout,
TokenCacheProvider tokenCacheProvider) {

super(tokenApiUrl, clientId, clientSecret, authMode, connectionTimeout, connectionRequestTimeout, socketTimeout);
super(tokenApiUrl, clientId, clientSecret, authMode, connectionTimeout, connectionRequestTimeout, socketTimeout,
tokenCacheProvider);
this.username = username;
this.password = password;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
import org.apache.synapse.config.SynapsePropertiesLoader;
import org.apache.synapse.endpoints.auth.AuthConstants;

import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

import static org.apache.synapse.endpoints.auth.AuthConstants.TOKEN_CACHE_TIMEOUT_PROPERTY;
Expand All @@ -35,7 +33,7 @@
* Token Cache Implementation
* Tokens will be invalidate after a interval of TOKEN_CACHE_TIMEOUT minutes
*/
public class TokenCache {
public class TokenCache implements TokenCacheProvider {

private static final Log log = LogFactory.getLog(TokenCache.class);

Expand Down Expand Up @@ -70,22 +68,35 @@ public static TokenCache getInstance() {
}

/**
* This method returns the value in the cache, or computes it from the specified Callable
* Stores a token in the cache with the specified ID.
*
* @param id id of the oauth handler
* @param callable to generate a new token by calling oauth server
* @return Token object
* @param id the unique identifier for the token
* @param token the token to be cached
*/
public String getToken(String id, Callable<String> callable) throws ExecutionException {
@Override
public void putToken(String id, String token) {

return tokenMap.get(id, callable);
tokenMap.put(id, token);
}

/**
* Retrieves a token from the cache using the specified ID.
*
* @param id the unique identifier for the token
* @return the cached token, or {@code null} if not found
*/
@Override
public String getToken(String id) {

return tokenMap.getIfPresent(id);
}

/**
* This method is called to remove the token from the cache when the token is invalid
*
* @param id id of the endpoint
*/
@Override
public void removeToken(String id) {

tokenMap.invalidate(id);
Expand All @@ -96,6 +107,7 @@ public void removeToken(String id) {
*
* @param oauthHandlerId id of the OAuth handler bounded to the endpoint
*/
@Override
public void removeTokens(String oauthHandlerId) {
tokenMap.asMap().entrySet().removeIf(entry -> entry.getKey().startsWith(oauthHandlerId));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright (c) 2024, WSO2 LLC. (https://www.wso2.com/).
*
* WSO2 LLC. licenses this file to you under the Apache License,
* Version 2.0 (the "License"); you may not use this file except
* in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.synapse.endpoints.auth.oauth;

import org.apache.synapse.SynapseException;
import org.apache.synapse.config.SynapsePropertiesLoader;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

/**
* Factory class responsible for providing the appropriate implementation of the TokenCacheProvider interface.
* This class manages the singleton instance of TokenCacheProvider, ensuring that it is only loaded once and reused
* across the application.
*/
public class TokenCacheFactory {

/**
* Singleton instance of TokenCacheProvider. This will be initialized the first time, and the same instance will be
* returned on subsequent calls.
*/
private static TokenCacheProvider tokenCacheProvider;

/**
* Retrieves the singleton instance of TokenCacheProvider. If the instance is not already initialized,
* it attempts to load the provider class specified in the `token.cache.class` property. If the property
* is not set or the class cannot be loaded, it defaults to the TokenCache implementation.
*
* @return the singleton instance of TokenCacheProvider
* @throws SynapseException if there is an error loading the specified class
*/
public static TokenCacheProvider getTokenCache() {
if (tokenCacheProvider != null) {
return tokenCacheProvider;
}

String classPath = SynapsePropertiesLoader.loadSynapseProperties().getProperty("token.cache.class");
if (classPath != null) {
tokenCacheProvider = loadTokenCacheProvider(classPath);
} else {
tokenCacheProvider = TokenCache.getInstance();
}
return tokenCacheProvider;
}

/**
* Loads the TokenCacheProvider implementation specified by the given class path.
*
* @param classPath the fully qualified class path of the TokenCacheProvider implementation
* @return an instance of the specified TokenCacheProvider implementation
* @throws SynapseException if there is an error loading the class or invoking the `getInstance` method
*/
private static TokenCacheProvider loadTokenCacheProvider(String classPath) {
try {
Class<?> clazz = Class.forName(classPath);
Method getInstanceMethod = clazz.getMethod("getInstance");
return (TokenCacheProvider) getInstanceMethod.invoke(null);
} catch (ClassNotFoundException e) {
throw new SynapseException("Error loading class: " + classPath, e);
} catch (NoSuchMethodException e) {
throw new SynapseException("getInstance method not found for class: " + classPath, e);
} catch (InvocationTargetException | IllegalAccessException e) {
throw new SynapseException("Error invoking getInstance method for class: " + classPath, e);
}
}
}
Loading

0 comments on commit 097ae45

Please sign in to comment.