Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions athena-docdb/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,6 @@
<artifactId>aws-athena-federation-sdk</artifactId>
<version>2022.47.1</version>
<classifier>withdep</classifier>
<exclusions>
<!-- replaced with jcl-over-slf4j -->
<exclusion>
<groupId>commons-logging</groupId>
<artifactId>commons-logging</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>com.amazonaws</groupId>
Expand All @@ -34,6 +27,12 @@
<artifactId>docdb</artifactId>
<version>${aws-sdk-v2.version}</version>
<scope>test</scope>
<exclusions>
<exclusion>
<groupId>software.amazon.awssdk</groupId>
<artifactId>netty-nio-client</artifactId>
</exclusion>
</exclusions>
</dependency>
<!-- https://mvnrepository.com/artifact/software.amazon.awscdk/docdb -->
<dependency>
Expand Down Expand Up @@ -86,6 +85,12 @@
<version>${log4j2Version}</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
<version>5.13.3</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.amazonaws.athena.connector.lambda.data.BlockWriter;
import com.amazonaws.athena.connector.lambda.domain.Split;
import com.amazonaws.athena.connector.lambda.domain.TableName;
import com.amazonaws.athena.connector.lambda.domain.predicate.functions.StandardFunctions;
import com.amazonaws.athena.connector.lambda.domain.spill.SpillLocation;
import com.amazonaws.athena.connector.lambda.handlers.GlueMetadataHandler;
import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesRequest;
Expand All @@ -39,9 +40,16 @@
import com.amazonaws.athena.connector.lambda.metadata.ListTablesResponse;
import com.amazonaws.athena.connector.lambda.metadata.MetadataRequest;
import com.amazonaws.athena.connector.lambda.metadata.glue.GlueFieldLexer;
import com.amazonaws.athena.connector.lambda.metadata.optimizations.DataSourceOptimizations;
import com.amazonaws.athena.connector.lambda.metadata.optimizations.OptimizationSubType;
import com.amazonaws.athena.connector.lambda.metadata.optimizations.pushdown.ComplexExpressionPushdownSubType;
import com.amazonaws.athena.connector.lambda.metadata.optimizations.pushdown.LimitPushdownSubType;
import com.amazonaws.athena.connector.lambda.request.FederationRequest;
import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory;
import com.amazonaws.athena.connector.lambda.security.FederatedIdentity;
import com.amazonaws.athena.connectors.docdb.qpt.DocDBQueryPassthrough;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableMap;
import com.mongodb.client.MongoClient;
Expand All @@ -62,10 +70,15 @@
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static com.amazonaws.athena.connector.lambda.connection.EnvironmentConstants.ENFORCE_SSL;
import static com.amazonaws.athena.connector.lambda.connection.EnvironmentConstants.FAS_TOKEN;
import static com.amazonaws.athena.connector.lambda.connection.EnvironmentConstants.JDBC_PARAMS;
import static com.amazonaws.athena.connector.lambda.connection.EnvironmentConstants.PORT;
import static com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest.UNLIMITED_PAGE_SIZE_VALUE;

