Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ public abstract ValueTask<string> GetUserAuthorizationTokenAsync(
AuthorizationTokenType tokenType,
ITrace trace);

/// <summary>
/// Adds the authorization header for an inference service request to <paramref name="headersCollection"/>.
/// </summary>
/// <param name="headersCollection">The header collection that receives the authorization header.</param>
/// <param name="requestAddress">The address of the request being authorized.</param>
/// <param name="verb">The HTTP verb of the request being authorized.</param>
/// <param name="tokenType">The type of authorization token requested.</param>
/// <returns>A <see cref="ValueTask"/> representing the asynchronous operation.</returns>
public abstract ValueTask AddInferenceAuthorizationHeaderAsync(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMHO let's not overload the core types with inference specific?

INameValueCollection headersCollection,
Uri requestAddress,
string verb,
AuthorizationTokenType tokenType);

public abstract void TraceUnauthorized(
DocumentClientException dce,
string authorizationToken,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,11 @@ private void Dispose(bool disposing)
this.authKeyHashFunction = null;
}

/// <summary>
/// Inference authorization is not available on this provider; every call fails.
/// </summary>
/// <param name="headersCollection">Unused.</param>
/// <param name="requestAddress">Unused.</param>
/// <param name="verb">Unused.</param>
/// <param name="tokenType">Unused.</param>
/// <returns>Never returns normally.</returns>
/// <exception cref="NotImplementedException">Always thrown; the inference header is only valid for AAD.</exception>
public override ValueTask AddInferenceAuthorizationHeaderAsync(
    INameValueCollection headersCollection,
    Uri requestAddress,
    string verb,
    AuthorizationTokenType tokenType)
    => throw new NotImplementedException("AddInferenceAuthorizationHeaderAsync is only valid for AAD");

// Use C# finalizer syntax for finalization code.
// This finalizer will run only if the Dispose method does not get called.
// It gives your base class the opportunity to finalize.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ private void Dispose(bool disposing)
// Do nothing
}

/// <summary>
/// This provider cannot authorize inference requests; the method throws unconditionally.
/// </summary>
/// <param name="headersCollection">Unused.</param>
/// <param name="requestAddress">Unused.</param>
/// <param name="verb">Unused.</param>
/// <param name="tokenType">Unused.</param>
/// <returns>Never returns normally.</returns>
/// <exception cref="NotImplementedException">Always thrown; the inference header is only valid for AAD.</exception>
public override ValueTask AddInferenceAuthorizationHeaderAsync(
    INameValueCollection headersCollection,
    Uri requestAddress,
    string verb,
    AuthorizationTokenType tokenType)
    => throw new NotImplementedException("AddInferenceAuthorizationHeaderAsync is only valid for AAD");

// Use C# finalizer syntax for finalization code.
// This finalizer will run only if the Dispose method does not get called.
// It gives your base class the opportunity to finalize.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,18 @@ namespace Microsoft.Azure.Cosmos

internal sealed class AuthorizationTokenProviderTokenCredential : AuthorizationTokenProvider
{
private const string InferenceTokenPrefix = "Bearer ";
internal readonly TokenCredentialCache tokenCredentialCache;
private bool isDisposed = false;

internal readonly TokenCredential tokenCredential;

public AuthorizationTokenProviderTokenCredential(
TokenCredential tokenCredential,
Uri accountEndpoint,
TimeSpan? backgroundTokenCredentialRefreshInterval)
{
this.tokenCredential = tokenCredential ?? throw new ArgumentNullException(nameof(tokenCredential));
this.tokenCredentialCache = new TokenCredentialCache(
tokenCredential: tokenCredential,
accountEndpoint: accountEndpoint,
Expand Down Expand Up @@ -71,6 +75,21 @@ public override async ValueTask AddAuthorizationHeaderAsync(
}
}

/// <summary>
/// Obtains a token from the token credential cache and adds it to
/// <paramref name="headersCollection"/> as a "Bearer " prefixed Authorization header.
/// </summary>
/// <param name="headersCollection">The header collection that receives the Authorization header.</param>
/// <param name="requestAddress">Not used by this implementation.</param>
/// <param name="verb">Not used by this implementation.</param>
/// <param name="tokenType">Not used by this implementation.</param>
/// <returns>A <see cref="ValueTask"/> that completes once the header has been added.</returns>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="headersCollection"/> is null.</exception>
public override async ValueTask AddInferenceAuthorizationHeaderAsync(
    INameValueCollection headersCollection,
    Uri requestAddress,
    string verb,
    AuthorizationTokenType tokenType)
{
    // Fail before acquiring a token when there is nowhere to put it.
    if (headersCollection == null)
    {
        throw new ArgumentNullException(nameof(headersCollection));
    }

    // The root trace is named after this method so the token acquisition is attributed correctly
    // (it was previously labelled with GetUserAuthorizationTokenAsync).
    using (Trace trace = Trace.GetRootTrace(nameof(AddInferenceAuthorizationHeaderAsync), TraceComponent.Authorization, TraceLevel.Info))
    {
        string token = await this.tokenCredentialCache.GetTokenAsync(trace);

        headersCollection.Add(HttpConstants.HttpHeaders.Authorization, InferenceTokenPrefix + token);
    }
}

public override void TraceUnauthorized(
DocumentClientException dce,
string authorizationToken,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,5 +125,10 @@ private void CheckAndRefreshTokenProvider()
}
}
}

