-
Notifications
You must be signed in to change notification settings - Fork 525
Semantic Rerank: Adds Semantic Rerank API #5445
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
7f732e5
11c81aa
bad23ca
f396716
697d9eb
02fc277
e16e988
3ad012b
f2e0e5b
db15b93
e42274e
8923888
5437698
4a5afaf
c05501e
76f9ce1
8fba1a1
e8445e7
57e65a7
0bd9579
257847e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,198 @@ | ||
| //------------------------------------------------------------ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| //------------------------------------------------------------ | ||
|
|
||
| namespace Microsoft.Azure.Cosmos | ||
| { | ||
| using System; | ||
| using System.Collections.Generic; | ||
| using System.Linq; | ||
| using System.Net.Http; | ||
| using System.Net.Http.Headers; | ||
| using System.Text; | ||
| using System.Threading; | ||
| using System.Threading.Tasks; | ||
| using global::Azure.Core; | ||
| using Microsoft.Azure.Documents; | ||
| using Microsoft.Azure.Documents.Collections; | ||
|
|
||
| /// <summary> | ||
| /// Provides functionality to interact with the Cosmos DB Inference Service for semantic reranking. | ||
| /// </summary> | ||
| internal class InferenceService : IDisposable | ||
| { | ||
| // Base path for the inference service endpoint. | ||
| private const string basePath = "/inference/semanticReranking"; | ||
| // User agent string for inference requests. | ||
| private const string inferenceUserAgent = "cosmos-inference-dotnet"; | ||
| // Default scope for AAD authentication. | ||
| private const string inferenceServiceDefaultScope = "https://dbinference.azure.com/.default"; | ||
|
|
||
| private readonly string inferenceServiceBaseUrl; | ||
| private readonly Uri inferenceEndpoint; | ||
| private readonly HttpClient httpClient; | ||
| private readonly AuthorizationTokenProvider cosmosAuthorization; | ||
|
|
||
| private bool disposedValue; | ||
|
|
||
| /// <summary> | ||
| /// Initializes a new instance of the <see cref="InferenceService"/> class. | ||
| /// </summary> | ||
| /// <param name="client">The CosmosClient instance.</param> | ||
| /// <exception cref="InvalidOperationException">Thrown if AAD authentication is not used.</exception> | ||
| public InferenceService(CosmosClient client) | ||
| { | ||
| this.inferenceServiceBaseUrl = ConfigurationManager.GetEnvironmentVariable<string>("AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT", null); | ||
|
|
||
| if (string.IsNullOrEmpty(this.inferenceServiceBaseUrl)) | ||
| { | ||
| throw new ArgumentNullException("Set environment variable AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT to use inference service"); | ||
| } | ||
|
|
||
| // Create and configure HttpClient for inference requests. | ||
| HttpMessageHandler httpMessageHandler = CosmosHttpClientCore.CreateHttpClientHandler( | ||
| gatewayModeMaxConnectionLimit: client.DocumentClient.ConnectionPolicy.MaxConnectionLimit, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's please isolate Inference settings/configurations. |
||
| webProxy: null, | ||
| serverCertificateCustomValidationCallback: client.DocumentClient.ConnectionPolicy.ServerCertificateCustomValidationCallback); | ||
|
|
||
| this.httpClient = new HttpClient(httpMessageHandler); | ||
|
|
||
| this.CreateClientHelper(this.httpClient); | ||
|
|
||
| // Construct the inference service endpoint URI. | ||
| this.inferenceEndpoint = new Uri($"{this.inferenceServiceBaseUrl}/{basePath}"); | ||
|
|
||
| // Ensure AAD authentication is used. | ||
| if (client.DocumentClient.cosmosAuthorization.GetType() != typeof(AuthorizationTokenProviderTokenCredential)) | ||
| { | ||
| throw new InvalidOperationException("InferenceService only supports AAD authentication."); | ||
| } | ||
|
|
||
| // Set up token credential for authorization. | ||
| AuthorizationTokenProviderTokenCredential defaultOperationTokenProvider = client.DocumentClient.cosmosAuthorization as AuthorizationTokenProviderTokenCredential; | ||
| TokenCredential tokenCredential = defaultOperationTokenProvider.tokenCredential; | ||
|
|
||
| this.cosmosAuthorization = new AuthorizationTokenProviderTokenCredential( | ||
| tokenCredential: tokenCredential, | ||
| accountEndpoint: new Uri(inferenceServiceDefaultScope), | ||
| backgroundTokenCredentialRefreshInterval: client.ClientOptions?.TokenCredentialBackgroundRefreshInterval); | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Sends a semantic rerank request to the inference service. | ||
| /// </summary> | ||
| /// <param name="rerankContext">The context/query for reranking.</param> | ||
| /// <param name="documents">The documents to be reranked.</param> | ||
| /// <param name="options">Optional additional options for the request.</param> | ||
| /// <param name="cancellationToken">Cancellation token.</param> | ||
| /// <returns>A dictionary containing the reranked results.</returns> | ||
| public async Task<SemanticRerankResult> SemanticRerankAsync( | ||
| string rerankContext, | ||
| IEnumerable<string> documents, | ||
| IDictionary<string, object> options = null, | ||
| CancellationToken cancellationToken = default) | ||
| { | ||
| // Prepare HTTP request for semantic reranking. | ||
| HttpRequestMessage message = new HttpRequestMessage(HttpMethod.Post, this.inferenceEndpoint); | ||
| INameValueCollection additionalHeaders = new RequestNameValueCollection(); | ||
| await this.cosmosAuthorization.AddInferenceAuthorizationHeaderAsync( | ||
| headersCollection: additionalHeaders, | ||
| this.inferenceEndpoint, | ||
| HttpConstants.HttpMethods.Post, | ||
| AuthorizationTokenType.AadToken); | ||
| additionalHeaders.Add(HttpConstants.HttpHeaders.UserAgent, inferenceUserAgent); | ||
|
|
||
| // Add all headers to the HTTP request. | ||
| foreach (string key in additionalHeaders.AllKeys()) | ||
| { | ||
| message.Headers.Add(key, additionalHeaders[key]); | ||
| } | ||
|
|
||
| // Build the request payload. | ||
| Dictionary<string, object> body = this.AddSemanticRerankPayload(rerankContext, documents, options); | ||
|
|
||
| message.Content = new StringContent( | ||
| Newtonsoft.Json.JsonConvert.SerializeObject(body), | ||
| Encoding.UTF8, | ||
| RuntimeConstants.MediaTypes.Json); | ||
|
|
||
| // Send the request and ensure success. | ||
| HttpResponseMessage responseMessage = await this.httpClient.SendAsync(message, cancellationToken); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's the reliability story? (ex: retries etc...)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Retries will come at a later date |
||
| responseMessage.EnsureSuccessStatusCode(); | ||
|
|
||
| // Deserialize and return the response content as a dictionary. | ||
| return await SemanticRerankResult.DeserializeSemanticRerankResultAsync(responseMessage); | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Configures the provided HttpClient with default headers and settings for inference requests. | ||
| /// </summary> | ||
| /// <param name="httpClient">The HttpClient to configure.</param> | ||
| private void CreateClientHelper(HttpClient httpClient) | ||
| { | ||
| httpClient.Timeout = TimeSpan.FromSeconds(120); | ||
| httpClient.DefaultRequestHeaders.CacheControl = new CacheControlHeaderValue { NoCache = true }; | ||
|
|
||
| // Set requested API version header for version enforcement. | ||
| httpClient.DefaultRequestHeaders.Add(HttpConstants.HttpHeaders.Version, | ||
| HttpConstants.Versions.CurrentVersion); | ||
|
|
||
| httpClient.DefaultRequestHeaders.Add(HttpConstants.HttpHeaders.Accept, RuntimeConstants.MediaTypes.Json); | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Constructs the payload for the semantic rerank request. | ||
| /// </summary> | ||
| /// <param name="rerankContext">The context/query for reranking.</param> | ||
| /// <param name="documents">The documents to be reranked.</param> | ||
| /// <param name="options">Optional additional options.</param> | ||
| /// <returns>A dictionary representing the request payload.</returns> | ||
| private Dictionary<string, object> AddSemanticRerankPayload(string rerankContext, IEnumerable<string> documents, IDictionary<string, object> options) | ||
| { | ||
| Dictionary<string, object> payload = new Dictionary<string, object> | ||
| { | ||
| { "query", rerankContext }, | ||
| { "documents", documents.ToArray() } | ||
| }; | ||
|
|
||
| if (options == null) | ||
| { | ||
| return payload; | ||
| } | ||
|
|
||
| // Add any additional options to the payload. | ||
| foreach (string option in options.Keys) | ||
| { | ||
| payload.Add(option, options[option]); | ||
| } | ||
|
|
||
| return payload; | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Disposes managed resources used by the service. | ||
| /// </summary> | ||
| /// <param name="disposing">Indicates if called from Dispose.</param> | ||
| protected void Dispose(bool disposing) | ||
| { | ||
| if (!this.disposedValue) | ||
| { | ||
| if (disposing) | ||
| { | ||
| this.httpClient.Dispose(); | ||
| this.cosmosAuthorization.Dispose(); | ||
| } | ||
|
|
||
| this.disposedValue = true; | ||
| } | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Disposes the service and its resources. | ||
| /// </summary> | ||
| public void Dispose() | ||
| { | ||
| this.Dispose(true); | ||
| } | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,46 @@ | ||
| //------------------------------------------------------------ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| //------------------------------------------------------------ | ||
|
|
||
| namespace Microsoft.Azure.Cosmos | ||
| { | ||
| /// <summary> | ||
| /// Represents the score assigned to a document after a reranking operation. | ||
| /// </summary> | ||
| #if PREVIEW | ||
| public | ||
| #else | ||
| internal | ||
| #endif | ||
|
|
||
| class RerankScore | ||
| { | ||
| /// <summary> | ||
| /// Gets the document content or identifier that was reranked. | ||
| /// </summary> | ||
| public object Document { get; } | ||
|
|
||
| /// <summary> | ||
| /// Gets the score assigned to the document after reranking. | ||
| /// </summary> | ||
| public double Score { get; } | ||
|
|
||
| /// <summary> | ||
| /// Gets the original index or position of the document before reranking. | ||
| /// </summary> | ||
| public int Index { get; } | ||
|
|
||
| /// <summary> | ||
| /// Initializes a new instance of the <see cref="RerankScore"/> class. | ||
| /// </summary> | ||
| /// <param name="document">The document content or identifier.</param> | ||
| /// <param name="score">The reranked score for the document.</param> | ||
| /// <param name="index">The original index of the document.</param> | ||
| public RerankScore(object document, double score, int index) | ||
| { | ||
| this.Document = document; | ||
| this.Score = score; | ||
| this.Index = index; | ||
| } | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMHO let's not overload the core types with inference specific?