Skip to content

Commit

Permalink
.Net: Google Gemini - Adding response schema (Structured Outputs supp…
Browse files Browse the repository at this point in the history
…ort) (#10135)

### Motivation and Context

- Resolves #9501

### Description

Allows defining a schema for the LLM response, similar to the
`Structured Outputs` concept from OpenAI.
  • Loading branch information
RogerBarreto authored Jan 10, 2025
1 parent e8b31a2 commit a76229b
Show file tree
Hide file tree
Showing 6 changed files with 248 additions and 6 deletions.
10 changes: 10 additions & 0 deletions .editorconfig
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ dotnet_diagnostic.IDE0005.severity = warning # Remove unnecessary using directiv
dotnet_diagnostic.IDE0009.severity = warning # Add this or Me qualification
dotnet_diagnostic.IDE0011.severity = warning # Add braces
dotnet_diagnostic.IDE0018.severity = warning # Inline variable declaration

dotnet_diagnostic.IDE0032.severity = warning # Use auto-implemented property
dotnet_diagnostic.IDE0034.severity = warning # Simplify 'default' expression
dotnet_diagnostic.IDE0035.severity = warning # Remove unreachable code
Expand Down Expand Up @@ -221,20 +222,29 @@ dotnet_diagnostic.RCS1241.severity = none # Implement IComparable when implement
dotnet_diagnostic.IDE0001.severity = none # Simplify name
dotnet_diagnostic.IDE0002.severity = none # Simplify member access
dotnet_diagnostic.IDE0004.severity = none # Remove unnecessary cast
dotnet_diagnostic.IDE0010.severity = none # Populate switch
dotnet_diagnostic.IDE0021.severity = none # Use block body for constructors
dotnet_diagnostic.IDE0022.severity = none # Use block body for methods
dotnet_diagnostic.IDE0024.severity = none # Use block body for operator
dotnet_diagnostic.IDE0035.severity = none # Remove unreachable code
dotnet_diagnostic.IDE0051.severity = none # Remove unused private member
dotnet_diagnostic.IDE0052.severity = none # Remove unread private member
dotnet_diagnostic.IDE0058.severity = none # Remove unused expression value
dotnet_diagnostic.IDE0059.severity = none # Unnecessary assignment of a value
dotnet_diagnostic.IDE0060.severity = none # Remove unused parameter
dotnet_diagnostic.IDE0061.severity = none # Use block body for local function
dotnet_diagnostic.IDE0079.severity = none # Remove unnecessary suppression.
dotnet_diagnostic.IDE0080.severity = none # Remove unnecessary suppression operator.
dotnet_diagnostic.IDE0100.severity = none # Remove unnecessary equality operator
dotnet_diagnostic.IDE0110.severity = none # Remove unnecessary discards
dotnet_diagnostic.IDE0130.severity = none # Namespace does not match folder structure
dotnet_diagnostic.IDE0290.severity = none # Use primary constructor
dotnet_diagnostic.IDE0032.severity = none # Use auto property
dotnet_diagnostic.IDE0160.severity = none # Use block-scoped namespace
dotnet_diagnostic.IDE1006.severity = warning # Naming rule violations
dotnet_diagnostic.IDE0046.severity = suggestion # If statement can be simplified
dotnet_diagnostic.IDE0056.severity = suggestion # Indexing can be simplified
dotnet_diagnostic.IDE0057.severity = suggestion # Substring can be simplified

###############################
# Naming Conventions #
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Text;
using System.Text.Json;
using System.Threading.Tasks;
using Microsoft.SemanticKernel.ChatCompletion;
Expand Down Expand Up @@ -419,13 +420,34 @@ public async Task ItCreatesPostRequestWithSemanticKernelVersionHeaderAsync()
Assert.Equal(expectedVersion, header);
}

[Fact]
public async Task ItCreatesPostRequestWithResponseSchemaPropertyAsync()
{
    // Arrange
    var client = this.CreateChatCompletionClient();
    var chatHistory = CreateSampleChatHistory();
    var settings = new GeminiPromptExecutionSettings { ResponseMimeType = "application/json", ResponseSchema = typeof(List<int>) };

    // Act
    await client.GenerateChatMessageAsync(chatHistory, settings);

    // Assert
    Assert.NotNull(this._messageHandlerStub.RequestHeaders);

    // Fail with a clear assertion (instead of an exception from GetString)
    // when the stub did not capture any request payload.
    Assert.NotNull(this._messageHandlerStub.RequestContent);

    // This is the body of the outgoing request, not of the response.
    var requestBody = Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent!);

    Assert.Contains("responseSchema", requestBody, StringComparison.Ordinal);
    Assert.Contains("\"responseSchema\":{\"type\":\"array\",\"items\":{\"type\":\"integer\"}}", requestBody, StringComparison.Ordinal);
    Assert.Contains("\"responseMimeType\":\"application/json\"", requestBody, StringComparison.Ordinal);
}

