From a76229bfc2058028381a56ebe35bf73c507e1e2a Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Fri, 10 Jan 2025 14:32:35 +0000 Subject: [PATCH] .Net: Google Gemini - Adding response schema (Structured Outputs support) (#10135) ### Motivation and Context - Resolves #9501 ### Description Allow schema definition for the LLM response. Similar to `Structured Output` concept from OpenAI. --- .editorconfig | 10 ++ .../Clients/GeminiChatGenerationTests.cs | 26 +++- .../Core/Gemini/GeminiRequestTests.cs | 115 +++++++++++++++++- .../GeminiPromptExecutionSettingsTests.cs | 7 +- .../Core/Gemini/Models/GeminiRequest.cs | 65 +++++++++- .../GeminiPromptExecutionSettings.cs | 31 ++++- 6 files changed, 248 insertions(+), 6 deletions(-) diff --git a/.editorconfig b/.editorconfig index ee1a588e7e30..14d07f4147f6 100644 --- a/.editorconfig +++ b/.editorconfig @@ -136,6 +136,7 @@ dotnet_diagnostic.IDE0005.severity = warning # Remove unnecessary using directiv dotnet_diagnostic.IDE0009.severity = warning # Add this or Me qualification dotnet_diagnostic.IDE0011.severity = warning # Add braces dotnet_diagnostic.IDE0018.severity = warning # Inline variable declaration + dotnet_diagnostic.IDE0032.severity = warning # Use auto-implemented property dotnet_diagnostic.IDE0034.severity = warning # Simplify 'default' expression dotnet_diagnostic.IDE0035.severity = warning # Remove unreachable code @@ -221,20 +222,29 @@ dotnet_diagnostic.RCS1241.severity = none # Implement IComparable when implement dotnet_diagnostic.IDE0001.severity = none # Simplify name dotnet_diagnostic.IDE0002.severity = none # Simplify member access dotnet_diagnostic.IDE0004.severity = none # Remove unnecessary cast +dotnet_diagnostic.IDE0010.severity = none # Populate switch +dotnet_diagnostic.IDE0021.severity = none # Use block body for constructors +dotnet_diagnostic.IDE0022.severity = none # Use block body for methods +dotnet_diagnostic.IDE0024.severity = none # Use block body for operator dotnet_diagnostic.IDE0035.severity = none # Remove unreachable code dotnet_diagnostic.IDE0051.severity = none # Remove unused private member dotnet_diagnostic.IDE0052.severity = none # Remove unread private member dotnet_diagnostic.IDE0058.severity = none # Remove unused expression value dotnet_diagnostic.IDE0059.severity = none # Unnecessary assignment of a value dotnet_diagnostic.IDE0060.severity = none # Remove unused parameter +dotnet_diagnostic.IDE0061.severity = none # Use block body for local function dotnet_diagnostic.IDE0079.severity = none # Remove unnecessary suppression. dotnet_diagnostic.IDE0080.severity = none # Remove unnecessary suppression operator. dotnet_diagnostic.IDE0100.severity = none # Remove unnecessary equality operator dotnet_diagnostic.IDE0110.severity = none # Remove unnecessary discards dotnet_diagnostic.IDE0130.severity = none # Namespace does not match folder structure +dotnet_diagnostic.IDE0290.severity = none # Use primary constructor dotnet_diagnostic.IDE0032.severity = none # Use auto property dotnet_diagnostic.IDE0160.severity = none # Use block-scoped namespace dotnet_diagnostic.IDE1006.severity = warning # Naming rule violations +dotnet_diagnostic.IDE0046.severity = suggestion # If statement can be simplified +dotnet_diagnostic.IDE0056.severity = suggestion # Indexing can be simplified +dotnet_diagnostic.IDE0057.severity = suggestion # Substring can be simplified ############################### # Naming Conventions # diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatGenerationTests.cs b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatGenerationTests.cs index 987e55f703bb..a27844d02f48 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatGenerationTests.cs +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatGenerationTests.cs @@ -5,6 +5,7 @@ using System.IO; using System.Linq; using System.Net.Http; +using System.Text; using System.Text.Json; using System.Threading.Tasks; using Microsoft.SemanticKernel.ChatCompletion; @@ -419,13 +420,34 @@ public async Task ItCreatesPostRequestWithSemanticKernelVersionHeaderAsync() Assert.Equal(expectedVersion, header); } + [Fact] + public async Task ItCreatesPostRequestWithResponseSchemaPropertyAsync() + { + // Arrange + var client = this.CreateChatCompletionClient(); + var chatHistory = CreateSampleChatHistory(); + var settings = new GeminiPromptExecutionSettings { ResponseMimeType = "application/json", ResponseSchema = typeof(List) }; + + // Act + await client.GenerateChatMessageAsync(chatHistory, settings); + + // Assert + Assert.NotNull(this._messageHandlerStub.RequestHeaders); + + var responseBody = Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent!); + + Assert.Contains("responseSchema", responseBody, StringComparison.Ordinal); + Assert.Contains("\"responseSchema\":{\"type\":\"array\",\"items\":{\"type\":\"integer\"}}", responseBody, StringComparison.Ordinal); + Assert.Contains("\"responseMimeType\":\"application/json\"", responseBody, StringComparison.Ordinal); + } + [Fact] public async Task ItCanUseValueTasksSequentiallyForBearerTokenAsync() { // Arrange var bearerTokenGenerator = new BearerTokenGenerator() { - BearerKeys = new List { "key1", "key2", "key3" } + BearerKeys = ["key1", "key2", "key3"] }; var responseContent = File.ReadAllText(ChatTestDataFilePath); @@ -442,7 +464,7 @@ public async Task ItCanUseValueTasksSequentiallyForBearerTokenAsync() httpClient: httpClient, modelId: "fake-model", apiVersion: VertexAIVersion.V1, - bearerTokenProvider: () => bearerTokenGenerator.GetBearerToken(), + bearerTokenProvider: bearerTokenGenerator.GetBearerToken, location: "fake-location", projectId: "fake-project-id"); diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs index c6701ee09b84..55283d191a84 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Text.Json; using System.Text.Json.Nodes; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.ChatCompletion; @@ -25,7 +26,8 @@ public void FromPromptItReturnsWithConfiguration() MaxTokens = 10, TopP = 0.9, AudioTimestamp = true, - ResponseMimeType = "application/json" + ResponseMimeType = "application/json", + ResponseSchema = JsonSerializer.Deserialize(@"{""schema"":""schema""}") }; // Act @@ -37,9 +39,120 @@ public void FromPromptItReturnsWithConfiguration() Assert.Equal(executionSettings.MaxTokens, request.Configuration.MaxOutputTokens); Assert.Equal(executionSettings.AudioTimestamp, request.Configuration.AudioTimestamp); Assert.Equal(executionSettings.ResponseMimeType, request.Configuration.ResponseMimeType); + Assert.Equal(executionSettings.ResponseSchema, request.Configuration.ResponseSchema); Assert.Equal(executionSettings.TopP, request.Configuration.TopP); } + [Fact] + public void JsonElementResponseSchemaFromPromptReturnsAsExpected() + { + // Arrange + var prompt = "prompt-example"; + var executionSettings = new GeminiPromptExecutionSettings + { + ResponseMimeType = "application/json", + ResponseSchema = Microsoft.Extensions.AI.AIJsonUtilities.CreateJsonSchema(typeof(int)) + }; + + // Act + var request = GeminiRequest.FromPromptAndExecutionSettings(prompt, executionSettings); + + // Assert + Assert.NotNull(request.Configuration); + Assert.Equal(executionSettings.ResponseMimeType, request.Configuration.ResponseMimeType); + Assert.Equal(executionSettings.ResponseSchema, request.Configuration.ResponseSchema); + } + + [Fact] + public void KernelJsonSchemaFromPromptReturnsAsExpected() + { + // Arrange + var prompt = "prompt-example"; + var executionSettings = new GeminiPromptExecutionSettings + { + ResponseMimeType = "application/json", + ResponseSchema = KernelJsonSchemaBuilder.Build(typeof(int)) + }; + + // Act + var request = GeminiRequest.FromPromptAndExecutionSettings(prompt, executionSettings); + + // Assert + Assert.NotNull(request.Configuration); + Assert.Equal(executionSettings.ResponseMimeType, request.Configuration.ResponseMimeType); + Assert.Equal(((KernelJsonSchema)executionSettings.ResponseSchema).RootElement, request.Configuration.ResponseSchema); + } + + [Fact] + public void JsonNodeResponseSchemaFromPromptReturnsAsExpected() + { + // Arrange + var prompt = "prompt-example"; + var executionSettings = new GeminiPromptExecutionSettings + { + ResponseMimeType = "application/json", + ResponseSchema = JsonNode.Parse(Microsoft.Extensions.AI.AIJsonUtilities.CreateJsonSchema(typeof(int)).GetRawText()) + }; + + // Act + var request = GeminiRequest.FromPromptAndExecutionSettings(prompt, executionSettings); + + // Assert + Assert.NotNull(request.Configuration); + Assert.Equal(executionSettings.ResponseMimeType, request.Configuration.ResponseMimeType); + Assert.NotNull(request.Configuration.ResponseSchema); + Assert.Equal(JsonSerializer.SerializeToElement(executionSettings.ResponseSchema).GetRawText(), request.Configuration.ResponseSchema.Value.GetRawText()); + } + + [Fact] + public void JsonDocumentResponseSchemaFromPromptReturnsAsExpected() + { + // Arrange + var prompt = "prompt-example"; + var executionSettings = new GeminiPromptExecutionSettings + { + ResponseMimeType = "application/json", + ResponseSchema = JsonDocument.Parse(Microsoft.Extensions.AI.AIJsonUtilities.CreateJsonSchema(typeof(int)).GetRawText()) + }; + + // Act + var request = GeminiRequest.FromPromptAndExecutionSettings(prompt, executionSettings); + + // Assert + Assert.NotNull(request.Configuration); + Assert.Equal(executionSettings.ResponseMimeType, request.Configuration.ResponseMimeType); + Assert.NotNull(request.Configuration.ResponseSchema); + Assert.Equal(JsonSerializer.SerializeToElement(executionSettings.ResponseSchema).GetRawText(), request.Configuration.ResponseSchema.Value.GetRawText()); + } + + [Theory] + [InlineData(typeof(int), "integer")] + [InlineData(typeof(bool), "boolean")] + [InlineData(typeof(string), "string")] + [InlineData(typeof(double), "number")] + [InlineData(typeof(GeminiRequest), "object")] + [InlineData(typeof(List), "array")] + public void TypeResponseSchemaFromPromptReturnsAsExpected(Type type, string expectedSchemaType) + { + // Arrange + var prompt = "prompt-example"; + var executionSettings = new GeminiPromptExecutionSettings + { + ResponseMimeType = "application/json", + ResponseSchema = type + }; + + // Act + var request = GeminiRequest.FromPromptAndExecutionSettings(prompt, executionSettings); + + // Assert + Assert.NotNull(request.Configuration); + var schemaType = request.Configuration.ResponseSchema?.GetProperty("type").GetString(); + + Assert.Equal(expectedSchemaType, schemaType); + Assert.Equal(executionSettings.ResponseMimeType, request.Configuration.ResponseMimeType); + } + [Fact] public void FromPromptItReturnsWithSafetySettings() { diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/GeminiPromptExecutionSettingsTests.cs b/dotnet/src/Connectors/Connectors.Google.UnitTests/GeminiPromptExecutionSettingsTests.cs index b13a2e397ec7..0d2955f18d7f 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/GeminiPromptExecutionSettingsTests.cs +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/GeminiPromptExecutionSettingsTests.cs @@ -28,6 +28,7 @@ public void ItCreatesGeminiExecutionSettingsWithCorrectDefaults() Assert.Null(executionSettings.SafetySettings); Assert.Null(executionSettings.AudioTimestamp); Assert.Null(executionSettings.ResponseMimeType); + Assert.Null(executionSettings.ResponseSchema); Assert.Equal(GeminiPromptExecutionSettings.DefaultTextMaxTokens, executionSettings.MaxTokens); } @@ -70,7 +71,8 @@ public void ItCreatesGeminiExecutionSettingsFromExtensionDataSnakeCase() { "max_tokens", 1000 }, { "temperature", 0 }, { "audio_timestamp", true }, - { "response_mimetype", "application/json" } + { "response_mimetype", "application/json" }, + { "response_schema", JsonSerializer.Serialize(new { }) } } }; @@ -81,6 +83,9 @@ public void ItCreatesGeminiExecutionSettingsFromExtensionDataSnakeCase() Assert.NotNull(executionSettings); Assert.Equal(1000, executionSettings.MaxTokens); Assert.Equal(0, executionSettings.Temperature); + Assert.Equal("application/json", executionSettings.ResponseMimeType); + Assert.NotNull(executionSettings.ResponseSchema); + Assert.Equal(typeof(JsonElement), executionSettings.ResponseSchema.GetType()); Assert.True(executionSettings.AudioTimestamp); } diff --git a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs index 7787122756be..2ebda2c2a0de 100644 --- a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs +++ b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs @@ -4,13 +4,25 @@ using System.Collections.Generic; using System.Linq; using System.Text.Json; +using System.Text.Json.Nodes; using System.Text.Json.Serialization; +using System.Text.Json.Serialization.Metadata; +using Microsoft.Extensions.AI; using Microsoft.SemanticKernel.ChatCompletion; namespace Microsoft.SemanticKernel.Connectors.Google.Core; internal sealed class GeminiRequest { + private static JsonSerializerOptions? s_options; + private static readonly AIJsonSchemaCreateOptions s_schemaOptions = new() + { + IncludeSchemaKeyword = false, + IncludeTypeInEnumSchemas = true, + RequireAllProperties = false, + DisallowAdditionalProperties = false, + }; + [JsonPropertyName("contents")] public IList Contents { get; set; } = null!; @@ -249,10 +261,57 @@ private static void AddConfiguration(GeminiPromptExecutionSettings executionSett StopSequences = executionSettings.StopSequences, CandidateCount = executionSettings.CandidateCount, AudioTimestamp = executionSettings.AudioTimestamp, - ResponseMimeType = executionSettings.ResponseMimeType + ResponseMimeType = executionSettings.ResponseMimeType, + ResponseSchema = GetResponseSchemaConfig(executionSettings.ResponseSchema) }; } + private static JsonElement? GetResponseSchemaConfig(object? responseSchemaSettings) + { + if (responseSchemaSettings is null) + { + return null; + } + + var jsonElement = responseSchemaSettings switch + { + JsonElement element => element, + Type type => CreateSchema(type, GetDefaultOptions()), + KernelJsonSchema kernelJsonSchema => kernelJsonSchema.RootElement, + JsonNode jsonNode => JsonSerializer.SerializeToElement(jsonNode, GetDefaultOptions()), + JsonDocument jsonDocument => JsonSerializer.SerializeToElement(jsonDocument, GetDefaultOptions()), + _ => CreateSchema(responseSchemaSettings.GetType(), GetDefaultOptions()) + }; + + return jsonElement; + } + + private static JsonElement CreateSchema( + Type type, + JsonSerializerOptions options, + string? description = null, + AIJsonSchemaCreateOptions? configuration = null) + { + configuration ??= s_schemaOptions; + return AIJsonUtilities.CreateJsonSchema(type, description, serializerOptions: options, inferenceOptions: configuration); + } + + private static JsonSerializerOptions GetDefaultOptions() + { + if (s_options is null) + { + JsonSerializerOptions options = new() + { + TypeInfoResolver = new DefaultJsonTypeInfoResolver(), + Converters = { new JsonStringEnumConverter() }, + }; + options.MakeReadOnly(); + s_options = options; + } + + return s_options; + } + private static void AddSafetySettings(GeminiPromptExecutionSettings executionSettings, GeminiRequest request) { request.SafetySettings = executionSettings.SafetySettings?.Select(s @@ -292,5 +351,9 @@ internal sealed class ConfigurationElement [JsonPropertyName("responseMimeType")] [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public string? ResponseMimeType { get; set; } + + [JsonPropertyName("responseSchema")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public JsonElement? ResponseSchema { get; set; } } } diff --git a/dotnet/src/Connectors/Connectors.Google/GeminiPromptExecutionSettings.cs b/dotnet/src/Connectors/Connectors.Google/GeminiPromptExecutionSettings.cs index cfb07941a393..fab00f01e11d 100644 --- a/dotnet/src/Connectors/Connectors.Google/GeminiPromptExecutionSettings.cs +++ b/dotnet/src/Connectors/Connectors.Google/GeminiPromptExecutionSettings.cs @@ -5,6 +5,7 @@ using System.Collections.ObjectModel; using System.Linq; using System.Text.Json; +using System.Text.Json.Nodes; using System.Text.Json.Serialization; using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Text; @@ -25,6 +26,7 @@ public sealed class GeminiPromptExecutionSettings : PromptExecutionSettings private IList? _stopSequences; private bool? _audioTimestamp; private string? _responseMimeType; + private object? _responseSchema; private IList? _safetySettings; private GeminiToolCallBehavior? _toolCallBehavior; @@ -206,6 +208,32 @@ public string? ResponseMimeType } } + /// + /// Optional. Output schema of the generated candidate text. Schemas must be a subset of the OpenAPI schema and can be objects, primitives or arrays. + /// If set, a compatible responseMimeType must also be set. Compatible MIME types: application/json: Schema for JSON response. + /// Refer to the https://ai.google.dev/gemini-api/docs/json-mode for more information. + /// + /// + /// Possible values are: + /// - which will be used to automatically generate a JSON schema. + /// - schema definition, which will be used as is. + /// - schema definition, which will be used as is. + /// - schema definition, which will be used as is. + /// - object, where none of the above matches which the type will be used to automatically generate a JSON schema. + /// + [JsonPropertyName("response_schema")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public object? ResponseSchema + { + get => this._responseSchema; + + set + { + this.ThrowIfFrozen(); + this._responseSchema = value; + } + } + /// public override void Freeze() { @@ -243,7 +271,8 @@ public override PromptExecutionSettings Clone() SafetySettings = this.SafetySettings?.Select(setting => new GeminiSafetySetting(setting)).ToList(), ToolCallBehavior = this.ToolCallBehavior?.Clone(), AudioTimestamp = this.AudioTimestamp, - ResponseMimeType = this.ResponseMimeType + ResponseMimeType = this.ResponseMimeType, + ResponseSchema = this.ResponseSchema, }; }