/// <summary>
/// Inference requests cannot be authorized through this provider; the call always throws.
/// </summary>
/// <param name="headersCollection">Unused.</param>
/// <param name="requestAddress">Unused.</param>
/// <param name="verb">Unused.</param>
/// <param name="tokenType">Unused.</param>
/// <returns>Never returns normally.</returns>
/// <exception cref="NotImplementedException">Always thrown; the inference header is only valid for AAD.</exception>
public override ValueTask AddInferenceAuthorizationHeaderAsync(
    INameValueCollection headersCollection,
    Uri requestAddress,
    string verb,
    AuthorizationTokenType tokenType)
    => throw new NotImplementedException("AddInferenceAuthorizationHeaderAsync is only valid for AAD");
}
}
198 changes: 198 additions & 0 deletions Microsoft.Azure.Cosmos/src/Inference/InferenceService.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
//------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
//------------------------------------------------------------

namespace Microsoft.Azure.Cosmos
{
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using global::Azure.Core;
using Microsoft.Azure.Documents;
using Microsoft.Azure.Documents.Collections;

/// <summary>
/// Provides functionality to interact with the Cosmos DB Inference Service for semantic reranking.
/// </summary>
internal class InferenceService : IDisposable
{
// Base path for the inference service endpoint.
private const string basePath = "/inference/semanticReranking";
// User agent string for inference requests.
private const string inferenceUserAgent = "cosmos-inference-dotnet";
// Default scope for AAD authentication.
private const string inferenceServiceDefaultScope = "https://dbinference.azure.com/.default";

private readonly string inferenceServiceBaseUrl;
private readonly Uri inferenceEndpoint;
private readonly HttpClient httpClient;
private readonly AuthorizationTokenProvider cosmosAuthorization;

private bool disposedValue;

/// <summary>
/// Initializes a new instance of the <see cref="InferenceService"/> class.
/// </summary>
/// <param name="client">The CosmosClient instance whose connection policy and token credential are reused.</param>
/// <exception cref="ArgumentNullException">
/// Thrown if <paramref name="client"/> is null, or if the
/// AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT environment variable is not set.
/// </exception>
/// <exception cref="InvalidOperationException">Thrown if AAD authentication is not used.</exception>
public InferenceService(CosmosClient client)
{
    const string endpointVariable = "AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT";

    if (client == null)
    {
        throw new ArgumentNullException(nameof(client));
    }

    this.inferenceServiceBaseUrl = ConfigurationManager.GetEnvironmentVariable<string>(endpointVariable, null);

    if (string.IsNullOrEmpty(this.inferenceServiceBaseUrl))
    {
        // The single-string ArgumentNullException constructor treats its argument as the parameter
        // name, so the two-argument overload is used to surface the guidance as the message.
        throw new ArgumentNullException(
            paramName: endpointVariable,
            message: $"Set environment variable {endpointVariable} to use inference service");
    }

    // Ensure AAD authentication is used. This is checked before any disposable resource is
    // created so a failed check does not leak an HttpClient.
    AuthorizationTokenProviderTokenCredential aadTokenProvider =
        client.DocumentClient.cosmosAuthorization as AuthorizationTokenProviderTokenCredential;
    if (aadTokenProvider == null)
    {
        throw new InvalidOperationException("InferenceService only supports AAD authentication.");
    }

    // Construct the inference service endpoint URI. basePath already starts with '/', so any
    // trailing slash on the configured base URL is trimmed to avoid a "//" in the path.
    this.inferenceEndpoint = new Uri(this.inferenceServiceBaseUrl.TrimEnd('/') + basePath);

    // Create and configure HttpClient for inference requests.
    // TODO(review): isolate inference-specific settings/configuration instead of borrowing the
    // Cosmos connection policy.
    HttpMessageHandler httpMessageHandler = CosmosHttpClientCore.CreateHttpClientHandler(
        gatewayModeMaxConnectionLimit: client.DocumentClient.ConnectionPolicy.MaxConnectionLimit,
        webProxy: null,
        serverCertificateCustomValidationCallback: client.DocumentClient.ConnectionPolicy.ServerCertificateCustomValidationCallback);

    this.httpClient = new HttpClient(httpMessageHandler);
    this.CreateClientHelper(this.httpClient);

    // Set up a dedicated token provider that reuses the client's token credential.
    // NOTE(review): the inference default scope is passed as the account endpoint, presumably so
    // the token cache requests tokens for that scope - confirm against TokenCredentialCache.
    this.cosmosAuthorization = new AuthorizationTokenProviderTokenCredential(
        tokenCredential: aadTokenProvider.tokenCredential,
        accountEndpoint: new Uri(inferenceServiceDefaultScope),
        backgroundTokenCredentialRefreshInterval: client.ClientOptions?.TokenCredentialBackgroundRefreshInterval);
}

/// <summary>
/// Sends a semantic rerank request to the inference service.
/// </summary>
/// <param name="rerankContext">The context/query for reranking.</param>
/// <param name="documents">The documents to be reranked.</param>
/// <param name="options">Optional additional options for the request.</param>
/// <param name="cancellationToken">Cancellation token.</param>
/// <returns>The <see cref="SemanticRerankResult"/> deserialized from the service response.</returns>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="documents"/> is null.</exception>
/// <exception cref="ObjectDisposedException">Thrown when the service has already been disposed.</exception>
/// <exception cref="HttpRequestException">Thrown when the service responds with a non-success status code.</exception>
public async Task<SemanticRerankResult> SemanticRerankAsync(
    string rerankContext,
    IEnumerable<string> documents,
    IDictionary<string, object> options = null,
    CancellationToken cancellationToken = default)
{
    if (this.disposedValue)
    {
        throw new ObjectDisposedException(nameof(InferenceService));
    }

    // Validate before acquiring a token so bad input fails fast and with the right parameter name.
    if (documents == null)
    {
        throw new ArgumentNullException(nameof(documents));
    }

    // Prepare HTTP request for semantic reranking; the request (and its content) is released on exit.
    using (HttpRequestMessage message = new HttpRequestMessage(HttpMethod.Post, this.inferenceEndpoint))
    {
        INameValueCollection additionalHeaders = new RequestNameValueCollection();
        await this.cosmosAuthorization.AddInferenceAuthorizationHeaderAsync(
            headersCollection: additionalHeaders,
            requestAddress: this.inferenceEndpoint,
            verb: HttpConstants.HttpMethods.Post,
            tokenType: AuthorizationTokenType.AadToken);
        additionalHeaders.Add(HttpConstants.HttpHeaders.UserAgent, inferenceUserAgent);

        // Add all headers to the HTTP request.
        foreach (string key in additionalHeaders.AllKeys())
        {
            message.Headers.Add(key, additionalHeaders[key]);
        }

        // Build the request payload.
        Dictionary<string, object> body = this.AddSemanticRerankPayload(rerankContext, documents, options);

        message.Content = new StringContent(
            Newtonsoft.Json.JsonConvert.SerializeObject(body),
            Encoding.UTF8,
            RuntimeConstants.MediaTypes.Json);

        // Send the request and ensure success.
        // NOTE: there is no retry policy here yet; per the PR review discussion, retries are
        // planned for a later change.
        // NOTE(review): disposing the response assumes DeserializeSemanticRerankResultAsync has
        // finished reading the content by the time it completes - confirm.
        using (HttpResponseMessage responseMessage = await this.httpClient.SendAsync(message, cancellationToken))
        {
            responseMessage.EnsureSuccessStatusCode();

            // Deserialize and return the response content.
            return await SemanticRerankResult.DeserializeSemanticRerankResultAsync(responseMessage);
        }
    }
}

/// <summary>
/// Applies the request timeout and the default headers (cache control, API version, accept)
/// that every inference request sent through <paramref name="httpClient"/> carries.
/// </summary>
/// <param name="httpClient">The HttpClient to configure.</param>
private void CreateClientHelper(HttpClient httpClient)
{
    HttpRequestHeaders defaultHeaders = httpClient.DefaultRequestHeaders;

    // Two minutes, expressed in seconds.
    httpClient.Timeout = TimeSpan.FromSeconds(120);

    defaultHeaders.CacheControl = new CacheControlHeaderValue { NoCache = true };

    // Requested API version header, used for version enforcement.
    defaultHeaders.Add(HttpConstants.HttpHeaders.Version, HttpConstants.Versions.CurrentVersion);
    defaultHeaders.Add(HttpConstants.HttpHeaders.Accept, RuntimeConstants.MediaTypes.Json);
}

/// <summary>
/// Constructs the payload for the semantic rerank request.
/// </summary>
/// <param name="rerankContext">The context/query for reranking; stored under the "query" key.</param>
/// <param name="documents">The documents to be reranked; stored as an array under the "documents" key.</param>
/// <param name="options">Optional additional options; each entry is copied into the payload.</param>
/// <returns>A dictionary representing the request payload.</returns>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="documents"/> is null.</exception>
/// <exception cref="ArgumentException">
/// Thrown when <paramref name="options"/> contains a key that is already present in the payload
/// (for example "query" or "documents").
/// </exception>
private Dictionary<string, object> AddSemanticRerankPayload(string rerankContext, IEnumerable<string> documents, IDictionary<string, object> options)
{
    if (documents == null)
    {
        throw new ArgumentNullException(nameof(documents));
    }

    Dictionary<string, object> payload = new Dictionary<string, object>
    {
        { "query", rerankContext },
        { "documents", documents.ToArray() }
    };

    if (options == null)
    {
        return payload;
    }

    // Add any additional options to the payload. Collisions still raise ArgumentException,
    // but with a message that names the offending option.
    foreach (KeyValuePair<string, object> option in options)
    {
        if (payload.ContainsKey(option.Key))
        {
            throw new ArgumentException(
                $"Option '{option.Key}' conflicts with a field that is already part of the rerank payload.",
                nameof(options));
        }

        payload.Add(option.Key, option.Value);
    }

    return payload;
}

/// <summary>
/// Releases the HTTP client and the token provider owned by this service. Only the first call
/// does any work; later calls return immediately.
/// </summary>
/// <param name="disposing">True when invoked from <see cref="Dispose()"/>; managed resources are released only in that case.</param>
protected void Dispose(bool disposing)
{
    if (this.disposedValue)
    {
        return;
    }

    if (disposing)
    {
        this.httpClient.Dispose();
        this.cosmosAuthorization.Dispose();
    }

    this.disposedValue = true;
}

/// <summary>
/// Releases the resources held by the service by forwarding to <see cref="Dispose(bool)"/>.
/// </summary>
public void Dispose() => this.Dispose(true);
}
}
46 changes: 46 additions & 0 deletions Microsoft.Azure.Cosmos/src/Inference/RerankScore.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
//------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
//------------------------------------------------------------

namespace Microsoft.Azure.Cosmos
{
    /// <summary>
    /// Immutable result entry of a reranking operation: a document together with the score it
    /// was assigned and its index before reranking.
    /// </summary>
#if PREVIEW
    public
#else
    internal
#endif

    class RerankScore
    {
        /// <summary>
        /// Creates a <see cref="RerankScore"/> from its three components.
        /// </summary>
        /// <param name="document">The document content or identifier.</param>
        /// <param name="score">The reranked score for the document.</param>
        /// <param name="index">The original index of the document.</param>
        public RerankScore(object document, double score, int index)
        {
            this.Index = index;
            this.Score = score;
            this.Document = document;
        }

        /// <summary>
        /// Gets the reranked document (its content or identifier).
        /// </summary>
        public object Document { get; }

        /// <summary>
        /// Gets the score the document received from reranking.
        /// </summary>
        public double Score { get; }

        /// <summary>
        /// Gets the index (position) the document had before reranking.
        /// </summary>
        public int Index { get; }
    }
}
Loading
Loading