[Fact]
public async Task ItCanUseValueTasksSequentiallyForBearerTokenAsync()
{
// Arrange
var bearerTokenGenerator = new BearerTokenGenerator()
{
BearerKeys = new List<string> { "key1", "key2", "key3" }
BearerKeys = ["key1", "key2", "key3"]
};

var responseContent = File.ReadAllText(ChatTestDataFilePath);
Expand All @@ -442,7 +464,7 @@ public async Task ItCanUseValueTasksSequentiallyForBearerTokenAsync()
httpClient: httpClient,
modelId: "fake-model",
apiVersion: VertexAIVersion.V1,
bearerTokenProvider: () => bearerTokenGenerator.GetBearerToken(),
bearerTokenProvider: bearerTokenGenerator.GetBearerToken,
location: "fake-location",
projectId: "fake-project-id");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Nodes;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
Expand All @@ -25,7 +26,8 @@ public void FromPromptItReturnsWithConfiguration()
MaxTokens = 10,
TopP = 0.9,
AudioTimestamp = true,
ResponseMimeType = "application/json"
ResponseMimeType = "application/json",
ResponseSchema = JsonSerializer.Deserialize<JsonElement>(@"{""schema"":""schema""}")
};

// Act
Expand All @@ -37,9 +39,120 @@ public void FromPromptItReturnsWithConfiguration()
Assert.Equal(executionSettings.MaxTokens, request.Configuration.MaxOutputTokens);
Assert.Equal(executionSettings.AudioTimestamp, request.Configuration.AudioTimestamp);
Assert.Equal(executionSettings.ResponseMimeType, request.Configuration.ResponseMimeType);
Assert.Equal(executionSettings.ResponseSchema, request.Configuration.ResponseSchema);
Assert.Equal(executionSettings.TopP, request.Configuration.TopP);
}

[Fact]
public void JsonElementResponseSchemaFromPromptReturnsAsExpected()
{
    // Arrange: the schema is supplied directly as a JsonElement.
    var promptText = "prompt-example";
    JsonElement intSchema = Microsoft.Extensions.AI.AIJsonUtilities.CreateJsonSchema(typeof(int));
    var settings = new GeminiPromptExecutionSettings
    {
        ResponseMimeType = "application/json",
        ResponseSchema = intSchema
    };

    // Act
    var geminiRequest = GeminiRequest.FromPromptAndExecutionSettings(promptText, settings);

    // Assert: MIME type and schema are both carried into the request configuration.
    var configuration = geminiRequest.Configuration;
    Assert.NotNull(configuration);
    Assert.Equal(settings.ResponseMimeType, configuration.ResponseMimeType);
    Assert.Equal(settings.ResponseSchema, configuration.ResponseSchema);
}

[Fact]
public void KernelJsonSchemaFromPromptReturnsAsExpected()
{
    // Arrange: the schema is supplied as a KernelJsonSchema instance.
    var promptText = "prompt-example";
    var settings = new GeminiPromptExecutionSettings
    {
        ResponseMimeType = "application/json",
        ResponseSchema = KernelJsonSchemaBuilder.Build(typeof(int))
    };

    // Act
    var geminiRequest = GeminiRequest.FromPromptAndExecutionSettings(promptText, settings);

    // Assert: the request carries the root element of the kernel schema.
    var configuration = geminiRequest.Configuration;
    Assert.NotNull(configuration);
    Assert.Equal(settings.ResponseMimeType, configuration.ResponseMimeType);

    var expectedRoot = ((KernelJsonSchema)settings.ResponseSchema).RootElement;
    Assert.Equal(expectedRoot, configuration.ResponseSchema);
}

[Fact]
public void JsonNodeResponseSchemaFromPromptReturnsAsExpected()
{
    // Arrange: the schema is supplied as a JsonNode parsed from the generated schema text.
    var promptText = "prompt-example";
    var schemaJson = Microsoft.Extensions.AI.AIJsonUtilities.CreateJsonSchema(typeof(int)).GetRawText();
    var settings = new GeminiPromptExecutionSettings
    {
        ResponseMimeType = "application/json",
        ResponseSchema = JsonNode.Parse(schemaJson)
    };

    // Act
    var geminiRequest = GeminiRequest.FromPromptAndExecutionSettings(promptText, settings);

    // Assert: the configured schema serializes to the same JSON text as the node.
    var configuration = geminiRequest.Configuration;
    Assert.NotNull(configuration);
    Assert.Equal(settings.ResponseMimeType, configuration.ResponseMimeType);
    Assert.NotNull(configuration.ResponseSchema);

    var expectedJson = JsonSerializer.SerializeToElement(settings.ResponseSchema).GetRawText();
    var actualJson = configuration.ResponseSchema.Value.GetRawText();
    Assert.Equal(expectedJson, actualJson);
}

