Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration;
import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.services.glue.model.ErrorDetails;
Expand All @@ -49,6 +50,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionException;
Expand Down Expand Up @@ -343,6 +345,23 @@ private Map<String, String> getRequestHeadersFromEnv()
return Collections.emptyMap();
}

/**
 * Builds the per-request override configuration that carries the custom headers
 * returned by {@link #getRequestHeadersFromEnv()}.
 *
 * @return an override configuration holding one header per map entry, or an empty
 *         Optional when no headers are available.
 */
private Optional<AwsRequestOverrideConfiguration> createRequestOverrideConfig()
{
    Map<String, String> envHeaders = getRequestHeadersFromEnv();
    if (!envHeaders.isEmpty()) {
        AwsRequestOverrideConfiguration.Builder builder = AwsRequestOverrideConfiguration.builder();
        // Copy every configured header onto the request override, one entry at a time.
        envHeaders.forEach((name, value) -> builder.putHeader(name, value));
        return Optional.of(builder.build());
    }
    return Optional.empty();
}

/**
* Writes (aka spills) a Block.
*/
Expand All @@ -361,12 +380,15 @@ protected SpillLocation write(Block block)

// Set the contentLength otherwise the s3 client will buffer again since it
// only sees the InputStream wrapper.
PutObjectRequest request = PutObjectRequest.builder()
PutObjectRequest.Builder requestBuilder = PutObjectRequest.builder()
.bucket(spillLocation.getBucket())
.key(spillLocation.getKey())
.contentLength((long) bytes.length)
.metadata(getRequestHeadersFromEnv())
.build();
.contentLength((long) bytes.length);

// Set request headers via overrideConfiguration instead of metadata
createRequestOverrideConfig().ifPresent(requestBuilder::overrideConfiguration);

PutObjectRequest request = requestBuilder.build();
amazonS3.putObject(request, RequestBody.fromBytes(bytes));
logger.info("write: Completed spilling block of size {} bytes", bytes.length);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.mockito.stubbing.Answer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration;
import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.services.s3.S3Client;
Expand All @@ -53,7 +54,6 @@

import static org.junit.Assert.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
Expand Down Expand Up @@ -196,6 +196,140 @@ public Object answer(InvocationOnMock invocationOnMock)
logger.info("spillTest: exit");
}

/**
 * Spills a block with SSE-KMS headers configured through spill_put_request_headers and
 * checks that they are attached to the request's override configuration rather than
 * to the object metadata.
 */
@Test
public void spillTest_WithRequestHeaders_SetsHeadersInOverrideConfiguration()
        throws IOException
{
    // JSON document with the two SSE-KMS headers under test.
    String headersJson = "{\"x-amz-server-side-encryption\":\"aws:kms\",\"x-amz-server-side-encryption-aws-kms-key-id\":\"arn:aws:kms:us-east-1:123456789012:key/test-key-id\"}";
    java.util.Map<String, String> spillerOptions =
            com.google.common.collect.ImmutableMap.of("spill_put_request_headers", headersJson);

    PutObjectRequest putRequest = executeSpillWithConfig(spillerOptions);

    // The headers must travel as request headers, so an override configuration has to exist.
    java.util.Optional<AwsRequestOverrideConfiguration> maybeOverride = putRequest.overrideConfiguration();
    assertTrue("Request should have overrideConfiguration", maybeOverride.isPresent());
    java.util.Map<String, java.util.List<String>> requestHeaders = maybeOverride.get().headers();

    // Encryption algorithm header.
    boolean hasSseHeader = requestHeaders.containsKey("x-amz-server-side-encryption");
    assertTrue("x-amz-server-side-encryption header should be present", hasSseHeader);
    assertEquals("aws:kms", requestHeaders.get("x-amz-server-side-encryption").get(0));

    // KMS key id header.
    boolean hasKeyIdHeader = requestHeaders.containsKey("x-amz-server-side-encryption-aws-kms-key-id");
    assertTrue("x-amz-server-side-encryption-aws-kms-key-id header should be present", hasKeyIdHeader);
    assertEquals("arn:aws:kms:us-east-1:123456789012:key/test-key-id",
            requestHeaders.get("x-amz-server-side-encryption-aws-kms-key-id").get(0));

    // Nothing may have been written into the object metadata.
    assertTrue("Metadata should be null or empty, not contain headers", putRequest.metadata().isEmpty());
}

/**
 * Spills a block with no spill_put_request_headers option and checks that the request
 * carries neither an override configuration nor any metadata entries.
 */
@Test
public void spillTest_WithoutRequestHeaders_DoesNotSetOverrideConfiguration()
        throws IOException
{
    // Empty option map: the header option is absent.
    java.util.Map<String, String> noOptions = com.google.common.collect.ImmutableMap.of();

    PutObjectRequest putRequest = executeSpillWithConfig(noOptions);

    // Without configured headers the request must not carry an override configuration.
    boolean hasOverride = putRequest.overrideConfiguration().isPresent();
    assertFalse("Request should not have overrideConfiguration when no headers configured", hasOverride);

    // The metadata map must be empty as well.
    boolean metadataEmpty = putRequest.metadata().isEmpty();
    assertTrue("Metadata should be null when no headers configured", metadataEmpty);
}

/**
 * Spills a block while spill_put_request_headers holds malformed JSON and checks that the
 * spill still completes and the request carries no override configuration.
 */
