diff --git a/LLama.Examples/Examples/QuantizeModel.cs b/LLama.Examples/Examples/QuantizeModel.cs
index dace956c..863bb0c3 100644
--- a/LLama.Examples/Examples/QuantizeModel.cs
+++ b/LLama.Examples/Examples/QuantizeModel.cs
@@ -2,7 +2,7 @@ namespace LLama.Examples.Examples
{
public class QuantizeModel
{
- public static async Task Run()
+ public static Task Run()
{
string inputPath = UserSettings.GetModelPath();
@@ -21,7 +21,7 @@ public static async Task Run()
Console.WriteLine("Quantization failed!");
}
- await Task.CompletedTask;
+ return Task.CompletedTask;
}
}
}
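
Side note on the change above: when a method awaits nothing, returning `Task.CompletedTask` directly avoids the compiler-generated async state machine. A minimal sketch of the pattern (names are illustrative, not from this repo):

```csharp
using System;
using System.Threading.Tasks;

static class AsyncReturnSketch
{
    static void DoWork() => Console.WriteLine("quantizing...");

    // Before: 'async' builds a state machine even though nothing truly awaits.
    public static async Task RunBefore()
    {
        DoWork();
        await Task.CompletedTask; // no-op await
    }

    // After: identical observable behavior, no state-machine overhead.
    public static Task RunAfter()
    {
        DoWork();
        return Task.CompletedTask;
    }
}
```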
diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs
index 90119d4f..bda7472d 100644
--- a/LLama/ChatSession.cs
+++ b/LLama/ChatSession.cs
@@ -76,9 +76,10 @@ public class ChatSession
/// <param name="executor">The executor for this session</param>
/// <param name="history">History for this session</param>
/// <param name="transform">History Transform for this session</param>
+ /// <param name="cancellationToken">A token that cancels the operation</param>
/// <returns>A new chat session.</returns>
public static async Task<ChatSession> InitializeSessionFromHistoryAsync(
- ILLamaExecutor executor, ChatHistory history, IHistoryTransform? transform = null)
+ ILLamaExecutor executor, ChatHistory history, IHistoryTransform? transform = null, CancellationToken cancellationToken = default)
{
if (executor is not StatefulExecutorBase statefulExecutor)
{
@@ -90,7 +91,7 @@ public static async Task InitializeSessionFromHistoryAsync(
session = session.WithHistoryTransform(transform);
}
- await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(history));
+ await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(history), cancellationToken);
return session;
}
@@ -311,13 +312,15 @@ public ChatSession RemoveLastMessage()
/// Compute KV cache for the message and add it to the chat history.
/// </summary>
/// <param name="message"></param>
+ /// <param name="cancellationToken"></param>
/// <returns></returns>
- public async Task<ChatSession> AddAndProcessMessage(ChatHistory.Message message)
+ public async Task<ChatSession> AddAndProcessMessage(ChatHistory.Message message, CancellationToken cancellationToken = default)
{
if (Executor is not StatefulExecutorBase statefulExecutor)
{
throw new InvalidOperationException("Executor must be a StatefulExecutorBase to support pre-processing of system messages.");
}
+
AddMessage(message);
var content = message.Content;
if (message.AuthorRole != AuthorRole.Assistant)
@@ -328,27 +331,27 @@ public async Task AddAndProcessMessage(ChatHistory.Message message)
}
}
- await statefulExecutor.PrefillPromptAsync(content);
+ await statefulExecutor.PrefillPromptAsync(content, cancellationToken);
return this;
}
/// <summary>
/// Compute KV cache for the system message and add it to the chat history.
/// </summary>
- public Task<ChatSession> AddAndProcessSystemMessage(string content)
- => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.System, content));
+ public Task<ChatSession> AddAndProcessSystemMessage(string content, CancellationToken cancellationToken = default)
+ => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.System, content), cancellationToken);
/// <summary>
/// Compute KV cache for the user message and add it to the chat history.
/// </summary>
- public Task<ChatSession> AddAndProcessUserMessage(string content)
- => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.User, content));
+ public Task<ChatSession> AddAndProcessUserMessage(string content, CancellationToken cancellationToken = default)
+ => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.User, content), cancellationToken);
/// <summary>
/// Compute KV cache for the assistant message and add it to the chat history.
/// </summary>
- public Task<ChatSession> AddAndProcessAssistantMessage(string content)
- => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.Assistant, content));
+ public Task<ChatSession> AddAndProcessAssistantMessage(string content, CancellationToken cancellationToken = default)
+ => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.Assistant, content), cancellationToken);
/// <summary>
/// Replace a user message with a new message and remove all messages after the new message.
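
With the token threaded through `InitializeSessionFromHistoryAsync` and the `AddAndProcess*` family, callers can now bound prefill work. A minimal usage sketch, assuming an already-configured stateful executor and history (names illustrative):

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;
using LLama;
using LLama.Abstractions;
using LLama.Common;

static class ChatSessionCancellationSketch
{
    // 'executor' is assumed to wrap a StatefulExecutorBase (e.g. InteractiveExecutor).
    public static async Task<ChatSession> BuildAsync(ILLamaExecutor executor, ChatHistory history)
    {
        // Give the whole prefill at most 30 seconds.
        using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));

        var session = await ChatSession.InitializeSessionFromHistoryAsync(
            executor, history, cancellationToken: cts.Token);

        // Later prefill calls accept the same token.
        await session.AddAndProcessUserMessage("Hello!", cts.Token);
        return session;
    }
}
```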
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 42d76c51..4188f9e5 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -1,14 +1,14 @@
-using LLama.Native;
using System;
using System.Collections.Generic;
using System.Diagnostics;
-using System.Text;
using System.IO;
using System.IO.MemoryMappedFiles;
+using System.Text;
+using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
+using LLama.Native;
using Microsoft.Extensions.Logging;
-using System.Threading;
namespace LLama
{
@@ -73,7 +73,7 @@ public int BatchThreads
/// Get the special tokens for the model associated with this context
/// </summary>
public SafeLlamaModelHandle.Vocabulary Vocab { get; }
-
+
/// <summary>
/// Create a new LLamaContext for the given LLamaWeights
/// </summary>
@@ -396,7 +396,7 @@ public Task<DecodeResult> DecodeAsync(LLamaBatch batch, CancellationToken cancel
{
return Task.Run(() => Decode(batch), cancellationToken);
}
-
+
/// <summary>
/// </summary>
/// <param name="batch"></param>
@@ -406,10 +406,10 @@ public DecodeResult Decode(LLamaBatchEmbeddings batch)
return 0;
if (batch.EmbeddingsCount > BatchSize)
throw new ArgumentException("Input contains more tokens than configured batch size", nameof(batch));
-
+
return (DecodeResult)NativeHandle.Decode(batch);
}
-
+
/// <summary>
/// </summary>
/// <param name="tokens"></param>
@@ -425,15 +425,16 @@ public Task<DecodeResult> DecodeAsync(LLamaBatchEmbeddings batch, CancellationTo
/// <param name="id"></param>
/// <param name="batch"></param>
/// <param name="n_past"></param>
+ /// <param name="cancellationToken"></param>
/// <returns>A tuple, containing the decode result, the number of tokens that have not been decoded yet and the total number of tokens that have been decoded.</returns>
- public Task<(DecodeResult, int, int)> DecodeAsync(List<LLamaToken> tokens, LLamaSeqId id, LLamaBatch batch, int n_past)
+ public Task<(DecodeResult, int, int)> DecodeAsync(List<LLamaToken> tokens, LLamaSeqId id, LLamaBatch batch, int n_past, CancellationToken cancellationToken = default)
{
return Task.Run(() =>
{
var past = n_past;
var res = NativeHandle.Decode(tokens, id, batch, ref past);
return (res.Item1, res.Item2, past);
- });
+ }, cancellationToken);
}
#endregion
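
One caveat worth flagging on `DecodeAsync`: a token passed to `Task.Run` only cancels work that has not started yet; once the native `Decode` call is executing it runs to completion. A small standalone sketch of that semantics (plain BCL behavior, not a LLamaSharp API):

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

class TaskRunCancellationSketch
{
    static async Task Main()
    {
        using var cts = new CancellationTokenSource();
        cts.Cancel(); // cancel before the delegate is scheduled

        var task = Task.Run(() => 42, cts.Token);
        try
        {
            await task;
        }
        catch (TaskCanceledException)
        {
            // The delegate never ran: the token is only checked at start time.
            Console.WriteLine("cancelled before start");
        }
        // If the token fires after the delegate begins, the work still
        // finishes; the token does not interrupt a running native call.
    }
}
```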
diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index eee5ea49..d0829dec 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -246,36 +246,41 @@ protected virtual void TryReuseMatchingPrefix()
/// Decide whether to continue the loop.
/// </summary>
/// <param name="args"></param>
+ /// <param name="cancellationToken"></param>
/// <returns></returns>
- protected abstract Task<bool> GetLoopCondition(InferStateArgs args);
+ protected abstract Task<bool> GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken = default);
/// <summary>
/// Preprocess the inputs before the inference.
/// </summary>
/// <param name="text"></param>
/// <param name="args"></param>
- protected abstract Task PreprocessInputs(string? text, InferStateArgs args);
+ /// <param name="cancellationToken"></param>
+ protected abstract Task PreprocessInputs(string? text, InferStateArgs args, CancellationToken cancellationToken = default);
/// <summary>
/// Do some post processing after the inference.
/// </summary>
/// <param name="inferenceParams"></param>
/// <param name="args"></param>
+ /// <param name="cancellationToken"></param>
/// <returns></returns>
- protected abstract Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args);
+ protected abstract Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken = default);
/// <summary>
/// The core inference logic.
/// </summary>
/// <param name="inferenceParams"></param>
/// <param name="args"></param>
- protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args);
+ /// <param name="cancellationToken"></param>
+ protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken = default);
/// <summary>
/// Save the current state to a file.
/// </summary>
/// <param name="filename"></param>
- public abstract Task SaveState(string filename);
+ /// <param name="cancellationToken"></param>
+ public abstract Task SaveState(string filename, CancellationToken cancellationToken = default);
/// <summary>
/// Get the current state data.
@@ -287,13 +292,15 @@ protected virtual void TryReuseMatchingPrefix()
/// Load the state from data.
/// </summary>
/// <param name="data"></param>
- public abstract Task LoadState(ExecutorBaseState data);
+ /// <param name="cancellationToken"></param>
+ public abstract Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken = default);
/// <summary>
/// Load the state from a file.
/// </summary>
/// <param name="filename"></param>
- public abstract Task LoadState(string filename);
+ /// <param name="cancellationToken"></param>
+ public abstract Task LoadState(string filename, CancellationToken cancellationToken = default);
/// <summary>
@@ -318,17 +325,17 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenc
};
AntipromptProcessor.SetAntiprompts(inferenceParams.AntiPrompts ?? []);
- await PreprocessInputs(text, args);
-
- while (await GetLoopCondition(args))
+ await PreprocessInputs(text, args, cancellationToken);
+ while (await GetLoopCondition(args, cancellationToken))
{
if (cancellationToken.IsCancellationRequested)
{
break;
}
+
args.LastOutput = string.Empty;
- await InferInternal(inferenceParams, args);
+ await InferInternal(inferenceParams, args, cancellationToken);
if (args.ReturnValue)
{
@@ -338,7 +345,7 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenc
yield return decoded;
}
- var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args);
+ var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args, cancellationToken);
if (extraOutputs is { Count: > 0 })
{
foreach (var item in extraOutputs)
@@ -358,8 +365,9 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenc
/// It could reduce the latency of the first time response if the first input from the user is not immediate.
/// </summary>
/// <param name="prompt">Prompt to process</param>
+ /// <param name="cancellationToken"></param>
/// <returns></returns>
- public virtual async Task PrefillPromptAsync(string prompt)
+ public virtual async Task PrefillPromptAsync(string prompt, CancellationToken cancellationToken = default)
{
var inferenceParams = new InferenceParams
{
@@ -374,11 +382,11 @@ public virtual async Task PrefillPromptAsync(string prompt)
NeedToSaveSession = false
};
- await PreprocessInputs(prompt, args);
+ await PreprocessInputs(prompt, args, cancellationToken);
// First run adds the prompt to the _embeds
- await InferInternal(inferenceParams, args);
+ await InferInternal(inferenceParams, args, cancellationToken);
// Second run puts it through decode
- await InferInternal(inferenceParams, args);
+ await InferInternal(inferenceParams, args, cancellationToken);
}
/// <summary>
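
End to end, the token now reaches every await inside the inference loop. A consumption sketch; the model path and prompt are placeholders, not part of this diff:

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;
using LLama;
using LLama.Common;

class InferCancellationSketch
{
    static async Task Main()
    {
        var parameters = new ModelParams("model.gguf"); // placeholder path
        using var weights = LLamaWeights.LoadFromFile(parameters);
        using var context = weights.CreateContext(parameters);
        var executor = new InteractiveExecutor(context);

        // Stop generating after 10 seconds, wherever the loop happens to be.
        using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10));

        await foreach (var piece in executor.InferAsync(
                           "Why is the sky blue?", new InferenceParams(), cts.Token))
        {
            Console.Write(piece);
        }
    }
}
```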
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index a2898c09..517a4e7d 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -1,14 +1,15 @@
-using LLama.Abstractions;
-using LLama.Common;
-using LLama.Native;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
+using System.Threading;
using System.Threading.Tasks;
+using LLama.Abstractions;
+using LLama.Common;
using LLama.Exceptions;
+using LLama.Native;
using LLama.Sampling;
using Microsoft.Extensions.Logging;
@@ -65,9 +66,9 @@ public override ExecutorBaseState GetStateData()
return state;
}
/// <inheritdoc />
- public override Task LoadState(ExecutorBaseState data)
+ public override Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken = default)
{
- if(data is InstructExecutorState state)
+ if (data is InstructExecutorState state)
{
_n_session_consumed = state.ConsumedSessionCount;
_embed_inps = state.EmbedInps!.ToList();
@@ -91,35 +92,35 @@ public override Task LoadState(ExecutorBaseState data)
}
/// <inheritdoc />
- public override async Task SaveState(string filename)
+ public override async Task SaveState(string filename, CancellationToken cancellationToken = default)
{
var state = (InstructExecutorState)GetStateData();
using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
{
- await JsonSerializer.SerializeAsync(fs, state);
+ await JsonSerializer.SerializeAsync(fs, state, cancellationToken: cancellationToken);
}
}
/// <inheritdoc />
- public override async Task LoadState(string filename)
+ public override async Task LoadState(string filename, CancellationToken cancellationToken = default)
{
using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
{
var state = await JsonSerializer.DeserializeAsync<InstructExecutorState>(fs);
- await LoadState(state!);
+ await LoadState(state!, cancellationToken);
}
}
/// <inheritdoc />
- protected override Task<bool> GetLoopCondition(InferStateArgs args)
+ protected override Task<bool> GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken)
{
return Task.FromResult(args.RemainedTokens != 0 || _is_prompt_run);
}
/// <inheritdoc />
- protected override Task PreprocessInputs(string? text, InferStateArgs args)
+ protected override Task PreprocessInputs(string? text, InferStateArgs args, CancellationToken cancellationToken)
{
- args.Antiprompts ??= [ ];
+ args.Antiprompts ??= [];
if (!args.Antiprompts.Contains(_instructionPrefix))
args.Antiprompts.Add(_instructionPrefix);
@@ -155,7 +156,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
}
/// <inheritdoc />
- protected override Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
+ protected override Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
{
if (_embed_inps.Count <= _consumedTokensCount)
{
@@ -167,7 +168,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
if (_pastTokensCount > 0 && args.WaitForInput)
{
- return Task.FromResult<(bool, IReadOnlyList<string>)>((true, [ "\n> " ]));
+ return Task.FromResult<(bool, IReadOnlyList<string>)>((true, ["\n> "]));
}
}
@@ -185,7 +186,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
}
/// <inheritdoc />
- protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+ protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
{
var batch = new LLamaBatch();
@@ -253,7 +254,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
return;
}
-
+
/// <summary>
/// The descriptor of the state of the instruct executor.
/// </summary>
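
Since `JsonSerializer.SerializeAsync` accepts the token, executor state snapshots become cancellable as well. A short sketch, assuming an `InstructExecutor` built elsewhere (file name illustrative):

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;
using LLama;

static class ExecutorStateSketch
{
    public static async Task RoundTripAsync(InstructExecutor executor)
    {
        using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5));

        // Serialization is aborted if the token fires mid-write.
        await executor.SaveState("instruct-state.json", cts.Token);
        await executor.LoadState("instruct-state.json", cts.Token);
    }
}
```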
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index c76a1121..e7cac4c4 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -1,14 +1,15 @@
-using LLama.Common;
-using LLama.Native;
-using LLama.Abstractions;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
+using System.Threading;
using System.Threading.Tasks;
+using LLama.Abstractions;
+using LLama.Common;
using LLama.Exceptions;
+using LLama.Native;
using LLama.Sampling;
using Microsoft.Extensions.Logging;
@@ -69,7 +70,7 @@ public override ExecutorBaseState GetStateData()
}
/// <inheritdoc />
- public override Task LoadState(ExecutorBaseState data)
+ public override Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken = default)
{
if (data is InteractiveExecutorState state)
{
@@ -91,34 +92,35 @@ public override Task LoadState(ExecutorBaseState data)
}
/// <inheritdoc />
- public override async Task SaveState(string filename)
+ public override async Task SaveState(string filename, CancellationToken cancellationToken = default)
{
var state = (InteractiveExecutorState)GetStateData();
using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
{
- await JsonSerializer.SerializeAsync(fs, state);
+ await JsonSerializer.SerializeAsync(fs, state, cancellationToken: cancellationToken);
}
}
/// <inheritdoc />
- public override async Task LoadState(string filename)
+ public override async Task LoadState(string filename, CancellationToken cancellationToken = default)
{
using var fs = new FileStream(filename, FileMode.Open, FileAccess.Read);
+
var state = await JsonSerializer.DeserializeAsync<InteractiveExecutorState>(fs);
- await LoadState(state!);
+ await LoadState(state!, cancellationToken);
}
/// <summary>
/// Define whether to continue the loop to generate responses.
/// </summary>
/// <returns></returns>
- protected override Task<bool> GetLoopCondition(InferStateArgs args)
+ protected override Task<bool> GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken)
{
return Task.FromResult(args.RemainedTokens != 0 && !args.WaitForInput || _is_prompt_run);
}
/// <inheritdoc />
- protected override Task PreprocessInputs(string? text, InferStateArgs args)
+ protected override Task PreprocessInputs(string? text, InferStateArgs args, CancellationToken cancellationToken)
{
if (_is_prompt_run)
{
@@ -164,7 +166,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
}
///
- private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = true)
+ private void PreprocessLlava(string text, InferStateArgs args, bool addBos = true)
{
// If the prompt contains the tag <image> extract this.
_imageInPrompt = text.Contains("<image>");
@@ -199,7 +201,6 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = tru
args.RemainedTokens -= line_inp.Length;
}
}
- return Task.CompletedTask;
}
/// <summary>
@@ -207,8 +208,9 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = tru
/// </summary>
/// <param name="inferenceParams"></param>
/// <param name="args"></param>
+ /// <param name="cancellationToken"></param>
/// <returns></returns>
- protected override Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
+ protected override Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
{
if (_embed_inps.Count <= _consumedTokensCount)
{
@@ -238,7 +240,7 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = tru
}
/// <inheritdoc />
- protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+ protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
{
var batch = new LLamaBatch();