[Fact]
public void JsonDocumentResponseSchemaFromPromptReturnsAsExpected()
{
    // Arrange
    var prompt = "prompt-example";

    // JsonDocument implements IDisposable, so it is released when the test ends.
    // All assertions below run before the document goes out of scope.
    using var schemaDocument = JsonDocument.Parse(Microsoft.Extensions.AI.AIJsonUtilities.CreateJsonSchema(typeof(int)).GetRawText());
    var executionSettings = new GeminiPromptExecutionSettings
    {
        ResponseMimeType = "application/json",
        ResponseSchema = schemaDocument
    };

    // Act
    var request = GeminiRequest.FromPromptAndExecutionSettings(prompt, executionSettings);

    // Assert
    Assert.NotNull(request.Configuration);
    Assert.Equal(executionSettings.ResponseMimeType, request.Configuration.ResponseMimeType);
    Assert.NotNull(request.Configuration.ResponseSchema);
    Assert.Equal(JsonSerializer.SerializeToElement(executionSettings.ResponseSchema).GetRawText(), request.Configuration.ResponseSchema.Value.GetRawText());
}

[Theory]
[InlineData(typeof(int), "integer")]
[InlineData(typeof(bool), "boolean")]
[InlineData(typeof(string), "string")]
[InlineData(typeof(double), "number")]
[InlineData(typeof(GeminiRequest), "object")]
[InlineData(typeof(List<int>), "array")]
public void TypeResponseSchemaFromPromptReturnsAsExpected(Type type, string expectedSchemaType)
{
    // Arrange: the schema is supplied as a .NET type.
    var promptText = "prompt-example";
    var settings = new GeminiPromptExecutionSettings
    {
        ResponseMimeType = "application/json",
        ResponseSchema = type
    };

    // Act
    var geminiRequest = GeminiRequest.FromPromptAndExecutionSettings(promptText, settings);

    // Assert: the generated schema reports the expected JSON "type" keyword.
    var configuration = geminiRequest.Configuration;
    Assert.NotNull(configuration);

    var actualSchemaType = configuration.ResponseSchema?.GetProperty("type").GetString();
    Assert.Equal(expectedSchemaType, actualSchemaType);
    Assert.Equal(settings.ResponseMimeType, configuration.ResponseMimeType);
}

[Fact]
public void FromPromptItReturnsWithSafetySettings()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public void ItCreatesGeminiExecutionSettingsWithCorrectDefaults()
Assert.Null(executionSettings.SafetySettings);
Assert.Null(executionSettings.AudioTimestamp);
Assert.Null(executionSettings.ResponseMimeType);
Assert.Null(executionSettings.ResponseSchema);
Assert.Equal(GeminiPromptExecutionSettings.DefaultTextMaxTokens, executionSettings.MaxTokens);
}

Expand Down Expand Up @@ -70,7 +71,8 @@ public void ItCreatesGeminiExecutionSettingsFromExtensionDataSnakeCase()
{ "max_tokens", 1000 },
{ "temperature", 0 },
{ "audio_timestamp", true },
{ "response_mimetype", "application/json" }
{ "response_mimetype", "application/json" },
{ "response_schema", JsonSerializer.Serialize(new { }) }
}
};

Expand All @@ -81,6 +83,9 @@ public void ItCreatesGeminiExecutionSettingsFromExtensionDataSnakeCase()
Assert.NotNull(executionSettings);
Assert.Equal(1000, executionSettings.MaxTokens);
Assert.Equal(0, executionSettings.Temperature);
Assert.Equal("application/json", executionSettings.ResponseMimeType);
Assert.NotNull(executionSettings.ResponseSchema);
Assert.Equal(typeof(JsonElement), executionSettings.ResponseSchema.GetType());
Assert.True(executionSettings.AudioTimestamp);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,25 @@
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;
using Microsoft.Extensions.AI;
using Microsoft.SemanticKernel.ChatCompletion;

namespace Microsoft.SemanticKernel.Connectors.Google.Core;

