From 6e7e6438cb5ff8a489ccefef91f9550e8de31eef Mon Sep 17 00:00:00 2001
From: Vishesh Ruparelia
Date: Wed, 18 Jun 2025 15:01:16 +0000
Subject: [PATCH] feat: add support for batch execution in parallel with custom Executor

---
 .../DynamoDBStreamBatchHandlerParallel.java   | 37 ++++++++++
 .../kinesis/KinesisBatchHandlerParallel.java  | 39 +++++++++++
 .../batch/sqs/SqsBatchHandlerParallel.java    | 37 ++++++++++
 .../batch/handler/BatchMessageHandler.java    | 13 ++++
 .../handler/DynamoDbBatchMessageHandler.java  | 26 ++++++-
 .../KinesisStreamsBatchMessageHandler.java    | 25 ++++++-
 .../batch/handler/SqsBatchMessageHandler.java | 36 +++++++++-
 .../batch/internal/MultiThreadMDC.java        |  7 ++
 .../batch/DdbBatchProcessorTest.java          | 68 ++++++++++++++----
 .../batch/KinesisBatchProcessorTest.java      | 69 +++++++++++++++----
 .../batch/SQSBatchProcessorTest.java          | 68 ++++++++++++++----
 11 files changed, 378 insertions(+), 47 deletions(-)
 create mode 100644 examples/powertools-examples-batch/src/main/java/org/demo/batch/dynamo/DynamoDBStreamBatchHandlerParallel.java
 create mode 100644 examples/powertools-examples-batch/src/main/java/org/demo/batch/kinesis/KinesisBatchHandlerParallel.java
 create mode 100644 examples/powertools-examples-batch/src/main/java/org/demo/batch/sqs/SqsBatchHandlerParallel.java

diff --git a/examples/powertools-examples-batch/src/main/java/org/demo/batch/dynamo/DynamoDBStreamBatchHandlerParallel.java b/examples/powertools-examples-batch/src/main/java/org/demo/batch/dynamo/DynamoDBStreamBatchHandlerParallel.java
new file mode 100644
index 000000000..bdcc6b080
--- /dev/null
+++ b/examples/powertools-examples-batch/src/main/java/org/demo/batch/dynamo/DynamoDBStreamBatchHandlerParallel.java
@@ -0,0 +1,37 @@
+package org.demo.batch.dynamo;
+
+import com.amazonaws.services.lambda.runtime.Context;
+import com.amazonaws.services.lambda.runtime.RequestHandler;
+import com.amazonaws.services.lambda.runtime.events.DynamodbEvent;
+import com.amazonaws.services.lambda.runtime.events.StreamsEventResponse;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import software.amazon.lambda.powertools.batch.BatchMessageHandlerBuilder;
+import software.amazon.lambda.powertools.batch.handler.BatchMessageHandler;
+
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
+public class DynamoDBStreamBatchHandlerParallel implements RequestHandler<DynamodbEvent, StreamsEventResponse> {
+
+    private static final Logger LOGGER = LoggerFactory.getLogger(DynamoDBStreamBatchHandlerParallel.class);
+    private final BatchMessageHandler<DynamodbEvent, StreamsEventResponse> handler;
+    private final ExecutorService executor;
+
+    public DynamoDBStreamBatchHandlerParallel() {
+        handler = new BatchMessageHandlerBuilder()
+                .withDynamoDbBatchHandler()
+                .buildWithRawMessageHandler(this::processMessage);
+        executor = Executors.newFixedThreadPool(2);
+    }
+
+    @Override
+    public StreamsEventResponse handleRequest(DynamodbEvent ddbEvent, Context context) {
+        return handler.processBatchInParallel(ddbEvent, context, executor);
+    }
+
+    private void processMessage(DynamodbEvent.DynamodbStreamRecord dynamodbStreamRecord, Context context) {
+        LOGGER.info("Processing DynamoDB Stream Record {}", dynamodbStreamRecord);
+    }
+
+}
diff --git a/examples/powertools-examples-batch/src/main/java/org/demo/batch/kinesis/KinesisBatchHandlerParallel.java b/examples/powertools-examples-batch/src/main/java/org/demo/batch/kinesis/KinesisBatchHandlerParallel.java
new file mode 100644
index 000000000..19e3201d5
--- /dev/null
+++ b/examples/powertools-examples-batch/src/main/java/org/demo/batch/kinesis/KinesisBatchHandlerParallel.java
@@ -0,0 +1,39 @@
+package org.demo.batch.kinesis;
+
+import com.amazonaws.services.lambda.runtime.Context;
+import com.amazonaws.services.lambda.runtime.RequestHandler;
+import com.amazonaws.services.lambda.runtime.events.KinesisEvent;
+import com.amazonaws.services.lambda.runtime.events.StreamsEventResponse;
+import org.demo.batch.model.Product;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import software.amazon.lambda.powertools.batch.BatchMessageHandlerBuilder;
+import software.amazon.lambda.powertools.batch.handler.BatchMessageHandler;
+
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
+public class KinesisBatchHandlerParallel implements RequestHandler<KinesisEvent, StreamsEventResponse> {
+
+    private static final Logger LOGGER = LoggerFactory.getLogger(KinesisBatchHandlerParallel.class);
+    private final BatchMessageHandler<KinesisEvent, StreamsEventResponse> handler;
+    private final ExecutorService executor;
+
+
+    public KinesisBatchHandlerParallel() {
+        handler = new BatchMessageHandlerBuilder()
+                .withKinesisBatchHandler()
+                .buildWithMessageHandler(this::processMessage, Product.class);
+        executor = Executors.newFixedThreadPool(2);
+    }
+
+    @Override
+    public StreamsEventResponse handleRequest(KinesisEvent kinesisEvent, Context context) {
+        return handler.processBatchInParallel(kinesisEvent, context, executor);
+    }
+
+    private void processMessage(Product p, Context c) {
+        LOGGER.info("Processing product {}", p);
+    }
+
+}
diff --git a/examples/powertools-examples-batch/src/main/java/org/demo/batch/sqs/SqsBatchHandlerParallel.java b/examples/powertools-examples-batch/src/main/java/org/demo/batch/sqs/SqsBatchHandlerParallel.java
new file mode 100644
index 000000000..21294dd55
--- /dev/null
+++ b/examples/powertools-examples-batch/src/main/java/org/demo/batch/sqs/SqsBatchHandlerParallel.java
@@ -0,0 +1,37 @@
+package org.demo.batch.sqs;
+
+import com.amazonaws.services.lambda.runtime.Context;
+import com.amazonaws.services.lambda.runtime.RequestHandler;
+import com.amazonaws.services.lambda.runtime.events.SQSBatchResponse;
+import com.amazonaws.services.lambda.runtime.events.SQSEvent;
+import org.demo.batch.model.Product;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import software.amazon.lambda.powertools.batch.BatchMessageHandlerBuilder;
+import software.amazon.lambda.powertools.batch.handler.BatchMessageHandler;
+import software.amazon.lambda.powertools.logging.Logging;
+import software.amazon.lambda.powertools.tracing.Tracing;
+
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
+public class SqsBatchHandlerParallel extends AbstractSqsBatchHandler implements RequestHandler<SQSEvent, SQSBatchResponse> {
+    private static final Logger LOGGER = LoggerFactory.getLogger(SqsBatchHandlerParallel.class);
+    private final BatchMessageHandler<SQSEvent, SQSBatchResponse> handler;
+    private final ExecutorService executor;
+
+    public SqsBatchHandlerParallel() {
+        handler = new BatchMessageHandlerBuilder()
+                .withSqsBatchHandler()
+                .buildWithMessageHandler(this::processMessage, Product.class);
+        executor = Executors.newFixedThreadPool(2);
+    }
+
+    @Logging
+    @Tracing
+    @Override
+    public SQSBatchResponse handleRequest(SQSEvent sqsEvent, Context context) {
+        LOGGER.info("Processing batch of {} messages", sqsEvent.getRecords().size());
+        return handler.processBatchInParallel(sqsEvent, context, executor);
+    }
+}
diff --git a/powertools-batch/src/main/java/software/amazon/lambda/powertools/batch/handler/BatchMessageHandler.java b/powertools-batch/src/main/java/software/amazon/lambda/powertools/batch/handler/BatchMessageHandler.java
index 18d74bb25..c63409e35 100644
--- a/powertools-batch/src/main/java/software/amazon/lambda/powertools/batch/handler/BatchMessageHandler.java
+++ b/powertools-batch/src/main/java/software/amazon/lambda/powertools/batch/handler/BatchMessageHandler.java
@@ -16,6 +16,9 @@
 
 import com.amazonaws.services.lambda.runtime.Context;
 
+import java.util.concurrent.Executor;
+import java.util.concurrent.ExecutorService;
+
 /**
  * The basic interface a batch message handler must meet.
  *
@@ -50,4 +53,14 @@
      * @return A partial batch response
      */
    R processBatchInParallel(E event, Context context);
+
+
+    /**
+     * Same as {@link #processBatchInParallel(Object, Context)} but with an option to provide a custom {@link Executor}
+     * @param event    The Lambda event containing the batch to process
+     * @param context  The lambda context
+     * @param executor Custom executor to use for parallel processing
+     * @return A partial batch response
+     */
+    R processBatchInParallel(E event, Context context, Executor executor);
 }
diff --git a/powertools-batch/src/main/java/software/amazon/lambda/powertools/batch/handler/DynamoDbBatchMessageHandler.java b/powertools-batch/src/main/java/software/amazon/lambda/powertools/batch/handler/DynamoDbBatchMessageHandler.java
index 4b03d0947..ce68907e8 100644
--- a/powertools-batch/src/main/java/software/amazon/lambda/powertools/batch/handler/DynamoDbBatchMessageHandler.java
+++ b/powertools-batch/src/main/java/software/amazon/lambda/powertools/batch/handler/DynamoDbBatchMessageHandler.java
@@ -17,8 +17,13 @@
 import com.amazonaws.services.lambda.runtime.Context;
 import com.amazonaws.services.lambda.runtime.events.DynamodbEvent;
 import com.amazonaws.services.lambda.runtime.events.StreamsEventResponse;
+
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Optional;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Executor;
+import java.util.concurrent.ExecutorService;
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 import java.util.stream.Collectors;
@@ -66,7 +71,9 @@ public StreamsEventResponse processBatchInParallel(DynamodbEvent event, Context
                 .parallelStream() // Parallel processing
                 .map(eventRecord -> {
                     multiThreadMDC.copyMDCToThread(Thread.currentThread().getName());
-                    return processBatchItem(eventRecord, context);
+                    Optional<StreamsEventResponse.BatchItemFailure> failureOpt = processBatchItem(eventRecord, context);
+                    multiThreadMDC.removeThread(Thread.currentThread().getName());
+                    return failureOpt;
                 })
                 .filter(Optional::isPresent)
                 .map(Optional::get)
@@ -75,6 +82,23 @@ public StreamsEventResponse processBatchInParallel(DynamodbEvent event, Context
         return StreamsEventResponse.builder().withBatchItemFailures(batchItemFailures).build();
     }
 
+    @Override
+    public StreamsEventResponse processBatchInParallel(DynamodbEvent event, Context context, Executor executor) {
+        MultiThreadMDC multiThreadMDC = new MultiThreadMDC();
+
+        List<StreamsEventResponse.BatchItemFailure> batchItemFailures = new ArrayList<>();
+        List<CompletableFuture<Void>> futures = event.getRecords().stream()
+                .map(eventRecord -> CompletableFuture.runAsync(() -> {
+                    multiThreadMDC.copyMDCToThread(Thread.currentThread().getName());
+                    Optional<StreamsEventResponse.BatchItemFailure> failureOpt = processBatchItem(eventRecord, context);
+                    failureOpt.ifPresent(batchItemFailures::add);
+                    multiThreadMDC.removeThread(Thread.currentThread().getName());
+                }, executor))
+                .collect(Collectors.toList());
+        futures.forEach(CompletableFuture::join);
+        return StreamsEventResponse.builder().withBatchItemFailures(batchItemFailures).build();
+    }
+
     private Optional<StreamsEventResponse.BatchItemFailure> processBatchItem(DynamodbEvent.DynamodbStreamRecord streamRecord, Context context) {
         try {
             LOGGER.debug("Processing item {}", streamRecord.getEventID());
diff --git a/powertools-batch/src/main/java/software/amazon/lambda/powertools/batch/handler/KinesisStreamsBatchMessageHandler.java b/powertools-batch/src/main/java/software/amazon/lambda/powertools/batch/handler/KinesisStreamsBatchMessageHandler.java
index 7b4179de7..574256cc6 100644
--- a/powertools-batch/src/main/java/software/amazon/lambda/powertools/batch/handler/KinesisStreamsBatchMessageHandler.java
+++ b/powertools-batch/src/main/java/software/amazon/lambda/powertools/batch/handler/KinesisStreamsBatchMessageHandler.java
@@ -18,8 +18,12 @@
 import com.amazonaws.services.lambda.runtime.Context;
 import com.amazonaws.services.lambda.runtime.events.KinesisEvent;
 import com.amazonaws.services.lambda.runtime.events.StreamsEventResponse;
+
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Optional;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Executor;
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 import java.util.stream.Collectors;
@@ -77,7 +81,9 @@ public StreamsEventResponse processBatchInParallel(KinesisEvent event, Context c
                 .parallelStream() // Parallel processing
                 .map(eventRecord -> {
                     multiThreadMDC.copyMDCToThread(Thread.currentThread().getName());
-                    return processBatchItem(eventRecord, context);
+                    Optional<StreamsEventResponse.BatchItemFailure> failureOpt = processBatchItem(eventRecord, context);
+                    multiThreadMDC.removeThread(Thread.currentThread().getName());
+                    return failureOpt;
                 })
                 .filter(Optional::isPresent)
                 .map(Optional::get)
@@ -86,6 +92,23 @@ public StreamsEventResponse processBatchInParallel(KinesisEvent event, Context c
         return StreamsEventResponse.builder().withBatchItemFailures(batchItemFailures).build();
     }
 
+    @Override
+    public StreamsEventResponse processBatchInParallel(KinesisEvent event, Context context, Executor executor) {
+        MultiThreadMDC multiThreadMDC = new MultiThreadMDC();
+
+        List<StreamsEventResponse.BatchItemFailure> batchItemFailures = new ArrayList<>();
+        List<CompletableFuture<Void>> futures = event.getRecords().stream()
+                .map(eventRecord -> CompletableFuture.runAsync(() -> {
+                    multiThreadMDC.copyMDCToThread(Thread.currentThread().getName());
+                    Optional<StreamsEventResponse.BatchItemFailure> failureOpt = processBatchItem(eventRecord, context);
+                    failureOpt.ifPresent(batchItemFailures::add);
+                    multiThreadMDC.removeThread(Thread.currentThread().getName());
+                }, executor))
+                .collect(Collectors.toList());
+        futures.forEach(CompletableFuture::join);
+        return StreamsEventResponse.builder().withBatchItemFailures(batchItemFailures).build();
+    }
+
     private Optional<StreamsEventResponse.BatchItemFailure> processBatchItem(KinesisEvent.KinesisEventRecord eventRecord, Context context) {
         try {
             LOGGER.debug("Processing item {}", eventRecord.getEventID());
diff --git a/powertools-batch/src/main/java/software/amazon/lambda/powertools/batch/handler/SqsBatchMessageHandler.java b/powertools-batch/src/main/java/software/amazon/lambda/powertools/batch/handler/SqsBatchMessageHandler.java
index 2dfb0a28e..1df3fdd1f 100644
--- a/powertools-batch/src/main/java/software/amazon/lambda/powertools/batch/handler/SqsBatchMessageHandler.java
+++ b/powertools-batch/src/main/java/software/amazon/lambda/powertools/batch/handler/SqsBatchMessageHandler.java
@@ -15,15 +15,20 @@
 package software.amazon.lambda.powertools.batch.handler;
 
 import com.amazonaws.services.lambda.runtime.Context;
+import com.amazonaws.services.lambda.runtime.events.KinesisEvent;
 import com.amazonaws.services.lambda.runtime.events.SQSBatchResponse;
 import com.amazonaws.services.lambda.runtime.events.SQSEvent;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Optional;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Executor;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 import java.util.stream.Collectors;
+
+import com.amazonaws.services.lambda.runtime.events.StreamsEventResponse;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import software.amazon.lambda.powertools.batch.internal.MultiThreadMDC;
@@ -99,7 +104,7 @@ public SQSBatchResponse processBatch(SQSEvent event, Context context) {
 
     @Override
     public SQSBatchResponse processBatchInParallel(SQSEvent event, Context context) {
-        if (!event.getRecords().isEmpty() && event.getRecords().get(0).getAttributes().get(MESSAGE_GROUP_ID_KEY) != null) {
+        if (isFIFOEnabled(event)) {
             throw new UnsupportedOperationException("FIFO queues are not supported in parallel mode, use the processBatch method instead");
         }
 
@@ -109,7 +114,9 @@ public SQSBatchResponse processBatchInParallel(SQSEvent event, Context context)
                 .map(sqsMessage -> {
                     multiThreadMDC.copyMDCToThread(Thread.currentThread().getName());
 
-                    return processBatchItem(sqsMessage, context);
+                    Optional<SQSBatchResponse.BatchItemFailure> failureOpt = processBatchItem(sqsMessage, context);
+                    multiThreadMDC.removeThread(Thread.currentThread().getName());
+                    return failureOpt;
                 })
                 .filter(Optional::isPresent)
                 .map(Optional::get)
@@ -118,6 +125,27 @@ public SQSBatchResponse processBatchInParallel(SQSEvent event, Context context)
         return SQSBatchResponse.builder().withBatchItemFailures(batchItemFailures).build();
     }
 
+    @Override
+    public SQSBatchResponse processBatchInParallel(SQSEvent event, Context context, Executor executor) {
+        if (isFIFOEnabled(event)) {
+            throw new UnsupportedOperationException("FIFO queues are not supported in parallel mode, use the processBatch method instead");
+        }
+
+        MultiThreadMDC multiThreadMDC = new MultiThreadMDC();
+        List<SQSBatchResponse.BatchItemFailure> batchItemFailures = new ArrayList<>();
+        List<CompletableFuture<Void>> futures = event.getRecords().stream()
+                .map(eventRecord -> CompletableFuture.runAsync(() -> {
+                    multiThreadMDC.copyMDCToThread(Thread.currentThread().getName());
+                    Optional<SQSBatchResponse.BatchItemFailure> failureOpt = processBatchItem(eventRecord, context);
+                    failureOpt.ifPresent(batchItemFailures::add);
+                    multiThreadMDC.removeThread(Thread.currentThread().getName());
+                }, executor))
+                .collect(Collectors.toList());
+        futures.forEach(CompletableFuture::join);
+
+        return SQSBatchResponse.builder().withBatchItemFailures(batchItemFailures).build();
+    }
+
     private Optional<SQSBatchResponse.BatchItemFailure> processBatchItem(SQSEvent.SQSMessage message, Context context) {
         try {
             LOGGER.debug("Processing message {}", message.getMessageId());
@@ -152,4 +180,8 @@ private Optional processBatchItem(SQSEvent.SQ
                     .build());
         }
     }
+
+    private boolean isFIFOEnabled(SQSEvent sqsEvent) {
+        return !sqsEvent.getRecords().isEmpty() && sqsEvent.getRecords().get(0).getAttributes().get(MESSAGE_GROUP_ID_KEY) != null;
+    }
 }
diff --git a/powertools-batch/src/main/java/software/amazon/lambda/powertools/batch/internal/MultiThreadMDC.java b/powertools-batch/src/main/java/software/amazon/lambda/powertools/batch/internal/MultiThreadMDC.java
index df1c2e7a0..b2b85044b 100644
--- a/powertools-batch/src/main/java/software/amazon/lambda/powertools/batch/internal/MultiThreadMDC.java
+++ b/powertools-batch/src/main/java/software/amazon/lambda/powertools/batch/internal/MultiThreadMDC.java @@ -44,4 +44,11 @@ public void copyMDCToThread(String thread) { mdcAwareThreads.add(thread); } } + + public void removeThread(String thread) { + if (mdcAwareThreads.contains(thread)) { + LOGGER.debug("Removing thread {}", thread); + mdcAwareThreads.remove(thread); + } + } } diff --git a/powertools-batch/src/test/java/software/amazon/lambda/powertools/batch/DdbBatchProcessorTest.java b/powertools-batch/src/test/java/software/amazon/lambda/powertools/batch/DdbBatchProcessorTest.java index 6bb247323..662675de9 100644 --- a/powertools-batch/src/test/java/software/amazon/lambda/powertools/batch/DdbBatchProcessorTest.java +++ b/powertools-batch/src/test/java/software/amazon/lambda/powertools/batch/DdbBatchProcessorTest.java @@ -23,7 +23,12 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.BiConsumer; + import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.params.ParameterizedTest; import org.mockito.Mock; @@ -78,6 +83,25 @@ private void processMessageInParallelFailsForFixedMessage(DynamodbEvent.Dynamodb } } + private StreamsEventResponse testParallelBatchExecution(DynamodbEvent event, + BiConsumer messageHandler, + Executor executor) { + // Arrange + BatchMessageHandler handler = new BatchMessageHandlerBuilder() + .withDynamoDbBatchHandler() + .buildWithRawMessageHandler(messageHandler); + + // Act + StreamsEventResponse dynamodbBatchResponse; + if (executor == null) { + dynamodbBatchResponse = handler.processBatchInParallel(event, context); + } else { + dynamodbBatchResponse = handler.processBatchInParallel(event, context, executor); + } + + return dynamodbBatchResponse; + } + @ParameterizedTest @Event(value = "dynamo_event.json", type = DynamodbEvent.class) void batchProcessingSucceedsAndReturns(DynamodbEvent event) { @@ -96,13 +120,7 @@ void batchProcessingSucceedsAndReturns(DynamodbEvent event) { @ParameterizedTest @Event(value = "dynamo_event_big.json", type = DynamodbEvent.class) void parallelBatchProcessingSucceedsAndReturns(DynamodbEvent event) { - // Arrange - BatchMessageHandler handler = new BatchMessageHandlerBuilder() - .withDynamoDbBatchHandler() - .buildWithRawMessageHandler(this::processMessageInParallelSucceeds); - - // Act - StreamsEventResponse dynamodbBatchResponse = handler.processBatchInParallel(event, context); + StreamsEventResponse dynamodbBatchResponse = testParallelBatchExecution(event, this::processMessageInParallelSucceeds, null); // Assert assertThat(dynamodbBatchResponse.getBatchItemFailures()).isEmpty(); @@ -129,13 +147,7 @@ void shouldAddMessageToBatchFailure_whenException_withMessage(DynamodbEvent even @ParameterizedTest @Event(value = "dynamo_event_big.json", type = DynamodbEvent.class) void parallelBatchProcessing_shouldAddMessageToBatchFailure_whenException_withMessage(DynamodbEvent event) { - // Arrange - BatchMessageHandler handler = new BatchMessageHandlerBuilder() - .withDynamoDbBatchHandler() - .buildWithRawMessageHandler(this::processMessageInParallelFailsForFixedMessage); - - // Act - StreamsEventResponse dynamodbBatchResponse = handler.processBatchInParallel(event, context); + StreamsEventResponse dynamodbBatchResponse = testParallelBatchExecution(event, this::processMessageInParallelFailsForFixedMessage, 
null); // Assert assertThat(dynamodbBatchResponse.getBatchItemFailures()).hasSize(1); @@ -196,4 +208,32 @@ void failingSuccessHandlerShouldntFailBatchButShouldFailMessage(DynamodbEvent ev assertThat(batchItemFailure.getItemIdentifier()).isEqualTo("4421584500000000017450439091"); } + @ParameterizedTest + @Event(value = "dynamo_event_big.json", type = DynamodbEvent.class) + void parallelBatchProcessingWithExecutorSucceedsAndReturns(DynamodbEvent event) { + ExecutorService executor = Executors.newFixedThreadPool(2); + + StreamsEventResponse dynamodbBatchResponse = testParallelBatchExecution(event, this::processMessageInParallelSucceeds, executor); + executor.shutdown(); + + // Assert + assertThat(dynamodbBatchResponse.getBatchItemFailures()).isEmpty(); + assertThat(threadList).hasSizeGreaterThan(1); + } + + @ParameterizedTest + @Event(value = "dynamo_event_big.json", type = DynamodbEvent.class) + void parallelBatchProcessingWithExecutor_shouldAddMessageToBatchFailure_whenException_withMessage(DynamodbEvent event) { + ExecutorService executor = Executors.newFixedThreadPool(2); + + StreamsEventResponse dynamodbBatchResponse = testParallelBatchExecution(event, this::processMessageInParallelFailsForFixedMessage, executor); + executor.shutdown(); + + // Assert + assertThat(dynamodbBatchResponse.getBatchItemFailures()).hasSize(1); + StreamsEventResponse.BatchItemFailure batchItemFailure = dynamodbBatchResponse.getBatchItemFailures().get(0); + assertThat(batchItemFailure.getItemIdentifier()).isEqualTo("4421584500000000017450439091"); + assertThat(threadList).hasSizeGreaterThan(1); + } + } diff --git a/powertools-batch/src/test/java/software/amazon/lambda/powertools/batch/KinesisBatchProcessorTest.java b/powertools-batch/src/test/java/software/amazon/lambda/powertools/batch/KinesisBatchProcessorTest.java index 059a4d2d0..32acde6f0 100644 --- a/powertools-batch/src/test/java/software/amazon/lambda/powertools/batch/KinesisBatchProcessorTest.java +++ b/powertools-batch/src/test/java/software/amazon/lambda/powertools/batch/KinesisBatchProcessorTest.java @@ -23,7 +23,12 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.BiConsumer; + import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.params.ParameterizedTest; import org.mockito.Mock; @@ -81,6 +86,25 @@ private void processMessageInParallelFailsForFixedMessage(KinesisEvent.KinesisEv } } + private StreamsEventResponse testParallelBatchExecution(KinesisEvent event, + BiConsumer messageHandler, + Executor executor) { + // Arrange + BatchMessageHandler handler = new BatchMessageHandlerBuilder() + .withKinesisBatchHandler() + .buildWithRawMessageHandler(messageHandler); + + // Act + StreamsEventResponse kinesisBatchResponse; + if (executor == null) { + kinesisBatchResponse = handler.processBatchInParallel(event, context); + } else { + kinesisBatchResponse = handler.processBatchInParallel(event, context, executor); + } + + return kinesisBatchResponse; + } + // A handler that throws an exception for _one_ of the deserialized products in the same messages public void processMessageFailsForFixedProduct(Product product, Context context) { if (product.getId() == 1234) { @@ -106,13 +130,7 @@ void batchProcessingSucceedsAndReturns(KinesisEvent event) { @ParameterizedTest @Event(value = "kinesis_event_big.json", type = 
KinesisEvent.class) void batchProcessingInParallelSucceedsAndReturns(KinesisEvent event) { - // Arrange - BatchMessageHandler handler = new BatchMessageHandlerBuilder() - .withKinesisBatchHandler() - .buildWithRawMessageHandler(this::processMessageInParallelSucceeds); - - // Act - StreamsEventResponse kinesisBatchResponse = handler.processBatchInParallel(event, context); + StreamsEventResponse kinesisBatchResponse = testParallelBatchExecution(event, this::processMessageInParallelSucceeds, null); // Assert assertThat(kinesisBatchResponse.getBatchItemFailures()).isEmpty(); @@ -140,13 +158,7 @@ void shouldAddMessageToBatchFailure_whenException_withMessage(KinesisEvent event @ParameterizedTest @Event(value = "kinesis_event_big.json", type = KinesisEvent.class) void batchProcessingInParallel_shouldAddMessageToBatchFailure_whenException_withMessage(KinesisEvent event) { - // Arrange - BatchMessageHandler handler = new BatchMessageHandlerBuilder() - .withKinesisBatchHandler() - .buildWithRawMessageHandler(this::processMessageInParallelFailsForFixedMessage); - - // Act - StreamsEventResponse kinesisBatchResponse = handler.processBatchInParallel(event, context); + StreamsEventResponse kinesisBatchResponse = testParallelBatchExecution(event, this::processMessageInParallelFailsForFixedMessage, null); // Assert assertThat(kinesisBatchResponse.getBatchItemFailures()).hasSize(1); @@ -227,4 +239,33 @@ void failingSuccessHandlerShouldntFailBatchButShouldFailMessage(KinesisEvent eve "49545115243490985018280067714973144582180062593244200961"); } + @ParameterizedTest + @Event(value = "kinesis_event_big.json", type = KinesisEvent.class) + void batchProcessingInParallelWithExecutorSucceedsAndReturns(KinesisEvent event) { + ExecutorService executor = Executors.newFixedThreadPool(2); + + StreamsEventResponse kinesisBatchResponse = testParallelBatchExecution(event, this::processMessageInParallelSucceeds, executor); + executor.shutdown(); + + // Assert + assertThat(kinesisBatchResponse.getBatchItemFailures()).isEmpty(); + assertThat(threadList).hasSizeGreaterThan(1); + } + + @ParameterizedTest + @Event(value = "kinesis_event_big.json", type = KinesisEvent.class) + void batchProcessingInParallelWithExecutor_shouldAddMessageToBatchFailure_whenException_withMessage(KinesisEvent event) { + ExecutorService executor = Executors.newFixedThreadPool(2); + + StreamsEventResponse kinesisBatchResponse = testParallelBatchExecution(event, this::processMessageInParallelFailsForFixedMessage, executor); + executor.shutdown(); + + // Assert + assertThat(kinesisBatchResponse.getBatchItemFailures()).hasSize(1); + StreamsEventResponse.BatchItemFailure batchItemFailure = kinesisBatchResponse.getBatchItemFailures().get(0); + assertThat(batchItemFailure.getItemIdentifier()).isEqualTo( + "49545115243490985018280067714973144582180062593244200961"); + assertThat(threadList).hasSizeGreaterThan(1); + } + } diff --git a/powertools-batch/src/test/java/software/amazon/lambda/powertools/batch/SQSBatchProcessorTest.java b/powertools-batch/src/test/java/software/amazon/lambda/powertools/batch/SQSBatchProcessorTest.java index 7dd51374e..f13196fc4 100644 --- a/powertools-batch/src/test/java/software/amazon/lambda/powertools/batch/SQSBatchProcessorTest.java +++ b/powertools-batch/src/test/java/software/amazon/lambda/powertools/batch/SQSBatchProcessorTest.java @@ -23,7 +23,12 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import 
java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.BiConsumer; + import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.params.ParameterizedTest; import org.mockito.Mock; @@ -45,7 +50,7 @@ public void clear() { private void processMessageSucceeds(SQSEvent.SQSMessage sqsMessage) { } - private void processMessageInParallelSucceeds(SQSEvent.SQSMessage sqsMessage) { + private void processMessageInParallelSucceeds(SQSEvent.SQSMessage sqsMessage, Context context) { String thread = Thread.currentThread().getName(); if (!threadList.contains(thread)) { threadList.add(thread); @@ -86,6 +91,25 @@ private void processMessageFailsForFixedProduct(Product product, Context context } } + private SQSBatchResponse testParallelBatchExecution(SQSEvent event, + BiConsumer messageHandler, + Executor executor) { + // Arrange + BatchMessageHandler handler = new BatchMessageHandlerBuilder() + .withSqsBatchHandler() + .buildWithRawMessageHandler(messageHandler); + + // Act + SQSBatchResponse sqsBatchResponse; + if (executor == null) { + sqsBatchResponse = handler.processBatchInParallel(event, context); + } else { + sqsBatchResponse = handler.processBatchInParallel(event, context, executor); + } + + return sqsBatchResponse; + } + @ParameterizedTest @Event(value = "sqs_event.json", type = SQSEvent.class) void batchProcessingSucceedsAndReturns(SQSEvent event) { @@ -104,13 +128,7 @@ void batchProcessingSucceedsAndReturns(SQSEvent event) { @ParameterizedTest @Event(value = "sqs_event_big.json", type = SQSEvent.class) void parallelBatchProcessingSucceedsAndReturns(SQSEvent event) { - // Arrange - BatchMessageHandler handler = new BatchMessageHandlerBuilder() - .withSqsBatchHandler() - .buildWithRawMessageHandler(this::processMessageInParallelSucceeds); - - // Act - SQSBatchResponse sqsBatchResponse = handler.processBatchInParallel(event, context); + SQSBatchResponse sqsBatchResponse = testParallelBatchExecution(event, this::processMessageInParallelSucceeds, null); // Assert assertThat(sqsBatchResponse.getBatchItemFailures()).isEmpty(); @@ -137,13 +155,7 @@ void shouldAddMessageToBatchFailure_whenException_withMessage(SQSEvent event) { @ParameterizedTest @Event(value = "sqs_event_big.json", type = SQSEvent.class) void parallelBatchProcessing_shouldAddMessageToBatchFailure_whenException_withMessage(SQSEvent event) { - // Arrange - BatchMessageHandler handler = new BatchMessageHandlerBuilder() - .withSqsBatchHandler() - .buildWithRawMessageHandler(this::processMessageInParallelFailsForFixedMessage); - - // Act - SQSBatchResponse sqsBatchResponse = handler.processBatchInParallel(event, context); + SQSBatchResponse sqsBatchResponse = testParallelBatchExecution(event, this::processMessageInParallelFailsForFixedMessage, null); // Assert assertThat(sqsBatchResponse.getBatchItemFailures()).hasSize(1); @@ -238,5 +250,31 @@ void failingSuccessHandlerShouldntFailBatchButShouldFailMessage(SQSEvent event) assertThat(batchItemFailure.getItemIdentifier()).isEqualTo("e9144555-9a4f-4ec3-99a0-34ce359b4b54"); } + @ParameterizedTest + @Event(value = "sqs_event_big.json", type = SQSEvent.class) + void parallelBatchProcessingWithExecutorSucceedsAndReturns(SQSEvent event) { + ExecutorService executor = Executors.newFixedThreadPool(2); + SQSBatchResponse sqsBatchResponse = testParallelBatchExecution(event, this::processMessageInParallelSucceeds, executor); + executor.shutdown(); + + // Assert + assertThat(sqsBatchResponse.getBatchItemFailures()).isEmpty(); + 
assertThat(threadList).hasSizeGreaterThan(1); + } + + @ParameterizedTest + @Event(value = "sqs_event_big.json", type = SQSEvent.class) + void parallelBatchProcessingWithExecutor_shouldAddMessageToBatchFailure_whenException_withMessage(SQSEvent event) { + ExecutorService executor = Executors.newFixedThreadPool(2); + SQSBatchResponse sqsBatchResponse = testParallelBatchExecution(event, this::processMessageInParallelFailsForFixedMessage, executor); + executor.shutdown(); + + // Assert + assertThat(sqsBatchResponse.getBatchItemFailures()).hasSize(1); + SQSBatchResponse.BatchItemFailure batchItemFailure = sqsBatchResponse.getBatchItemFailures().get(0); + assertThat(batchItemFailure.getItemIdentifier()).isEqualTo("e9144555-9a4f-4ec3-99a0-34ce359b4b54"); + assertThat(threadList).hasSizeGreaterThan(1); + } + }
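
Usage note: a minimal sketch of the new overload follows, mirroring the SqsBatchHandlerParallel example added by this patch. The class name OrderBatchHandler, the pool size, and the empty processMessage body are illustrative placeholders, not part of the patch. The handler and executor are created once per handler instance so they are reused across invocations, and processBatchInParallel(event, context, executor) schedules each record as a CompletableFuture on the supplied executor and joins them all before building the partial batch response.

import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.RequestHandler;
import com.amazonaws.services.lambda.runtime.events.SQSBatchResponse;
import com.amazonaws.services.lambda.runtime.events.SQSEvent;
import software.amazon.lambda.powertools.batch.BatchMessageHandlerBuilder;
import software.amazon.lambda.powertools.batch.handler.BatchMessageHandler;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Illustrative handler, not part of the patch: builds the batch handler once and
// reuses a fixed-size executor across invocations.
public class OrderBatchHandler implements RequestHandler<SQSEvent, SQSBatchResponse> {

    private final BatchMessageHandler<SQSEvent, SQSBatchResponse> handler =
            new BatchMessageHandlerBuilder()
                    .withSqsBatchHandler()
                    .buildWithRawMessageHandler(this::processMessage);

    // Pool size is an assumption; size it to the function's vCPU allocation.
    private final ExecutorService executor = Executors.newFixedThreadPool(4);

    @Override
    public SQSBatchResponse handleRequest(SQSEvent event, Context context) {
        // New overload introduced by this patch: records are processed on the supplied executor.
        return handler.processBatchInParallel(event, context, executor);
    }

    private void processMessage(SQSEvent.SQSMessage message, Context context) {
        // Per-record business logic; an exception thrown here marks the record as a batch item failure.
    }
}

As in the examples above, the executor is kept for the lifetime of the handler instance rather than shut down after each invocation, and the parallel paths apply to standard (non-FIFO) queues only, since FIFO batches are rejected with an UnsupportedOperationException.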