Skip to content

Commit de1656b

Browse files
committed
Add an IChatClient implementation to OnnxRuntimeGenAI
1 parent fe3604a commit de1656b

4 files changed

Lines changed: 393 additions & 1 deletion

File tree

src/csharp/ChatClient.cs

Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
using Microsoft.Extensions.AI;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Runtime.CompilerServices;
5+
using System.Text;
6+
using System.Threading;
7+
using System.Threading.Tasks;
8+
9+
namespace Microsoft.ML.OnnxRuntimeGenAI;
10+
11+
/// <summary>Provides an <see cref="IChatClient"/> implementation for interacting with a <see cref="Model"/>.</summary>
12+
public sealed partial class ChatClient : IChatClient
13+
{
14+
/// <summary>The options used to configure the instance; supplies the prompt formatter and the client-level stop sequences.</summary>
15+
private readonly ChatClientConfiguration _config;
16+
/// <summary>The wrapped <see cref="Model"/>, used to create generator parameters and generators for each request.</summary>
17+
private readonly Model _model;
18+
/// <summary>The wrapped <see cref="Tokenizer"/>, created from <see cref="_model"/>; encodes prompts and decodes generated tokens.</summary>
19+
private readonly Tokenizer _tokenizer;
20+
/// <summary>Whether to dispose of <see cref="_model"/> when this instance is disposed.</summary>
21+
private readonly bool _ownsModel;
22+
/// <summary>Metadata for the chat client; returned from <c>GetService</c> and the source of the model ID reported in responses.</summary>
23+
private readonly ChatClientMetadata _metadata;
24+
25+
/// <summary>
/// Initializes an instance of the <see cref="ChatClient"/> class that loads its own <see cref="Model"/>
/// from <paramref name="modelPath"/> and disposes of it when the client is disposed.
/// </summary>
/// <param name="configuration">Options used to configure the client instance.</param>
/// <param name="modelPath">The file path to the model to load.</param>
/// <exception cref="ArgumentNullException"><paramref name="configuration"/> is null.</exception>
/// <exception cref="ArgumentNullException"><paramref name="modelPath"/> is null.</exception>
public ChatClient(ChatClientConfiguration configuration, string modelPath)
{
    if (configuration is null)
    {
        throw new ArgumentNullException(nameof(configuration));
    }

    if (modelPath is null)
    {
        throw new ArgumentNullException(nameof(modelPath));
    }

    _config = configuration;
    _ownsModel = true;

    // The model is created here, so nobody else can release it if the rest of
    // construction fails; clean up explicitly before letting the exception escape.
    Model model = new(modelPath);
    Tokenizer tokenizer = null;
    try
    {
        tokenizer = new Tokenizer(model);

        // Build the provider URI from the absolute path. Concatenating "file://" with a
        // relative path would make the first path segment the URI host.
        Uri modelUri = new(System.IO.Path.GetFullPath(modelPath));
        _metadata = new("onnx", modelUri, modelPath);
    }
    catch
    {
        tokenizer?.Dispose();
        model.Dispose();
        throw;
    }

    _model = model;
    _tokenizer = tokenizer;
}
49+
50+
/// <summary>Creates a <see cref="ChatClient"/> on top of an existing <see cref="Model"/>.</summary>
/// <param name="configuration">Options used to configure the client instance.</param>
/// <param name="model">The model to employ.</param>
/// <param name="ownsModel">
/// <see langword="true"/> if this <see cref="IChatClient"/> owns the <paramref name="model"/> and should
/// dispose of it when this <see cref="IChatClient"/> is disposed; otherwise, <see langword="false"/>.
/// The default is <see langword="true"/>.
/// </param>
/// <exception cref="ArgumentNullException"><paramref name="configuration"/> or <paramref name="model"/> is null.</exception>
public ChatClient(ChatClientConfiguration configuration, Model model, bool ownsModel = true)
{
    // Arguments are validated in declaration order, so a null configuration is reported first.
    _config = configuration ?? throw new ArgumentNullException(nameof(configuration));
    _model = model ?? throw new ArgumentNullException(nameof(model));
    _ownsModel = ownsModel;

    _tokenizer = new Tokenizer(model);

    // No path is known for a caller-supplied model, so only the provider name is recorded.
    _metadata = new("onnx");
}
79+
80+
/// <inheritdoc/>
public void Dispose()
{
    // The tokenizer is always created by this class, so it is always released here.
    _tokenizer.Dispose();

    // A model the caller kept ownership of is left untouched.
    if (!_ownsModel)
    {
        return;
    }

    _model.Dispose();
}
90+
91+
/// <inheritdoc/>
public async Task<ChatResponse> GetResponseAsync(IList<ChatMessage> chatMessages, ChatOptions options = null, CancellationToken cancellationToken = default)
{
    if (chatMessages is null)
    {
        throw new ArgumentNullException(nameof(chatMessages));
    }

    // Filled in by the worker delegate below and read only after it has completed.
    int inputTokens = 0, outputTokens = 0;
    StringBuilder text = new();

    // The whole generation loop is offloaded via Task.Run so the caller's thread is not blocked.
    // ConfigureAwait(false): nothing after the await needs the caller's synchronization
    // context, so do not capture it (avoids deadlocks when a caller blocks on the task).
    await Task.Run(() =>
    {
        using Sequences tokens = _tokenizer.Encode(_config.PromptFormatter(chatMessages));
        using GeneratorParams generatorParams = new(_model);

        inputTokens = tokens[0].Length;
        UpdateGeneratorParamsFromOptions(inputTokens, generatorParams, options);

        using Generator generator = new(_model, generatorParams);
        generator.AppendTokenSequences(tokens);

        using var tokenizerStream = _tokenizer.CreateStream();

        while (!generator.IsDone())
        {
            cancellationToken.ThrowIfCancellationRequested();

            generator.GenerateNextToken();

            // Only the most recently generated token (the last element) needs decoding.
            ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
            string next = tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]);

            // A stop sequence ends generation; it is neither counted nor added to the text.
            if (IsStop(next, options))
            {
                break;
            }

            outputTokens++;
            text.Append(next);
        }
    }, cancellationToken).ConfigureAwait(false);

    return new ChatResponse(new ChatMessage(ChatRole.Assistant, text.ToString()))
    {
        ResponseId = Guid.NewGuid().ToString(),
        CreatedAt = DateTimeOffset.UtcNow,
        ModelId = _metadata.ModelId,
        Usage = new()
        {
            InputTokenCount = inputTokens,
            OutputTokenCount = outputTokens,
            TotalTokenCount = inputTokens + outputTokens,
        },
    };
}
146+
147+
/// <inheritdoc/>
public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
    IList<ChatMessage> chatMessages, ChatOptions options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    if (chatMessages is null)
    {
        throw new ArgumentNullException(nameof(chatMessages));
    }

    using Sequences tokens = _tokenizer.Encode(_config.PromptFormatter(chatMessages));
    using GeneratorParams generatorParams = new(_model);

    int inputTokens = tokens[0].Length, outputTokens = 0;
    UpdateGeneratorParamsFromOptions(inputTokens, generatorParams, options);

    using Generator generator = new(_model, generatorParams);
    generator.AppendTokenSequences(tokens);

    using var tokenizerStream = _tokenizer.CreateStream();

    // Every update of one streamed response shares the same response ID.
    var completionId = Guid.NewGuid().ToString();
    while (!generator.IsDone())
    {
        // Each token is generated via Task.Run so the consumer's thread is not blocked.
        // ConfigureAwait(false): the iterator body does not need the caller's synchronization
        // context, so avoid hopping back to it after every generated token.
        string next = await Task.Run(() =>
        {
            generator.GenerateNextToken();

            // Only the most recently generated token (the last element) needs decoding.
            ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
            return tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]);
        }, cancellationToken).ConfigureAwait(false);

        // A stop sequence ends the stream; it is neither counted nor yielded.
        if (IsStop(next, options))
        {
            break;
        }

        outputTokens++;
        yield return new()
        {
            CreatedAt = DateTimeOffset.UtcNow,
            ResponseId = completionId,
            Role = ChatRole.Assistant,
            Text = next,
        };
    }

    // Final update: carries no text, only the token usage for the whole response.
    yield return new()
    {
        Contents = [new UsageContent(new()
        {
            InputTokenCount = inputTokens,
            OutputTokenCount = outputTokens,
            TotalTokenCount = inputTokens + outputTokens,
        })],
        CreatedAt = DateTimeOffset.UtcNow,
        ResponseId = completionId,
        Role = ChatRole.Assistant,
    };
}
205+
206+
/// <inheritdoc/>
object IChatClient.GetService(Type serviceType, object serviceKey = null)
{
    // Keyed services are not supported: any non-null key yields no service.
    if (serviceKey is not null)
    {
        return null;
    }

