diff --git a/athena-federation-sdk/pom.xml b/athena-federation-sdk/pom.xml index e5a4baac91..0f0d8cbc78 100644 --- a/athena-federation-sdk/pom.xml +++ b/athena-federation-sdk/pom.xml @@ -290,6 +290,11 @@ core ${io.substrait.version} + + org.junit.jupiter + junit-jupiter-params + test + diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/Block.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/Block.java index 2b1d14bcef..8c3a446048 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/Block.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/Block.java @@ -179,6 +179,34 @@ public boolean offerValue(String fieldName, int row, Object value) return false; } + /** + * Attempts to write the provided value to the specified field on the specified row. This method does _not_ update the + * row count on the underlying Apache Arrow VectorSchema. You must call setRowCount(...) to ensure the values + * your have written are considered 'valid rows' and thus available when you attempt to serialize this Block. This + * method replies on BlockUtils' field conversion/coercion logic to convert the provided value into a type that + * matches Apache Arrow's supported serialization format. For more details on coercion please see @BlockUtils + * + * @param fieldName The name of the field you wish to write to. + * @param row The row number to write to. Note that Apache Arrow Blocks begin with row 0 just like a typical array. + * @param value The value you wish to write. + * @param hasQueryPlan Whether the operation is running under a query plan, if true, bypasses constraint checks. + * @return True if the value was written to the Block (even if the field is missing from the Block), + * False if the value was not written due to failing a constraint or if query plan is present. + * @note This method will take no action if the provided fieldName is not a valid field in this Block's Schema. + * In such cases the method will return true. + */ + public boolean offerValue(String fieldName, int row, Object value, boolean hasQueryPlan) + { + if (!hasQueryPlan && !constraintEvaluator.apply(fieldName, value)) { + return false; + } + FieldVector vector = getFieldVector(fieldName); + if (vector != null) { + BlockUtils.setValue(vector, row, value); + } + return true; + } + /** * Attempts to set the provided value for the given field name and row. If the Block's schema does not * contain such a field, this method does nothing and returns false. diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/FederationRequestHandler.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/FederationRequestHandler.java index 4e7008fd9e..276113bde3 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/FederationRequestHandler.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/FederationRequestHandler.java @@ -7,9 +7,9 @@ * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -19,8 +19,16 @@ */ package com.amazonaws.athena.connector.lambda.handlers; +import com.amazonaws.athena.connector.credentials.CredentialsProvider; +import com.amazonaws.athena.connector.credentials.DefaultCredentialsProvider; +import com.amazonaws.athena.connector.lambda.request.FederationRequest; +import com.amazonaws.athena.connector.lambda.security.CachableSecretsManager; +import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; import com.amazonaws.athena.connector.lambda.security.KmsEncryptionProvider; import com.amazonaws.services.lambda.runtime.RequestStreamHandler; +import org.apache.commons.lang3.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import software.amazon.awssdk.auth.credentials.AwsCredentials; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; @@ -36,6 +44,72 @@ public interface FederationRequestHandler extends RequestStreamHandler { + /** + * Gets the CachableSecretsManager instance used by this handler. + * Implementations must provide access to their secrets manager instance. + * + * @return The CachableSecretsManager instance + */ + CachableSecretsManager getCachableSecretsManager(); + + /** + * Gets the KmsEncryptionProvider instance used by this handler. + * Implementations must provide access to their KMS encryption provider instance. + * + * @return The KmsEncryptionProvider instance + */ + KmsEncryptionProvider getKmsEncryptionProvider(); + + /** + * Resolves any secrets found in the supplied string, for example: MyString${WithSecret} would have ${WithSecret} + * replaced by the corresponding value of the secret in AWS Secrets Manager with that name. If no such secret is found + * the function throws. + * + * @param rawString The string in which you'd like to replace SecretsManager placeholders. + * (e.g. ThisIsA${Secret}Here - The ${Secret} would be replaced with the contents of a SecretsManager + * secret called Secret. If no such secret is found, the function throws. If no ${} are found in + * the input string, nothing is replaced and the original string is returned. + * @return The processed string with secrets resolved + */ + default String resolveSecrets(String rawString) + { + return getCachableSecretsManager().resolveSecrets(rawString); + } + + /** + * Resolves secrets with default credentials format (username:password). + * + * @param rawString The string containing secret placeholders to resolve + * @return The processed string with secrets resolved in default credentials format + */ + default String resolveWithDefaultCredentials(String rawString) + { + return getCachableSecretsManager().resolveWithDefaultCredentials(rawString); + } + + /** + * Retrieves a secret from AWS Secrets Manager. + * + * @param secretName The name of the secret to retrieve + * @return The secret value + */ + default String getSecret(String secretName) + { + return getCachableSecretsManager().getSecret(secretName); + } + + /** + * Retrieves a secret from AWS Secrets Manager with request override configuration. + * + * @param secretName The name of the secret to retrieve + * @param requestOverrideConfiguration AWS request override configuration for federated requests + * @return The secret value + */ + default String getSecret(String secretName, AwsRequestOverrideConfiguration requestOverrideConfiguration) + { + return getCachableSecretsManager().getSecret(secretName, requestOverrideConfiguration); + } + default AwsCredentials getSessionCredentials(String kmsKeyId, String tokenString, KmsEncryptionProvider kmsEncryptionProvider) @@ -43,6 +117,40 @@ default AwsCredentials getSessionCredentials(String kmsKeyId, return kmsEncryptionProvider.getFasCredentials(kmsKeyId, tokenString); } + /** + * Gets the AWS request override configuration for a FederationRequest. + * This method extracts the configuration options from the federated identity and delegates + * to the Map-based overload. + * + * @param request The federation request + * @return The AWS request override configuration, or null if not a federated request + */ + default AwsRequestOverrideConfiguration getRequestOverrideConfig(FederationRequest request) + { + if (isRequestFederated(request)) { + FederatedIdentity federatedIdentity = request.getIdentity(); + Map connectorRequestOptions = federatedIdentity != null ? federatedIdentity.getConfigOptions() : null; + + if (connectorRequestOptions != null && connectorRequestOptions.get(FAS_TOKEN) != null) { + return getRequestOverrideConfig(connectorRequestOptions); + } + } + return null; + } + + /** + * Gets the AWS request override configuration for the given config options. + * This is a convenience method that delegates to the full overload using the handler's + * KMS encryption provider. + * + * @param configOptions The configuration options map + * @return The AWS request override configuration, or null if not applicable + */ + default AwsRequestOverrideConfiguration getRequestOverrideConfig(Map configOptions) + { + return getRequestOverrideConfig(configOptions, getKmsEncryptionProvider()); + } + default AwsRequestOverrideConfiguration getRequestOverrideConfig(Map configOptions, KmsEncryptionProvider kmsEncryptionProvider) { @@ -85,4 +193,54 @@ default AthenaClient getAthenaClient(AwsRequestOverrideConfiguration awsRequestO return defaultAthena; } } + + default boolean isRequestFederated(FederationRequest req) + { + FederatedIdentity federatedIdentity = req.getIdentity(); + Map connectorRequestOptions = federatedIdentity != null ? federatedIdentity.getConfigOptions() : null; + return (connectorRequestOptions != null && connectorRequestOptions.get(FAS_TOKEN) != null); + } + + /** + * Gets a credentials provider for database connections with optional request override configuration. + * This method checks if a secret name is configured and creates a credentials provider if available. + * Subclasses can override createCredentialsProvider() to provide custom credential provider implementations. + * + * @param requestOverrideConfiguration Optional AWS request override configuration for federated requests + * @return CredentialsProvider instance or null if no secret is configured + */ + default CredentialsProvider getCredentialProvider(AwsRequestOverrideConfiguration requestOverrideConfiguration) + { + final String secretName = getDatabaseConnectionSecret(); + if (StringUtils.isNotBlank(secretName)) { + Logger logger = LoggerFactory.getLogger(this.getClass()); + logger.info("Using Secrets Manager."); + return createCredentialsProvider(secretName, requestOverrideConfiguration); + } + return null; + } + + /** + * Factory method to create CredentialsProvider. Subclasses can override this to provide + * custom credential provider implementations (e.g., SnowflakeCredentialsProvider). + * + * @param secretName The secret name to retrieve credentials from + * @param requestOverrideConfiguration Optional AWS request override configuration + * @return CredentialsProvider instance + */ + default CredentialsProvider createCredentialsProvider(String secretName, AwsRequestOverrideConfiguration requestOverrideConfiguration) + { + return new DefaultCredentialsProvider(getSecret(secretName, requestOverrideConfiguration)); + } + + /** + * Gets the database connection secret name. Subclasses that use database credentials + * should override this method to provide the secret name from their configuration. + * + * @return The secret name, or null if not applicable + */ + default String getDatabaseConnectionSecret() + { + return null; + } } diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/MetadataHandler.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/MetadataHandler.java index 5921ac6c8c..b7b05ace0a 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/MetadataHandler.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/MetadataHandler.java @@ -66,6 +66,7 @@ import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; @@ -75,12 +76,13 @@ import software.amazon.awssdk.services.kms.KmsClient; import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; -import software.amazon.awssdk.utils.StringUtils; +import software.amazon.awssdk.utils.CollectionUtils; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.util.Collections; +import java.util.HashMap; import java.util.Map; import java.util.Objects; import java.util.UUID; @@ -172,13 +174,13 @@ public MetadataHandler(String sourceType, java.util.Map configOp * @param sourceType Used to aid in logging diagnostic info when raising a support case. */ public MetadataHandler( - EncryptionKeyFactory encryptionKeyFactory, - SecretsManagerClient secretsManager, - AthenaClient athena, - String sourceType, - String spillBucket, - String spillPrefix, - java.util.Map configOptions) + EncryptionKeyFactory encryptionKeyFactory, + SecretsManagerClient secretsManager, + AthenaClient athena, + String sourceType, + String spillBucket, + String spillPrefix, + java.util.Map configOptions) { this.configOptions = configOptions; this.encryptionKeyFactory = encryptionKeyFactory; @@ -194,38 +196,22 @@ public MetadataHandler( } /** - * Resolves any secrets found in the supplied string, for example: MyString${WithSecret} would have ${WithSecret} - * by the corresponding value of the secret in AWS Secrets Manager with that name. If no such secret is found - * the function throws. - * - * @param rawString The string in which you'd like to replace SecretsManager placeholders. - * (e.g. ThisIsA${Secret}Here - The ${Secret} would be replaced with the contents of an SecretsManager - * secret called Secret. If no such secret is found, the function throws. If no ${} are found in - * the input string, nothing is replaced and the original string is returned. + * Gets the CachableSecretsManager instance used by this handler. + * This is used by credential providers to reuse the same secrets manager instance. + * @return The CachableSecretsManager instance */ - protected String resolveSecrets(String rawString) - { - return secretsManager.resolveSecrets(rawString); - } - - protected String resolveWithDefaultCredentials(String rawString) - { - return secretsManager.resolveWithDefaultCredentials(rawString); - } - - protected String getSecret(String secretName) + public CachableSecretsManager getCachableSecretsManager() { - return secretsManager.getSecret(secretName); + return secretsManager; } /** - * Gets the CachableSecretsManager instance used by this handler. - * This is used by credential providers to reuse the same secrets manager instance. - * @return The CachableSecretsManager instance + * Gets the KmsEncryptionProvider instance used by this handler. + * @return The KmsEncryptionProvider instance */ - protected CachableSecretsManager getCachableSecretsManager() + public KmsEncryptionProvider getKmsEncryptionProvider() { - return secretsManager; + return kmsEncryptionProvider; } protected EncryptionKey makeEncryptionKey() @@ -241,13 +227,17 @@ protected EncryptionKey makeEncryptionKey(AwsRequestOverrideConfiguration awsReq /** * Used to make a spill location for a split. Each split should have a unique spill location, so be sure * to call this method once per split! - * @param request + * @param request * @return A unique spill location. */ protected SpillLocation makeSpillLocation(MetadataRequest request) { FederatedIdentity federatedIdentity = request.getIdentity(); Map configOptions = federatedIdentity.getConfigOptions(); + if (CollectionUtils.isNullOrEmpty(configOptions)) { + logger.debug("configOptions is empty from federation. Use default configOptions."); + configOptions = new HashMap<>(this.configOptions); + } String queryId = request.getQueryId(); String prefix = StringUtils.isBlank(configOptions.get(SPILL_PREFIX_ENV)) ? spillPrefix : configOptions.get(SPILL_PREFIX_ENV); @@ -292,9 +282,9 @@ public final void handleRequest(InputStream inputStream, OutputStream outputStre } protected final void doHandleRequest(BlockAllocator allocator, - ObjectMapper objectMapper, - MetadataRequest req, - OutputStream outputStream) + ObjectMapper objectMapper, + MetadataRequest req, + OutputStream outputStream) throws Exception { logger.info("doHandleRequest: request[{}]", req); @@ -465,8 +455,8 @@ public GetTableLayoutResponse doGetTableLayout(final BlockAllocator allocator, f try (ConstraintEvaluator constraintEvaluator = new ConstraintEvaluator(allocator, constraintSchema.build(), request.getConstraints()); - QueryStatusChecker queryStatusChecker = new QueryStatusChecker(getAthenaClient(overrideConfig, athena), - athenaInvoker, request.getQueryId()) + QueryStatusChecker queryStatusChecker = new QueryStatusChecker(getAthenaClient(overrideConfig, athena), + athenaInvoker, request.getQueryId()) ) { Block partitions = allocator.createBlock(partitionSchemaBuilder.build()); partitions.constrain(constraintEvaluator); @@ -513,7 +503,7 @@ public void enhancePartitionSchema(SchemaBuilder partitionSchemaBuilder, GetTabl * for pushing down into the source you are querying. */ public abstract void getPartitions(final BlockWriter blockWriter, - final GetTableLayoutRequest request, QueryStatusChecker queryStatusChecker) + final GetTableLayoutRequest request, QueryStatusChecker queryStatusChecker) throws Exception; /** @@ -575,11 +565,6 @@ public void onPing(PingRequest request) //NoOp } - public AwsRequestOverrideConfiguration getRequestOverrideConfig(Map configOptions) - { - return getRequestOverrideConfig(configOptions, kmsEncryptionProvider); - } - /** * Helper function that is used to ensure we always have a non-null response. * diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/RecordHandler.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/RecordHandler.java index 5e4727a713..6f27b33fcf 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/RecordHandler.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/RecordHandler.java @@ -28,6 +28,7 @@ import com.amazonaws.athena.connector.lambda.data.S3BlockSpiller; import com.amazonaws.athena.connector.lambda.data.SpillConfig; import com.amazonaws.athena.connector.lambda.domain.predicate.ConstraintEvaluator; +import com.amazonaws.athena.connector.lambda.domain.predicate.Constraints; import com.amazonaws.athena.connector.lambda.exceptions.AthenaConnectorException; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; import com.amazonaws.athena.connector.lambda.records.ReadRecordsResponse; @@ -43,8 +44,11 @@ import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; import com.amazonaws.athena.connector.lambda.security.KmsEncryptionProvider; import com.amazonaws.athena.connector.lambda.serde.VersionedObjectMapperFactory; +import com.amazonaws.athena.connector.substrait.util.LimitAndSortHelper; import com.amazonaws.services.lambda.runtime.Context; import com.fasterxml.jackson.databind.ObjectMapper; +import io.substrait.proto.Plan; +import org.apache.commons.lang3.tuple.Pair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; @@ -58,7 +62,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; -import java.util.Map; +import java.util.List; import static com.amazonaws.athena.connector.lambda.handlers.AthenaExceptionFilter.ATHENA_EXCEPTION_FILTER; import static com.amazonaws.athena.connector.lambda.handlers.FederationCapabilities.CAPABILITIES; @@ -111,38 +115,22 @@ public RecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, Ath } /** - * Resolves any secrets found in the supplied string, for example: MyString${WithSecret} would have ${WithSecret} - * by the corresponding value of the secret in AWS Secrets Manager with that name. If no such secret is found - * the function throws. - * - * @param rawString The string in which you'd like to replace SecretsManager placeholders. - * (e.g. ThisIsA${Secret}Here - The ${Secret} would be replaced with the contents of an SecretsManager - * secret called Secret. If no such secret is found, the function throws. If no ${} are found in - * the input string, nothing is replaced and the original string is returned. + * Gets the CachableSecretsManager instance used by this handler. + * This is used by credential providers to reuse the same secrets manager instance. + * @return The CachableSecretsManager instance */ - protected String resolveSecrets(String rawString) - { - return secretsManager.resolveSecrets(rawString); - } - - protected String resolveWithDefaultCredentials(String rawString) + public CachableSecretsManager getCachableSecretsManager() { - return secretsManager.resolveWithDefaultCredentials(rawString); - } - - protected String getSecret(String secretName) - { - return secretsManager.getSecret(secretName); + return secretsManager; } /** - * Gets the CachableSecretsManager instance used by this handler. - * This is used by credential providers to reuse the same secrets manager instance. - * @return The CachableSecretsManager instance + * Gets the KmsEncryptionProvider instance used by this handler. + * @return The KmsEncryptionProvider instance */ - protected CachableSecretsManager getCachableSecretsManager() + public KmsEncryptionProvider getKmsEncryptionProvider() { - return secretsManager; + return kmsEncryptionProvider; } public final void handleRequest(InputStream inputStream, OutputStream outputStream, final Context context) @@ -173,9 +161,9 @@ public final void handleRequest(InputStream inputStream, OutputStream outputStre } protected final void doHandleRequest(BlockAllocator allocator, - ObjectMapper objectMapper, - RecordRequest req, - OutputStream outputStream) + ObjectMapper objectMapper, + RecordRequest req, + OutputStream outputStream) throws Exception { logger.info("doHandleRequest: request[{}]", req); @@ -217,8 +205,8 @@ public RecordResponse doReadRecords(BlockAllocator allocator, ReadRecordsRequest try (ConstraintEvaluator evaluator = new ConstraintEvaluator(allocator, request.getSchema(), request.getConstraints()); - S3BlockSpiller spiller = new S3BlockSpiller(s3Client, spillConfig, allocator, request.getSchema(), evaluator, configOptions); - QueryStatusChecker queryStatusChecker = new QueryStatusChecker(athenaClient, athenaInvoker, request.getQueryId()) + S3BlockSpiller spiller = new S3BlockSpiller(s3Client, spillConfig, allocator, request.getSchema(), evaluator, configOptions); + QueryStatusChecker queryStatusChecker = new QueryStatusChecker(athenaClient, athenaInvoker, request.getQueryId()) ) { readWithConstraint(spiller, request, queryStatusChecker); @@ -234,11 +222,6 @@ public RecordResponse doReadRecords(BlockAllocator allocator, ReadRecordsRequest } } - public AwsRequestOverrideConfiguration getRequestOverrideConfig(Map configOptions) - { - return getRequestOverrideConfig(configOptions, kmsEncryptionProvider); - } - /** * A more stream lined option for reading the row data associated with the provided Split. This method differs from * doReadRecords(...) in that the SDK handles more of the request lifecycle, leaving you to focus more closely on @@ -292,6 +275,22 @@ protected void onPing(PingRequest request) //NoOp } + /** + * Determines if a LIMIT can be applied and extracts the limit value. + */ + protected Pair getLimitFromPlan(Plan plan, Constraints constraints) + { + return LimitAndSortHelper.getLimit(plan, constraints); + } + + /** + * Extracts sort information from Substrait plan for ORDER BY pushdown optimization. + */ + protected Pair> getSortFromPlan(Plan plan) + { + return LimitAndSortHelper.getSortFromPlan(plan); + } + private void assertNotNull(FederationResponse response) { if (response == null) { diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/security/CachableSecretsManager.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/security/CachableSecretsManager.java index 43f607ea71..34a919be1e 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/security/CachableSecretsManager.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/security/CachableSecretsManager.java @@ -25,6 +25,7 @@ import org.apache.arrow.util.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; @@ -133,6 +134,18 @@ private String useDefaultCredentials(String secret) throws RuntimeException * @return The value of the secret, throws if no such secret is found. */ public String getSecret(String secretName) + { + return getSecret(secretName, null); + } + + /** + * Retrieves a secret from SecretsManager, first checking the cache. Newly fetched secrets are added to the cache. + * + * @param secretName The name of the secret to retrieve. + * @param overrideConfiguration override configuration for the aws request. Most commonly, for FAS_TOKEN AwsCredentials override for federation requests. + * @return The value of the secret, throws if no such secret is found. + */ + public String getSecret(String secretName, AwsRequestOverrideConfiguration overrideConfiguration) { CacheEntry cacheEntry = cache.get(secretName); @@ -140,6 +153,7 @@ public String getSecret(String secretName) logger.info("getSecret: Resolving secret[{}].", secretName); GetSecretValueResponse secretValueResult = secretsManager.getSecretValue(GetSecretValueRequest.builder() .secretId(secretName) + .overrideConfiguration(overrideConfiguration) .build()); cacheEntry = new CacheEntry(secretName, secretValueResult.secretString()); evictCache(cache.size() >= MAX_CACHE_SIZE); diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/substrait/SubstraitFunctionParser.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/substrait/SubstraitFunctionParser.java index 39c403f67c..2b702ee6ab 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/substrait/SubstraitFunctionParser.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/substrait/SubstraitFunctionParser.java @@ -7,9 +7,9 @@ * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -20,6 +20,7 @@ package com.amazonaws.athena.connector.substrait; import com.amazonaws.athena.connector.substrait.model.ColumnPredicate; +import com.amazonaws.athena.connector.substrait.model.LogicalExpression; import com.amazonaws.athena.connector.substrait.model.SubstraitOperator; import io.substrait.proto.Expression; import io.substrait.proto.FunctionArgument; @@ -32,6 +33,22 @@ import java.util.List; import java.util.Map; +import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.AND_BOOL; +import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.EQUAL_ANY_ANY; +import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.GT_ANY_ANY; +import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.GTE_ANY_ANY; +import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.GTE_PTS_PTS; +import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.GT_PTS_PTS; +import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.IS_NOT_NULL_ANY; +import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.IS_NULL_ANY; +import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.LT_ANY_ANY; +import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.LTE_ANY_ANY; +import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.LTE_PTS_PTS; +import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.LT_PTS_PTS; +import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.NOT_BOOL; +import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.NOT_EQUAL_ANY_ANY; +import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.OR_BOOL; + /** * Utility class for parsing Substrait function expressions into filter predicates and column predicates. * This parser handles scalar functions, comparison operations, logical operations (AND/OR), and literal values @@ -47,7 +64,7 @@ private SubstraitFunctionParser() /** * Parses a Substrait expression into a map of column predicates grouped by column name. * This method extracts all column predicates from the expression and organizes them by the column they apply to. - * + * * @param extensionDeclarationList List of function extension declarations from the Substrait plan * @param expression The Substrait expression to parse * @param columnNames List of column names in the schema for field reference resolution @@ -68,7 +85,7 @@ public static Map> getColumnPredicatesMap(List parseColumnPredicates(List columnPredicates = new ArrayList<>(); ScalarFunctionInfo functionInfo = extractScalarFunctionInfo(expression, extensionDeclarationList); - + if (functionInfo == null) { return columnPredicates; } + // Handle NOT unary operator + if (NOT_BOOL.equals(functionInfo.getFunctionName())) { + ColumnPredicate notPredicate = handleNotOperator(functionInfo, extensionDeclarationList, columnNames); + if (notPredicate != null) { + columnPredicates.add(notPredicate); + } + return columnPredicates; + } + // Handle logical operators by flattening if (isLogicalOperator(functionInfo.getFunctionName())) { for (FunctionArgument argument : functionInfo.getArguments()) { @@ -108,10 +134,79 @@ public static List parseColumnPredicates(List extensionDeclarationList, + Expression expression, + List columnNames) + { + if (expression == null) { + return null; + } + + // Extract function information from the Substrait expression + ScalarFunctionInfo functionInfo = extractScalarFunctionInfo(expression, extensionDeclarationList); + + if (functionInfo == null) { + return null; + } + + // Handle NOT operator by delegating to handleNotOperator which converts NOT(expr) + // to appropriate predicates (NOT_EQUAL, NAND, NOR) based on the inner expression type + if (NOT_BOOL.equals(functionInfo.getFunctionName())) { + ColumnPredicate notPredicate = handleNotOperator(functionInfo, extensionDeclarationList, columnNames); + if (notPredicate != null) { + return new LogicalExpression(notPredicate); + } + return null; + } + + // Handle logical operators (AND/OR) by building tree structure instead of flattening + // This preserves the original logical hierarchy from the SQL query + if (isLogicalOperator(functionInfo.getFunctionName())) { + List childExpressions = new ArrayList<>(); + + // Recursively parse each child argument to build the expression tree + for (FunctionArgument argument : functionInfo.getArguments()) { + LogicalExpression childExpr = parseLogicalExpression(extensionDeclarationList, argument.getValue(), columnNames); + if (childExpr != null) { + childExpressions.add(childExpr); + } + } + + // Create logical expression node with the operator and its children + SubstraitOperator operator = mapToOperator(functionInfo.getFunctionName()); + return new LogicalExpression(operator, childExpressions); + } + + // Handle binary comparison operations (e.g., column = value, column > value) + if (functionInfo.getArguments().size() == 2) { + ColumnPredicate predicate = createBinaryColumnPredicate(functionInfo, columnNames); + // Wrap the predicate in a leaf LogicalExpression node + return new LogicalExpression(predicate); + } + + // Handle unary operations (e.g., column IS NULL, column IS NOT NULL) + if (functionInfo.getArguments().size() == 1) { + ColumnPredicate predicate = createUnaryColumnPredicate(functionInfo, columnNames); + // Wrap the predicate in a leaf LogicalExpression node + return new LogicalExpression(predicate); + } + + return null; + } + /** * Creates a mapping from function reference anchors to function names. * This mapping is used to resolve function references in Substrait expressions. - * + * * @param extensionDeclarationList List of extension declarations containing function definitions * @return A map from function anchor IDs to function names */ @@ -161,15 +256,15 @@ private static ScalarFunctionInfo extractScalarFunctionInfo(Expression expressio if (!expression.hasScalarFunction()) { return null; } - + Expression.ScalarFunction scalarFunction = expression.getScalarFunction(); Map functionMap = mapFunctionReferences(extensionDeclarationList); String functionName = functionMap.get(scalarFunction.getFunctionReference()); List arguments = scalarFunction.getArgumentsList(); - + return new ScalarFunctionInfo(functionName, arguments); } - + /** * Creates a column predicate for unary operations. */ @@ -179,7 +274,7 @@ private static ColumnPredicate createUnaryColumnPredicate(ScalarFunctionInfo fun SubstraitOperator substraitOperator = mapToOperator(functionInfo.getFunctionName()); return new ColumnPredicate(columnName, substraitOperator, null, null); } - + /** * Creates a column predicate for binary operations. */ @@ -190,46 +285,53 @@ private static ColumnPredicate createBinaryColumnPredicate(ScalarFunctionInfo fu SubstraitOperator substraitOperator = mapToOperator(functionInfo.getFunctionName()); return new ColumnPredicate(columnName, substraitOperator, value.getLeft(), value.getRight()); } - + /** * Checks if a function name represents a logical operator. */ private static boolean isLogicalOperator(String functionName) { - return "and:bool".equals(functionName) || "or:bool".equals(functionName); + return AND_BOOL.equals(functionName) || OR_BOOL.equals(functionName); } /** * Maps Substrait function names to corresponding Operator enum values. - * This method is mapping only small set of operators, and we will extend this as we need. - * - * @param functionName The Substrait function name (e.g., "gt:any_any", "equal:any_any") + * This method supports comparison operators, logical operators, null checks, and the NOT operator. + * The mapping will be extended as additional operators are needed. + * + * @param functionName The Substrait function name (e.g., "gt:any_any", "equal:any_any", "not:bool") * @return The corresponding Operator enum value * @throws UnsupportedOperationException if the function name is not supported */ private static SubstraitOperator mapToOperator(String functionName) { switch (functionName) { - case "gt:any_any": + case GT_ANY_ANY: + case GT_PTS_PTS: return SubstraitOperator.GREATER_THAN; - case "gte:any_any": + case GTE_ANY_ANY: + case GTE_PTS_PTS: return SubstraitOperator.GREATER_THAN_OR_EQUAL_TO; - case "lt:any_any": + case LT_ANY_ANY: + case LT_PTS_PTS: return SubstraitOperator.LESS_THAN; - case "lte:any_any": + case LTE_ANY_ANY: + case LTE_PTS_PTS: return SubstraitOperator.LESS_THAN_OR_EQUAL_TO; - case "equal:any_any": + case EQUAL_ANY_ANY: return SubstraitOperator.EQUAL; - case "not_equal:any_any": + case NOT_EQUAL_ANY_ANY: return SubstraitOperator.NOT_EQUAL; - case "is_null:any": + case IS_NULL_ANY: return SubstraitOperator.IS_NULL; - case "is_not_null:any": + case IS_NOT_NULL_ANY: return SubstraitOperator.IS_NOT_NULL; - case "and:bool": + case AND_BOOL: return SubstraitOperator.AND; - case "or:bool": + case OR_BOOL: return SubstraitOperator.OR; + case NOT_BOOL: + return SubstraitOperator.NOT; default: throw new UnsupportedOperationException("Unsupported operator function: " + functionName); } @@ -259,4 +361,142 @@ public List getArguments() return arguments; } } + + /** + * Handles NOT operator expressions by analyzing the inner expression and applying appropriate negation logic. + * Supports various NOT patterns including NOT(AND), NOT(OR), NOT IN, and simple predicate negation. + * + * @param notFunctionInfo The scalar function info for the NOT operation + * @param extensionDeclarationList List of extension declarations for function mapping + * @param columnNames List of available column names + * @return ColumnPredicate representing the negated expression, or null if not supported + */ + private static ColumnPredicate handleNotOperator( + ScalarFunctionInfo notFunctionInfo, + List extensionDeclarationList, + List columnNames) + { + if (notFunctionInfo.getArguments().size() != 1) { + return null; + } + Expression innerExpression = notFunctionInfo.getArguments().get(0).getValue(); + ScalarFunctionInfo innerFunctionInfo = extractScalarFunctionInfo(innerExpression, extensionDeclarationList); + // Case: NOT(AND(...)) => NAND + if (innerFunctionInfo != null && AND_BOOL.equals(innerFunctionInfo.getFunctionName())) { + List childPredicates = + parseColumnPredicates(extensionDeclarationList, innerExpression, columnNames); + return new ColumnPredicate( + null, + SubstraitOperator.NAND, + childPredicates, + null + ); + } + // Case: NOT(OR(...)) => NOR + if (innerFunctionInfo != null && OR_BOOL.equals(innerFunctionInfo.getFunctionName())) { + List childPredicates = + parseColumnPredicates(extensionDeclarationList, innerExpression, columnNames); + return new ColumnPredicate( + null, + SubstraitOperator.NOR, + childPredicates, + null + ); + } + // NOT IN pattern detection - reserved for future use cases + // NOTE: This handles scenarios where expressions other than direct OR operations + // may flatten to multiple EQUAL predicates on the same column. Currently, standard + // NOT(OR(...)) expressions are processed as NOR operations above. + List innerPredicates = parseColumnPredicates(extensionDeclarationList, innerExpression, columnNames); + if (isNotInPattern(innerPredicates)) { + return createNotInPredicate(innerPredicates); + } + if (innerPredicates.size() == 1) { + return createNegatedPredicate(innerPredicates.get(0)); + } + return null; + } + + /** + * Determines if a list of predicates represents a NOT IN pattern. + * A NOT IN pattern consists of multiple EQUAL predicates on the same column. + * + * @param predicates List of column predicates to analyze + * @return true if the predicates form a NOT IN pattern, false otherwise + */ + private static boolean isNotInPattern(List predicates) + { + if (predicates.size() <= 1) { + return false; + } + String firstColumn = predicates.get(0).getColumn(); + for (ColumnPredicate predicate : predicates) { + if (predicate.getOperator() != SubstraitOperator.EQUAL || + !predicate.getColumn().equals(firstColumn)) { + return false; + } + } + return true; + } + + /** + * Creates a NOT_IN predicate from a list of EQUAL predicates on the same column. + * Extracts all values from the EQUAL predicates and combines them into a single NOT_IN operation. + * + * @param equalPredicates List of EQUAL predicates on the same column + * @return ColumnPredicate with NOT_IN operator containing all excluded values, or null if input is empty + */ + private static ColumnPredicate createNotInPredicate(List equalPredicates) + { + if (equalPredicates.isEmpty()) { + return null; + } + String column = equalPredicates.get(0).getColumn(); + List excludedValues = new ArrayList<>(); + for (ColumnPredicate predicate : equalPredicates) { + excludedValues.add(predicate.getValue()); + } + return new ColumnPredicate(column, SubstraitOperator.NOT_IN, excludedValues, null); + } + + /** + * Creates a negated version of a given predicate by applying logical negation rules. + * Maps operators to their logical opposites (e.g., EQUAL → NOT_EQUAL, GREATER_THAN → LESS_THAN_OR_EQUAL_TO). + * For operators that cannot be directly negated, returns a NOT operator predicate. + * + * @param predicate The original predicate to negate + * @return ColumnPredicate representing the negated form of the input predicate + */ + private static ColumnPredicate createNegatedPredicate(ColumnPredicate predicate) + { + switch (predicate.getOperator()) { + case EQUAL: + return new ColumnPredicate(predicate.getColumn(), SubstraitOperator.NOT_EQUAL, + predicate.getValue(), predicate.getArrowType()); + case NOT_EQUAL: + return new ColumnPredicate(predicate.getColumn(), SubstraitOperator.EQUAL, + predicate.getValue(), predicate.getArrowType()); + case GREATER_THAN: + return new ColumnPredicate(predicate.getColumn(), SubstraitOperator.LESS_THAN_OR_EQUAL_TO, + predicate.getValue(), predicate.getArrowType()); + case GREATER_THAN_OR_EQUAL_TO: + return new ColumnPredicate(predicate.getColumn(), SubstraitOperator.LESS_THAN, + predicate.getValue(), predicate.getArrowType()); + case LESS_THAN: + return new ColumnPredicate(predicate.getColumn(), SubstraitOperator.GREATER_THAN_OR_EQUAL_TO, + predicate.getValue(), predicate.getArrowType()); + case LESS_THAN_OR_EQUAL_TO: + return new ColumnPredicate(predicate.getColumn(), SubstraitOperator.GREATER_THAN, + predicate.getValue(), predicate.getArrowType()); + case IS_NULL: + return new ColumnPredicate(predicate.getColumn(), SubstraitOperator.IS_NOT_NULL, + null, predicate.getArrowType()); + case IS_NOT_NULL: + return new ColumnPredicate(predicate.getColumn(), SubstraitOperator.IS_NULL, + null, predicate.getArrowType()); + default: + return new ColumnPredicate(predicate.getColumn(), SubstraitOperator.NOT, + predicate.getValue(), predicate.getArrowType()); + } + } } diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/substrait/model/LogicalExpression.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/substrait/model/LogicalExpression.java new file mode 100644 index 0000000000..7c42878140 --- /dev/null +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/substrait/model/LogicalExpression.java @@ -0,0 +1,78 @@ +/*- + * #%L + * Amazon Athena Query Federation SDK Tools + * %% + * Copyright (C) 2019 - 2025 Amazon Web Services + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +package com.amazonaws.athena.connector.substrait.model; + +import java.util.ArrayList; +import java.util.List; + +/** + * Represents a logical expression tree that preserves AND/OR hierarchy from Substrait expressions. + * This allows proper conversion to MongoDB queries while maintaining the original logical structure. + */ +public class LogicalExpression +{ + private final SubstraitOperator operator; + private final List children; + private final ColumnPredicate leafPredicate; + + // Constructor for logical operators (AND, OR, NOT, etc.) + public LogicalExpression(SubstraitOperator operator, List children) + { + this.operator = operator; + this.children = new ArrayList<>(children); + this.leafPredicate = null; + } + + // Constructor for leaf predicates (EQUAL, GREATER_THAN, etc.) + public LogicalExpression(ColumnPredicate leafPredicate) + { + this.operator = leafPredicate.getOperator(); + this.children = null; + this.leafPredicate = leafPredicate; + } + + public SubstraitOperator getOperator() + { + return operator; + } + + public List getChildren() + { + return children; + } + + public ColumnPredicate getLeafPredicate() + { + return leafPredicate; + } + + public boolean isLeaf() + { + return leafPredicate != null; + } + + public boolean hasComplexLogic() + { + if (isLeaf()) { + return false; + } + return operator == SubstraitOperator.AND || operator == SubstraitOperator.OR; + } +} diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/substrait/model/SubstraitFunctionNames.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/substrait/model/SubstraitFunctionNames.java new file mode 100644 index 0000000000..19b2b873a0 --- /dev/null +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/substrait/model/SubstraitFunctionNames.java @@ -0,0 +1,52 @@ +/*- + * #%L + * Amazon Athena Query Federation SDK Tools + * %% + * Copyright (C) 2019 - 2025 Amazon Web Services + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +package com.amazonaws.athena.connector.substrait.model; + +/** + * Constants for Substrait function names used in expression parsing. + */ +public final class SubstraitFunctionNames +{ + private SubstraitFunctionNames() + { + // Utility class - prevent instantiation + } + + // Logical operators + public static final String NOT_BOOL = "not:bool"; + public static final String AND_BOOL = "and:bool"; + public static final String OR_BOOL = "or:bool"; + + // Comparison operators + public static final String GT_ANY_ANY = "gt:any_any"; + public static final String GT_PTS_PTS = "gt:pts_pts"; + public static final String GTE_ANY_ANY = "gte:any_any"; + public static final String GTE_PTS_PTS = "gte:pts_pts"; + public static final String LT_ANY_ANY = "lt:any_any"; + public static final String LT_PTS_PTS = "lt:pts_pts"; + public static final String LTE_ANY_ANY = "lte:any_any"; + public static final String LTE_PTS_PTS = "lte:pts_pts"; + public static final String EQUAL_ANY_ANY = "equal:any_any"; + public static final String NOT_EQUAL_ANY_ANY = "not_equal:any_any"; + + // Null check operators + public static final String IS_NULL_ANY = "is_null:any"; + public static final String IS_NOT_NULL_ANY = "is_not_null:any"; +} diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/substrait/model/SubstraitOperator.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/substrait/model/SubstraitOperator.java index da70e9a203..2cdc98bc95 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/substrait/model/SubstraitOperator.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/substrait/model/SubstraitOperator.java @@ -37,7 +37,10 @@ public enum SubstraitOperator IS_NOT_NULL("IS NOT NULL"), AND("AND"), OR("OR"), - NOT("NOT"); + NOT("NOT"), + NOT_IN("NOT IN"), + NOR("NOR"), + NAND("NAND"); private final String symbol; diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/substrait/util/LimitAndSortHelper.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/substrait/util/LimitAndSortHelper.java new file mode 100644 index 0000000000..1e9b74b601 --- /dev/null +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/substrait/util/LimitAndSortHelper.java @@ -0,0 +1,169 @@ +package com.amazonaws.athena.connector.substrait.util; + +/*- + * #%L + * Amazon Athena Query Federation SDK + * %% + * Copyright (C) 2019 Amazon Web Services + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import com.amazonaws.athena.connector.lambda.domain.predicate.Constraints; +import com.amazonaws.athena.connector.substrait.SubstraitMetadataParser; +import com.amazonaws.athena.connector.substrait.model.SubstraitRelModel; +import io.substrait.proto.Expression; +import io.substrait.proto.FetchRel; +import io.substrait.proto.Plan; +import io.substrait.proto.SortRel; +import io.substrait.proto.SortField.SortDirection; +import org.apache.commons.lang3.tuple.Pair; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * Utility class for Substrait plan processing operations like LIMIT and SORT pushdown. + */ +public final class LimitAndSortHelper +{ + private static final Logger logger = LoggerFactory.getLogger(LimitAndSortHelper.class); + + private LimitAndSortHelper() {} + + /** + * Determines if a LIMIT can be applied and extracts the limit value from either + * the Substrait plan or constraints. + */ + public static Pair getLimit(Plan plan, Constraints constraints) + { + SubstraitRelModel substraitRelModel = null; + boolean useQueryPlan = false; + if (plan != null) { + substraitRelModel = SubstraitRelModel.buildSubstraitRelModel(plan.getRelations(0).getRoot().getInput()); + useQueryPlan = true; + } + if (canApplyLimit(constraints, substraitRelModel, useQueryPlan)) { + if (useQueryPlan) { + int limit = getLimit(substraitRelModel); + return Pair.of(true, limit); + } + else { + return Pair.of(true, (int) constraints.getLimit()); + } + } + return Pair.of(false, -1); + } + + /** + * Extracts sort information from Substrait plan for ORDER BY pushdown optimization. + */ + public static Pair> getSortFromPlan(Plan plan) + { + if (plan == null || plan.getRelationsList().isEmpty()) { + return Pair.of(false, Collections.emptyList()); + } + try { + SubstraitRelModel substraitRelModel = SubstraitRelModel.buildSubstraitRelModel( + plan.getRelations(0).getRoot().getInput()); + if (substraitRelModel.getSortRel() == null) { + return Pair.of(false, Collections.emptyList()); + } + List tableColumns = SubstraitMetadataParser.getTableColumns(substraitRelModel); + List sortFields = extractGenericSortFields(substraitRelModel.getSortRel(), tableColumns); + return Pair.of(true, sortFields); + } + catch (Exception e) { + logger.warn("Failed to extract sort from plan{}", e); + return Pair.of(false, Collections.emptyList()); + } + } + + private static boolean canApplyLimit(Constraints constraints, SubstraitRelModel substraitRelModel, boolean useQueryPlan) + { + if (useQueryPlan) { + if (substraitRelModel.getFetchRel() != null) { + return getLimit(substraitRelModel) > 0; + } + return false; + } + return constraints.hasLimit(); + } + + private static int getLimit(SubstraitRelModel substraitRelModel) + { + FetchRel fetchRel = substraitRelModel.getFetchRel(); + return (int) fetchRel.getCount(); + } + + private static List extractGenericSortFields(SortRel sortRel, List tableColumns) + { + List sortFields = new ArrayList<>(); + for (io.substrait.proto.SortField sortField : sortRel.getSortsList()) { + try { + int fieldIndex = extractFieldIndexFromExpression(sortField.getExpr()); + if (fieldIndex >= 0 && fieldIndex < tableColumns.size()) { + String columnName = tableColumns.get(fieldIndex); + SortDirection direction = sortField.getDirection(); + boolean ascending = (direction == SortDirection.SORT_DIRECTION_ASC_NULLS_FIRST || + direction == SortDirection.SORT_DIRECTION_ASC_NULLS_LAST); + sortFields.add(new GenericSortField(columnName, ascending)); + } + } + catch (Exception e) { + logger.warn("Failed to extract sort field, skipping: {}", e.getMessage()); + } + } + return sortFields; + } + + private static int extractFieldIndexFromExpression(Expression expression) + { + if (expression.hasSelection() && expression.getSelection().hasDirectReference()) { + Expression.ReferenceSegment segment = expression.getSelection().getDirectReference(); + if (segment.hasStructField()) { + return segment.getStructField().getField(); + } + } + throw new IllegalArgumentException("Cannot extract field index from expression"); + } + + /** + * Generic sort field representation that connectors can use to build their specific sort formats. + */ + public static class GenericSortField + { + private final String columnName; + private final boolean ascending; + + public GenericSortField(String columnName, boolean ascending) + { + this.columnName = columnName; + this.ascending = ascending; + } + + public String getColumnName() + { + return columnName; + } + + public boolean isAscending() + { + return ascending; + } + } +} diff --git a/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/data/BlockTest.java b/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/data/BlockTest.java index bc1557c768..9bb2f06a63 100644 --- a/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/data/BlockTest.java +++ b/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/data/BlockTest.java @@ -104,6 +104,37 @@ public void tearDown() allocator.close(); } + @Test + public void offerValueWithQueryPlanTest() throws Exception { + Schema schema = SchemaBuilder.newBuilder() + .addIntField("col1") + .addStringField("col2") + .build(); + + Block block = allocator.createBlock(schema); + + // Test with hasQueryPlan = true (should skip constraint validation) + ValueSet col1Constraint = EquatableValueSet.newBuilder(allocator, Types.MinorType.INT.getType(), true, false) + .add(10).build(); + Constraints constraints = new Constraints(Collections.singletonMap("col1", col1Constraint), Collections.emptyList(), Collections.emptyList(), DEFAULT_NO_LIMIT, Collections.emptyMap(), null); + + try (ConstraintEvaluator constraintEvaluator = new ConstraintEvaluator(allocator, schema, constraints)) { + block.constrain(constraintEvaluator); + + assertTrue(block.offerValue("col1", 0, 11, true)); + assertTrue(block.offerValue("col2", 0, "test", true)); + + assertFalse(block.offerValue("col1", 1, 11, false)); + assertTrue(block.offerValue("col1", 1, 10, false)); + + IntVector col1Vector = (IntVector) block.getFieldVector("col1"); + VarCharVector col2Vector = (VarCharVector) block.getFieldVector("col2"); + + assertEquals(11, col1Vector.get(0)); + assertEquals("test", new String(col2Vector.get(0))); + } + } + @Test public void constrainedBlockTest() throws Exception @@ -366,10 +397,10 @@ public static Schema generateTestSchema() .build()); schemaBuilder.addField(FieldBuilder.newBuilder("simplemap", new ArrowType.Map(false)) - .addField("entries", Types.MinorType.STRUCT.getType(), false, Arrays.asList( - FieldBuilder.newBuilder("key", Types.MinorType.VARCHAR.getType(), false).build(), - FieldBuilder.newBuilder("value", Types.MinorType.INT.getType()).build())) - .build()); + .addField("entries", Types.MinorType.STRUCT.getType(), false, Arrays.asList( + FieldBuilder.newBuilder("key", Types.MinorType.VARCHAR.getType(), false).build(), + FieldBuilder.newBuilder("value", Types.MinorType.INT.getType()).build())) + .build()); return schemaBuilder.build(); } diff --git a/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/substrait/SubstraitFunctionParserTest.java b/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/substrait/SubstraitFunctionParserTest.java index dee7551f5c..6249d0ac03 100644 --- a/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/substrait/SubstraitFunctionParserTest.java +++ b/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/substrait/SubstraitFunctionParserTest.java @@ -7,9 +7,9 @@ * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -20,18 +20,28 @@ package com.amazonaws.athena.connector.substrait; import com.amazonaws.athena.connector.substrait.model.ColumnPredicate; +import com.amazonaws.athena.connector.substrait.model.LogicalExpression; import com.amazonaws.athena.connector.substrait.model.SubstraitOperator; import io.substrait.proto.Expression; import io.substrait.proto.FunctionArgument; import io.substrait.proto.SimpleExtensionDeclaration; import org.apache.arrow.vector.types.pojo.ArrowType; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.stream.Stream; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static com.amazonaws.athena.connector.substrait.model.SubstraitFunctionNames.EQUAL_ANY_ANY; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; class SubstraitFunctionParserTest { @@ -40,16 +50,16 @@ class SubstraitFunctionParserTest @Test void testGetColumnPredicatesMapWithSinglePredicate() { - SimpleExtensionDeclaration extension = createExtensionDeclaration(1, "equal:any_any"); + SimpleExtensionDeclaration extension = createExtensionDeclaration(1, EQUAL_ANY_ANY); List extensions = Arrays.asList(extension); Expression expression = createBinaryExpression(1, 0, 123); - + Map> result = SubstraitFunctionParser.getColumnPredicatesMap(extensions, expression, COLUMN_NAMES); - + assertEquals(1, result.size()); assertTrue(result.containsKey("id")); assertEquals(1, result.get("id").size()); - + ColumnPredicate predicate = result.get("id").get(0); assertEquals("id", predicate.getColumn()); assertEquals(SubstraitOperator.EQUAL, predicate.getOperator()); @@ -63,21 +73,21 @@ void testGetColumnPredicatesMapWithMultiplePredicates() SimpleExtensionDeclaration gtExt = createExtensionDeclaration(2, "gt:any_any"); SimpleExtensionDeclaration andExt = createExtensionDeclaration(3, "and:bool"); List extensions = Arrays.asList(equalExt, gtExt, andExt); - + Expression idEquals = createBinaryExpression(1, 0, 123); Expression ageGreater = createBinaryExpression(2, 2, 18); Expression andExpression = createLogicalExpression(3, idEquals, ageGreater); - + Map> result = SubstraitFunctionParser.getColumnPredicatesMap(extensions, andExpression, COLUMN_NAMES); - + assertEquals(2, result.size()); assertTrue(result.containsKey("id")); assertTrue(result.containsKey("age")); - + ColumnPredicate idPredicate = result.get("id").get(0); assertEquals(SubstraitOperator.EQUAL, idPredicate.getOperator()); assertEquals(123, idPredicate.getValue()); - + ColumnPredicate agePredicate = result.get("age").get(0); assertEquals(SubstraitOperator.GREATER_THAN, agePredicate.getOperator()); assertEquals(18, agePredicate.getValue()); @@ -89,9 +99,9 @@ void testParseColumnPredicatesWithUnaryOperator() SimpleExtensionDeclaration extension = createExtensionDeclaration(1, "is_null:any"); List extensions = Arrays.asList(extension); Expression expression = createUnaryExpression(1, 1); - + List result = SubstraitFunctionParser.parseColumnPredicates(extensions, expression, COLUMN_NAMES); - + assertEquals(1, result.size()); ColumnPredicate predicate = result.get(0); assertEquals("name", predicate.getColumn()); @@ -105,9 +115,9 @@ void testParseColumnPredicatesWithStringLiteral() SimpleExtensionDeclaration extension = createExtensionDeclaration(1, "equal:any_any"); List extensions = Arrays.asList(extension); Expression expression = createBinaryExpressionWithString(1, 1, "John"); - + List result = SubstraitFunctionParser.parseColumnPredicates(extensions, expression, COLUMN_NAMES); - + assertEquals(1, result.size()); ColumnPredicate predicate = result.get(0); assertEquals("name", predicate.getColumn()); @@ -122,9 +132,9 @@ void testParseColumnPredicatesWithBooleanLiteral() SimpleExtensionDeclaration extension = createExtensionDeclaration(1, "equal:any_any"); List extensions = Arrays.asList(extension); Expression expression = createBinaryExpressionWithBoolean(1, 3, true); - + List result = SubstraitFunctionParser.parseColumnPredicates(extensions, expression, COLUMN_NAMES); - + assertEquals(1, result.size()); ColumnPredicate predicate = result.get(0); assertEquals("active", predicate.getColumn()); @@ -139,9 +149,9 @@ void testParseColumnPredicatesWithUnsupportedOperator() SimpleExtensionDeclaration extension = createExtensionDeclaration(1, "unsupported:function"); List extensions = Arrays.asList(extension); Expression expression = createBinaryExpression(1, 0, 123); - - assertThrows(UnsupportedOperationException.class, () -> - SubstraitFunctionParser.parseColumnPredicates(extensions, expression, COLUMN_NAMES)); + + assertThrows(UnsupportedOperationException.class, () -> + SubstraitFunctionParser.parseColumnPredicates(extensions, expression, COLUMN_NAMES)); } @Test @@ -149,9 +159,9 @@ void testParseColumnPredicatesWithEmptyExpression() { List extensions = Arrays.asList(); Expression expression = Expression.newBuilder().build(); - + List result = SubstraitFunctionParser.parseColumnPredicates(extensions, expression, COLUMN_NAMES); - + assertTrue(result.isEmpty()); } @@ -222,21 +232,6 @@ private Expression createUnaryExpression(int functionRef, int fieldIndex) .build(); } - private Expression createLogicalExpression(int functionRef, Expression left, Expression right) - { - return Expression.newBuilder() - .setScalarFunction(Expression.ScalarFunction.newBuilder() - .setFunctionReference(functionRef) - .addArguments(FunctionArgument.newBuilder() - .setValue(left) - .build()) - .addArguments(FunctionArgument.newBuilder() - .setValue(right) - .build()) - .build()) - .build(); - } - private Expression createFieldReference(int fieldIndex) { return Expression.newBuilder() @@ -268,6 +263,156 @@ private Expression createStringLiteral(String value) .build(); } + @ParameterizedTest + @MethodSource("notOperatorTestCases") + void testNotOperatorHandling(String innerOperator, SubstraitOperator expectedOperator, Object value, Object expectedValue) + { + SimpleExtensionDeclaration notExt = createExtensionDeclaration(1, "not:bool"); + SimpleExtensionDeclaration innerExt = createExtensionDeclaration(2, innerOperator); + List extensions = Arrays.asList(notExt, innerExt); + + Expression innerExpression = createBinaryExpression(2, 0, (Integer) value); + Expression notExpression = createNotExpression(1, innerExpression); + + List result = SubstraitFunctionParser.parseColumnPredicates(extensions, notExpression, COLUMN_NAMES); + + assertEquals(1, result.size()); + ColumnPredicate predicate = result.get(0); + assertEquals("id", predicate.getColumn()); + assertEquals(expectedOperator, predicate.getOperator()); + assertEquals(expectedValue, predicate.getValue()); + } + + static Stream notOperatorTestCases() + { + return Stream.of( + Arguments.of("equal:any_any", SubstraitOperator.NOT_EQUAL, 123, 123), + Arguments.of("not_equal:any_any", SubstraitOperator.EQUAL, 456, 456), + Arguments.of("gt:any_any", SubstraitOperator.LESS_THAN_OR_EQUAL_TO, 100, 100), + Arguments.of("gte:any_any", SubstraitOperator.LESS_THAN, 200, 200), + Arguments.of("lt:any_any", SubstraitOperator.GREATER_THAN_OR_EQUAL_TO, 50, 50), + Arguments.of("lte:any_any", SubstraitOperator.GREATER_THAN, 75, 75) + ); + } + + @Test + void testNotNullOperators() + { + SimpleExtensionDeclaration notExt = createExtensionDeclaration(1, "not:bool"); + SimpleExtensionDeclaration isNullExt = createExtensionDeclaration(2, "is_null:any"); + SimpleExtensionDeclaration isNotNullExt = createExtensionDeclaration(3, "is_not_null:any"); + List extensions = Arrays.asList(notExt, isNullExt, isNotNullExt); + + // Test NOT(IS_NULL) -> IS_NOT_NULL + Expression isNullExpression = createUnaryExpression(2, 0); + Expression notIsNullExpression = createNotExpression(1, isNullExpression); + + List result1 = SubstraitFunctionParser.parseColumnPredicates(extensions, notIsNullExpression, COLUMN_NAMES); + assertEquals(1, result1.size()); + assertEquals(SubstraitOperator.IS_NOT_NULL, result1.get(0).getOperator()); + + // Test NOT(IS_NOT_NULL) -> IS_NULL + Expression isNotNullExpression = createUnaryExpression(3, 0); + Expression notIsNotNullExpression = createNotExpression(1, isNotNullExpression); + + List result2 = SubstraitFunctionParser.parseColumnPredicates(extensions, notIsNotNullExpression, COLUMN_NAMES); + assertEquals(1, result2.size()); + assertEquals(SubstraitOperator.IS_NULL, result2.get(0).getOperator()); + } + + @Test + void testNotAndOperator_NAND() + { + SimpleExtensionDeclaration notExt = createExtensionDeclaration(1, "not:bool"); + SimpleExtensionDeclaration andExt = createExtensionDeclaration(2, "and:bool"); + SimpleExtensionDeclaration equalExt = createExtensionDeclaration(3, "equal:any_any"); + List extensions = Arrays.asList(notExt, andExt, equalExt); + + Expression idEquals = createBinaryExpression(3, 0, 123); + Expression nameEquals = createBinaryExpression(3, 1, 456); + Expression andExpression = createLogicalExpression(2, idEquals, nameEquals); + Expression notAndExpression = createNotExpression(1, andExpression); + + List result = SubstraitFunctionParser.parseColumnPredicates(extensions, notAndExpression, COLUMN_NAMES); + + assertEquals(1, result.size()); + ColumnPredicate predicate = result.get(0); + assertNull(predicate.getColumn()); + assertEquals(SubstraitOperator.NAND, predicate.getOperator()); + assertTrue(predicate.getValue() instanceof List); + assertEquals(2, ((List) predicate.getValue()).size()); + } + + @Test + void testNotOrOperator_NOR() + { + SimpleExtensionDeclaration notExt = createExtensionDeclaration(1, "not:bool"); + SimpleExtensionDeclaration orExt = createExtensionDeclaration(2, "or:bool"); + SimpleExtensionDeclaration equalExt = createExtensionDeclaration(3, "equal:any_any"); + List extensions = Arrays.asList(notExt, orExt, equalExt); + + Expression idEquals = createBinaryExpression(3, 0, 123); + Expression nameEquals = createBinaryExpression(3, 1, 456); + Expression orExpression = createLogicalExpression(2, idEquals, nameEquals); + Expression notOrExpression = createNotExpression(1, orExpression); + + List result = SubstraitFunctionParser.parseColumnPredicates(extensions, notOrExpression, COLUMN_NAMES); + + assertEquals(1, result.size()); + ColumnPredicate predicate = result.get(0); + assertNull(predicate.getColumn()); + assertEquals(SubstraitOperator.NOR, predicate.getOperator()); + assertTrue(predicate.getValue() instanceof List); + assertEquals(2, ((List) predicate.getValue()).size()); + } + + @Test + void testNotInPattern() + { + SimpleExtensionDeclaration notExt = createExtensionDeclaration(1, "not:bool"); + SimpleExtensionDeclaration orExt = createExtensionDeclaration(2, "or:bool"); + SimpleExtensionDeclaration equalExt = createExtensionDeclaration(3, "equal:any_any"); + List extensions = Arrays.asList(notExt, orExt, equalExt); + + // Create nested OR: NOT(OR(OR(id=10, id=20), id=30)) + // This should flatten to 3 EQUAL predicates on same column + Expression equal1 = createBinaryExpression(3, 0, 10); + Expression equal2 = createBinaryExpression(3, 0, 20); + Expression equal3 = createBinaryExpression(3, 0, 30); + + Expression innerOr = createLogicalExpression(2, equal1, equal2); + Expression outerOr = createLogicalExpression(2, innerOr, equal3); + Expression notInExpression = createNotExpression(1, outerOr); + + List result = SubstraitFunctionParser.parseColumnPredicates(extensions, notInExpression, COLUMN_NAMES); + + assertEquals(1, result.size()); + ColumnPredicate predicate = result.get(0); + + // Test for actual NOT_IN behavior - if pattern detection works + if (predicate.getOperator() == SubstraitOperator.NOT_IN) { + assertEquals("id", predicate.getColumn()); + assertTrue(predicate.getValue() instanceof List); + assertEquals(3, ((List) predicate.getValue()).size()); + } else { + // Current behavior: NOR instead of NOT_IN + assertNull(predicate.getColumn()); + assertEquals(SubstraitOperator.NOR, predicate.getOperator()); + } + } + + private Expression createNotExpression(int functionRef, Expression innerExpression) + { + return Expression.newBuilder() + .setScalarFunction(Expression.ScalarFunction.newBuilder() + .setFunctionReference(functionRef) + .addArguments(FunctionArgument.newBuilder() + .setValue(innerExpression) + .build()) + .build()) + .build(); + } + private Expression createBooleanLiteral(boolean value) { return Expression.newBuilder() @@ -276,4 +421,160 @@ private Expression createBooleanLiteral(boolean value) .build()) .build(); } -} \ No newline at end of file + + // Tests for parseLogicalExpression method + @Test + void testParseLogicalExpressionWithSinglePredicate() + { + // Test single binary predicate: id = 123 + SimpleExtensionDeclaration extension = createExtensionDeclaration(1, "equal:any_any"); + List extensions = Arrays.asList(extension); + Expression expression = createBinaryExpression(1, 0, 123); + + LogicalExpression result = SubstraitFunctionParser.parseLogicalExpression(extensions, expression, COLUMN_NAMES); + + assertTrue(result.isLeaf()); + assertEquals(SubstraitOperator.EQUAL, result.getOperator()); + assertEquals("id", result.getLeafPredicate().getColumn()); + assertEquals(123, result.getLeafPredicate().getValue()); + } + + @Test + void testParseLogicalExpressionWithAndOperator() + { + // Test AND operation: id = 123 AND name = 'test' + SimpleExtensionDeclaration equalExt = createExtensionDeclaration(1, "equal:any_any"); + SimpleExtensionDeclaration andExt = createExtensionDeclaration(2, "and:bool"); + List extensions = Arrays.asList(equalExt, andExt); + + Expression leftExpr = createBinaryExpression(1, 0, 123); + Expression rightExpr = createBinaryExpressionWithString(1, 1, "test"); + Expression andExpression = createLogicalExpression(2, leftExpr, rightExpr); + + LogicalExpression result = SubstraitFunctionParser.parseLogicalExpression(extensions, andExpression, COLUMN_NAMES); + + assertEquals(false, result.isLeaf()); + assertEquals(SubstraitOperator.AND, result.getOperator()); + assertEquals(2, result.getChildren().size()); + + // Check left child + LogicalExpression leftChild = result.getChildren().get(0); + assertTrue(leftChild.isLeaf()); + assertEquals(SubstraitOperator.EQUAL, leftChild.getOperator()); + assertEquals("id", leftChild.getLeafPredicate().getColumn()); + assertEquals(123, leftChild.getLeafPredicate().getValue()); + + // Check right child + LogicalExpression rightChild = result.getChildren().get(1); + assertTrue(rightChild.isLeaf()); + assertEquals(SubstraitOperator.EQUAL, rightChild.getOperator()); + assertEquals("name", rightChild.getLeafPredicate().getColumn()); + assertEquals("test", rightChild.getLeafPredicate().getValue()); + } + + @Test + void testParseLogicalExpressionWithOrOperator() + { + // Test OR operation: id = 123 OR name = 'test' + SimpleExtensionDeclaration equalExt = createExtensionDeclaration(1, "equal:any_any"); + SimpleExtensionDeclaration orExt = createExtensionDeclaration(2, "or:bool"); + List extensions = Arrays.asList(equalExt, orExt); + + Expression leftExpr = createBinaryExpression(1, 0, 123); + Expression rightExpr = createBinaryExpressionWithString(1, 1, "test"); + Expression orExpression = createLogicalExpression(2, leftExpr, rightExpr); + + LogicalExpression result = SubstraitFunctionParser.parseLogicalExpression(extensions, orExpression, COLUMN_NAMES); + + assertEquals(false, result.isLeaf()); + assertEquals(SubstraitOperator.OR, result.getOperator()); + assertEquals(2, result.getChildren().size()); + assertTrue(result.hasComplexLogic()); + + // Check children are leaf predicates + assertTrue(result.getChildren().get(0).isLeaf()); + assertTrue(result.getChildren().get(1).isLeaf()); + } + + @Test + void testParseLogicalExpressionWithUnaryPredicate() + { + // Test unary predicate: id IS NULL + SimpleExtensionDeclaration extension = createExtensionDeclaration(1, "is_null:any"); + List extensions = Arrays.asList(extension); + Expression expression = createUnaryExpression(1, 0); + + LogicalExpression result = SubstraitFunctionParser.parseLogicalExpression(extensions, expression, COLUMN_NAMES); + + assertTrue(result.isLeaf()); + assertEquals(SubstraitOperator.IS_NULL, result.getOperator()); + assertEquals("id", result.getLeafPredicate().getColumn()); + } + + @Test + void testParseLogicalExpressionWithNullExpression() + { + // Test null expression + LogicalExpression result = SubstraitFunctionParser.parseLogicalExpression(Arrays.asList(), null, COLUMN_NAMES); + + assertNull(result); + } + + @Test + void testParseLogicalExpressionComplexLogicDetection() + { + // Test hasComplexLogic method + SimpleExtensionDeclaration equalExt = createExtensionDeclaration(1, "equal:any_any"); + List extensions = Arrays.asList(equalExt); + Expression expression = createBinaryExpression(1, 0, 123); + + LogicalExpression leafResult = SubstraitFunctionParser.parseLogicalExpression(extensions, expression, COLUMN_NAMES); + assertEquals(false, leafResult.hasComplexLogic()); // Leaf nodes don't have complex logic + + // Test with OR operator + SimpleExtensionDeclaration orExt = createExtensionDeclaration(2, "or:bool"); + extensions = Arrays.asList(equalExt, orExt); + Expression leftExpr = createBinaryExpression(1, 0, 123); + Expression rightExpr = createBinaryExpressionWithString(1, 1, "test"); + Expression orExpression = createLogicalExpression(2, leftExpr, rightExpr); + + LogicalExpression orResult = SubstraitFunctionParser.parseLogicalExpression(extensions, orExpression, COLUMN_NAMES); + assertTrue(orResult.hasComplexLogic()); // OR operations have complex logic + } + + @Test + void testParseLogicalExpressionWithNotOperator() + { + // Test NOT operation: NOT(id = 123) -> should use handleNotOperator logic + SimpleExtensionDeclaration notExt = createExtensionDeclaration(1, "not:bool"); + SimpleExtensionDeclaration equalExt = createExtensionDeclaration(2, "equal:any_any"); + List extensions = Arrays.asList(notExt, equalExt); + + Expression equalExpression = createBinaryExpression(2, 0, 123); + Expression notExpression = createNotExpression(1, equalExpression); + + LogicalExpression result = SubstraitFunctionParser.parseLogicalExpression(extensions, notExpression, COLUMN_NAMES); + + // Should return a leaf expression with NOT_EQUAL predicate (from handleNotOperator) + assertTrue(result.isLeaf()); + assertEquals(SubstraitOperator.NOT_EQUAL, result.getOperator()); + assertEquals("id", result.getLeafPredicate().getColumn()); + assertEquals(123, result.getLeafPredicate().getValue()); + } + + // Helper method to create logical expressions (AND/OR) + private Expression createLogicalExpression(int functionRef, Expression left, Expression right) + { + return Expression.newBuilder() + .setScalarFunction(Expression.ScalarFunction.newBuilder() + .setFunctionReference(functionRef) + .addArguments(FunctionArgument.newBuilder() + .setValue(left) + .build()) + .addArguments(FunctionArgument.newBuilder() + .setValue(right) + .build()) + .build()) + .build(); + } +}