1+ using Microsoft . Extensions . AI ;
2+ using System ;
3+ using System . Collections . Generic ;
4+ using System . Runtime . CompilerServices ;
5+ using System . Text ;
6+ using System . Threading ;
7+ using System . Threading . Tasks ;
8+
9+ namespace Microsoft . ML . OnnxRuntimeGenAI ;
10+
11+ /// <summary>Provides an <see cref="IChatClient"/> implementation for interacting with a <see cref="Model"/>.</summary>
12+ public sealed partial class ChatClient : IChatClient
13+ {
14+ /// <summary>The options used to configure the instance.</summary>
15+ private readonly ChatClientConfiguration _config ;
16+ /// <summary>The wrapped <see cref="Model"/>.</summary>
17+ private readonly Model _model ;
18+ /// <summary>The wrapped <see cref="Tokenizer"/>.</summary>
19+ private readonly Tokenizer _tokenizer ;
20+ /// <summary>Whether to dispose of <see cref="_model"/> when this instance is disposed.</summary>
21+ private readonly bool _ownsModel ;
22+ /// <summary>Metadata for the chat client.</summary>
23+ private readonly ChatClientMetadata _metadata ;
24+
/// <summary>
/// Initializes an instance of the <see cref="ChatClient"/> class by loading a model from disk.
/// The client owns the loaded <see cref="Model"/> and disposes it in <see cref="Dispose"/>.
/// </summary>
/// <param name="configuration">Options used to configure the client instance.</param>
/// <param name="modelPath">The file path to the model to load.</param>
/// <exception cref="ArgumentNullException">
/// <paramref name="configuration"/> or <paramref name="modelPath"/> is null.
/// </exception>
public ChatClient(ChatClientConfiguration configuration, string modelPath)
{
    if (configuration is null)
    {
        throw new ArgumentNullException(nameof(configuration));
    }

    if (modelPath is null)
    {
        throw new ArgumentNullException(nameof(modelPath));
    }

    _config = configuration;
    _ownsModel = true;

    // The model is created first so that an invalid path still surfaces the model loader's own error.
    Model model = new(modelPath);
    Tokenizer tokenizer = null;
    try
    {
        tokenizer = new Tokenizer(model);

        // Build the provider URI from the fully-qualified path rather than by concatenating "file://":
        // with concatenation a relative path's first segment is parsed as the URI host, and characters
        // such as '#' or '?' are parsed as fragment/query delimiters.
        Uri providerUri = new(System.IO.Path.GetFullPath(modelPath));
        _metadata = new("onnx", providerUri, modelPath);
    }
    catch
    {
        // This constructor created the model, so nobody else can release it if construction fails here.
        tokenizer?.Dispose();
        model.Dispose();
        throw;
    }

    _model = model;
    _tokenizer = tokenizer;
}
49+
/// <summary>Initializes an instance of the <see cref="ChatClient"/> class around an existing <see cref="Model"/>.</summary>
/// <param name="configuration">Options used to configure the client instance.</param>
/// <param name="model">The model to employ.</param>
/// <param name="ownsModel">
/// <see langword="true"/> if this <see cref="IChatClient"/> owns the <paramref name="model"/> and should
/// dispose of it when this <see cref="IChatClient"/> is disposed; otherwise, <see langword="false"/>.
/// The default is <see langword="true"/>.
/// </param>
/// <exception cref="ArgumentNullException">
/// <paramref name="configuration"/> or <paramref name="model"/> is null.
/// </exception>
public ChatClient(ChatClientConfiguration configuration, Model model, bool ownsModel = true)
{
    // Validate both arguments, in declaration order, before touching any state.
    _ = configuration ?? throw new ArgumentNullException(nameof(configuration));
    _ = model ?? throw new ArgumentNullException(nameof(model));

    _config = configuration;
    _model = model;
    _ownsModel = ownsModel;

    // The tokenizer is always created here, regardless of who owns the model.
    _tokenizer = new Tokenizer(model);

    // No model path is known for a caller-supplied model, so only the provider name is recorded.
    _metadata = new ChatClientMetadata("onnx");
}
79+
/// <inheritdoc/>
public void Dispose()
{
    // The tokenizer is created by every constructor of this class, so it is always released here.
    _tokenizer.Dispose();

    // A model this instance does not own is left for its owner to dispose.
    if (!_ownsModel)
    {
        return;
    }

    _model.Dispose();
}
90+
/// <inheritdoc/>
/// <remarks>
/// The prompt is formatted with the configured prompt formatter, encoded, and generated token by token on a
/// thread-pool thread until the generator reports completion, a stop sequence is produced, or
/// <paramref name="cancellationToken"/> is signaled. The returned usage counts the prompt tokens and the
/// decoded pieces appended to the response text.
/// </remarks>
public async Task<ChatResponse> GetResponseAsync(IList<ChatMessage> chatMessages, ChatOptions options = null, CancellationToken cancellationToken = default)
{
    if (chatMessages is null)
    {
        throw new ArgumentNullException(nameof(chatMessages));
    }

    // The generation loop is synchronous, so it is offloaded as a whole. The lambda returns its results
    // instead of mutating locals captured from this frame, and the await does not capture the caller's
    // synchronization context because nothing after it needs that context.
    var (responseText, inputTokens, outputTokens) = await Task.Run(() =>
    {
        using Sequences tokens = _tokenizer.Encode(_config.PromptFormatter(chatMessages));
        using GeneratorParams generatorParams = new(_model);

        int promptLength = tokens[0].Length;
        UpdateGeneratorParamsFromOptions(promptLength, generatorParams, options);

        using Generator generator = new(_model, generatorParams);
        generator.AppendTokenSequences(tokens);

        using var tokenizerStream = _tokenizer.CreateStream();

        StringBuilder text = new();
        int generated = 0;
        while (!generator.IsDone())
        {
            cancellationToken.ThrowIfCancellationRequested();

            generator.GenerateNextToken();

            // Only the most recently generated token is decoded on each iteration.
            ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
            string next = tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]);

            // A stop sequence ends generation and is neither counted nor included in the text.
            if (IsStop(next, options))
            {
                break;
            }

            generated++;
            text.Append(next);
        }

        return (text.ToString(), promptLength, generated);
    }, cancellationToken).ConfigureAwait(false);

    return new ChatResponse(new ChatMessage(ChatRole.Assistant, responseText))
    {
        ResponseId = Guid.NewGuid().ToString(),
        CreatedAt = DateTimeOffset.UtcNow,
        ModelId = _metadata.ModelId,
        Usage = new()
        {
            InputTokenCount = inputTokens,
            OutputTokenCount = outputTokens,
            TotalTokenCount = inputTokens + outputTokens,
        },
    };
}
146+
/// <inheritdoc/>
/// <remarks>
/// Yields one update per decoded piece of text as it is generated, followed by a final update that carries
/// only a <see cref="UsageContent"/> with the token counts. All updates share the same response id.
/// </remarks>
public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
    IList<ChatMessage> chatMessages, ChatOptions options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    if (chatMessages is null)
    {
        throw new ArgumentNullException(nameof(chatMessages));
    }

    using Sequences tokens = _tokenizer.Encode(_config.PromptFormatter(chatMessages));
    using GeneratorParams generatorParams = new(_model);

    // Read the prompt length once; it drives both the max_length option and the usage report.
    int inputTokens = tokens[0].Length;
    UpdateGeneratorParamsFromOptions(inputTokens, generatorParams, options);

    using Generator generator = new(_model, generatorParams);
    generator.AppendTokenSequences(tokens);

    using var tokenizerStream = _tokenizer.CreateStream();

    int outputTokens = 0;
    var completionId = Guid.NewGuid().ToString();
    while (!generator.IsDone())
    {
        // Each token is generated on the thread pool. ConfigureAwait(false) avoids posting back to the
        // caller's synchronization context once per token; nothing in this iterator needs that context.
        string next = await Task.Run(() =>
        {
            generator.GenerateNextToken();

            // Only the most recently generated token is decoded on each iteration.
            ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
            return tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]);
        }, cancellationToken).ConfigureAwait(false);

        // A stop sequence ends the stream and is neither counted nor yielded.
        if (IsStop(next, options))
        {
            break;
        }

        outputTokens++;
        yield return new()
        {
            CreatedAt = DateTimeOffset.UtcNow,
            ResponseId = completionId,
            Role = ChatRole.Assistant,
            Text = next,
        };
    }

    // Final update: usage only, no text.
    yield return new()
    {
        Contents = [new UsageContent(new()
        {
            InputTokenCount = inputTokens,
            OutputTokenCount = outputTokens,
            TotalTokenCount = inputTokens + outputTokens,
        })],
        CreatedAt = DateTimeOffset.UtcNow,
        ResponseId = completionId,
        Role = ChatRole.Assistant,
    };
}
205+
/// <inheritdoc/>
object IChatClient.GetService(Type serviceType, object serviceKey = null)
{
    // Any non-null service key yields no service.
    if (serviceKey is not null)
    {
        return null;
    }