    if (serviceType == typeof(ChatClientMetadata))
    {
        return _metadata;
    }

    if (serviceType == typeof(Model))
    {
        return _model;
    }

    if (serviceType == typeof(Tokenizer))
    {
        return _tokenizer;
    }

    // Finally, hand back this client for any type it is an instance of
    // (a null serviceType matches nothing).
    if (serviceType is not null && serviceType.IsInstanceOfType(this))
    {
        return this;
    }

    return null;
}
214+
215+
/// <summary>
/// Determines whether <paramref name="token"/> exactly matches a stop sequence from either the
/// per-request <paramref name="options"/> or the client configuration.
/// </summary>
private bool IsStop(string token, ChatOptions options)
{
    // Per-request stop sequences are consulted first; they are optional at both levels.
    var requestStops = options?.StopSequences;
    if (requestStops is not null && requestStops.Contains(token))
    {
        return true;
    }

    return Array.IndexOf(_config.StopSequences, token) >= 0;
}
219+
220+
/// <summary>Updates the <paramref name="generatorParams"/> based on the supplied <paramref name="options"/>.</summary>
/// <param name="numInputTokens">The number of tokens in the encoded prompt.</param>
/// <param name="generatorParams">The generator parameters to update.</param>
/// <param name="options">The per-request options; when <see langword="null"/>, nothing is changed.</param>
private static void UpdateGeneratorParamsFromOptions(int numInputTokens, GeneratorParams generatorParams, ChatOptions options)
{
    if (options is null)
    {
        return;
    }

    if (options.MaxOutputTokens is { } maxOutputTokens)
    {
        // max_length is set to the prompt length plus the requested output budget. Add in
        // 64-bit and clamp, so a large budget (e.g. int.MaxValue) cannot wrap negative.
        long maxLength = Math.Min((long)numInputTokens + maxOutputTokens, int.MaxValue);
        generatorParams.SetSearchOption("max_length", maxLength);
    }

    if (options.Temperature is { } temperature)
    {
        generatorParams.SetSearchOption("temperature", temperature);
    }

    if (options.PresencePenalty is { } presencePenalty)
    {
        // NOTE(review): the presence penalty is forwarded unchanged as "repetition_penalty";
        // confirm that the two options use compatible scales.
        generatorParams.SetSearchOption("repetition_penalty", presencePenalty);
    }

    if (options.TopP is { } topP)
    {
        generatorParams.SetSearchOption("top_p", topP);
    }

    if (options.TopK is { } topK)
    {
        generatorParams.SetSearchOption("top_k", topK);
    }

    if (options.Seed is { } seed)
    {
        generatorParams.SetSearchOption("random_seed", seed);
    }

    if (options.AdditionalProperties is { } props)
    {
        // Forward any extra properties as search options: booleans as-is, and everything
        // else that can be converted to a number as a double.
        foreach (var entry in props)
        {
            if (entry.Value is bool b)
            {
                generatorParams.SetSearchOption(entry.Key, b);
            }
            else if (entry.Value is not null)
            {
                try
                {
                    // Invariant culture keeps string values such as "0.5" parsing the same
                    // way regardless of the machine's regional settings.
                    double d = Convert.ToDouble(entry.Value, System.Globalization.CultureInfo.InvariantCulture);
                    generatorParams.SetSearchOption(entry.Key, d);
                }
                catch
                {
                    // Deliberately best-effort: ignore values we can't convert or apply.
                }
            }
        }
    }
}
284+
}
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
using Microsoft.Extensions.AI;
2+
using System;
3+
using System.Collections.Generic;
4+
5+
namespace Microsoft.ML.OnnxRuntimeGenAI;
6+
7+
/// <summary>Provides configuration options used when constructing a <see cref="ChatClient"/>.</summary>
/// <remarks>
/// Every model has different requirements for stop sequences and prompt formatting. For best results,
/// the configuration should be tailored to the exact nature of the model being used. For example,
/// when using a Phi3 model, a configuration like the following may be used:
/// <code>
/// static ChatClientConfiguration CreateForPhi3() =&gt;
///     new(["&lt;|system|&gt;", "&lt;|user|&gt;", "&lt;|assistant|&gt;", "&lt;|end|&gt;"],
///         (IEnumerable&lt;ChatMessage&gt; messages) =&gt;
///         {
///             StringBuilder prompt = new();
///
///             foreach (var message in messages)
///                 foreach (var content in message.Contents.OfType&lt;TextContent&gt;())
///                     prompt.Append("&lt;|").Append(message.Role.Value).Append("|&gt;\n").Append(content.Text).Append("&lt;|end|&gt;\n");
///
///             return prompt.Append("&lt;|assistant|&gt;\n").ToString();
///         });
/// </code>
/// </remarks>
public sealed class ChatClientConfiguration
{
    private string[] _stopSequences;
    private Func<IEnumerable<ChatMessage>, string> _promptFormatter;