/**
Expand All @@ -86,13 +99,13 @@ public class DocDBMetadataHandler

//Used to denote the 'type' of this connector for diagnostic purposes.
private static final String SOURCE_TYPE = "documentdb";
private static final String CONNECTION_STRING_TEMPLATE = "mongodb://%s:%s@%s:%s/%s";
private static final String ENFORCE_SSL_JDBC_PARAM = "ssl=true&ssl_ca_certs=rds-combined-ca-bundle.pem";
//Field name used to store the connection string as a property on Split objects.
protected static final String DOCDB_CONN_STR = "connStr";
//The Env variable name used to store the default DocDB connection string if no catalog specific
//env variable is set.
private static final String DEFAULT_DOCDB = "default_docdb";
//The env secret_name to use if defined
private static final String SECRET_NAME = "secret_name";
//The Glue table property that indicates that a table matching the name of an DocDB table
//is indeed enabled for use by this connector.
private static final String DOCDB_METADATA_FLAG = "docdb-metadata-flag";
Expand All @@ -103,6 +116,14 @@ public class DocDBMetadataHandler
// used to filter out Glue databases which lack the docdb-metadata-flag in the URI.
private static final DatabaseFilter DB_FILTER = (Database database) -> (database.locationUri() != null && database.locationUri().contains(DOCDB_METADATA_FLAG));

private static final String SECRET_ARN_KEY = "secret_arn";
private static final String AUTH_DB_KEY = "AUTHENTICATION_DATABASE";

// JSON credential field names
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should all these be here; athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/DocDBEnvironmentProperties.java ?

private static final String USERNAME_FIELD = "username";
private static final String PASSWORD_FIELD = "password";
public static final String HOST = "host";

private final GlueClient glue;
private final DocDBConnectionFactory connectionFactory;
private final DocDBQueryPassthrough queryPassthrough = new DocDBQueryPassthrough();
Expand Down Expand Up @@ -140,6 +161,16 @@ private MongoClient getOrCreateConn(MetadataRequest request)
/**
* Retrieves the DocDB connection details from an env variable matching the catalog name, if no such
* env variable exists we fall back to the default env variable defined by DEFAULT_DOCDB.
*
* <p>For federated requests, this method dynamically constructs the connection string using:
* <ul>
* <li>Host and port from federated identity config options</li>
* <li>Username and password extracted from AWS Secrets Manager (JSON format)</li>
* <li>SSL enforcement and authentication database settings</li>
* </ul>
*
* @param request The metadata request containing catalog name and federated identity information
* @return The DocDB connection string, either from environment variables or dynamically constructed for federated requests
*/
private String getConnStr(MetadataRequest request)
{
Expand All @@ -149,6 +180,11 @@ private String getConnStr(MetadataRequest request)
request.getCatalogName(), DEFAULT_DOCDB);
conStr = configOptions.get(DEFAULT_DOCDB);
}
if (isRequestFederated(request)) {
logger.info("Using federated request to frame default_docdb connection string.");
final Map<String, String> configOptionsFromFederatedIdentity = request.getIdentity().getConfigOptions();
conStr = getConfigOptionsFromFederatedIdentity(configOptionsFromFederatedIdentity);
}
return conStr;
}

