Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ public abstract ValueTask<string> GetUserAuthorizationTokenAsync(
AuthorizationTokenType tokenType,
ITrace trace);

public abstract ValueTask AddInferenceAuthorizationHeaderAsync(
INameValueCollection headersCollection,
Uri requestAddress,
string verb,
AuthorizationTokenType tokenType);

public abstract void TraceUnauthorized(
DocumentClientException dce,
string authorizationToken,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,11 @@ private void Dispose(bool disposing)
this.authKeyHashFunction = null;
}

public override ValueTask AddInferenceAuthorizationHeaderAsync(INameValueCollection headersCollection, Uri requestAddress, string verb, AuthorizationTokenType tokenType)
{
throw new NotImplementedException("AddInferenceAuthorizationHeaderAsync is only valid for AAD");
}

// Use C# finalizer syntax for finalization code.
// This finalizer will run only if the Dispose method does not get called.
// It gives your base class the opportunity to finalize.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ private void Dispose(bool disposing)
// Do nothing
}

public override ValueTask AddInferenceAuthorizationHeaderAsync(INameValueCollection headersCollection, Uri requestAddress, string verb, AuthorizationTokenType tokenType)
{
throw new NotImplementedException("AddInferenceAuthorizationHeaderAsync is only valid for AAD");
}

// Use C# finalizer syntax for finalization code.
// This finalizer will run only if the Dispose method does not get called.
// It gives your base class the opportunity to finalize.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,18 @@ namespace Microsoft.Azure.Cosmos

internal sealed class AuthorizationTokenProviderTokenCredential : AuthorizationTokenProvider
{
private const string InferenceTokenPrefix = "Bearer ";
internal readonly TokenCredentialCache tokenCredentialCache;
private bool isDisposed = false;

internal readonly TokenCredential tokenCredential;

public AuthorizationTokenProviderTokenCredential(
TokenCredential tokenCredential,
Uri accountEndpoint,
TimeSpan? backgroundTokenCredentialRefreshInterval)
{
this.tokenCredential = tokenCredential ?? throw new ArgumentNullException(nameof(tokenCredential));
this.tokenCredentialCache = new TokenCredentialCache(
tokenCredential: tokenCredential,
accountEndpoint: accountEndpoint,
Expand Down Expand Up @@ -71,6 +75,21 @@ public override async ValueTask AddAuthorizationHeaderAsync(
}
}

public override async ValueTask AddInferenceAuthorizationHeaderAsync(
INameValueCollection headersCollection,
Uri requestAddress,
string verb,
AuthorizationTokenType tokenType)
{
using (Trace trace = Trace.GetRootTrace(nameof(GetUserAuthorizationTokenAsync), TraceComponent.Authorization, TraceLevel.Info))
{
string token = await this.tokenCredentialCache.GetTokenAsync(trace);

string inferenceToken = InferenceTokenPrefix + token;
headersCollection.Add(HttpConstants.HttpHeaders.Authorization, inferenceToken);
}
}

public override void TraceUnauthorized(
DocumentClientException dce,
string authorizationToken,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,5 +125,10 @@ private void CheckAndRefreshTokenProvider()
}
}
}

public override ValueTask AddInferenceAuthorizationHeaderAsync(INameValueCollection headersCollection, Uri requestAddress, string verb, AuthorizationTokenType tokenType)
{
throw new NotImplementedException("AddInferenceAuthorizationHeaderAsync is only valid for AAD");
}
}
}
162 changes: 162 additions & 0 deletions Microsoft.Azure.Cosmos/src/Inference/InferenceService.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
//------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
//------------------------------------------------------------

