diff --git a/LLama.Examples/Examples/QuantizeModel.cs b/LLama.Examples/Examples/QuantizeModel.cs
index dace956c..863bb0c3 100644
--- a/LLama.Examples/Examples/QuantizeModel.cs
+++ b/LLama.Examples/Examples/QuantizeModel.cs
@@ -2,7 +2,7 @@ namespace LLama.Examples.Examples
 {
     public class QuantizeModel
     {
-        public static async Task Run()
+        public static Task Run()
         {
             string inputPath = UserSettings.GetModelPath();
 
@@ -21,7 +21,7 @@ public static async Task Run()
                 Console.WriteLine("Quantization failed!");
             }
 
-            await Task.CompletedTask;
+            return Task.CompletedTask;
         }
     }
 }
diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs
index 90119d4f..bda7472d 100644
--- a/LLama/ChatSession.cs
+++ b/LLama/ChatSession.cs
@@ -76,9 +76,10 @@ public class ChatSession
     /// <param name="executor">The executor for this session</param>
     /// <param name="history">History for this session</param>
     /// <param name="transform">History Transform for this session</param>
+    /// <param name="cancellationToken">A token that cancels the operation</param>
     /// <returns>A new chat session.</returns>
     public static async Task<ChatSession> InitializeSessionFromHistoryAsync(
-        ILLamaExecutor executor, ChatHistory history, IHistoryTransform? transform = null)
+        ILLamaExecutor executor, ChatHistory history, IHistoryTransform? transform = null, CancellationToken cancellationToken = default)
     {
         if (executor is not StatefulExecutorBase statefulExecutor)
         {
@@ -90,7 +91,7 @@ public static async Task<ChatSession> InitializeSessionFromHistoryAsync(
             session = session.WithHistoryTransform(transform);
         }
 
-        await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(history));
+        await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(history), cancellationToken);
 
         return session;
     }
@@ -311,13 +312,15 @@ public ChatSession RemoveLastMessage()
     /// Compute KV cache for the message and add it to the chat history.
     /// </summary>
     /// <param name="message"></param>
+    /// <param name="cancellationToken"></param>
     /// <returns></returns>
-    public async Task<ChatSession> AddAndProcessMessage(ChatHistory.Message message)
+    public async Task<ChatSession> AddAndProcessMessage(ChatHistory.Message message, CancellationToken cancellationToken = default)
     {
         if (Executor is not StatefulExecutorBase statefulExecutor)
         {
             throw new InvalidOperationException("Executor must be a StatefulExecutorBase to support pre-processing of system messages.");
         }
+        AddMessage(message);
         var content = message.Content;
         if (message.AuthorRole != AuthorRole.Assistant)
         {
@@ -328,27 +331,27 @@ public async Task<ChatSession> AddAndProcessMessage(ChatHistory.Message message)
             }
         }
 
-        await statefulExecutor.PrefillPromptAsync(content);
+        await statefulExecutor.PrefillPromptAsync(content, cancellationToken);
         return this;
     }
 
     /// <summary>
     /// Compute KV cache for the system message and add it to the chat history.
     /// </summary>
-    public Task<ChatSession> AddAndProcessSystemMessage(string content)
-        => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.System, content));
+    public Task<ChatSession> AddAndProcessSystemMessage(string content, CancellationToken cancellationToken = default)
+        => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.System, content), cancellationToken);
 
     /// <summary>
     /// Compute KV cache for the user message and add it to the chat history.
     /// </summary>
-    public Task<ChatSession> AddAndProcessUserMessage(string content)
-        => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.User, content));
+    public Task<ChatSession> AddAndProcessUserMessage(string content, CancellationToken cancellationToken = default)
+        => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.User, content), cancellationToken);
 
     /// <summary>
     /// Compute KV cache for the assistant message and add it to the chat history.
     /// </summary>
-    public Task<ChatSession> AddAndProcessAssistantMessage(string content)
-        => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.Assistant, content));
+    public Task<ChatSession> AddAndProcessAssistantMessage(string content, CancellationToken cancellationToken = default)
+        => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.Assistant, content), cancellationToken);
 
     /// <summary>
     /// Replace a user message with a new message and remove all messages after the new message.
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 42d76c51..4188f9e5 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -1,14 +1,14 @@
-using LLama.Native;
 using System;
 using System.Collections.Generic;
 using System.Diagnostics;
-using System.Text;
 using System.IO;
 using System.IO.MemoryMappedFiles;
+using System.Text;
+using System.Threading;
 using System.Threading.Tasks;
 using LLama.Abstractions;
+using LLama.Native;
 using Microsoft.Extensions.Logging;
-using System.Threading;
 
 namespace LLama
 {
@@ -73,7 +73,7 @@ public int BatchThreads
         /// Get the special tokens for the model associated with this context
         /// </summary>
         public SafeLlamaModelHandle.Vocabulary Vocab { get; }
-        
+
         /// <summary>
         /// Create a new LLamaContext for the given LLamaWeights
         /// </summary>
@@ -396,7 +396,7 @@ public Task<DecodeResult> DecodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
         {
             return Task.Run(() => Decode(batch), cancellationToken);
         }
-        
+
         /// <summary>
         /// </summary>
         /// <param name="batch"></param>
@@ -406,10 +406,10 @@ public DecodeResult Decode(LLamaBatchEmbeddings batch)
                 return 0;
             if (batch.EmbeddingsCount > BatchSize)
                 throw new ArgumentException("Input contains more tokens than configured batch size", nameof(batch));
-        
+
             return (DecodeResult)NativeHandle.Decode(batch);
         }
-        
+
         /// <summary>
         /// </summary>
         /// <param name="batch"></param>
@@ -425,15 +425,16 @@ public Task<DecodeResult> DecodeAsync(LLamaBatchEmbeddings batch, CancellationToken cancellationToken = default)
         /// <param name="id"></param>
         /// <param name="batch"></param>
         /// <param name="n_past"></param>
+        /// <param name="cancellationToken"></param>
         /// <returns>A tuple, containing the decode result, the number of tokens that have not been decoded yet and the total number of tokens that have been decoded.</returns>
-        public Task<(DecodeResult, int, int)> DecodeAsync(List<LLamaToken> tokens, LLamaSeqId id, LLamaBatch batch, int n_past)
+        public Task<(DecodeResult, int, int)> DecodeAsync(List<LLamaToken> tokens, LLamaSeqId id, LLamaBatch batch, int n_past, CancellationToken cancellationToken = default)
         {
             return Task.Run(() =>
             {
                 var past = n_past;
                 var res = NativeHandle.Decode(tokens, id, batch, ref past);
                 return (res.Item1, res.Item2, past);
-            });
+            }, cancellationToken);
         }
 
         #endregion
diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index eee5ea49..d0829dec 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -246,36 +246,41 @@ protected virtual void TryReuseMatchingPrefix()
         /// Decide whether to continue the loop.
         /// </summary>
         /// <param name="args"></param>
+        /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        protected abstract Task<bool> GetLoopCondition(InferStateArgs args);
+        protected abstract Task<bool> GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken = default);
 
         /// <summary>
         /// Preprocess the inputs before the inference.
         /// </summary>
         /// <param name="text"></param>
         /// <param name="args"></param>
-        protected abstract Task PreprocessInputs(string? text, InferStateArgs args);
+        /// <param name="cancellationToken"></param>
+        protected abstract Task PreprocessInputs(string? text, InferStateArgs args, CancellationToken cancellationToken = default);
 
         /// <summary>
         /// Do some post processing after the inference.
         /// </summary>
         /// <param name="inferenceParams"></param>
         /// <param name="args"></param>
+        /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        protected abstract Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args);
+        protected abstract Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken = default);
 
         /// <summary>
         /// The core inference logic.
         /// </summary>
         /// <param name="inferenceParams"></param>
         /// <param name="args"></param>
-        protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args);
+        /// <param name="cancellationToken"></param>
+        protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken = default);
 
         /// <summary>
         /// Save the current state to a file.
         /// </summary>
         /// <param name="filename"></param>
-        public abstract Task SaveState(string filename);
+        /// <param name="cancellationToken"></param>
+        public abstract Task SaveState(string filename, CancellationToken cancellationToken = default);
 
         /// <summary>
         /// Get the current state data.
@@ -287,13 +292,15 @@ protected virtual void TryReuseMatchingPrefix()
         /// Load the state from data.
         /// </summary>
         /// <param name="data"></param>
-        public abstract Task LoadState(ExecutorBaseState data);
+        /// <param name="cancellationToken"></param>
+        public abstract Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken = default);
 
         /// <summary>
         /// Load the state from a file.
         /// </summary>
         /// <param name="filename"></param>
-        public abstract Task LoadState(string filename);
+        /// <param name="cancellationToken"></param>
+        public abstract Task LoadState(string filename, CancellationToken cancellationToken = default);
 
         /// <summary>
@@ -318,17 +325,17 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
             };
 
             AntipromptProcessor.SetAntiprompts(inferenceParams.AntiPrompts ?? []);
+            await PreprocessInputs(text, args, cancellationToken);
 
-            await PreprocessInputs(text, args);
-
-            while (await GetLoopCondition(args))
+            while (await GetLoopCondition(args, cancellationToken))
             {
                 if (cancellationToken.IsCancellationRequested)
                 {
                     break;
                 }
+
                 args.LastOutput = string.Empty;
-                await InferInternal(inferenceParams, args);
+                await InferInternal(inferenceParams, args, cancellationToken);
 
                 if (args.ReturnValue)
                 {
@@ -338,7 +345,7 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
                     yield return decoded;
                 }
 
-                var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args);
+                var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args, cancellationToken);
                 if (extraOutputs is { Count: > 0 })
                 {
                     foreach (var item in extraOutputs)
@@ -358,8 +365,9 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         /// It could reduce the latency of the first time response if the first input from the user is not immediate.
         /// </summary>
         /// <param name="prompt">Prompt to process</param>
+        /// <param name="cancellationToken"></param>
        /// <returns></returns>
-        public virtual async Task PrefillPromptAsync(string prompt)
+        public virtual async Task PrefillPromptAsync(string prompt, CancellationToken cancellationToken = default)
         {
             var inferenceParams = new InferenceParams
             {
@@ -374,11 +382,11 @@ public virtual async Task PrefillPromptAsync(string prompt)
                 NeedToSaveSession = false
             };
 
-            await PreprocessInputs(prompt, args);
+            await PreprocessInputs(prompt, args, cancellationToken);
             // First run adds the prompt to the _embeds
-            await InferInternal(inferenceParams, args);
+            await InferInternal(inferenceParams, args, cancellationToken);
             // Second run puts it through decode
-            await InferInternal(inferenceParams, args);
+            await InferInternal(inferenceParams, args, cancellationToken);
         }
 
         /// <summary>
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index a2898c09..517a4e7d 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -1,14 +1,15 @@
-using LLama.Abstractions;
-using LLama.Common;
-using LLama.Native;
 using System;
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
 using System.Text.Json;
 using System.Text.Json.Serialization;
+using System.Threading;
 using System.Threading.Tasks;
+using LLama.Abstractions;
+using LLama.Common;
 using LLama.Exceptions;
+using LLama.Native;
 using LLama.Sampling;
 using Microsoft.Extensions.Logging;
 
@@ -65,9 +66,9 @@ public override ExecutorBaseState GetStateData()
             return state;
         }
         /// <inheritdoc />
-        public override Task LoadState(ExecutorBaseState data)
+        public override Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken = default)
         {
-            if(data is InstructExecutorState state)
+            if (data is InstructExecutorState state)
             {
                 _n_session_consumed = state.ConsumedSessionCount;
                 _embed_inps = state.EmbedInps!.ToList();
@@ -91,35 +92,35 @@ public override Task LoadState(ExecutorBaseState data)
         }
 
         /// <inheritdoc />
-        public override async Task SaveState(string filename)
+        public override async Task SaveState(string filename, CancellationToken cancellationToken = default)
         {
             var state = (InstructExecutorState)GetStateData();
             using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
             {
-                await JsonSerializer.SerializeAsync(fs, state);
+                await JsonSerializer.SerializeAsync(fs, state, cancellationToken: cancellationToken);
             }
         }
 
         /// <inheritdoc />
-        public override async Task LoadState(string filename)
+        public override async Task LoadState(string filename, CancellationToken cancellationToken = default)
         {
             using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
             {
                 var state = await JsonSerializer.DeserializeAsync<InstructExecutorState>(fs);
-                await LoadState(state!);
+                await LoadState(state!, cancellationToken);
             }
         }
 
         /// <inheritdoc />
-        protected override Task<bool> GetLoopCondition(InferStateArgs args)
+        protected override Task<bool> GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken)
         {
             return Task.FromResult(args.RemainedTokens != 0 || _is_prompt_run);
         }
 
         /// <inheritdoc />
-        protected override Task PreprocessInputs(string? text, InferStateArgs args)
+        protected override Task PreprocessInputs(string? text, InferStateArgs args, CancellationToken cancellationToken)
         {
-            args.Antiprompts ??= [ ];
+            args.Antiprompts ??= [];
             if (!args.Antiprompts.Contains(_instructionPrefix))
                 args.Antiprompts.Add(_instructionPrefix);
 
@@ -155,7 +156,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
         }
 
         /// <inheritdoc />
-        protected override Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
+        protected override Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
         {
             if (_embed_inps.Count <= _consumedTokensCount)
             {
@@ -167,7 +168,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
 
                 if (_pastTokensCount > 0 && args.WaitForInput)
                 {
-                    return Task.FromResult<(bool, IReadOnlyList<string>)>((true, [ "\n> " ]));
+                    return Task.FromResult<(bool, IReadOnlyList<string>)>((true, ["\n> "]));
                 }
             }
@@ -185,7 +186,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
         }
 
         /// <inheritdoc />
-        protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+        protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
         {
             var batch = new LLamaBatch();
@@ -253,7 +254,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
 
             return;
         }
-        
+
         /// <summary>
         /// The descriptor of the state of the instruct executor.
         /// </summary>
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index c76a1121..e7cac4c4 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -1,14 +1,15 @@
-using LLama.Common;
-using LLama.Native;
-using LLama.Abstractions;
 using System;
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
 using System.Text.Json;
 using System.Text.Json.Serialization;
+using System.Threading;
 using System.Threading.Tasks;
+using LLama.Abstractions;
+using LLama.Common;
 using LLama.Exceptions;
+using LLama.Native;
 using LLama.Sampling;
 using Microsoft.Extensions.Logging;
 
@@ -69,7 +70,7 @@ public override ExecutorBaseState GetStateData()
         }
 
         /// <inheritdoc />
-        public override Task LoadState(ExecutorBaseState data)
+        public override Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken = default)
         {
             if (data is InteractiveExecutorState state)
             {
@@ -91,34 +92,35 @@ public override Task LoadState(ExecutorBaseState data)
         }
 
         /// <inheritdoc />
-        public override async Task SaveState(string filename)
+        public override async Task SaveState(string filename, CancellationToken cancellationToken = default)
         {
             var state = (InteractiveExecutorState)GetStateData();
             using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
             {
-                await JsonSerializer.SerializeAsync(fs, state);
+                await JsonSerializer.SerializeAsync(fs, state, cancellationToken: cancellationToken);
             }
         }
 
         /// <inheritdoc />
-        public override async Task LoadState(string filename)
+        public override async Task LoadState(string filename, CancellationToken cancellationToken = default)
         {
             using var fs = new FileStream(filename, FileMode.Open, FileAccess.Read);
+
             var state = await JsonSerializer.DeserializeAsync<InteractiveExecutorState>(fs);
-            await LoadState(state!);
+            await LoadState(state!, cancellationToken);
         }
 
         /// <summary>
         /// Define whether to continue the loop to generate responses.
         /// </summary>
         /// <returns></returns>
-        protected override Task<bool> GetLoopCondition(InferStateArgs args)
+        protected override Task<bool> GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken)
         {
             return Task.FromResult(args.RemainedTokens != 0 && !args.WaitForInput || _is_prompt_run);
         }
 
         /// <inheritdoc />
-        protected override Task PreprocessInputs(string? text, InferStateArgs args)
+        protected override Task PreprocessInputs(string? text, InferStateArgs args, CancellationToken cancellationToken)
         {
             if (_is_prompt_run)
             {
@@ -164,7 +166,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
         }
 
         /// <returns></returns>
-        private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = true)
+        private void PreprocessLlava(string text, InferStateArgs args, bool addBos = true)
         {
             // If the prompt contains the <image> tag extract this.
             _imageInPrompt = text.Contains("<image>");
@@ -199,7 +201,6 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = true)
                     args.RemainedTokens -= line_inp.Length;
                 }
             }
-            return Task.CompletedTask;
         }
 
         /// <summary>
@@ -207,8 +208,9 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = true)
         /// </summary>
         /// <param name="inferenceParams"></param>
         /// <param name="args"></param>
+        /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        protected override Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
+        protected override Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
         {
             if (_embed_inps.Count <= _consumedTokensCount)
             {
@@ -238,7 +240,7 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = true)
         }
 
         /// <inheritdoc />
-        protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+        protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
         {
             var batch = new LLamaBatch();
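Taken together, the patch threads one `CancellationToken` from the public entry points (`ChatSession`, the executors, `LLamaContext`) down to the prefill and decode calls. The snippet below is not part of the diff; it is a minimal usage sketch of the new overloads, in which the model path, the prompt strings, and the 30-second timeout are placeholder values:

```csharp
using System;
using System.Threading;
using LLama;
using LLama.Common;

// Placeholder model path; any local GGUF model works.
var parameters = new ModelParams("model.gguf");
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

var history = new ChatHistory();
history.AddMessage(AuthorRole.System, "You are a helpful assistant.");

// A single token now covers session prefill, message processing and persistence.
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));

// Prefilling the restored history can be abandoned mid-way...
var session = await ChatSession.InitializeSessionFromHistoryAsync(
    executor, history, cancellationToken: cts.Token);

// ...and so can KV-cache prefills and executor state persistence.
await session.AddAndProcessUserMessage("Hello!", cts.Token);
await executor.SaveState("executor_state.json", cts.Token);
```

One design note: `Task.Run(work, cancellationToken)`, as used in `LLamaContext.DecodeAsync`, only abandons work that has not yet been scheduled; a decode already running on the thread pool completes normally. Cancellation therefore remains cooperative at batch boundaries, via the `IsCancellationRequested` check in `InferAsync`.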