Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce ChatHistory interface #669

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions LLama.Examples/Examples/ChatChineseGB2312.cs
Original file line number Diff line number Diff line change
@@ -49,7 +49,7 @@ public static async Task Run()
else
{
var chatHistoryJson = File.ReadAllText("Assets/chat-with-kunkun-chinese.json");
ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson, typeof(ChatHistory)) ?? new ChatHistory();

session = new ChatSession(executor, chatHistory);
}
@@ -105,7 +105,7 @@ in session.RegenerateAssistantMessageAsync(
await foreach (
var text
in session.ChatAsync(
new ChatHistory.Message(AuthorRole.User, userInput),
new Message(AuthorRole.User, userInput),
inferenceParams))
{
Console.ForegroundColor = ConsoleColor.White;
4 changes: 2 additions & 2 deletions LLama.Examples/Examples/ChatSessionStripRoleName.cs
Original file line number Diff line number Diff line change
@@ -21,7 +21,7 @@ public static async Task Run()
var executor = new InteractiveExecutor(context);

var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson, typeof(ChatHistory)) ?? new ChatHistory();

ChatSession session = new(executor, chatHistory);
session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(
@@ -46,7 +46,7 @@ public static async Task Run()
await foreach (
var text
in session.ChatAsync(
new ChatHistory.Message(AuthorRole.User, userInput),
new Message(AuthorRole.User, userInput),
inferenceParams))
{
Console.ForegroundColor = ConsoleColor.White;
4 changes: 2 additions & 2 deletions LLama.Examples/Examples/ChatSessionWithHistory.cs
Original file line number Diff line number Diff line change
@@ -31,7 +31,7 @@ public static async Task Run()
else
{
var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson, typeof(ChatHistory)) ?? new ChatHistory();

session = new ChatSession(executor, chatHistory);
}
@@ -92,7 +92,7 @@ in session.RegenerateAssistantMessageAsync(
await foreach (
var text
in session.ChatAsync(
new ChatHistory.Message(AuthorRole.User, userInput),
new Message(AuthorRole.User, userInput),
inferenceParams))
{
Console.ForegroundColor = ConsoleColor.White;
14 changes: 7 additions & 7 deletions LLama.Examples/Examples/ChatSessionWithRestart.cs
Original file line number Diff line number Diff line change
@@ -19,8 +19,8 @@ public static async Task Run()
var executor = new InteractiveExecutor(context);

var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
ChatSession prototypeSession =
IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson, typeof(ChatHistory)) ?? new ChatHistory();
ChatSession prototypeSession =
await ChatSession.InitializeSessionFromHistoryAsync(executor, chatHistory);
prototypeSession.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(
new string[] { "User:", "Assistant:" },
@@ -50,10 +50,10 @@ public static async Task Run()
while (userInput != "exit")
{
// Load the session state from the reset state
if(userInput == "reset")
if (userInput == "reset")
{
session.LoadSession(resetState);
Console.WriteLine($"Reset to history:\n{session.HistoryTransform.HistoryToText(session.History)}");
Console.WriteLine($"Reset to history:\n{session.HistoryTransform.HistoryToText(session.SessionChatHistory)}");
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Session reset.");
}
@@ -75,10 +75,10 @@ public static async Task Run()

Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Provide assistant input: ");

Console.ForegroundColor = ConsoleColor.Green;
string assistantInputOverride = Console.ReadLine() ?? "";

await session.AddAndProcessUserMessage(userInputOverride);
await session.AddAndProcessAssistantMessage(assistantInputOverride);

@@ -90,7 +90,7 @@ public static async Task Run()
await foreach (
var text
in session.ChatAsync(
new ChatHistory.Message(AuthorRole.User, userInput),
new Message(AuthorRole.User, userInput),
inferenceParams))
{
Console.ForegroundColor = ConsoleColor.White;
4 changes: 2 additions & 2 deletions LLama.Examples/Examples/ChatSessionWithRoleName.cs
Original file line number Diff line number Diff line change
@@ -19,7 +19,7 @@ public static async Task Run()
var executor = new InteractiveExecutor(context);

var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson, typeof(ChatHistory)) ?? new ChatHistory();

ChatSession session = new(executor, chatHistory);

@@ -41,7 +41,7 @@ public static async Task Run()
await foreach (
var text
in session.ChatAsync(
new ChatHistory.Message(AuthorRole.User, userInput),
new Message(AuthorRole.User, userInput),
inferenceParams))
{
Console.ForegroundColor = ConsoleColor.White;
2 changes: 1 addition & 1 deletion LLama.Examples/Examples/LoadAndSaveSession.cs
Original file line number Diff line number Diff line change
@@ -33,7 +33,7 @@ public static async Task Run()
await foreach (
var text
in session.ChatAsync(
new ChatHistory.Message(AuthorRole.User, prompt),
new Message(AuthorRole.User, prompt),
new InferenceParams()
{
Temperature = 0.6f,
2 changes: 1 addition & 1 deletion LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs
Original file line number Diff line number Diff line change
@@ -10,7 +10,7 @@ namespace LLamaSharp.SemanticKernel.ChatCompletion;
public class HistoryTransform : DefaultHistoryTransform
{
/// <inheritdoc/>
public override string HistoryToText(global::LLama.Common.ChatHistory history)
public string HistoryToText(global::LLama.Common.ChatHistory history)
{
return base.HistoryToText(history) + $"{AuthorRole.Assistant}: ";
}
2 changes: 1 addition & 1 deletion LLama.WebAPI/Controllers/ChatController.cs
Original file line number Diff line number Diff line change
@@ -43,7 +43,7 @@ public async Task<string> SendHistory([FromBody] HistoryInput input, [FromServic
{
var history = new ChatHistory();

var messages = input.Messages.Select(m => new ChatHistory.Message(Enum.Parse<AuthorRole>(m.Role), m.Content));
var messages = input.Messages.Select(m => new Message(Enum.Parse<AuthorRole>(m.Role), m.Content));

history.Messages.AddRange(messages);

6 changes: 3 additions & 3 deletions LLama.WebAPI/Services/StatefulChatService.cs
Original file line number Diff line number Diff line change
@@ -28,7 +28,7 @@ public StatefulChatService(IConfiguration configuration, ILogger<StatefulChatSer
_context = new LLamaContext(weights, @params);

_session = new ChatSession(new InteractiveExecutor(_context));
_session.History.AddMessage(Common.AuthorRole.System, SystemPrompt);
_session.SessionChatHistory.AddMessage(Common.AuthorRole.System, SystemPrompt);
}

public void Dispose()
@@ -46,7 +46,7 @@ public async Task<string> Send(SendMessageInput input)
}
_logger.LogInformation("Input: {text}", input.Text);
var outputs = _session.ChatAsync(
new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text),
new Common.Message(Common.AuthorRole.User, input.Text),
new Common.InferenceParams()
{
RepeatPenalty = 1.0f,
@@ -74,7 +74,7 @@ public async IAsyncEnumerable<string> SendStream(SendMessageInput input)
_logger.LogInformation(input.Text);

var outputs = _session.ChatAsync(
new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text!)
new Common.Message(Common.AuthorRole.User, input.Text!)
, new Common.InferenceParams()
{
RepeatPenalty = 1.0f,
3 changes: 1 addition & 2 deletions LLama.WebAPI/Services/StatelessChatService.cs
Original file line number Diff line number Diff line change
@@ -48,10 +48,9 @@ public async Task<string> SendAsync(ChatHistory history)
}
public class HistoryTransform : DefaultHistoryTransform
{
public override string HistoryToText(ChatHistory history)
public override string HistoryToText(IChatHistory history)
{
return base.HistoryToText(history) + "\n Assistant:";
}

}
}
8 changes: 5 additions & 3 deletions LLama/Abstractions/IHistoryTransform.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using LLama.Common;
using System;
using System.Text.Json.Serialization;

namespace LLama.Abstractions
@@ -14,15 +15,16 @@ public interface IHistoryTransform
/// </summary>
/// <param name="history">The ChatHistory instance</param>
/// <returns></returns>
string HistoryToText(ChatHistory history);
string HistoryToText(IChatHistory history);

/// <summary>
/// Converts plain text to a ChatHistory instance.
/// </summary>
/// <param name="role">The role for the author.</param>
/// <param name="text">The chat history as plain text.</param>
/// <param name="type">The type of the chat history.</param>
/// <returns>The updated history.</returns>
ChatHistory TextToHistory(AuthorRole role, string text);
IChatHistory TextToHistory(AuthorRole role, string text, Type type);

/// <summary>
/// Copy the transform.
107 changes: 53 additions & 54 deletions LLama/ChatSession.cs

Large diffs are not rendered by default.

95 changes: 62 additions & 33 deletions LLama/Common/ChatHistory.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Collections.Generic;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
@@ -33,41 +34,60 @@ public enum AuthorRole

// copy from semantic-kernel
/// <summary>
/// The chat history class
/// The message class
/// </summary>
public class ChatHistory
public class Message
{
private static readonly JsonSerializerOptions _jsonOptions = new() { WriteIndented = true };
/// <summary>
/// Role of the message author, e.g. user/assistant/system
/// </summary>
[JsonConverter(typeof(JsonStringEnumConverter))]
[JsonPropertyName("author_role")]
public AuthorRole AuthorRole { get; set; }

/// <summary>
/// Chat message representation
/// Message content
/// </summary>
public class Message
[JsonPropertyName("content")]
public string Content { get; set; }

/// <summary>
/// Create a new instance
/// </summary>
/// <param name="authorRole">Role of message author</param>
/// <param name="content">Message content</param>
public Message(AuthorRole authorRole, string content)
{
/// <summary>
/// Role of the message author, e.g. user/assistant/system
/// </summary>
[JsonConverter(typeof(JsonStringEnumConverter))]
[JsonPropertyName("author_role")]
public AuthorRole AuthorRole { get; set; }

/// <summary>
/// Message content
/// </summary>
[JsonPropertyName("content")]
public string Content { get; set; }

/// <summary>
/// Create a new instance
/// </summary>
/// <param name="authorRole">Role of message author</param>
/// <param name="content">Message content</param>
public Message(AuthorRole authorRole, string content)
{
this.AuthorRole = authorRole;
this.Content = content;
}
this.AuthorRole = authorRole;
this.Content = content;
}
}

/// <summary>
/// Interface for chat history
/// </summary>
public interface IChatHistory
{
/// <summary>
/// List of messages in the chat
/// </summary>
List<Message> Messages { get; set; }

/// <summary>
/// Add a message to the chat history
/// </summary>
/// <param name="authorRole">Role of the message author</param>
/// <param name="content">Message content</param>
void AddMessage(AuthorRole authorRole, string content);
}

// copy from semantic-kernel
/// <summary>
/// The chat history class
/// </summary>
public class ChatHistory : IChatHistory
{
private static readonly JsonSerializerOptions _jsonOptions = new() { WriteIndented = true };

/// <summary>
/// List of messages in the chat
@@ -99,24 +119,33 @@ public void AddMessage(AuthorRole authorRole, string content)
{
this.Messages.Add(new Message(authorRole, content));
}
}

/// <summary>
/// Serializer for chat history
/// </summary>
public class ChatHistorySerializer
{
private static readonly JsonSerializerOptions _jsonOptions = new() { WriteIndented = true };

/// <summary>
/// Serialize the chat history to JSON
/// </summary>
/// <returns></returns>
public string ToJson()
public static string ToJson(IChatHistory chatHistory)
{
return JsonSerializer.Serialize(this, _jsonOptions);
return JsonSerializer.Serialize(chatHistory, _jsonOptions);
}

/// <summary>
/// Deserialize a chat history from JSON
/// </summary>
/// <param name="json"></param>
/// <param name="type"></param>
/// <returns></returns>
public static ChatHistory? FromJson(string json)
public static IChatHistory? FromJson(string json, Type type)
{
return JsonSerializer.Deserialize<ChatHistory>(json);
return JsonSerializer.Deserialize(json, type) as IChatHistory;
}
}
}
9 changes: 5 additions & 4 deletions LLama/LLamaTransforms.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using LLama.Abstractions;
using LLama.Common;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
@@ -44,7 +45,7 @@ public class DefaultHistoryTransform : IHistoryTransform
/// <param name="systemName"></param>
/// <param name="unknownName"></param>
/// <param name="isInstructMode"></param>
public DefaultHistoryTransform(string? userName = null, string? assistantName = null,
public DefaultHistoryTransform(string? userName = null, string? assistantName = null,
string? systemName = null, string? unknownName = null, bool isInstructMode = false)
{
_userName = userName ?? defaultUserName;
@@ -61,7 +62,7 @@ public IHistoryTransform Clone()
}

/// <inheritdoc />
public virtual string HistoryToText(ChatHistory history)
public virtual string HistoryToText(IChatHistory history)
{
StringBuilder sb = new();
foreach (var message in history.Messages)
@@ -87,9 +88,9 @@ public virtual string HistoryToText(ChatHistory history)
}

/// <inheritdoc />
public virtual ChatHistory TextToHistory(AuthorRole role, string text)
public virtual IChatHistory TextToHistory(AuthorRole role, string text, Type type)
{
ChatHistory history = new ChatHistory();
IChatHistory history = (IChatHistory)(Activator.CreateInstance(type) ?? new ChatHistory());
history.AddMessage(role, TrimNamesFromText(text, role));
return history;
}

Unchanged files with check annotations (Beta)

using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
{
var state = await JsonSerializer.DeserializeAsync<InteractiveExecutorState>(fs);
await LoadState(state);

Check warning on line 104 in LLama/LLamaInteractExecutor.cs

GitHub Actions / Test (linux-release)

Possible null reference argument for parameter 'data' in 'Task InteractiveExecutor.LoadState(ExecutorBaseState data)'.

Check warning on line 104 in LLama/LLamaInteractExecutor.cs

GitHub Actions / Test (windows-release)

Possible null reference argument for parameter 'data' in 'Task InteractiveExecutor.LoadState(ExecutorBaseState data)'.
}
}
/// <param name="inferenceParams"></param>
/// <param name="args"></param>
/// <returns></returns>
protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)

Check warning on line 184 in LLama/LLamaInteractExecutor.cs

GitHub Actions / Test (linux-release)

This async method lacks 'await' operators and will run synchronously. Consider using the 'await' operator to await non-blocking API calls, or 'await Task.Run(...)' to do CPU-bound work on a background thread.

Check warning on line 184 in LLama/LLamaInteractExecutor.cs

GitHub Actions / Test (windows-release)

This async method lacks 'await' operators and will run synchronously. Consider using the 'await' operator to await non-blocking API calls, or 'await Task.Run(...)' to do CPU-bound work on a background thread.
{
if (_embed_inps.Count <= _consumedTokensCount)
{
if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))

Check warning on line 188 in LLama/LLamaInteractExecutor.cs

GitHub Actions / Test (linux-release)

'IReadOnlyListExtensions.TokensEndsWithAnyString<TTokens>(TTokens, IList<string>?, SafeLlamaModelHandle, Encoding)' is obsolete: 'Use an Antiprompt processor instead'

Check warning on line 188 in LLama/LLamaInteractExecutor.cs

GitHub Actions / Test (windows-release)

'IReadOnlyListExtensions.TokensEndsWithAnyString<TTokens>(TTokens, IList<string>?, SafeLlamaModelHandle, Encoding)' is obsolete: 'Use an Antiprompt processor instead'
args.WaitForInput = true;
if (_pastTokensCount > 0 && args.WaitForInput)
// Images
foreach( var image in _imageEmbedHandles )
ClipModel.EvalImageEmbed(Context, image, ref _pastTokensCount);

Check warning on line 235 in LLama/LLamaInteractExecutor.cs

GitHub Actions / Test (linux-release)

Dereference of a possibly null reference.
// Post-image Tokens
end = Context.NativeHandle.Decode(_embeds.GetRange(_EmbedImagePosition, _embeds.Count - _EmbedImagePosition), LLamaSeqId.Zero, batch, ref _pastTokensCount);
using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
{
var state = await JsonSerializer.DeserializeAsync<InstructExecutorState>(fs);
await LoadState(state);

Check warning on line 109 in LLama/LLamaInstructExecutor.cs

GitHub Actions / Test (linux-release)

Possible null reference argument for parameter 'data' in 'Task InstructExecutor.LoadState(ExecutorBaseState data)'.

Check warning on line 109 in LLama/LLamaInstructExecutor.cs

GitHub Actions / Test (windows-release)

Possible null reference argument for parameter 'data' in 'Task InstructExecutor.LoadState(ExecutorBaseState data)'.
}
}
}
/// <inheritdoc />
protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)

Check warning on line 150 in LLama/LLamaInstructExecutor.cs

GitHub Actions / Test (linux-release)

This async method lacks 'await' operators and will run synchronously. Consider using the 'await' operator to await non-blocking API calls, or 'await Task.Run(...)' to do CPU-bound work on a background thread.
{
if (_embed_inps.Count <= _consumedTokensCount)
{
if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))

Check warning on line 154 in LLama/LLamaInstructExecutor.cs

GitHub Actions / Test (linux-release)

'IReadOnlyListExtensions.TokensEndsWithAnyString<TTokens>(TTokens, IList<string>?, SafeLlamaModelHandle, Encoding)' is obsolete: 'Use an Antiprompt processor instead'
{
args.WaitForInput = true;
return (true, Array.Empty<string>());
if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession)
{
args.NeedToSaveSession = false;
SaveSessionFile(_pathSession);

Check warning on line 215 in LLama/LLamaInstructExecutor.cs

GitHub Actions / Test (linux-release)

Possible null reference argument for parameter 'filename' in 'void StatefulExecutorBase.SaveSessionFile(string filename)'.
}
LLamaToken id;
/// Instruction prefix tokens.
/// </summary>
[JsonPropertyName("inp_pfx")]
public LLamaToken[] InputPrefixTokens { get; set; }

Check warning on line 275 in LLama/LLamaInstructExecutor.cs

GitHub Actions / Test (linux-release)

Non-nullable property 'InputPrefixTokens' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.
/// <summary>
/// Instruction suffix tokens.
/// </summary>
[JsonPropertyName("inp_sfx")]
public LLamaToken[] InputSuffixTokens { get; set; }

Check warning on line 280 in LLama/LLamaInstructExecutor.cs

GitHub Actions / Test (linux-release)

Non-nullable property 'InputSuffixTokens' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.
}
}
}
public override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions options)
{
writer.WriteStartObject();
writer.WriteString("Name", value.GetType().Name);

Check warning on line 51 in LLama/Common/PolymorphicJSONConverter.cs

GitHub Actions / Test (windows-release)

Dereference of a possibly null reference.

Check warning on line 51 in LLama/Common/PolymorphicJSONConverter.cs

GitHub Actions / Test (windows-release)

Dereference of a possibly null reference.
writer.WritePropertyName("Data");
JsonSerializer.Serialize(writer, value, value.GetType(), options);
writer.WriteEndObject();
public string? SessionFilePath { get; set; }
[JsonPropertyName("embd")]
public LLamaToken[] Embeds { get; set; }

Check warning on line 422 in LLama/LLamaExecutorBase.cs

GitHub Actions / Test (windows-release)

Non-nullable property 'Embeds' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.
[JsonPropertyName("embd_inps")]
public LLamaToken[] EmbedInps { get; set; }

Check warning on line 425 in LLama/LLamaExecutorBase.cs

GitHub Actions / Test (windows-release)

Non-nullable property 'EmbedInps' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.
[JsonPropertyName("session_tokens")]
public LLamaToken[] SessionTokens { get; set; }