|
34 | 34 | import org.slf4j.Logger;
|
35 | 35 | import org.slf4j.LoggerFactory;
|
36 | 36 | import reactor.core.publisher.Flux;
|
37 |
| -import reactor.core.publisher.Sinks; |
38 |
| -import reactor.core.publisher.Sinks.EmitFailureHandler; |
39 | 37 | import reactor.core.scheduler.Schedulers;
|
40 | 38 | import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
|
41 | 39 | import software.amazon.awssdk.core.SdkBytes;
|
|
51 | 49 | import software.amazon.awssdk.services.bedrockruntime.model.ConverseMetrics;
|
52 | 50 | import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest;
|
53 | 51 | import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
|
54 |
| -import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput; |
55 | 52 | import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest;
|
56 |
| -import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler; |
57 | 53 | import software.amazon.awssdk.services.bedrockruntime.model.DocumentBlock;
|
58 | 54 | import software.amazon.awssdk.services.bedrockruntime.model.DocumentSource;
|
59 | 55 | import software.amazon.awssdk.services.bedrockruntime.model.ImageBlock;
|
|
76 | 72 |
|
77 | 73 | import org.springframework.ai.bedrock.converse.api.BedrockMediaFormat;
|
78 | 74 | import org.springframework.ai.bedrock.converse.api.ConverseApiUtils;
|
| 75 | +import org.springframework.ai.bedrock.converse.api.ConverseChatResponseStream; |
79 | 76 | import org.springframework.ai.bedrock.converse.api.URLValidator;
|
80 | 77 | import org.springframework.ai.chat.messages.AssistantMessage;
|
81 | 78 | import org.springframework.ai.chat.messages.MessageType;
|
|
84 | 81 | import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
85 | 82 | import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
86 | 83 | import org.springframework.ai.chat.metadata.DefaultUsage;
|
| 84 | +import org.springframework.ai.chat.metadata.Usage; |
87 | 85 | import org.springframework.ai.chat.model.ChatModel;
|
88 | 86 | import org.springframework.ai.chat.model.ChatResponse;
|
89 | 87 | import org.springframework.ai.chat.model.Generation;
|
@@ -680,9 +678,14 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse perviousCh
|
680 | 678 | .toolConfig(converseRequest.toolConfig())
|
681 | 679 | .build();
|
682 | 680 |
|
683 |
| - Flux<ConverseStreamOutput> response = converseStream(converseStreamRequest); |
| 681 | + Usage accumulatedUsage = null; |
| 682 | + if (perviousChatResponse != null && perviousChatResponse.getMetadata() != null) { |
| 683 | + accumulatedUsage = perviousChatResponse.getMetadata().getUsage(); |
| 684 | + } |
684 | 685 |
|
685 |
| - Flux<ChatResponse> chatResponses = ConverseApiUtils.toChatResponse(response, perviousChatResponse); |
| 686 | + Flux<ChatResponse> chatResponses = new ConverseChatResponseStream(this.bedrockRuntimeAsyncClient, |
| 687 | + converseStreamRequest, accumulatedUsage) |
| 688 | + .stream(); |
686 | 689 |
|
687 | 690 | Flux<ChatResponse> chatResponseFlux = chatResponses.switchMap(chatResponse -> {
|
688 | 691 |
|
@@ -729,48 +732,6 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse perviousCh
|
729 | 732 | });
|
730 | 733 | }
|
731 | 734 |
|
732 |
| - public static final EmitFailureHandler DEFAULT_EMIT_FAILURE_HANDLER = EmitFailureHandler |
733 |
| - .busyLooping(Duration.ofSeconds(10)); |
734 |
| - |
735 |
| - /** |
736 |
| - * Invoke the model and return the response stream. |
737 |
| - * |
738 |
| - * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html |
739 |
| - * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html |
740 |
| - * https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient.html#converseStream |
741 |
| - * @param converseStreamRequest Model invocation request. |
742 |
| - * @return The model invocation response stream. |
743 |
| - */ |
744 |
| - public Flux<ConverseStreamOutput> converseStream(ConverseStreamRequest converseStreamRequest) { |
745 |
| - Assert.notNull(converseStreamRequest, "'converseStreamRequest' must not be null"); |
746 |
| - |
747 |
| - Sinks.Many<ConverseStreamOutput> eventSink = Sinks.many().multicast().onBackpressureBuffer(); |
748 |
| - |
749 |
| - ConverseStreamResponseHandler.Visitor visitor = ConverseStreamResponseHandler.Visitor.builder() |
750 |
| - .onDefault(output -> { |
751 |
| - logger.debug("Received converse stream output:{}", output); |
752 |
| - eventSink.emitNext(output, DEFAULT_EMIT_FAILURE_HANDLER); |
753 |
| - }) |
754 |
| - .build(); |
755 |
| - |
756 |
| - ConverseStreamResponseHandler responseHandler = ConverseStreamResponseHandler.builder() |
757 |
| - .onEventStream(stream -> stream.subscribe(e -> e.accept(visitor))) |
758 |
| - .onComplete(() -> { |
759 |
| - eventSink.emitComplete(DEFAULT_EMIT_FAILURE_HANDLER); |
760 |
| - logger.info("Completed streaming response."); |
761 |
| - }) |
762 |
| - .onError(error -> { |
763 |
| - logger.error("Error handling Bedrock converse stream response", error); |
764 |
| - eventSink.emitError(error, DEFAULT_EMIT_FAILURE_HANDLER); |
765 |
| - }) |
766 |
| - .build(); |
767 |
| - |
768 |
| - this.bedrockRuntimeAsyncClient.converseStream(converseStreamRequest, responseHandler); |
769 |
| - |
770 |
| - return eventSink.asFlux(); |
771 |
| - |
772 |
| - } |
773 |
| - |
774 | 735 | /**
|
775 | 736 | * Use the provided convention for reporting observation data
|
776 | 737 | * @param observationConvention The provided convention
|
|
0 commit comments