Skip to content

Commit afb8682

Browse files
committed
Fix converse streaming issues:
- Correct finish reason when stop reason is not tool_use
- Populate finish reason for non-tool_use cases
- Ensure multiple tool calls are output in ChatResponse

Closes gh-4374, gh-4126, gh-3251

Signed-off-by: Jared Rufer <[email protected]>
1 parent 3e17e16 commit afb8682

File tree

6 files changed

+324
-474
lines changed

6 files changed

+324
-474
lines changed

models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java

Lines changed: 9 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,6 @@
3434
import org.slf4j.Logger;
3535
import org.slf4j.LoggerFactory;
3636
import reactor.core.publisher.Flux;
37-
import reactor.core.publisher.Sinks;
38-
import reactor.core.publisher.Sinks.EmitFailureHandler;
3937
import reactor.core.scheduler.Schedulers;
4038
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
4139
import software.amazon.awssdk.core.SdkBytes;
@@ -51,9 +49,7 @@
5149
import software.amazon.awssdk.services.bedrockruntime.model.ConverseMetrics;
5250
import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest;
5351
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
54-
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput;
5552
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest;
56-
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler;
5753
import software.amazon.awssdk.services.bedrockruntime.model.DocumentBlock;
5854
import software.amazon.awssdk.services.bedrockruntime.model.DocumentSource;
5955
import software.amazon.awssdk.services.bedrockruntime.model.ImageBlock;
@@ -76,6 +72,7 @@
7672

7773
import org.springframework.ai.bedrock.converse.api.BedrockMediaFormat;
7874
import org.springframework.ai.bedrock.converse.api.ConverseApiUtils;
75+
import org.springframework.ai.bedrock.converse.api.ConverseChatResponseStream;
7976
import org.springframework.ai.bedrock.converse.api.URLValidator;
8077
import org.springframework.ai.chat.messages.AssistantMessage;
8178
import org.springframework.ai.chat.messages.MessageType;
@@ -84,6 +81,7 @@
8481
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
8582
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
8683
import org.springframework.ai.chat.metadata.DefaultUsage;
84+
import org.springframework.ai.chat.metadata.Usage;
8785
import org.springframework.ai.chat.model.ChatModel;
8886
import org.springframework.ai.chat.model.ChatResponse;
8987
import org.springframework.ai.chat.model.Generation;
@@ -680,9 +678,14 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse perviousCh
680678
.toolConfig(converseRequest.toolConfig())
681679
.build();
682680

683-
Flux<ConverseStreamOutput> response = converseStream(converseStreamRequest);
681+
Usage accumulatedUsage = null;
682+
if (perviousChatResponse != null && perviousChatResponse.getMetadata() != null) {
683+
accumulatedUsage = perviousChatResponse.getMetadata().getUsage();
684+
}
684685

685-
Flux<ChatResponse> chatResponses = ConverseApiUtils.toChatResponse(response, perviousChatResponse);
686+
Flux<ChatResponse> chatResponses = new ConverseChatResponseStream(this.bedrockRuntimeAsyncClient,
687+
converseStreamRequest, accumulatedUsage)
688+
.stream();
686689

687690
Flux<ChatResponse> chatResponseFlux = chatResponses.switchMap(chatResponse -> {
688691

@@ -729,48 +732,6 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse perviousCh
729732
});
730733
}
731734

732-
public static final EmitFailureHandler DEFAULT_EMIT_FAILURE_HANDLER = EmitFailureHandler
733-
.busyLooping(Duration.ofSeconds(10));
734-
735-
/**
736-
* Invoke the model and return the response stream.
737-
*
738-
* https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
739-
* https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
740-
* https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient.html#converseStream
741-
* @param converseStreamRequest Model invocation request.
742-
* @return The model invocation response stream.
743-
*/
744-
public Flux<ConverseStreamOutput> converseStream(ConverseStreamRequest converseStreamRequest) {
745-
Assert.notNull(converseStreamRequest, "'converseStreamRequest' must not be null");
746-
747-
Sinks.Many<ConverseStreamOutput> eventSink = Sinks.many().multicast().onBackpressureBuffer();
748-
749-
ConverseStreamResponseHandler.Visitor visitor = ConverseStreamResponseHandler.Visitor.builder()
750-
.onDefault(output -> {
751-
logger.debug("Received converse stream output:{}", output);
752-
eventSink.emitNext(output, DEFAULT_EMIT_FAILURE_HANDLER);
753-
})
754-
.build();
755-
756-
ConverseStreamResponseHandler responseHandler = ConverseStreamResponseHandler.builder()
757-
.onEventStream(stream -> stream.subscribe(e -> e.accept(visitor)))
758-
.onComplete(() -> {
759-
eventSink.emitComplete(DEFAULT_EMIT_FAILURE_HANDLER);
760-
logger.info("Completed streaming response.");
761-
})
762-
.onError(error -> {
763-
logger.error("Error handling Bedrock converse stream response", error);
764-
eventSink.emitError(error, DEFAULT_EMIT_FAILURE_HANDLER);
765-
})
766-
.build();
767-
768-
this.bedrockRuntimeAsyncClient.converseStream(converseStreamRequest, responseHandler);
769-
770-
return eventSink.asFlux();
771-
772-
}
773-
774735
/**
775736
* Use the provided convention for reporting observation data
776737
* @param observationConvention The provided convention

0 commit comments

Comments
 (0)