diff --git a/athena-federation-sdk/pom.xml b/athena-federation-sdk/pom.xml
index e5a4baac91..0f0d8cbc78 100644
--- a/athena-federation-sdk/pom.xml
+++ b/athena-federation-sdk/pom.xml
@@ -290,6 +290,11 @@
core${io.substrait.version}
+ <dependency>
+ <groupId>org.junit.jupiter</groupId>
+ <artifactId>junit-jupiter-params</artifactId>
+ <scope>test</scope>
+ </dependency>
diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/Block.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/Block.java
index 2b1d14bcef..8c3a446048 100644
--- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/Block.java
+++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/Block.java
@@ -179,6 +179,34 @@ public boolean offerValue(String fieldName, int row, Object value)
return false;
}
+ /**
+ * Attempts to write the provided value to the specified field on the specified row. This method does _not_ update the
+ * row count on the underlying Apache Arrow VectorSchema. You must call setRowCount(...) to ensure the values
+ * you have written are considered 'valid rows' and thus available when you attempt to serialize this Block. This
+ * method relies on BlockUtils' field conversion/coercion logic to convert the provided value into a type that
+ * matches Apache Arrow's supported serialization format. For more details on coercion please see {@link BlockUtils}.
+ *
+ * @param fieldName The name of the field you wish to write to.
+ * @param row The row number to write to. Note that Apache Arrow Blocks begin with row 0 just like a typical array.
+ * @param value The value you wish to write.
+ * @param hasQueryPlan Whether the operation is running under a query plan; if true, the constraint check is skipped.
+ * @return True if the constraint check passed or was skipped (the value is written only if the field exists),
+ * False if hasQueryPlan is false and the value failed the constraint check; nothing is written in that case.
+ * @note This method will take no action if the provided fieldName is not a valid field in this Block's Schema.
+ * In such cases the method will return true.
+ */
+ public boolean offerValue(String fieldName, int row, Object value, boolean hasQueryPlan)
+ {
+ if (!hasQueryPlan && !constraintEvaluator.apply(fieldName, value)) {
+ return false;
+ }
+ FieldVector vector = getFieldVector(fieldName);
+ if (vector != null) {
+ BlockUtils.setValue(vector, row, value);
+ }
+ return true;
+ }
+
/**
* Attempts to set the provided value for the given field name and row. If the Block's schema does not
* contain such a field, this method does nothing and returns false.
diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/FederationRequestHandler.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/FederationRequestHandler.java
index 4e7008fd9e..276113bde3 100644
--- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/FederationRequestHandler.java
+++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/FederationRequestHandler.java
@@ -7,9 +7,9 @@
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -19,8 +19,16 @@
*/
package com.amazonaws.athena.connector.lambda.handlers;
+import com.amazonaws.athena.connector.credentials.CredentialsProvider;
+import com.amazonaws.athena.connector.credentials.DefaultCredentialsProvider;
+import com.amazonaws.athena.connector.lambda.request.FederationRequest;
+import com.amazonaws.athena.connector.lambda.security.CachableSecretsManager;
+import com.amazonaws.athena.connector.lambda.security.FederatedIdentity;
import com.amazonaws.athena.connector.lambda.security.KmsEncryptionProvider;
import com.amazonaws.services.lambda.runtime.RequestStreamHandler;
+import org.apache.commons.lang3.StringUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
@@ -36,6 +44,72 @@
public interface FederationRequestHandler extends RequestStreamHandler
{
+ /**
+ * Gets the CachableSecretsManager instance used by this handler.
+ * Implementations must provide access to their secrets manager instance.
+ *
+ * @return The CachableSecretsManager instance
+ */
+ CachableSecretsManager getCachableSecretsManager();
+
+ /**
+ * Gets the KmsEncryptionProvider instance used by this handler.
+ * Implementations must provide access to their KMS encryption provider instance.
+ *
+ * @return The KmsEncryptionProvider instance
+ */
+ KmsEncryptionProvider getKmsEncryptionProvider();
+
+ /**
+ * Resolves any secrets found in the supplied string, for example: MyString${WithSecret} would have ${WithSecret}
+ * replaced by the corresponding value of the secret in AWS Secrets Manager with that name. If no such secret is found
+ * the function throws. The work is delegated to the instance returned by getCachableSecretsManager().
+ *
+ * @param rawString The string in which you'd like to replace SecretsManager placeholders
+ * (e.g. ThisIsA${Secret}Here - The ${Secret} would be replaced with the contents of a SecretsManager
+ * secret called Secret). If no such secret is found, the function throws. If no ${} are found in
+ * the input string, nothing is replaced and the original string is returned.
+ * @return The processed string with secrets resolved
+ */
+ default String resolveSecrets(String rawString)
+ {
+ return getCachableSecretsManager().resolveSecrets(rawString);
+ }
+
+ /**
+ * Resolves secrets with default credentials format (username:password).
+ *
+ * @param rawString The string containing secret placeholders to resolve
+ * @return The processed string with secrets resolved in default credentials format
+ */
+ default String resolveWithDefaultCredentials(String rawString)
+ {
+ return getCachableSecretsManager().resolveWithDefaultCredentials(rawString);
+ }
+
+ /**
+ * Retrieves a secret from AWS Secrets Manager.
+ *
+ * @param secretName The name of the secret to retrieve
+ * @return The secret value
+ */
+ default String getSecret(String secretName)
+ {
+ return getCachableSecretsManager().getSecret(secretName);
+ }
+
+ /**
+ * Retrieves a secret from AWS Secrets Manager with request override configuration.
+ *
+ * @param secretName The name of the secret to retrieve
+ * @param requestOverrideConfiguration AWS request override configuration for federated requests
+ * @return The secret value
+ */
+ default String getSecret(String secretName, AwsRequestOverrideConfiguration requestOverrideConfiguration)
+ {
+ return getCachableSecretsManager().getSecret(secretName, requestOverrideConfiguration);
+ }
+
default AwsCredentials getSessionCredentials(String kmsKeyId,
String tokenString,
KmsEncryptionProvider kmsEncryptionProvider)
@@ -43,6 +117,40 @@ default AwsCredentials getSessionCredentials(String kmsKeyId,
return kmsEncryptionProvider.getFasCredentials(kmsKeyId, tokenString);
}
+ /**
+ * Gets the AWS request override configuration for a FederationRequest.
+ * When isRequestFederated(request) is true, this method reads the configuration options from the request's
+ * federated identity and delegates to the Map-based overload; otherwise it returns null.
+ *
+ * @param request The federation request
+ * @return The AWS request override configuration, or null if not a federated request
+ */
+ default AwsRequestOverrideConfiguration getRequestOverrideConfig(FederationRequest request)
+ {
+ if (isRequestFederated(request)) {
+ FederatedIdentity federatedIdentity = request.getIdentity();
+ Map connectorRequestOptions = federatedIdentity != null ? federatedIdentity.getConfigOptions() : null;
+ // NOTE(review): isRequestFederated(request) above already evaluated this same condition, so this re-check looks redundant.
+ if (connectorRequestOptions != null && connectorRequestOptions.get(FAS_TOKEN) != null) {
+ return getRequestOverrideConfig(connectorRequestOptions);
+ }
+ }
+ return null;
+ }
+
+ /**
+ * Gets the AWS request override configuration for the given config options.
+ * This is a convenience method that delegates to the full overload using the handler's
+ * KMS encryption provider.
+ *
+ * @param configOptions The configuration options map
+ * @return The AWS request override configuration, or null if not applicable
+ */
+ default AwsRequestOverrideConfiguration getRequestOverrideConfig(Map configOptions)
+ {
+ return getRequestOverrideConfig(configOptions, getKmsEncryptionProvider());
+ }
+
default AwsRequestOverrideConfiguration getRequestOverrideConfig(Map configOptions,
KmsEncryptionProvider kmsEncryptionProvider)
{
@@ -85,4 +193,54 @@ default AthenaClient getAthenaClient(AwsRequestOverrideConfiguration awsRequestO
return defaultAthena;
}
}
+
+ default boolean isRequestFederated(FederationRequest req)
+ {
+ FederatedIdentity federatedIdentity = req.getIdentity();
+ Map connectorRequestOptions = federatedIdentity != null ? federatedIdentity.getConfigOptions() : null;
+ return (connectorRequestOptions != null && connectorRequestOptions.get(FAS_TOKEN) != null);
+ }
+
+ /**
+ * Gets a credentials provider for database connections with optional request override configuration.
+ * If getDatabaseConnectionSecret() returns a non-blank secret name, the provider is built by createCredentialsProvider().
+ * Implementations can override createCredentialsProvider() to provide custom credential provider implementations.
+ *
+ * @param requestOverrideConfiguration Optional AWS request override configuration for federated requests
+ * @return CredentialsProvider instance or null if no secret is configured
+ */
+ default CredentialsProvider getCredentialProvider(AwsRequestOverrideConfiguration requestOverrideConfiguration)
+ {
+ final String secretName = getDatabaseConnectionSecret();
+ if (StringUtils.isNotBlank(secretName)) {
+ Logger logger = LoggerFactory.getLogger(this.getClass());
+ logger.info("Using Secrets Manager.");
+ return createCredentialsProvider(secretName, requestOverrideConfiguration);
+ }
+ return null;
+ }
+
+ /**
+ * Factory method to create CredentialsProvider. Subclasses can override this to provide
+ * custom credential provider implementations (e.g., SnowflakeCredentialsProvider).
+ *
+ * @param secretName The secret name to retrieve credentials from
+ * @param requestOverrideConfiguration Optional AWS request override configuration
+ * @return CredentialsProvider instance
+ */
+ default CredentialsProvider createCredentialsProvider(String secretName, AwsRequestOverrideConfiguration requestOverrideConfiguration)
+ {
+ return new DefaultCredentialsProvider(getSecret(secretName, requestOverrideConfiguration));
+ }
+
+ /**
+ * Gets the database connection secret name. Subclasses that use database credentials
+ * should override this method to provide the secret name from their configuration.
+ *
+ * @return The secret name, or null if not applicable
+ */
+ default String getDatabaseConnectionSecret()
+ {
+ return null;
+ }
}
diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/MetadataHandler.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/MetadataHandler.java
index 5921ac6c8c..b7b05ace0a 100644
--- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/MetadataHandler.java
+++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/MetadataHandler.java
@@ -66,6 +66,7 @@
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration;
@@ -75,12 +76,13 @@
import software.amazon.awssdk.services.kms.KmsClient;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
-import software.amazon.awssdk.utils.StringUtils;
+import software.amazon.awssdk.utils.CollectionUtils;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Collections;
+import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
@@ -172,13 +174,13 @@ public MetadataHandler(String sourceType, java.util.Map configOp
* @param sourceType Used to aid in logging diagnostic info when raising a support case.
*/
public MetadataHandler(
- EncryptionKeyFactory encryptionKeyFactory,
- SecretsManagerClient secretsManager,
- AthenaClient athena,
- String sourceType,
- String spillBucket,
- String spillPrefix,
- java.util.Map configOptions)
+ EncryptionKeyFactory encryptionKeyFactory,
+ SecretsManagerClient secretsManager,
+ AthenaClient athena,
+ String sourceType,
+ String spillBucket,
+ String spillPrefix,
+ java.util.Map configOptions)
{
this.configOptions = configOptions;
this.encryptionKeyFactory = encryptionKeyFactory;
@@ -194,38 +196,22 @@ public MetadataHandler(
}
/**
- * Resolves any secrets found in the supplied string, for example: MyString${WithSecret} would have ${WithSecret}
- * by the corresponding value of the secret in AWS Secrets Manager with that name. If no such secret is found
- * the function throws.
- *
- * @param rawString The string in which you'd like to replace SecretsManager placeholders.
- * (e.g. ThisIsA${Secret}Here - The ${Secret} would be replaced with the contents of an SecretsManager
- * secret called Secret. If no such secret is found, the function throws. If no ${} are found in
- * the input string, nothing is replaced and the original string is returned.
+ * Gets the CachableSecretsManager instance used by this handler.
+ * This is used by credential providers to reuse the same secrets manager instance.
+ * @return The CachableSecretsManager instance
*/
- protected String resolveSecrets(String rawString)
- {
- return secretsManager.resolveSecrets(rawString);
- }
-
- protected String resolveWithDefaultCredentials(String rawString)
- {
- return secretsManager.resolveWithDefaultCredentials(rawString);
- }
-
- protected String getSecret(String secretName)
+ public CachableSecretsManager getCachableSecretsManager()
{
- return secretsManager.getSecret(secretName);
+ return secretsManager;
}
/**
- * Gets the CachableSecretsManager instance used by this handler.
- * This is used by credential providers to reuse the same secrets manager instance.
- * @return The CachableSecretsManager instance
+ * Gets the KmsEncryptionProvider instance used by this handler.
+ * @return The KmsEncryptionProvider instance
*/
- protected CachableSecretsManager getCachableSecretsManager()
+ public KmsEncryptionProvider getKmsEncryptionProvider()
{
- return secretsManager;
+ return kmsEncryptionProvider;
}
protected EncryptionKey makeEncryptionKey()
@@ -241,13 +227,17 @@ protected EncryptionKey makeEncryptionKey(AwsRequestOverrideConfiguration awsReq
/**
* Used to make a spill location for a split. Each split should have a unique spill location, so be sure
* to call this method once per split!
- * @param request
+ * @param request
* @return A unique spill location.
*/
protected SpillLocation makeSpillLocation(MetadataRequest request)
{
FederatedIdentity federatedIdentity = request.getIdentity();
Map configOptions = federatedIdentity.getConfigOptions();
+ if (CollectionUtils.isNullOrEmpty(configOptions)) {
+ logger.debug("configOptions is empty from federation. Use default configOptions.");
+ configOptions = new HashMap<>(this.configOptions);
+ }
String queryId = request.getQueryId();
String prefix = StringUtils.isBlank(configOptions.get(SPILL_PREFIX_ENV))
? spillPrefix : configOptions.get(SPILL_PREFIX_ENV);
@@ -292,9 +282,9 @@ public final void handleRequest(InputStream inputStream, OutputStream outputStre
}
protected final void doHandleRequest(BlockAllocator allocator,
- ObjectMapper objectMapper,
- MetadataRequest req,
- OutputStream outputStream)
+ ObjectMapper objectMapper,
+ MetadataRequest req,
+ OutputStream outputStream)
throws Exception
{
logger.info("doHandleRequest: request[{}]", req);
@@ -465,8 +455,8 @@ public GetTableLayoutResponse doGetTableLayout(final BlockAllocator allocator, f
try (ConstraintEvaluator constraintEvaluator = new ConstraintEvaluator(allocator,
constraintSchema.build(),
request.getConstraints());
- QueryStatusChecker queryStatusChecker = new QueryStatusChecker(getAthenaClient(overrideConfig, athena),
- athenaInvoker, request.getQueryId())
+ QueryStatusChecker queryStatusChecker = new QueryStatusChecker(getAthenaClient(overrideConfig, athena),
+ athenaInvoker, request.getQueryId())
) {
Block partitions = allocator.createBlock(partitionSchemaBuilder.build());
partitions.constrain(constraintEvaluator);
@@ -513,7 +503,7 @@ public void enhancePartitionSchema(SchemaBuilder partitionSchemaBuilder, GetTabl
* for pushing down into the source you are querying.
*/
public abstract void getPartitions(final BlockWriter blockWriter,
- final GetTableLayoutRequest request, QueryStatusChecker queryStatusChecker)
+ final GetTableLayoutRequest request, QueryStatusChecker queryStatusChecker)
throws Exception;
/**
@@ -575,11 +565,6 @@ public void onPing(PingRequest request)
//NoOp
}
- public AwsRequestOverrideConfiguration getRequestOverrideConfig(Map configOptions)
- {
- return getRequestOverrideConfig(configOptions, kmsEncryptionProvider);
- }
-
/**
* Helper function that is used to ensure we always have a non-null response.
*
diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/RecordHandler.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/RecordHandler.java
index 5e4727a713..6f27b33fcf 100644
--- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/RecordHandler.java
+++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/RecordHandler.java
@@ -28,6 +28,7 @@
import com.amazonaws.athena.connector.lambda.data.S3BlockSpiller;
import com.amazonaws.athena.connector.lambda.data.SpillConfig;
import com.amazonaws.athena.connector.lambda.domain.predicate.ConstraintEvaluator;
+import com.amazonaws.athena.connector.lambda.domain.predicate.Constraints;
import com.amazonaws.athena.connector.lambda.exceptions.AthenaConnectorException;
import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest;
import com.amazonaws.athena.connector.lambda.records.ReadRecordsResponse;
@@ -43,8 +44,11 @@
import com.amazonaws.athena.connector.lambda.security.FederatedIdentity;
import com.amazonaws.athena.connector.lambda.security.KmsEncryptionProvider;
import com.amazonaws.athena.connector.lambda.serde.VersionedObjectMapperFactory;
+import com.amazonaws.athena.connector.substrait.util.LimitAndSortHelper;
import com.amazonaws.services.lambda.runtime.Context;
import com.fasterxml.jackson.databind.ObjectMapper;
+import io.substrait.proto.Plan;
+import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration;
@@ -58,7 +62,7 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
-import java.util.Map;
+import java.util.List;
import static com.amazonaws.athena.connector.lambda.handlers.AthenaExceptionFilter.ATHENA_EXCEPTION_FILTER;
import static com.amazonaws.athena.connector.lambda.handlers.FederationCapabilities.CAPABILITIES;
@@ -111,38 +115,22 @@ public RecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, Ath
}
/**
- * Resolves any secrets found in the supplied string, for example: MyString${WithSecret} would have ${WithSecret}
- * by the corresponding value of the secret in AWS Secrets Manager with that name. If no such secret is found
- * the function throws.
- *
- * @param rawString The string in which you'd like to replace SecretsManager placeholders.
- * (e.g. ThisIsA${Secret}Here - The ${Secret} would be replaced with the contents of an SecretsManager
- * secret called Secret. If no such secret is found, the function throws. If no ${} are found in
- * the input string, nothing is replaced and the original string is returned.
+ * Gets the CachableSecretsManager instance used by this handler.
+ * This is used by credential providers to reuse the same secrets manager instance.
+ * @return The CachableSecretsManager instance
*/
- protected String resolveSecrets(String rawString)
- {
- return secretsManager.resolveSecrets(rawString);
- }
-
- protected String resolveWithDefaultCredentials(String rawString)
+ public CachableSecretsManager getCachableSecretsManager()
{
- return secretsManager.resolveWithDefaultCredentials(rawString);
- }
-
- protected String getSecret(String secretName)
- {
- return secretsManager.getSecret(secretName);
+ return secretsManager;
}
/**
- * Gets the CachableSecretsManager instance used by this handler.
- * This is used by credential providers to reuse the same secrets manager instance.
- * @return The CachableSecretsManager instance
+ * Gets the KmsEncryptionProvider instance used by this handler.
+ * @return The KmsEncryptionProvider instance
*/
- protected CachableSecretsManager getCachableSecretsManager()
+ public KmsEncryptionProvider getKmsEncryptionProvider()
{
- return secretsManager;
+ return kmsEncryptionProvider;
}
public final void handleRequest(InputStream inputStream, OutputStream outputStream, final Context context)
@@ -173,9 +161,9 @@ public final void handleRequest(InputStream inputStream, OutputStream outputStre
}
protected final void doHandleRequest(BlockAllocator allocator,
- ObjectMapper objectMapper,
- RecordRequest req,
- OutputStream outputStream)
+ ObjectMapper objectMapper,
+ RecordRequest req,
+ OutputStream outputStream)
throws Exception
{
logger.info("doHandleRequest: request[{}]", req);
@@ -217,8 +205,8 @@ public RecordResponse doReadRecords(BlockAllocator allocator, ReadRecordsRequest
try (ConstraintEvaluator evaluator = new ConstraintEvaluator(allocator,
request.getSchema(),
request.getConstraints());
- S3BlockSpiller spiller = new S3BlockSpiller(s3Client, spillConfig, allocator, request.getSchema(), evaluator, configOptions);
- QueryStatusChecker queryStatusChecker = new QueryStatusChecker(athenaClient, athenaInvoker, request.getQueryId())
+ S3BlockSpiller spiller = new S3BlockSpiller(s3Client, spillConfig, allocator, request.getSchema(), evaluator, configOptions);
+ QueryStatusChecker queryStatusChecker = new QueryStatusChecker(athenaClient, athenaInvoker, request.getQueryId())
) {
readWithConstraint(spiller, request, queryStatusChecker);
@@ -234,11 +222,6 @@ public RecordResponse doReadRecords(BlockAllocator allocator, ReadRecordsRequest
}
}
- public AwsRequestOverrideConfiguration getRequestOverrideConfig(Map configOptions)
- {
- return getRequestOverrideConfig(configOptions, kmsEncryptionProvider);
- }
-
/**
* A more stream lined option for reading the row data associated with the provided Split. This method differs from
* doReadRecords(...) in that the SDK handles more of the request lifecycle, leaving you to focus more closely on
@@ -292,6 +275,22 @@ protected void onPing(PingRequest request)
//NoOp
}
+ /**
+ * Determines if a LIMIT can be applied and extracts the limit value by delegating to LimitAndSortHelper.getLimit(plan, constraints).
+ */
+ protected Pair getLimitFromPlan(Plan plan, Constraints constraints)
+ {
+ return LimitAndSortHelper.getLimit(plan, constraints);
+ }
+
+ /**
+ * Extracts sort information from Substrait plan for ORDER BY pushdown optimization by delegating to LimitAndSortHelper.getSortFromPlan(plan).
+ */
+ protected Pair> getSortFromPlan(Plan plan)
+ {
+ return LimitAndSortHelper.getSortFromPlan(plan);
+ }
+
private void assertNotNull(FederationResponse response)
{
if (response == null) {
diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/security/CachableSecretsManager.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/security/CachableSecretsManager.java
index 43f607ea71..34a919be1e 100644
--- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/security/CachableSecretsManager.java
+++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/security/CachableSecretsManager.java
@@ -25,6 +25,7 @@
import org.apache.arrow.util.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest;
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse;
@@ -133,6 +134,18 @@ private String useDefaultCredentials(String secret) throws RuntimeException
* @return The value of the secret, throws if no such secret is found.
*/
public String getSecret(String secretName)
+ {
+ return getSecret(secretName, null);
+ }
+
+ /**
+ * Retrieves a secret from SecretsManager, first checking the cache. Newly fetched secrets are added to the cache.
+ * NOTE(review): the cache lookup uses secretName alone; confirm it is safe to share cached values across different override configurations.
+ * @param secretName The name of the secret to retrieve.
+ * @param overrideConfiguration Override configuration for the AWS request, most commonly a FAS_TOKEN AwsCredentials override for federation requests. May be null (the single-argument overload passes null).
+ * @return The value of the secret, throws if no such secret is found.
+ */
+ public String getSecret(String secretName, AwsRequestOverrideConfiguration overrideConfiguration)
{
CacheEntry cacheEntry = cache.get(secretName);
@@ -140,6 +153,7 @@ public String getSecret(String secretName)
logger.info("getSecret: Resolving secret[{}].", secretName);
GetSecretValueResponse secretValueResult = secretsManager.getSecretValue(GetSecretValueRequest.builder()
.secretId(secretName)
+ .overrideConfiguration(overrideConfiguration)
.build());
cacheEntry = new CacheEntry(secretName, secretValueResult.secretString());
evictCache(cache.size() >= MAX_CACHE_SIZE);
diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/substrait/SubstraitFunctionParser.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/substrait/SubstraitFunctionParser.java
index 39c403f67c..2b702ee6ab 100644
--- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/substrait/SubstraitFunctionParser.java
+++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/substrait/SubstraitFunctionParser.java
@@ -7,9 +7,9 @@
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -20,6 +20,7 @@
package com.amazonaws.athena.connector.substrait;
import com.amazonaws.athena.connector.substrait.model.ColumnPredicate;
+import com.amazonaws.athena.connector.substrait.model.LogicalExpression;
import com.amazonaws.athena.connector.substrait.model.SubstraitOperator;
import io.substrait.proto.Expression;
import io.substrait.proto.FunctionArgument;
@@ -32,6 +33,22 @@
import java.util.List;
import java.util.Map;
+import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.AND_BOOL;
+import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.EQUAL_ANY_ANY;
+import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.GT_ANY_ANY;
+import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.GTE_ANY_ANY;
+import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.GTE_PTS_PTS;
+import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.GT_PTS_PTS;
+import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.IS_NOT_NULL_ANY;
+import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.IS_NULL_ANY;
+import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.LT_ANY_ANY;
+import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.LTE_ANY_ANY;
+import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.LTE_PTS_PTS;
+import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.LT_PTS_PTS;
+import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.NOT_BOOL;
+import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.NOT_EQUAL_ANY_ANY;
+import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.OR_BOOL;
+
/**
* Utility class for parsing Substrait function expressions into filter predicates and column predicates.
* This parser handles scalar functions, comparison operations, logical operations (AND/OR), and literal values
@@ -47,7 +64,7 @@ private SubstraitFunctionParser()
/**
* Parses a Substrait expression into a map of column predicates grouped by column name.
* This method extracts all column predicates from the expression and organizes them by the column they apply to.
- *
+ *
* @param extensionDeclarationList List of function extension declarations from the Substrait plan
* @param expression The Substrait expression to parse
* @param columnNames List of column names in the schema for field reference resolution
@@ -68,7 +85,7 @@ public static Map> getColumnPredicatesMap(List parseColumnPredicates(List columnPredicates = new ArrayList<>();
ScalarFunctionInfo functionInfo = extractScalarFunctionInfo(expression, extensionDeclarationList);
-
+
if (functionInfo == null) {
return columnPredicates;
}
+ // Handle NOT unary operator
+ if (NOT_BOOL.equals(functionInfo.getFunctionName())) {
+ ColumnPredicate notPredicate = handleNotOperator(functionInfo, extensionDeclarationList, columnNames);
+ if (notPredicate != null) {
+ columnPredicates.add(notPredicate);
+ }
+ return columnPredicates;
+ }
+
// Handle logical operators by flattening
if (isLogicalOperator(functionInfo.getFunctionName())) {
for (FunctionArgument argument : functionInfo.getArguments()) {
@@ -108,10 +134,79 @@ public static List parseColumnPredicates(List extensionDeclarationList,
+ Expression expression,
+ List columnNames)
+ {
+ if (expression == null) {
+ return null;
+ }
+
+ // Extract function information from the Substrait expression
+ ScalarFunctionInfo functionInfo = extractScalarFunctionInfo(expression, extensionDeclarationList);
+
+ if (functionInfo == null) {
+ return null;
+ }
+
+ // Handle NOT operator by delegating to handleNotOperator which converts NOT(expr)
+ // to appropriate predicates (NOT_EQUAL, NAND, NOR) based on the inner expression type
+ if (NOT_BOOL.equals(functionInfo.getFunctionName())) {
+ ColumnPredicate notPredicate = handleNotOperator(functionInfo, extensionDeclarationList, columnNames);
+ if (notPredicate != null) {
+ return new LogicalExpression(notPredicate);
+ }
+ return null;
+ }
+
+ // Handle logical operators (AND/OR) by building tree structure instead of flattening
+ // This preserves the original logical hierarchy from the SQL query
+ if (isLogicalOperator(functionInfo.getFunctionName())) {
+ List childExpressions = new ArrayList<>();
+
+ // Recursively parse each child argument to build the expression tree
+ for (FunctionArgument argument : functionInfo.getArguments()) {
+ LogicalExpression childExpr = parseLogicalExpression(extensionDeclarationList, argument.getValue(), columnNames);
+ if (childExpr != null) {
+ childExpressions.add(childExpr);
+ }
+ }
+
+ // Create logical expression node with the operator and its children
+ SubstraitOperator operator = mapToOperator(functionInfo.getFunctionName());
+ return new LogicalExpression(operator, childExpressions);
+ }
+
+ // Handle binary comparison operations (e.g., column = value, column > value)
+ if (functionInfo.getArguments().size() == 2) {
+ ColumnPredicate predicate = createBinaryColumnPredicate(functionInfo, columnNames);
+ // Wrap the predicate in a leaf LogicalExpression node
+ return new LogicalExpression(predicate);
+ }
+
+ // Handle unary operations (e.g., column IS NULL, column IS NOT NULL)
+ if (functionInfo.getArguments().size() == 1) {
+ ColumnPredicate predicate = createUnaryColumnPredicate(functionInfo, columnNames);
+ // Wrap the predicate in a leaf LogicalExpression node
+ return new LogicalExpression(predicate);
+ }
+
+ return null;
+ }
+
/**
* Creates a mapping from function reference anchors to function names.
* This mapping is used to resolve function references in Substrait expressions.
- *
+ *
* @param extensionDeclarationList List of extension declarations containing function definitions
* @return A map from function anchor IDs to function names
*/
@@ -161,15 +256,15 @@ private static ScalarFunctionInfo extractScalarFunctionInfo(Expression expressio
if (!expression.hasScalarFunction()) {
return null;
}
-
+
Expression.ScalarFunction scalarFunction = expression.getScalarFunction();
Map functionMap = mapFunctionReferences(extensionDeclarationList);
String functionName = functionMap.get(scalarFunction.getFunctionReference());
List arguments = scalarFunction.getArgumentsList();
-
+
return new ScalarFunctionInfo(functionName, arguments);
}
-
+
/**
* Creates a column predicate for unary operations.
*/
@@ -179,7 +274,7 @@ private static ColumnPredicate createUnaryColumnPredicate(ScalarFunctionInfo fun
SubstraitOperator substraitOperator = mapToOperator(functionInfo.getFunctionName());
return new ColumnPredicate(columnName, substraitOperator, null, null);
}
-
+
/**
* Creates a column predicate for binary operations.
*/
@@ -190,46 +285,53 @@ private static ColumnPredicate createBinaryColumnPredicate(ScalarFunctionInfo fu
SubstraitOperator substraitOperator = mapToOperator(functionInfo.getFunctionName());
return new ColumnPredicate(columnName, substraitOperator, value.getLeft(), value.getRight());
}
-
+
/**
* Checks if a function name represents a logical operator.
*/
private static boolean isLogicalOperator(String functionName)
{
- return "and:bool".equals(functionName) || "or:bool".equals(functionName);
+ // Constant-first equals keeps a null function name safe: it simply reports false.
+ if (AND_BOOL.equals(functionName)) {
+     return true;
+ }
+ return OR_BOOL.equals(functionName);
}
/**
* Maps Substrait function names to corresponding Operator enum values.
- * This method is mapping only small set of operators, and we will extend this as we need.
- *
- * @param functionName The Substrait function name (e.g., "gt:any_any", "equal:any_any")
+ * This method supports comparison operators, logical operators, null checks, and the NOT operator.
+ * The mapping will be extended as additional operators are needed.
+ *
+ * @param functionName The Substrait function name (e.g., "gt:any_any", "equal:any_any", "not:bool")
* @return The corresponding Operator enum value
* @throws UnsupportedOperationException if the function name is not supported
*/
private static SubstraitOperator mapToOperator(String functionName)
{
switch (functionName) {
- case "gt:any_any":
+ case GT_ANY_ANY:
+ case GT_PTS_PTS:
return SubstraitOperator.GREATER_THAN;
- case "gte:any_any":
+ case GTE_ANY_ANY:
+ case GTE_PTS_PTS:
return SubstraitOperator.GREATER_THAN_OR_EQUAL_TO;
- case "lt:any_any":
+ case LT_ANY_ANY:
+ case LT_PTS_PTS:
return SubstraitOperator.LESS_THAN;
- case "lte:any_any":
+ case LTE_ANY_ANY:
+ case LTE_PTS_PTS:
return SubstraitOperator.LESS_THAN_OR_EQUAL_TO;
- case "equal:any_any":
+ case EQUAL_ANY_ANY:
return SubstraitOperator.EQUAL;
- case "not_equal:any_any":
+ case NOT_EQUAL_ANY_ANY:
return SubstraitOperator.NOT_EQUAL;
- case "is_null:any":
+ case IS_NULL_ANY:
return SubstraitOperator.IS_NULL;
- case "is_not_null:any":
+ case IS_NOT_NULL_ANY:
return SubstraitOperator.IS_NOT_NULL;
- case "and:bool":
+ case AND_BOOL:
return SubstraitOperator.AND;
- case "or:bool":
+ case OR_BOOL:
return SubstraitOperator.OR;
+ case NOT_BOOL:
+ return SubstraitOperator.NOT;
default:
throw new UnsupportedOperationException("Unsupported operator function: " + functionName);
}
@@ -259,4 +361,142 @@ public List getArguments()
return arguments;
}
}
+
+ /**
+ * Handles NOT operator expressions by analyzing the inner expression and applying appropriate negation logic.
+ * Supports various NOT patterns including NOT(AND), NOT(OR), NOT IN, and simple predicate negation.
+ *
+ * The inner expression is resolved in this order; the first match wins:
+ * 1. NOT(AND(...)) becomes a single NAND predicate holding the parsed child predicates.
+ * 2. NOT(OR(...)) becomes a single NOR predicate holding the parsed child predicates.
+ * 3. Several EQUAL predicates on one column become a NOT_IN predicate.
+ * 4. Exactly one predicate is negated via createNegatedPredicate.
+ *
+ * Null is returned when the NOT call does not have exactly one argument, or when the inner expression
+ * parses to zero predicates or to several predicates that do not form a NOT IN pattern.
+ *
+ * @param notFunctionInfo The scalar function info for the NOT operation
+ * @param extensionDeclarationList List of extension declarations for function mapping
+ * @param columnNames List of available column names
+ * @return ColumnPredicate representing the negated expression, or null if not supported
+ */
+ private static ColumnPredicate handleNotOperator(
+ ScalarFunctionInfo notFunctionInfo,
+ List extensionDeclarationList,
+ List columnNames)
+ {
+ // NOT is unary; any other argument count is treated as unsupported.
+ if (notFunctionInfo.getArguments().size() != 1) {
+ return null;
+ }
+ Expression innerExpression = notFunctionInfo.getArguments().get(0).getValue();
+ // innerFunctionInfo is null when the inner expression is not a scalar function call.
+ ScalarFunctionInfo innerFunctionInfo = extractScalarFunctionInfo(innerExpression, extensionDeclarationList);
+ // Case: NOT(AND(...)) => NAND
+ // The group predicate has no column of its own; the child predicates are passed in the third
+ // constructor argument, the slot a binary predicate uses for its literal value.
+ // NOTE(review): parseColumnPredicates is not visible here. If it flattens nested AND/OR groups,
+ // the child list loses that nesting - confirm this is acceptable for NAND/NOR.
+ if (innerFunctionInfo != null && AND_BOOL.equals(innerFunctionInfo.getFunctionName())) {
+ List childPredicates =
+ parseColumnPredicates(extensionDeclarationList, innerExpression, columnNames);
+ return new ColumnPredicate(
+ null,
+ SubstraitOperator.NAND,
+ childPredicates,
+ null
+ );
+ }
+ // Case: NOT(OR(...)) => NOR
+ // Built the same way as NAND above: null column, child predicates in the value slot.
+ if (innerFunctionInfo != null && OR_BOOL.equals(innerFunctionInfo.getFunctionName())) {
+ List childPredicates =
+ parseColumnPredicates(extensionDeclarationList, innerExpression, columnNames);
+ return new ColumnPredicate(
+ null,
+ SubstraitOperator.NOR,
+ childPredicates,
+ null
+ );
+ }
+ // NOT IN pattern detection - reserved for future use cases
+ // NOTE: This handles scenarios where expressions other than direct OR operations
+ // may flatten to multiple EQUAL predicates on the same column. Currently, standard
+ // NOT(OR(...)) expressions are processed as NOR operations above.
+ List innerPredicates = parseColumnPredicates(extensionDeclarationList, innerExpression, columnNames);
+ if (isNotInPattern(innerPredicates)) {
+ return createNotInPredicate(innerPredicates);
+ }
+ // Simple case: a single inner predicate is negated directly.
+ if (innerPredicates.size() == 1) {
+ return createNegatedPredicate(innerPredicates.get(0));
+ }
+ // Zero predicates, or several that are not a NOT IN pattern: signal "unsupported" to the caller.
+ return null;
+ }
+
+ /**
+  * Determines if a list of predicates represents a NOT IN pattern.
+  * A NOT IN pattern consists of multiple EQUAL predicates on the same column.
+  *
+  * Predicates without a column (for example the NAND/NOR group predicates, which are built with a
+  * null column) never form a NOT IN pattern; they make this method return false instead of failing
+  * with a NullPointerException.
+  *
+  * @param predicates List of column predicates to analyze
+  * @return true if there are at least two predicates, all of them EQUAL comparisons on the same
+  *         non-null column; false otherwise
+  */
+ private static boolean isNotInPattern(List<ColumnPredicate> predicates)
+ {
+     // A single predicate is a plain negation, not a NOT IN list.
+     if (predicates.size() <= 1) {
+         return false;
+     }
+     String firstColumn = predicates.get(0).getColumn();
+     // Without a column on the first predicate there is nothing to build a NOT IN against.
+     if (firstColumn == null) {
+         return false;
+     }
+     for (ColumnPredicate predicate : predicates) {
+         // Calling equals on the known non-null column keeps the comparison safe
+         // when a later predicate has no column.
+         if (predicate.getOperator() != SubstraitOperator.EQUAL ||
+                 !firstColumn.equals(predicate.getColumn())) {
+             return false;
+         }
+     }
+     return true;
+ }
+
+ /**
+ * Creates a NOT_IN predicate from a list of EQUAL predicates on the same column.
+ * Extracts all values from the EQUAL predicates and combines them into a single NOT_IN operation.
+ *
+ * @param equalPredicates List of EQUAL predicates on the same column
+ * @return ColumnPredicate with NOT_IN operator containing all excluded values, or null if input is empty
+ */
+ private static ColumnPredicate createNotInPredicate(List equalPredicates)
+ {
+ if (equalPredicates.isEmpty()) {
+ return null;
+ }
+ String column = equalPredicates.get(0).getColumn();
+ List