internal sealed class GeminiRequest
{
// Cached serializer options; null until GetDefaultOptions() creates and stores them on first use.
private static JsonSerializerOptions? s_options;
// Schema-creation options that CreateSchema falls back to when its caller passes no configuration.
private static readonly AIJsonSchemaCreateOptions s_schemaOptions = new()
{
IncludeSchemaKeyword = false,
IncludeTypeInEnumSchemas = true,
RequireAllProperties = false,
DisallowAdditionalProperties = false,
};

[JsonPropertyName("contents")]
public IList<GeminiContent> Contents { get; set; } = null!;

Expand Down Expand Up @@ -249,10 +261,57 @@ private static void AddConfiguration(GeminiPromptExecutionSettings executionSett
StopSequences = executionSettings.StopSequences,
CandidateCount = executionSettings.CandidateCount,
AudioTimestamp = executionSettings.AudioTimestamp,
ResponseMimeType = executionSettings.ResponseMimeType
ResponseMimeType = executionSettings.ResponseMimeType,
ResponseSchema = GetResponseSchemaConfig(executionSettings.ResponseSchema)
};
}

/// <summary>
/// Converts the response schema supplied in the execution settings into a <see cref="JsonElement"/>.
/// </summary>
/// <param name="responseSchemaSettings">
/// The schema as a <see cref="JsonElement"/>, <see cref="Type"/>, <see cref="KernelJsonSchema"/>,
/// <see cref="JsonNode"/> or <see cref="JsonDocument"/>; for any other object a schema is created from its runtime type.
/// </param>
/// <returns>The schema element, or <c>null</c> when no schema was supplied.</returns>
private static JsonElement? GetResponseSchemaConfig(object? responseSchemaSettings)
{
    if (responseSchemaSettings is null)
    {
        return null;
    }

    // Already in the target representation.
    if (responseSchemaSettings is JsonElement schemaElement)
    {
        return schemaElement;
    }

    // A .NET type: create the schema for that type.
    if (responseSchemaSettings is Type schemaType)
    {
        return CreateSchema(schemaType, GetDefaultOptions());
    }

    if (responseSchemaSettings is KernelJsonSchema kernelSchema)
    {
        return kernelSchema.RootElement;
    }

    if (responseSchemaSettings is JsonNode schemaNode)
    {
        return JsonSerializer.SerializeToElement(schemaNode, GetDefaultOptions());
    }

    if (responseSchemaSettings is JsonDocument schemaDocument)
    {
        return JsonSerializer.SerializeToElement(schemaDocument, GetDefaultOptions());
    }

    // Any other object: create the schema from its runtime type.
    return CreateSchema(responseSchemaSettings.GetType(), GetDefaultOptions());
}

/// <summary>
/// Creates a JSON schema for <paramref name="type"/> through <c>AIJsonUtilities.CreateJsonSchema</c>.
/// </summary>
/// <param name="type">The type to create the schema for.</param>
/// <param name="options">Serializer options passed to the schema creation call.</param>
/// <param name="description">Optional description passed to the schema creation call.</param>
/// <param name="configuration">Schema creation options; the class-level defaults are used when null.</param>
/// <returns>The created schema.</returns>
private static JsonElement CreateSchema(
    Type type,
    JsonSerializerOptions options,
    string? description = null,
    AIJsonSchemaCreateOptions? configuration = null)
    => AIJsonUtilities.CreateJsonSchema(
        type,
        description,
        serializerOptions: options,
        inferenceOptions: configuration ?? s_schemaOptions);

/// <summary>
/// Returns the cached serializer options, creating them on the first call.
/// </summary>
/// <returns>
/// Options configured with a <see cref="DefaultJsonTypeInfoResolver"/> and a
/// <see cref="JsonStringEnumConverter"/>, made read-only before being cached.
/// </returns>
private static JsonSerializerOptions GetDefaultOptions()
{
    if (s_options is not null)
    {
        return s_options;
    }

    var serializerOptions = new JsonSerializerOptions
    {
        TypeInfoResolver = new DefaultJsonTypeInfoResolver(),
    };
    serializerOptions.Converters.Add(new JsonStringEnumConverter());
    serializerOptions.MakeReadOnly();

    s_options = serializerOptions;
    return s_options;
}

private static void AddSafetySettings(GeminiPromptExecutionSettings executionSettings, GeminiRequest request)
{
request.SafetySettings = executionSettings.SafetySettings?.Select(s
Expand Down Expand Up @@ -292,5 +351,9 @@ internal sealed class ConfigurationElement
/// <summary>
/// Response MIME type, serialized as "responseMimeType" and left out of the JSON when null.
/// </summary>
[JsonPropertyName("responseMimeType")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? ResponseMimeType { get; set; }

/// <summary>
/// Response schema as a JSON element, serialized as "responseSchema" and left out of the JSON when null.
/// </summary>
[JsonPropertyName("responseSchema")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public JsonElement? ResponseSchema { get; set; }
}
}
Loading

0 comments on commit a76229b

Please sign in to comment.