namespace Microsoft.Azure.Cosmos
{
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using global::Azure.Core;
using Microsoft.Azure.Documents;
using Microsoft.Azure.Documents.Collections;

internal class InferenceService : IDisposable
{
private const string basePath = "dbinference.azure.com/inference/semanticReranking";
private const string inferenceUserAgent = "cosmos-inference-dotnet";
private const string inferenceServiceDefaultScope = "https://dbinference.azure.com/.default";

private readonly Uri inferenceEndpoint;
private readonly HttpClient httpClient;
private readonly AuthorizationTokenProvider cosmosAuthorization;

private bool disposedValue;

public InferenceService(CosmosClient client, AccountProperties accountProperties)
{
//Create HttpClient
HttpMessageHandler httpMessageHandler = CosmosHttpClientCore.CreateHttpClientHandler(
gatewayModeMaxConnectionLimit: client.DocumentClient.ConnectionPolicy.MaxConnectionLimit,
webProxy: null,
serverCertificateCustomValidationCallback: client.DocumentClient.ConnectionPolicy.ServerCertificateCustomValidationCallback);

this.httpClient = new HttpClient(httpMessageHandler);

this.CreateClientHelper(this.httpClient);

//Set endpoints
this.inferenceEndpoint = new Uri($"https://{accountProperties.Id}.{basePath}");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed offline, please add an environment variable where the inference endpoint can be set up.

AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT


//set authorization
if (client.DocumentClient.cosmosAuthorization.GetType() != typeof(AuthorizationTokenProviderTokenCredential))
{
throw new InvalidOperationException("InferenceService only supports AAD authentication.");
}

AuthorizationTokenProviderTokenCredential defaultOperationTokenProvider = client.DocumentClient.cosmosAuthorization as AuthorizationTokenProviderTokenCredential;
TokenCredential tokenCredential = defaultOperationTokenProvider.tokenCredential;

this.cosmosAuthorization = new AuthorizationTokenProviderTokenCredential(
tokenCredential: tokenCredential,
accountEndpoint: new Uri(inferenceServiceDefaultScope),
backgroundTokenCredentialRefreshInterval: client.ClientOptions?.TokenCredentialBackgroundRefreshInterval);
}

public async Task<IReadOnlyDictionary<string, dynamic>> SemanticRerankAsync(
string renrankContext,
IEnumerable<string> documents,
SemanticRerankRequestOptions options = null,
CancellationToken cancellationToken = default)
{
HttpRequestMessage message = new HttpRequestMessage(HttpMethod.Post, this.inferenceEndpoint);
INameValueCollection additionalHeaders = new RequestNameValueCollection();
await this.cosmosAuthorization.AddInferenceAuthorizationHeaderAsync(
headersCollection: additionalHeaders,
this.inferenceEndpoint,
HttpConstants.HttpMethods.Post,
AuthorizationTokenType.AadToken);
additionalHeaders.Add(HttpConstants.HttpHeaders.UserAgent, inferenceUserAgent);

foreach (string key in additionalHeaders.AllKeys())
{
message.Headers.Add(key, additionalHeaders[key]);
}

Dictionary<string, dynamic> body = this.AddSemanticRerankPayload(renrankContext, documents, options);

message.Content = new StringContent(
Newtonsoft.Json.JsonConvert.SerializeObject(body),
Encoding.UTF8,
RuntimeConstants.MediaTypes.Json);

HttpResponseMessage responseMessage = await this.httpClient.SendAsync(message, cancellationToken);
responseMessage.EnsureSuccessStatusCode();

// return the content of the responsemessage as a dictonary
string content = await responseMessage.Content.ReadAsStringAsync();
return Newtonsoft.Json.JsonConvert.DeserializeObject<Dictionary<string, dynamic>>(content);
}

private void CreateClientHelper(HttpClient httpClient)
{
httpClient.Timeout = TimeSpan.FromSeconds(120);
httpClient.DefaultRequestHeaders.CacheControl = new CacheControlHeaderValue { NoCache = true };

// Set requested API version header that can be used for
// version enforcement.
httpClient.DefaultRequestHeaders.Add(HttpConstants.HttpHeaders.Version,
HttpConstants.Versions.CurrentVersion);

httpClient.DefaultRequestHeaders.Add(HttpConstants.HttpHeaders.Accept, RuntimeConstants.MediaTypes.Json);
}

private Dictionary<string, dynamic> AddSemanticRerankPayload(string rerankContext, IEnumerable<string> documents, SemanticRerankRequestOptions options)
{
Dictionary<string, dynamic> payload = new Dictionary<string, dynamic>
{
{ "query", rerankContext },
{ "documents", documents.ToArray() }
};

if (options == null)
{
return payload;
}

payload["return_documents"] = options.ReturnDocuments;
if (options.TopK > -1)
{
payload["top_k"] = options.TopK;
}
if (options.BatchSize > -1)
{
payload["batch_size"] = options.BatchSize;
}
payload["sort"] = options.Sort;
if (!string.IsNullOrEmpty(options.DocumentType))
{
payload["document_type"] = options.DocumentType;
}
if (!string.IsNullOrEmpty(options.TargetPaths))
{
payload["target_paths"] = options.TargetPaths;
}

return payload;
}

protected void Dispose(bool disposing)
{
if (!this.disposedValue)
{
if (disposing)
{
this.httpClient.Dispose();
}

this.disposedValue = true;
}
}

public void Dispose()
{
this.Dispose(true);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
//------------------------------------------------------------

namespace Microsoft.Azure.Cosmos
{
/// <summary>
/// Request options for semantic rerank operations in Azure Cosmos DB.
/// </summary>
public class SemanticRerankRequestOptions : RequestOptions
{
/// <summary>
/// Gets or sets a value indicating whether to return the documents text in the response. Default is true.
/// </summary>
public bool ReturnDocuments { get; set; } = true;

/// <summary>
/// Gets or sets the number of top documents to return. Default all documents are returned.
/// </summary>
public int TopK { get; set; } = -1;

/// <summary>
/// Batch size for internal scoring operations
/// </summary>
public int BatchSize { get; set; } = -1;

/// <summary>
/// Whether to sort the results by relevance score in descending order.
/// </summary>
public bool Sort { get; set; } = true;

/// <summary>
/// Type of document being processed. Supported values are "string" and "json".
/// </summary>
public string DocumentType { get; set; }

/// <summary>
/// If document type is "json", the list of JSON paths to extract text from for reranking. Comma separated string.
/// </summary>
public string TargetPaths { get; set; }

}
}
29 changes: 29 additions & 0 deletions Microsoft.Azure.Cosmos/src/Resource/ClientContextCore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
namespace Microsoft.Azure.Cosmos
{
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Net.Http;
Expand Down Expand Up @@ -34,6 +35,7 @@ internal class ClientContextCore : CosmosClientContext

private readonly string userAgent;
private bool isDisposed = false;
private InferenceService inferenceService = null;

private ClientContextCore(
CosmosClient client,
Expand Down Expand Up @@ -467,6 +469,32 @@ await this.DocumentClient.OpenConnectionsToAllReplicasAsync(
cancellationToken);
}

internal override async Task<IReadOnlyDictionary<string, dynamic>> SemanticRerankAsync(
string renrankContext,
IEnumerable<string> documents,
SemanticRerankRequestOptions options = null,
CancellationToken cancellationToken = default)
{
InferenceService inferenceService = await this.GetOrCreateInferenceServiceAsync();
return await inferenceService.SemanticRerankAsync(renrankContext, documents, options, cancellationToken);
}

/// <inheritdoc/>
internal override async Task<InferenceService> GetOrCreateInferenceServiceAsync()
{
AccountProperties accountProperties = await this.client.DocumentClient.GlobalEndpointManager.GetDatabaseAccountAsync() ?? throw new InvalidOperationException("Failed to retrieve AccountProperties. The response was null.");
if (this.inferenceService == null)
{
// Double check locking to avoid unnecessary locks
lock (this)
{
this.inferenceService ??= new InferenceService(this.client, accountProperties);
}
}

return this.inferenceService;
}

public override void Dispose()
{
this.Dispose(true);
Expand All @@ -484,6 +512,7 @@ protected virtual void Dispose(bool disposing)
{
this.batchExecutorCache.Dispose();
this.DocumentClient.Dispose();
this.inferenceService?.Dispose();
}

this.isDisposed = true;
Expand Down
16 changes: 16 additions & 0 deletions Microsoft.Azure.Cosmos/src/Resource/Container/Container.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1678,6 +1678,22 @@ public abstract ChangeFeedProcessorBuilder GetChangeFeedProcessorBuilder(
public abstract ChangeFeedProcessorBuilder GetChangeFeedProcessorBuilderWithManualCheckpoint(
string processorName,
ChangeFeedStreamHandlerWithManualCheckpoint onChangesDelegate);

/// <summary>
/// Rerank a list of documents using semantic reranking.
/// This method uses a semantic reranker to score and reorder the provided documents
/// based on their relevance to the given reranking context.
/// </summary>
/// <param name="renrankContext"> The context or query string to use for reranking the documents.</param>
/// <param name="documents"> A list of documents to be reranked</param>
/// <param name="options"> (Optional) The options for the semantic reranking request.</param>
/// <param name="cancellationToken">(Optional) <see cref="CancellationToken"/> representing request cancellation.</param>
/// <returns> The reranking results, typically including the reranked documents and their scores. </returns>
public abstract Task<IReadOnlyDictionary<string, dynamic>> SemanticRerankAsync(
string renrankContext,
IEnumerable<string> documents,
SemanticRerankRequestOptions options = null,
CancellationToken cancellationToken = default);

/// <summary>
/// Deletes all items in the Container with the specified <see cref="PartitionKey"/> value.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -697,5 +697,14 @@ public override Task<bool> IsFeedRangePartOfAsync(
y,
cancellationToken: cancellationToken));
}

public override Task<IReadOnlyDictionary<string, dynamic>> SemanticRerankAsync(
string renrankContext,
IEnumerable<string> documents,
SemanticRerankRequestOptions options = null,
CancellationToken cancellationToken = default)
{
return this.ClientContext.SemanticRerankAsync(renrankContext, documents, options, cancellationToken);
}
}
}
Loading
Loading