    /// <summary>Initializes a new instance of the <see cref="ChatClientConfiguration"/> class.</summary>
    /// <param name="stopSequences">The stop sequences used by the model.</param>
    /// <param name="promptFormatter">The function to use to format a list of messages for input into the model.</param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="stopSequences"/> or <paramref name="promptFormatter"/> is null.
    /// </exception>
    public ChatClientConfiguration(
        string[] stopSequences,
        Func<IEnumerable<ChatMessage>, string> promptFormatter)
    {
        // Validate here rather than through the property setters so the exception names
        // the constructor parameter instead of "value".
        _stopSequences = stopSequences ?? throw new ArgumentNullException(nameof(stopSequences));
        _promptFormatter = promptFormatter ?? throw new ArgumentNullException(nameof(promptFormatter));
    }

    /// <summary>
    /// Gets or sets stop sequences to use during generation.
    /// </summary>
    /// <remarks>
    /// These will apply in addition to any stop sequences that are a part of the <see cref="ChatOptions.StopSequences"/>.
    /// </remarks>
    /// <exception cref="ArgumentNullException">The value being set is null.</exception>
    public string[] StopSequences
    {
        get => _stopSequences;
        set => _stopSequences = value ?? throw new ArgumentNullException(nameof(value));
    }

    /// <summary>Gets or sets the function that creates a prompt string from the chat history.</summary>
    /// <exception cref="ArgumentNullException">The value being set is null.</exception>
    public Func<IEnumerable<ChatMessage>, string> PromptFormatter
    {
        get => _promptFormatter;
        set => _promptFormatter = value ?? throw new ArgumentNullException(nameof(value));
    }
}

src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,4 +121,8 @@
121121
<PackageReference Include="System.Memory" Version="4.5.5" />
122122
</ItemGroup>
123123

124+
<ItemGroup>
125+
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.3.0-preview.1.25114.11" />
126+
</ItemGroup>
127+
124128
</Project>

0 commit comments

Comments
 (0)