diff --git a/docs/release_notes.md b/docs/release_notes.md index 20a0ef6f8..cf1d8f869 100644 --- a/docs/release_notes.md +++ b/docs/release_notes.md @@ -45,6 +45,7 @@ ### ✨ New Functionality +- [Orchestration] For streaming, add convenience configuration for output-filter-overlap, chunk-size, and delimiters via `OrchestrationModuleConfig#withStreamConfig`. - [Orchestration] Added embedding generation support with new `OrchestrationClient#embed()` methods. - Added `OrchestrationEmbeddingModel` with `TEXT_EMBEDDING_3_SMALL`, `TEXT_EMBEDDING_3_LARGE`, `AMAZON_TITAN_EMBED_TEXT` and `NVIDIA_LLAMA_32_NV_EMBEDQA_1B` embedding models. - Introduced `OrchestrationEmbeddingRequest` for building requests fluently and `OrchestrationEmbeddingResponse#getEmbeddingVectors()` to retrieve embeddings. diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformer.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformer.java index dcd8e96dc..36cbdb9e9 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformer.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformer.java @@ -38,8 +38,11 @@ static CompletionRequestConfiguration toCompletionPostRequest( val moduleConfigs = toModuleConfigs(configCopy); + val reqConfig = + OrchestrationConfig.create().modules(moduleConfigs).stream(config.getGlobalStreamOptions()); + return CompletionRequestConfiguration.create() - .config(OrchestrationConfig.create().modules(moduleConfigs)) + .config(reqConfig) .placeholderValues(prompt.getTemplateParameters()) .messagesHistory(messageHistory); } diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java index b7ba4568a..6ec9de4cf 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java @@ -220,7 +220,13 @@ public OrchestrationChatResponse executeRequestFromJsonModuleConfig( @Nonnull public Stream streamChatCompletionDeltas( @Nonnull final CompletionRequestConfiguration request) throws OrchestrationClientException { - request.getConfig().setStream(GlobalStreamOptions.create().enabled(true).delimiters(null)); + val config = request.getConfig(); + val stream = config.getStream(); + if (stream == null) { + config.setStream(GlobalStreamOptions.create().enabled(true).delimiters(null)); + } else { + stream.enabled(true); + } return executor.stream(COMPLETION_ENDPOINT, request, customHeaders); } diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfig.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfig.java index cf87316ca..bdd8da4d5 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfig.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfig.java @@ -2,6 +2,8 @@ import com.google.common.annotations.Beta; import com.sap.ai.sdk.orchestration.model.FilteringModuleConfig; +import com.sap.ai.sdk.orchestration.model.FilteringStreamOptions; +import com.sap.ai.sdk.orchestration.model.GlobalStreamOptions; import com.sap.ai.sdk.orchestration.model.GroundingModuleConfig; import com.sap.ai.sdk.orchestration.model.InputFilteringConfig; import com.sap.ai.sdk.orchestration.model.LLMModelDetails; @@ -18,6 +20,7 @@ import javax.annotation.Nullable; import lombok.AccessLevel; import lombok.AllArgsConstructor; +import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Value; import lombok.With; @@ -101,6 +104,18 @@ public class OrchestrationModuleConfig { @Nullable SAPDocumentTranslationOutput outputTranslationConfig; + /** Configuration of optional streaming options for output filtering. */ + @With(AccessLevel.NONE) // may be exposed to public in the future + @Getter(AccessLevel.PACKAGE) + @Nullable + FilteringStreamOptions outputFilteringStreamOptions; + + /** Configuration of optional global streaming options, e.g. chunk-size. */ + @With(AccessLevel.PRIVATE) // may be exposed to public in the future + @Getter(AccessLevel.PACKAGE) + @Nullable + GlobalStreamOptions globalStreamOptions; + /** * Creates a new configuration with the given LLM configuration. * @@ -116,6 +131,19 @@ public OrchestrationModuleConfig withLlmConfig(@Nonnull final OrchestrationAiMod return withLlmConfig(aiModel.createConfig()); } + /** + * Creates a new configuration with the given stream configuration. + * + * @param config The stream configuration to use. + * @return A new configuration with the given stream configuration. + */ + @Nonnull + public OrchestrationModuleConfig withStreamConfig( + @Nonnull final OrchestrationStreamConfig config) { + return this.withOutputFilteringStreamOptions(config.createFilteringStreamOptions()) + .withGlobalStreamOptions(config.createGlobalStreamOptions()); + } + /** * Creates a new configuration with the given Data Masking configuration. * @@ -204,7 +232,10 @@ public OrchestrationModuleConfig withOutputFiltering( .map(ContentFilter::createOutputFilterConfig) .toList(); - final var outputFilter = OutputFilteringConfig.create().filters(filterConfigs); + final var outputFilter = + OutputFilteringConfig.create() + .filters(filterConfigs) + .streamOptions(outputFilteringStreamOptions); final var newFilteringConfig = FilteringModuleConfig.create() @@ -214,6 +245,33 @@ public OrchestrationModuleConfig withOutputFiltering( return this.withFilteringConfig(newFilteringConfig); } + /** + * Creates a new configuration with the given output filtering stream options. + * + * @see Orchestration + * documentation on streaming. + * @param outputFilteringStreamOptions The output filtering stream options to use. + * @return A new configuration with the given output filtering stream options. + */ + @Nonnull + OrchestrationModuleConfig withOutputFilteringStreamOptions( + @Nullable final FilteringStreamOptions outputFilteringStreamOptions) { + if (filteringConfig != null && filteringConfig.getOutput() != null) { + filteringConfig.getOutput().setStreamOptions(outputFilteringStreamOptions); + } + return new OrchestrationModuleConfig( + this.llmConfig, + this.templateConfig, + this.maskingConfig, + this.filteringConfig, + this.groundingConfig, + this.inputTranslationConfig, + this.outputTranslationConfig, + outputFilteringStreamOptions, + this.globalStreamOptions); + } + /** * Creates a new configuration with the given grounding configuration. * diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationStreamConfig.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationStreamConfig.java new file mode 100644 index 000000000..bb7d40ccc --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationStreamConfig.java @@ -0,0 +1,55 @@ +package com.sap.ai.sdk.orchestration; + +import com.sap.ai.sdk.orchestration.model.FilteringStreamOptions; +import com.sap.ai.sdk.orchestration.model.GlobalStreamOptions; +import java.util.List; +import java.util.Optional; +import javax.annotation.Nullable; +import lombok.AccessLevel; +import lombok.AllArgsConstructor; +import lombok.Value; +import lombok.With; +import lombok.val; + +/** + * Configuration for orchestration streaming options. + * + * @since 1.12.0 + */ +@Value +@With +@AllArgsConstructor(access = AccessLevel.PRIVATE) +public class OrchestrationStreamConfig { + /** + * Number of characters that should be additionally sent to content filtering services from + * previous chunks as additional context. + */ + @Nullable Integer filterOverlap; + + /** Size of the chunks the response will be split into when streaming. */ + @Nullable Integer chunkSize; + + /** List of delimiters to use for chunking the response when streaming. */ + @Nullable List delimiters; + + /** Default constructor for OrchestrationStreamConfig. */ + public OrchestrationStreamConfig() { + this(null, null, null); + } + + @Nullable + FilteringStreamOptions createFilteringStreamOptions() { + return filterOverlap == null ? null : FilteringStreamOptions.create().overlap(filterOverlap); + } + + @Nullable + GlobalStreamOptions createGlobalStreamOptions() { + if (chunkSize == null && delimiters == null) { + return null; + } + val opts = GlobalStreamOptions.create(); + Optional.ofNullable(chunkSize).ifPresent(opts::setChunkSize); + opts.setDelimiters(delimiters == null ? null : List.copyOf(delimiters)); + return opts; + } +} diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java index 092403158..ba389c4ba 100644 --- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java @@ -46,6 +46,7 @@ import com.sap.ai.sdk.orchestration.model.DataRepositoryType; import com.sap.ai.sdk.orchestration.model.DocumentGroundingFilter; import com.sap.ai.sdk.orchestration.model.ErrorResponse; +import com.sap.ai.sdk.orchestration.model.FilteringStreamOptions; import com.sap.ai.sdk.orchestration.model.GenericModuleResult; import com.sap.ai.sdk.orchestration.model.GroundingFilterSearchConfiguration; import com.sap.ai.sdk.orchestration.model.GroundingModuleConfig; @@ -63,6 +64,7 @@ import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor; import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Cache; import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination; +import io.vavr.control.Try; import java.io.IOException; import java.io.InputStream; import java.nio.file.Files; @@ -103,6 +105,8 @@ class OrchestrationUnitTest { private final Function fileLoader = filename -> Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream(filename)); + private final Function fileLoaderStr = + filename -> new String(Try.of(() -> fileLoader.apply(filename).readAllBytes()).get()); private static OrchestrationClient client; private static OrchestrationModuleConfig config; @@ -259,11 +263,9 @@ void testGrounding() throws IOException { "masked_grounding_input", // maskGroundingInput: true will make this field present "[\"What does Joule do?\"]")); - try (var requestInputStream = fileLoader.apply("groundingRequest.json")) { - final String request = new String(requestInputStream.readAllBytes()); - verify( - postRequestedFor(urlPathEqualTo("/v2/completion")).withRequestBody(equalToJson(request))); - } + final String request = fileLoaderStr.apply("groundingRequest.json"); + verify( + postRequestedFor(urlPathEqualTo("/v2/completion")).withRequestBody(equalToJson(request))); } @Test @@ -293,12 +295,10 @@ void testGroundingWithHelpSapCom() throws IOException { "A fuzzy search is a search technique that is designed to be fast and tolerant of errors"); assertThat(response.getContent()).startsWith("A fuzzy search is a search technique"); - try (var requestInputStream = fileLoader.apply("groundingHelpSapComRequest.json")) { - final String request = new String(requestInputStream.readAllBytes()); - verify( - postRequestedFor(urlPathEqualTo("/v2/completion")) - .withRequestBody(equalToJson(request, true, true))); - } + final String request = fileLoaderStr.apply("groundingHelpSapComRequest.json"); + verify( + postRequestedFor(urlPathEqualTo("/v2/completion")) + .withRequestBody(equalToJson(request, true, true))); } @Test @@ -364,10 +364,8 @@ void testTemplating() throws IOException { assertThat(usage.getTotalTokens()).isEqualTo(26); // verify that null fields are absent from the sent request - try (var requestInputStream = fileLoader.apply("templatingRequest.json")) { - final String request = new String(requestInputStream.readAllBytes()); - verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(request))); - } + final String request = fileLoaderStr.apply("templatingRequest.json"); + verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(request))); } @Test @@ -432,10 +430,61 @@ void filteringLoose() throws IOException { // the result is asserted in the verify step below // verify that null fields are absent from the sent request - try (var requestInputStream = fileLoader.apply("filteringLooseRequest.json")) { - final String request = new String(requestInputStream.readAllBytes()); - verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(request, true, true))); - } + final String request = fileLoaderStr.apply("filteringLooseRequest.json"); + verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(request, true, true))); + } + + @Test + void filteringLooseStream() throws IOException { + final var res = new String(fileLoader.apply("streamChatCompletion.txt").readAllBytes()); + stubFor( + post(anyUrl()) + .willReturn(aResponse().withBody(res).withHeader("Content-Type", "application/json"))); + + final var azureFilter = + new AzureContentFilter() + .hate(ALLOW_SAFE_LOW_MEDIUM) + .selfHarm(ALLOW_SAFE_LOW_MEDIUM) + .sexual(ALLOW_SAFE_LOW_MEDIUM) + .violence(ALLOW_SAFE_LOW_MEDIUM); + + final var llamaFilter = new LlamaGuardFilter().config(LlamaGuard38b.create().selfHarm(true)); + + OrchestrationModuleConfig myConfig = + config + .withInputFiltering(azureFilter, llamaFilter) + .withOutputFiltering(azureFilter) + .withOutputFilteringStreamOptions(FilteringStreamOptions.create().overlap(1_000)); + + Stream result = client.streamChatCompletion(prompt, myConfig); + assertThat(result).containsExactly("", "Sure", "!"); + // the result is asserted in the verify step below + + // verify that null fields are absent from the sent request + final String request = fileLoaderStr.apply("filteringLooseRequestStream.json"); + verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(request, true, false))); + } + + @Test + void convenienceConfig() { + final var azureFilter = new AzureContentFilter().hate(ALLOW_SAFE_LOW_MEDIUM); + + OrchestrationModuleConfig myConfig = + config + .withOutputFiltering(azureFilter) + .withOutputFilteringStreamOptions(FilteringStreamOptions.create().overlap(1_000)); + OrchestrationModuleConfig myConfig2 = + config + .withOutputFiltering(azureFilter) + .withStreamConfig(new OrchestrationStreamConfig().withFilterOverlap(1_000)); + assertThat(myConfig).isEqualTo(myConfig2); + + OrchestrationModuleConfig myConfig3 = + config + .withOutputFiltering(azureFilter) + .withStreamConfig( + new OrchestrationStreamConfig().withFilterOverlap(1_000).withChunkSize(10)); + assertThat(myConfig2).isNotEqualTo(myConfig3); } @Test @@ -572,10 +621,8 @@ void messagesHistory() throws IOException { .isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91"); // verify that the history is sent correctly - try (var requestInputStream = fileLoader.apply("messagesHistoryRequest.json")) { - final String requestBody = new String(requestInputStream.readAllBytes()); - verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(requestBody))); - } + final String requestBody = fileLoaderStr.apply("messagesHistoryRequest.json"); + verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(requestBody))); } @Test @@ -600,10 +647,8 @@ void maskingPseudonymization() throws IOException { assertThat(result.getContent()).contains("Hi Mallory"); // verify that the request is sent correctly - try (var requestInputStream = fileLoader.apply("maskingRequest.json")) { - final String request = new String(requestInputStream.readAllBytes()); - verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(request, true, true))); - } + final String request = fileLoaderStr.apply("maskingRequest.json"); + verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(request, true, true))); } private static Runnable[] errorHandlingCalls() { @@ -1032,12 +1077,10 @@ void testMultiMessage() throws IOException { assertThat(orchestrationResult.getChoices().get(0).getFinishReason()).isEqualTo("stop"); assertThat(orchestrationResult.getChoices().get(0).getMessage().getRole()).isEqualTo(ASSISTANT); - try (var requestInputStream = fileLoader.apply("multiMessageRequest.json")) { - final String requestBody = new String(requestInputStream.readAllBytes()); - verify( - postRequestedFor(urlPathEqualTo("/v2/completion")) - .withRequestBody(equalToJson(requestBody))); - } + final String requestBody = fileLoaderStr.apply("multiMessageRequest.json"); + verify( + postRequestedFor(urlPathEqualTo("/v2/completion")) + .withRequestBody(equalToJson(requestBody))); } // Example class @@ -1095,10 +1138,8 @@ class TranslationNotStaticNoConstructor { assertThat(translation.language).isEqualTo("German"); assertThat(translation.translation).isEqualTo("Apfel"); - try (var requestInputStream = fileLoader.apply("jsonSchemaRequest.json")) { - final String request = new String(requestInputStream.readAllBytes()); - verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(request))); - } + final String request = fileLoaderStr.apply("jsonSchemaRequest.json"); + verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(request))); } @Test @@ -1160,10 +1201,8 @@ void testResponseFormatJsonObject() throws IOException { final var message = client.chatCompletion(prompt, configWithJsonResponse).getContent(); assertThat(message).isEqualTo("{\"language\": \"German\", \"translation\": \"Apfel\"}"); - try (var requestInputStream = fileLoader.apply("jsonObjectRequest.json")) { - final String request = new String(requestInputStream.readAllBytes()); - verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(request))); - } + final String request = fileLoaderStr.apply("jsonObjectRequest.json"); + verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(request))); } @Test @@ -1193,10 +1232,8 @@ void testResponseFormatText() throws IOException { .isEqualTo( "```json\n{\n \"word\": \"apple\",\n \"translation\": \"Apfel\",\n \"language\": \"German\"\n}\n```"); - try (var requestInputStream = fileLoader.apply("responseFormatTextRequest.json")) { - final String request = new String(requestInputStream.readAllBytes()); - verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(request))); - } + final String request = fileLoaderStr.apply("responseFormatTextRequest.json"); + verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(request))); } @Test @@ -1220,10 +1257,8 @@ void testTemplateFromPromptRegistryById() throws IOException { assertThat(response.getOriginalResponse().getIntermediateResults().getTemplating()) .hasSize(2); - try (var requestInputStream = fileLoader.apply("templateReferenceByIdRequest.json")) { - final String request = new String(requestInputStream.readAllBytes()); - verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(request))); - } + final String request = fileLoaderStr.apply("templateReferenceByIdRequest.json"); + verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(request))); } } @@ -1246,10 +1281,8 @@ void testTemplateFromPromptRegistryByScenario() throws IOException { assertThat(response.getContent()).startsWith("I sistemi ERP (Enterprise Resource Planning)"); assertThat(response.getOriginalResponse().getIntermediateResults().getTemplating()).hasSize(2); - try (var requestInputStream = fileLoader.apply("templateReferenceByScenarioRequest.json")) { - final String request = new String(requestInputStream.readAllBytes()); - verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(request))); - } + final String request = fileLoaderStr.apply("templateReferenceByScenarioRequest.json"); + verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(request))); } @Test @@ -1272,10 +1305,8 @@ void testTemplateFromInput() throws IOException { final var response = client.chatCompletion(prompt, configWithTemplate); - try (var requestInputStream = fileLoader.apply("localTemplateRequest.json")) { - final String request = new String(requestInputStream.readAllBytes()); - verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(request))); - } + final String request = fileLoaderStr.apply("localTemplateRequest.json"); + verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(request))); } @Test diff --git a/orchestration/src/test/resources/filteringLooseRequestStream.json b/orchestration/src/test/resources/filteringLooseRequestStream.json new file mode 100644 index 000000000..749125d03 --- /dev/null +++ b/orchestration/src/test/resources/filteringLooseRequestStream.json @@ -0,0 +1,75 @@ +{ + "config": { + "modules": { + "prompt_templating": { + "model": { + "name": "gpt-4o", + "params": { + "temperature": 0.1, + "max_tokens": 50, + "frequency_penalty": 0, + "presence_penalty": 0, + "top_p": 1, + "n": 1 + }, + "version": "latest", + "timeout" : 600, + "max_retries" : 2 + }, + "prompt": { + "template": [ + { + "role": "user", + "content": "Hello World! Why is this phrase so famous?" + } + ], + "defaults": {}, + "tools": [] + } + }, + "filtering": { + "input": { + "filters": [ + { + "type": "azure_content_safety", + "config": { + "hate": 4, + "self_harm": 4, + "sexual": 4, + "violence": 4 + } + }, + { + "type": "llama_guard_3_8b", + "config": { + "self_harm": true + } + } + ] + }, + "output": { + "filters": [ + { + "type": "azure_content_safety", + "config": { + "hate": 4, + "self_harm": 4, + "sexual": 4, + "violence": 4 + } + } + ], + "stream_options" : { + "overlap" : 1000 + } + } + } + }, + "stream" : { + "enabled" : true, + "chunk_size" : 100 + } + }, + "placeholder_values": {}, + "messages_history": [] +} \ No newline at end of file