@Test
public void spillTest_WithInvalidJsonHeaders_HandlesGracefully()
        throws IOException
{
    // Deliberately malformed JSON value for the header option.
    String malformedHeaders = "{\"x-amz-server-side-encryption\":\"aws:kms\"invalid}";
    java.util.Map<String, String> spillerOptions =
            com.google.common.collect.ImmutableMap.of("spill_put_request_headers", malformedHeaders);

    PutObjectRequest putRequest = executeSpillWithConfig(spillerOptions);

    // An unparsable header option must result in a request without override configuration.
    boolean hasOverride = putRequest.overrideConfiguration().isPresent();
    assertFalse("Request should not have overrideConfiguration when JSON is invalid", hasOverride);
}

/**
 * Spills a block with three headers configured and checks that each one, and only those
 * three, ends up in the request's override configuration while metadata stays empty.
 */
@Test
public void spillTest_WithMultipleHeaders_SetsAllHeadersInOverrideConfiguration()
        throws IOException
{
    // Three headers: encryption algorithm, KMS key id and storage class.
    String headersJson = "{\"x-amz-server-side-encryption\":\"aws:kms\",\"x-amz-server-side-encryption-aws-kms-key-id\":\"arn:aws:kms:us-east-1:123456789012:key/test-key-id\",\"x-amz-storage-class\":\"STANDARD_IA\"}";
    java.util.Map<String, String> spillerOptions =
            com.google.common.collect.ImmutableMap.of("spill_put_request_headers", headersJson);

    PutObjectRequest putRequest = executeSpillWithConfig(spillerOptions);

    // The override configuration must exist before its headers can be inspected.
    java.util.Optional<AwsRequestOverrideConfiguration> maybeOverride = putRequest.overrideConfiguration();
    assertTrue("Request should have overrideConfiguration", maybeOverride.isPresent());
    java.util.Map<String, java.util.List<String>> requestHeaders = maybeOverride.get().headers();

    // Each configured header is checked by its first value.
    assertEquals("aws:kms", requestHeaders.get("x-amz-server-side-encryption").get(0));
    assertEquals("arn:aws:kms:us-east-1:123456789012:key/test-key-id",
            requestHeaders.get("x-amz-server-side-encryption-aws-kms-key-id").get(0));
    assertEquals("STANDARD_IA", requestHeaders.get("x-amz-storage-class").get(0));

    // No additional headers may have been added.
    assertEquals("Should have 3 headers", 3, requestHeaders.size());

    // Object metadata must stay untouched.
    assertTrue("Metadata should be null, not contain headers", putRequest.metadata().isEmpty());
}

/**
 * Builds an S3BlockSpiller over the shared test fixtures (mock S3 client, spill config,
 * allocator and the schema of the expected block) using an empty constraint evaluator.
 *
 * @param configOptions The configuration options handed to the spiller.
 * @return A new S3BlockSpiller instance.
 */
private S3BlockSpiller createBlockSpiller(java.util.Map<String, String> configOptions)
{
    ConstraintEvaluator noConstraints = ConstraintEvaluator.emptyEvaluator();
    return new S3BlockSpiller(
            mockS3,
            spillConfig,
            allocator,
            expected.getSchema(),
            noConstraints,
            configOptions);
}

/**
 * Stubs putObject on the mock S3 client so that every call stores the uploaded request
 * body in the supplied holder and answers with an empty PutObjectResponse.
 *
 * @param byteHolder Receives the bytes read from the RequestBody passed to putObject.
 */
private void setupMockPutObject(ByteHolder byteHolder)
{
    when(mockS3.putObject(any(PutObjectRequest.class), any(RequestBody.class)))
            .thenAnswer(invocation -> {
                // The second argument of putObject is the body being uploaded.
                RequestBody uploadedBody = (RequestBody) invocation.getArguments()[1];
                InputStream bodyStream = uploadedBody.contentStreamProvider().newStream();
                byteHolder.setBytes(ByteStreams.toByteArray(bodyStream));
                return PutObjectResponse.builder().build();
            });
}

/**
 * Spills the expected block through a freshly created S3BlockSpiller and returns the
 * PutObjectRequest that was handed to the mock S3 client.
 * <p>
 * The spiller is closed in a finally block so that it is released even when the write or
 * the verification fails.
 *
 * @param configOptions The configuration options handed to the spiller.
 * @return The single PutObjectRequest captured from the mock S3 client.
 * @throws IOException declared for callers; propagated from the helpers invoked here.
 */
private PutObjectRequest executeSpillWithConfig(java.util.Map<String, String> configOptions)
        throws IOException
{
    S3BlockSpiller blockWriter = createBlockSpiller(configOptions);
    try {
        ByteHolder byteHolder = new ByteHolder();
        ArgumentCaptor<PutObjectRequest> requestArgument = ArgumentCaptor.forClass(PutObjectRequest.class);

        setupMockPutObject(byteHolder);
        blockWriter.write(expected);

        // Exactly one upload is expected; capture its request for the caller's assertions.
        verify(mockS3, times(1)).putObject(requestArgument.capture(), any(RequestBody.class));
        return requestArgument.getValue();
    }
    finally {
        // Always release the spiller, including when write() or verify() throws.
        blockWriter.close();
    }
}

private class ByteHolder
{
private byte[] bytes;
Expand Down