diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpiller.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpiller.java index 6965b9aa62..9848e51b3c 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpiller.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpiller.java @@ -35,6 +35,7 @@ import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; import software.amazon.awssdk.core.ResponseInputStream; import software.amazon.awssdk.core.sync.RequestBody; import software.amazon.awssdk.services.glue.model.ErrorDetails; @@ -49,6 +50,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.concurrent.ExecutorService; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.RejectedExecutionException; @@ -343,6 +345,23 @@ private Map getRequestHeadersFromEnv() return Collections.emptyMap(); } + /** + * Creates an AwsRequestOverrideConfiguration with custom headers from the environment + */ + private Optional createRequestOverrideConfig() + { + Map headers = getRequestHeadersFromEnv(); + if (headers.isEmpty()) { + return Optional.empty(); + } + + AwsRequestOverrideConfiguration.Builder overrideConfigBuilder = AwsRequestOverrideConfiguration.builder(); + for (Map.Entry header : headers.entrySet()) { + overrideConfigBuilder.putHeader(header.getKey(), header.getValue()); + } + return Optional.of(overrideConfigBuilder.build()); + } + /** * Writes (aka spills) a Block. */ @@ -361,12 +380,15 @@ protected SpillLocation write(Block block) // Set the contentLength otherwise the s3 client will buffer again since it // only sees the InputStream wrapper. - PutObjectRequest request = PutObjectRequest.builder() + PutObjectRequest.Builder requestBuilder = PutObjectRequest.builder() .bucket(spillLocation.getBucket()) .key(spillLocation.getKey()) - .contentLength((long) bytes.length) - .metadata(getRequestHeadersFromEnv()) - .build(); + .contentLength((long) bytes.length); + + // Set request headers via overrideConfiguration instead of metadata + createRequestOverrideConfig().ifPresent(requestBuilder::overrideConfiguration); + + PutObjectRequest request = requestBuilder.build(); amazonS3.putObject(request, RequestBody.fromBytes(bytes)); logger.info("write: Completed spilling block of size {} bytes", bytes.length); diff --git a/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpillerTest.java b/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpillerTest.java index 0abc45c3ec..718b21ef5c 100644 --- a/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpillerTest.java +++ b/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpillerTest.java @@ -39,6 +39,7 @@ import org.mockito.stubbing.Answer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; import software.amazon.awssdk.core.ResponseInputStream; import software.amazon.awssdk.core.sync.RequestBody; import software.amazon.awssdk.services.s3.S3Client; @@ -53,7 +54,6 @@ import static org.junit.Assert.*; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -196,6 +196,140 @@ public Object answer(InvocationOnMock invocationOnMock) logger.info("spillTest: exit"); } + @Test + public void spillTest_WithRequestHeaders_SetsHeadersInOverrideConfiguration() + throws IOException + { + // Setup config with spill_put_request_headers for SSE-KMS + String spillHeaders = "{\"x-amz-server-side-encryption\":\"aws:kms\",\"x-amz-server-side-encryption-aws-kms-key-id\":\"arn:aws:kms:us-east-1:123456789012:key/test-key-id\"}"; + java.util.Map configOptions = com.google.common.collect.ImmutableMap.of("spill_put_request_headers", spillHeaders); + + PutObjectRequest capturedRequest = executeSpillWithConfig(configOptions); + + // Verify headers are in overrideConfiguration (request headers), NOT in metadata + assertTrue("Request should have overrideConfiguration", capturedRequest.overrideConfiguration().isPresent()); + AwsRequestOverrideConfiguration overrideConfig = capturedRequest.overrideConfiguration().get(); + + // Verify SSE-KMS headers are present in request headers + assertTrue("x-amz-server-side-encryption header should be present", + overrideConfig.headers().containsKey("x-amz-server-side-encryption")); + assertEquals("aws:kms", overrideConfig.headers().get("x-amz-server-side-encryption").get(0)); + + assertTrue("x-amz-server-side-encryption-aws-kms-key-id header should be present", + overrideConfig.headers().containsKey("x-amz-server-side-encryption-aws-kms-key-id")); + assertEquals("arn:aws:kms:us-east-1:123456789012:key/test-key-id", + overrideConfig.headers().get("x-amz-server-side-encryption-aws-kms-key-id").get(0)); + + // Verify headers are NOT in metadata + assertTrue("Metadata should be null or empty, not contain headers", capturedRequest.metadata().isEmpty()); + } + + @Test + public void spillTest_WithoutRequestHeaders_DoesNotSetOverrideConfiguration() + throws IOException + { + // Setup config without spill_put_request_headers + java.util.Map configOptions = com.google.common.collect.ImmutableMap.of(); + + PutObjectRequest capturedRequest = executeSpillWithConfig(configOptions); + + // Verify no overrideConfiguration when headers are not configured + assertFalse("Request should not have overrideConfiguration when no headers configured", + capturedRequest.overrideConfiguration().isPresent()); + + // Verify metadata is null or empty + assertTrue("Metadata should be null when no headers configured", capturedRequest.metadata().isEmpty()); + } + + @Test + public void spillTest_WithInvalidJsonHeaders_HandlesGracefully() + throws IOException + { + // Setup config with invalid JSON in spill_put_request_headers + String invalidJson = "{\"x-amz-server-side-encryption\":\"aws:kms\"invalid}"; + java.util.Map configOptions = com.google.common.collect.ImmutableMap.of("spill_put_request_headers", invalidJson); + + PutObjectRequest capturedRequest = executeSpillWithConfig(configOptions); + + // Verify no overrideConfiguration when JSON is invalid (should be handled gracefully) + assertFalse("Request should not have overrideConfiguration when JSON is invalid", + capturedRequest.overrideConfiguration().isPresent()); + } + + @Test + public void spillTest_WithMultipleHeaders_SetsAllHeadersInOverrideConfiguration() + throws IOException + { + // Setup config with multiple headers + String spillHeaders = "{\"x-amz-server-side-encryption\":\"aws:kms\",\"x-amz-server-side-encryption-aws-kms-key-id\":\"arn:aws:kms:us-east-1:123456789012:key/test-key-id\",\"x-amz-storage-class\":\"STANDARD_IA\"}"; + java.util.Map configOptions = com.google.common.collect.ImmutableMap.of("spill_put_request_headers", spillHeaders); + + PutObjectRequest capturedRequest = executeSpillWithConfig(configOptions); + + // Verify all headers are present in overrideConfiguration + assertTrue("Request should have overrideConfiguration", capturedRequest.overrideConfiguration().isPresent()); + AwsRequestOverrideConfiguration overrideConfig = capturedRequest.overrideConfiguration().get(); + + assertEquals("aws:kms", overrideConfig.headers().get("x-amz-server-side-encryption").get(0)); + assertEquals("arn:aws:kms:us-east-1:123456789012:key/test-key-id", + overrideConfig.headers().get("x-amz-server-side-encryption-aws-kms-key-id").get(0)); + assertEquals("STANDARD_IA", overrideConfig.headers().get("x-amz-storage-class").get(0)); + + assertEquals("Should have 3 headers", 3, overrideConfig.headers().size()); + + // Verify headers are NOT in metadata + assertTrue("Metadata should be null, not contain headers", capturedRequest.metadata().isEmpty()); + } + + /** + * Helper method to create S3BlockSpiller with given config options + */ + private S3BlockSpiller createBlockSpiller(java.util.Map configOptions) + { + return new S3BlockSpiller(mockS3, spillConfig, allocator, expected.getSchema(), + ConstraintEvaluator.emptyEvaluator(), configOptions); + } + + /** + * Helper method to setup mock S3 putObject call + */ + private void setupMockPutObject(ByteHolder byteHolder) + { + when(mockS3.putObject(any(PutObjectRequest.class), any(RequestBody.class))) + .thenAnswer(new Answer() + { + @Override + public Object answer(InvocationOnMock invocationOnMock) + throws Throwable + { + PutObjectResponse response = PutObjectResponse.builder().build(); + InputStream inputStream = ((RequestBody) invocationOnMock.getArguments()[1]).contentStreamProvider().newStream(); + byteHolder.setBytes(ByteStreams.toByteArray(inputStream)); + return response; + } + }); + } + + /** + * Helper method to execute spill and capture PutObjectRequest + */ + private PutObjectRequest executeSpillWithConfig(java.util.Map configOptions) + throws IOException + { + S3BlockSpiller blockWriter = createBlockSpiller(configOptions); + ByteHolder byteHolder = new ByteHolder(); + ArgumentCaptor requestArgument = ArgumentCaptor.forClass(PutObjectRequest.class); + + setupMockPutObject(byteHolder); + blockWriter.write(expected); + + verify(mockS3, times(1)).putObject(requestArgument.capture(), any(RequestBody.class)); + PutObjectRequest capturedRequest = requestArgument.getValue(); + + blockWriter.close(); + return capturedRequest; + } + private class ByteHolder { private byte[] bytes;