Expand All @@ -157,6 +193,30 @@ public GetDataSourceCapabilitiesResponse doGetDataSourceCapabilities(BlockAlloca
{
ImmutableMap.Builder<String, List<OptimizationSubType>> capabilities = ImmutableMap.builder();
queryPassthrough.addQueryPassthroughCapabilityIfEnabled(capabilities, configOptions);
capabilities.put(DataSourceOptimizations.SUPPORTS_LIMIT_PUSHDOWN.withSupportedSubTypes(
LimitPushdownSubType.INTEGER_CONSTANT
));

List<StandardFunctions> supportedFunctions = new ArrayList<>();
supportedFunctions.add(StandardFunctions.AND_FUNCTION_NAME);
supportedFunctions.add(StandardFunctions.IN_PREDICATE_FUNCTION_NAME);
supportedFunctions.add(StandardFunctions.NOT_FUNCTION_NAME);
supportedFunctions.add(StandardFunctions.IS_NULL_FUNCTION_NAME);
supportedFunctions.add(StandardFunctions.EQUAL_OPERATOR_FUNCTION_NAME);
supportedFunctions.add(StandardFunctions.GREATER_THAN_OPERATOR_FUNCTION_NAME);
supportedFunctions.add(StandardFunctions.LESS_THAN_OPERATOR_FUNCTION_NAME);
supportedFunctions.add(StandardFunctions.GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME);
supportedFunctions.add(StandardFunctions.LESS_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME);
supportedFunctions.add(StandardFunctions.NOT_EQUAL_OPERATOR_FUNCTION_NAME);

// To check for $nin and $nor

capabilities.put(DataSourceOptimizations.SUPPORTS_COMPLEX_EXPRESSION_PUSHDOWN.withSupportedSubTypes(
ComplexExpressionPushdownSubType.SUPPORTED_FUNCTION_EXPRESSION_TYPES
.withSubTypeProperties(supportedFunctions.stream()
.map(f -> f.getFunctionName().getFunctionName())
.toArray(String[]::new))
));

return new GetDataSourceCapabilitiesResponse(request.getCatalogName(), capabilities.build());
}
Expand Down Expand Up @@ -365,4 +425,119 @@ protected Field convertField(String name, String glueType)
{
return GlueFieldLexer.lex(name, glueType);
}

/**
* Constructs a DocDB connection string from federated identity configuration options.
*
* <p>This method dynamically builds a MongoDB connection string by:
* <ul>
* <li>Extracting host and port from the provided config options</li>
* <li>Retrieving credentials from AWS Secrets Manager using the secret ARN</li>
* <li>Parsing JSON credentials to extract username and password</li>
* <li>Applying SSL enforcement and authentication database settings</li>
* <li>Constructing the final MongoDB connection string with proper formatting</li>
* </ul>
*
* <p>Expected JSON credential format from Secrets Manager:
* <pre>
* {
* "username": "mongodbadmin",
* "password": "secretpassword",
* "engine": "mongo",
* "host": "cluster.docdb.amazonaws.com",
* "port": 27017
* }
* </pre>
*
* @param configOptions Map containing federated identity configuration including:
* HOST, PORT, secret_arn, JDBC_PARAMS, ENFORCE_SSL, AUTHENTICATION_DATABASE
* @return Fully constructed MongoDB connection string in format: mongodb://username:password@host:port/?jdbcParams
* @throws RuntimeException if JSON credential parsing fails or required parameters are missing
*/
private String getConfigOptionsFromFederatedIdentity(Map<String, String> configOptions)
{
final String secretName = getSecretNameFromArn(configOptions.get(SECRET_ARN_KEY));
final String credentials = getSecret(secretName, getRequestOverrideConfig(configOptions));
final String username;
final String password;
final String host;
try {
ObjectMapper mapper = new ObjectMapper();
JsonNode credNode = mapper.readTree(credentials);
username = credNode.get(USERNAME_FIELD).asText();
password = credNode.get(PASSWORD_FIELD).asText();
host = credNode.get(HOST).asText();
}
catch (Exception e) {
logger.error("Failed to parse JSON credentials", e);
throw new RuntimeException("Invalid JSON credentials format", e);
}

String jdbcParams = configOptions.get(JDBC_PARAMS);
String enforceSsl = configOptions.get(ENFORCE_SSL);
String authDb = configOptions.getOrDefault(AUTH_DB_KEY, "");

if (Boolean.parseBoolean(enforceSsl)) {
if (jdbcParams == null) {
jdbcParams = ENFORCE_SSL_JDBC_PARAM;
}
else if (!jdbcParams.contains(ENFORCE_SSL_JDBC_PARAM)) {
jdbcParams = ENFORCE_SSL_JDBC_PARAM + "&" + jdbcParams;
}
}

String connStr = String.format(CONNECTION_STRING_TEMPLATE, username, password, host, configOptions.get(PORT),
authDb);
if (jdbcParams != null) {
connStr += "?" + jdbcParams;
}
return connStr;
}

/**
* Extracts the secret name from an AWS Secrets Manager ARN.
*
* <p>AWS Secrets Manager ARNs follow the format:
* {@code arn:aws:secretsmanager:region:account:secret:name-suffix}
*
* <p>This method extracts the secret name by:
* <ul>
* <li>Splitting the ARN by colons to get individual components</li>
* <li>Taking the 7th component (index 6) which contains "name-suffix"</li>
* <li>Removing the suffix (everything after the last hyphen) to get the clean secret name</li>
* </ul>
*
* @param secretArn The full AWS Secrets Manager ARN
* @return The extracted secret name without the suffix
* @throws ArrayIndexOutOfBoundsException if the ARN format is invalid
*/
private static String getSecretNameFromArn(String secretArn)
{
final String[] parts = secretArn.split(":");
final String nameWithSuffix = parts[6];
return nameWithSuffix.substring(0, nameWithSuffix.lastIndexOf('-'));
}

/**
* Determines if the current request is a federated request by checking for the presence of a FAS token.
*
* <p>A federated request is identified by:
* <ul>
* <li>The presence of a {@link FederatedIdentity} in the request</li>
* <li>The existence of configuration options within the federated identity</li>
* <li>The presence of a FAS (Federation Access Service) token in the config options</li>
* </ul>
*
* <p>Federated requests require dynamic connection string construction using credentials
* from AWS Secrets Manager rather than static environment variables.
*
* @param req The federation request to check
* @return true if this is a federated request with a FAS token, false otherwise
*/
private boolean isRequestFederated(FederationRequest req)
{
FederatedIdentity federatedIdentity = req.getIdentity();
Map<String, String> connectorRequestOptions = federatedIdentity != null ? federatedIdentity.getConfigOptions() : null;
return (connectorRequestOptions != null && connectorRequestOptions.get(FAS_TOKEN) != null);
}
}
Loading
Loading