    // Exact type matches expose the metadata and the wrapped model/tokenizer.
    if (serviceType == typeof(ChatClientMetadata))
    {
        return _metadata;
    }

    if (serviceType == typeof(Model))
    {
        return _model;
    }

    if (serviceType == typeof(Tokenizer))
    {
        return _tokenizer;
    }

    // Otherwise hand back this client when it is an instance of the requested type.
    // A null serviceType falls through to the final null result.
    if (serviceType is not null && serviceType.IsInstanceOfType(this))
    {
        return this;
    }

    return null;
}
214+
/// <summary>
/// Determines whether a decoded piece of text exactly matches one of the stop sequences supplied in the
/// request <paramref name="options"/> or in the client configuration.
/// </summary>
/// <param name="token">The decoded text to test.</param>
/// <param name="options">The per-request options; may be <see langword="null"/>.</param>
/// <returns><see langword="true"/> if <paramref name="token"/> is a stop sequence; otherwise, <see langword="false"/>.</returns>
private bool IsStop(string token, ChatOptions options)
{
    // Per-request stop sequences are consulted first.
    if (options?.StopSequences is { } requested && requested.Contains(token))
    {
        return true;
    }

    // Fall back to the stop sequences from the client configuration.
    return Array.IndexOf(_config.StopSequences, token) >= 0;
}
219+
/// <summary>Updates the <paramref name="generatorParams"/> based on the supplied <paramref name="options"/>.</summary>
/// <param name="numInputTokens">
/// The number of tokens in the encoded prompt; added to the requested output token count to form <c>max_length</c>.
/// </param>
/// <param name="generatorParams">The generator parameters to update.</param>
/// <param name="options">The per-request options. When <see langword="null"/>, nothing is changed.</param>
private static void UpdateGeneratorParamsFromOptions(int numInputTokens, GeneratorParams generatorParams, ChatOptions options)
{
    if (options is null)
    {
        return;
    }

    if (options.MaxOutputTokens is { } maxOutputTokens)
    {
        // Add in double precision: an int sum wraps negative when the caller passes a very large
        // limit such as int.MaxValue.
        generatorParams.SetSearchOption("max_length", (double)numInputTokens + maxOutputTokens);
    }

    if (options.Temperature is { } temperature)
    {
        generatorParams.SetSearchOption("temperature", temperature);
    }

    if (options.PresencePenalty is { } presencePenalty)
    {
        // NOTE(review): the presence penalty is forwarded unchanged as "repetition_penalty" —
        // confirm the two settings use compatible scales.
        generatorParams.SetSearchOption("repetition_penalty", presencePenalty);
    }

    if (options.TopP is { } topP)
    {
        generatorParams.SetSearchOption("top_p", topP);
    }

    if (options.TopK is { } topK)
    {
        generatorParams.SetSearchOption("top_k", topK);
    }

    if (options.Seed is { } seed)
    {
        generatorParams.SetSearchOption("random_seed", seed);
    }

    if (options.AdditionalProperties is { } props)
    {
        // Additional properties are forwarded as search options: booleans as-is, everything else
        // as a double when it can be converted.
        foreach (var entry in props)
        {
            if (entry.Value is bool b)
            {
                generatorParams.SetSearchOption(entry.Key, b);
            }
            else if (entry.Value is not null)
            {
                try
                {
                    // Parse with the invariant culture so a string such as "0.5" means the same
                    // thing regardless of the machine's regional settings.
                    double d = Convert.ToDouble(entry.Value, System.Globalization.CultureInfo.InvariantCulture);
                    generatorParams.SetSearchOption(entry.Key, d);
                }
                catch
                {
                    // Deliberately best-effort: entries that cannot be converted or applied are skipped.
                }
            }
        }
    }
}
284+ }
0 commit comments