diff --git a/LLama.Rag/IWebScraper.cs b/LLama.Rag/IWebScraper.cs new file mode 100644 index 000000000..2d7b616e1 --- /dev/null +++ b/LLama.Rag/IWebScraper.cs @@ -0,0 +1,15 @@ +using System.Collections.Generic; +using System.Threading.Tasks; +using HtmlAgilityPack; + +namespace LLama.Rag +{ + public interface IWebScraper + { + HashSet VisitedUrls { get; } + List Documents { get; } + + Task> ExtractVisibleTextAsync(int minWordLength, bool checkSentences, bool explodeParagraphs); + Task> ExtractParagraphsAsync(); + } +} \ No newline at end of file diff --git a/LLama.Rag/LLama.Rag.csproj b/LLama.Rag/LLama.Rag.csproj new file mode 100644 index 000000000..dea41a939 --- /dev/null +++ b/LLama.Rag/LLama.Rag.csproj @@ -0,0 +1,14 @@ + + + + net8.0 + enable + enable + Exe + + + + + + + diff --git a/LLama.Rag/Rag.cs b/LLama.Rag/Rag.cs new file mode 100644 index 000000000..9200af756 --- /dev/null +++ b/LLama.Rag/Rag.cs @@ -0,0 +1,53 @@ +using System; +using System.Threading.Tasks; + +namespace LLama.Rag +{ + public class Rag + { + public static async Task Main(string[] args) + { + try + { + Console.WriteLine("Initializing WebScraper..."); + + string startUrl = "https://en.wikipedia.org/wiki/Aluminium_alloy"; + int depth = 0; // Scrape only the provided webpage and no links. + int minWordLength = 4; // Minimum word count for a text block to be extracted. + bool checkSentences = false; + bool explodeParagraphs = true; + + WebScraper webScraper = await WebScraper.CreateAsync(startUrl, depth); + + Console.WriteLine("WebScraper initialized successfully."); + Console.WriteLine("Extracting visible text..."); + + var documentText = webScraper.ExtractVisibleTextAsync(minWordLength, checkSentences, explodeParagraphs); + + Console.WriteLine($"Extracted {documentText.Result.Count} blocks of text."); + + if (documentText.Result.Count == 0) + { + Console.WriteLine("Warning: No text was extracted. Try lowering minWordLength or changing extraction settings."); + } + + foreach (string text in documentText.Result) + { + Console.WriteLine("Extracted Block:"); + Console.WriteLine(text); + Console.WriteLine(""); // Space between blocks for readability + } + + Console.WriteLine("Scraping complete."); + } + catch (Exception ex) + { + Console.WriteLine($"An error occurred: {ex.Message}"); + Console.WriteLine($"StackTrace: {ex.StackTrace}"); + } + + Console.WriteLine("Press any key to exit..."); + Console.ReadKey(); + } + } +} diff --git a/LLama.Rag/WebScraper.cs b/LLama.Rag/WebScraper.cs new file mode 100644 index 000000000..f575f0102 --- /dev/null +++ b/LLama.Rag/WebScraper.cs @@ -0,0 +1,144 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Http; +using System.Threading.Tasks; +using System.Text.RegularExpressions; +using HtmlAgilityPack; +using System.Web; + +namespace LLama.Rag +{ + class WebScraper : IWebScraper + { + private static readonly HttpClient httpClient = new HttpClient(); + public HashSet VisitedUrls { get; } = new HashSet(); + public List Documents { get; } = new List(); + + private WebScraper() { } + + public static async Task CreateAsync(string url, int queryDepth) + { + WebScraper instance = new WebScraper(); + await instance.FetchContentAsynch(url, queryDepth); + return instance; + } + + private async Task FetchContentAsynch(string url, int queryDepth) + { + if (queryDepth < 0 || VisitedUrls.Contains(url)) return; + + try + { + VisitedUrls.Add(url); + string pageContent = await httpClient.GetStringAsync(url); + HtmlDocument doc = new HtmlDocument(); + doc.LoadHtml(pageContent); + Documents.Add(doc); + + if (queryDepth > 0) + { + var links = ExtractLinks(doc, url); + var tasks = links.Select(link => FetchContentAsynch(link, queryDepth - 1)); + await Task.WhenAll(tasks); + } + } + catch (Exception ex) + { + Console.WriteLine($"Error scraping {url}: {ex.Message}"); + } + } + + private static List ExtractLinks(HtmlDocument doc, string baseUrl) + { + return doc.DocumentNode + .SelectNodes("//body//a[@href]")? + .Select(node => node.GetAttributeValue("href", "")) + .Where(href => !string.IsNullOrEmpty(href)) + .Select(href => NormalizeUrl(href, baseUrl)) + .Where(link => link != null) + .Distinct() + .ToList() ?? new List(); + } + + private static string NormalizeUrl(string href, string baseUrl) + { + if (href.StartsWith("http", StringComparison.OrdinalIgnoreCase)) + return href; + + if (href.StartsWith("/")) + return new Uri(new Uri(baseUrl), href).ToString(); + + return null; + } + + public async Task> ExtractVisibleTextAsync(int minWordLength, bool checkSentences, bool explodeParagraphs) + { + return await Task.Run(() => + { + List allDocumentText = new List(); + foreach (HtmlDocument doc in Documents) + { + var currentDocText = doc.DocumentNode + .SelectNodes("//body//*[not(ancestor::table) and not(self::script or self::style)] | //body//a[not(self::script or self::style)]")? + .Select(node => + { + string cleanedText = HtmlEntity.DeEntitize(node.InnerText.Trim()); + cleanedText = cleanedText.Replace("\t", " "); + cleanedText = Regex.Replace(cleanedText, @"\s+", " "); + return cleanedText; + }) + .Where(text => !string.IsNullOrWhiteSpace(text) && text.Split(' ').Length >= minWordLength) + .ToList() ?? new List(); + + allDocumentText.AddRange(currentDocText); + } + + if (explodeParagraphs) allDocumentText = ExplodeParagraphs(allDocumentText, minWordLength); + if (checkSentences) allDocumentText = RudimentarySentenceCheck(allDocumentText); + return allDocumentText; + }); + } + + public async Task> ExtractParagraphsAsync() + { + return await Task.Run(() => + { + List paragraphs = new List(); + foreach (HtmlDocument doc in Documents) + { + var currentDocParagraph = doc.DocumentNode + .SelectNodes("//p//text()")? + .Select(node => HtmlEntity.DeEntitize(node.InnerText.Trim())) + .Where(text => !string.IsNullOrWhiteSpace(text)) + .ToList() ?? new List(); + + paragraphs.AddRange(currentDocParagraph); + } + return paragraphs; + }); + } + + private static List RudimentarySentenceCheck(List sentences) + { + List sentenceRules = new List + { + new Regex(@"^[A-Za-z0-9]+[\w\s,;:'""-]*", RegexOptions.Compiled | RegexOptions.IgnoreCase), + new Regex(@"[^\W]{2,}", RegexOptions.Compiled), + new Regex(@"\b(\w*:?[/\w\d]+\.){2,}\d+\b", RegexOptions.Compiled) + }; + + return sentences.Where(sentence => sentenceRules.All(regex => regex.IsMatch(sentence))).ToList(); + } + + private static List ExplodeParagraphs(List paragraphs, int minWordLength) + { + return paragraphs + .SelectMany(paragraph => + Regex.Matches(paragraph, @"(?() + .Select(m => m.Value.Trim())) + .ToList(); + } + } +} diff --git a/LLama.Unittest/ChatSessionTests.cs b/LLama.Unittest/ChatSessionTests.cs new file mode 100644 index 000000000..d15655717 --- /dev/null +++ b/LLama.Unittest/ChatSessionTests.cs @@ -0,0 +1,31 @@ +using LLama.Common; +using Xunit.Abstractions; + +namespace LLama.Unittest +{ + public sealed class ChatSessionTests + : IDisposable + { + private readonly ITestOutputHelper _testOutputHelper; + private readonly ModelParams _params; + private readonly LLamaWeights _model; + + public ChatSessionTests(ITestOutputHelper testOutputHelper) + { + _testOutputHelper = testOutputHelper; + _params = new ModelParams(Constants.GenerativeModelPath2) + { + ContextSize = 128, + GpuLayerCount = Constants.CIGpuLayerCount + }; + _model = LLamaWeights.LoadFromFile(_params); + } + + public void Dispose() + { + _model.Dispose(); + } + + + } +} \ No newline at end of file diff --git a/LLama.Unittest/GrammarTest.cs b/LLama.Unittest/GrammarTest.cs new file mode 100644 index 000000000..0b701b3a4 --- /dev/null +++ b/LLama.Unittest/GrammarTest.cs @@ -0,0 +1,63 @@ +using Xunit; + +namespace LLama.Sampling.Tests +{ + public class GrammarTests + { + [Fact] + public void Constructor_SetsPropertiesCorrectly() + { + // Arrange + var gbnf = "test_gbnf"; + var root = "test_root"; + + // Act + var grammar = new Grammar(gbnf, root); + + // Assert + Assert.Equal(gbnf, grammar.Gbnf); + Assert.Equal(root, grammar.Root); + } + + [Fact] + public void ToString_ReturnsExpectedString() + { + // Arrange + var gbnf = "test_gbnf"; + var root = "test_root"; + var grammar = new Grammar(gbnf, root); + + // Act + var toString = grammar.ToString(); + + // Assert + Assert.Equal($"Grammar {{ Gbnf = {gbnf}, Root = {root} }}", toString); + } + + [Fact] + public void Equality_ChecksPropertiesCorrectly() + { + // Arrange + var gbnf = "test_gbnf"; + var root = "test_root"; + var grammar1 = new Grammar(gbnf, root); + var grammar2 = new Grammar(gbnf, root); + + // Act and Assert + Assert.Equal(grammar1, grammar2); + } + + [Fact] + public void Inequality_ChecksPropertiesCorrectly() + { + // Arrange + var gbnf = "test_gbnf"; + var root = "test_root"; + var grammar1 = new Grammar(gbnf, root); + var grammar2 = new Grammar("different_gbnf", root); + + // Act and Assert + Assert.NotEqual(grammar1, grammar2); + } + } +} diff --git a/LLama.Unittest/GreedysamplingTests.cs b/LLama.Unittest/GreedysamplingTests.cs new file mode 100644 index 000000000..4be07e3c7 --- /dev/null +++ b/LLama.Unittest/GreedysamplingTests.cs @@ -0,0 +1,92 @@ +using Xunit; +using LLama.Sampling; +using LLama.Native; + +namespace LLama.Sampling.Tests +{ + public class TestableSafeLLamaSamplerChainHandle + { + private SafeLLamaSamplerChainHandle chain; + public List Grammars { get; set; } + + public TestableSafeLLamaSamplerChainHandle(SafeLLamaSamplerChainHandle chain) + { + this.chain = chain; + this.Grammars = new List(); + } + + public SafeLLamaSamplerChainHandle GetChain() + { + return chain; + } + } + + public class TestableGreedySamplingPipeline : GreedySamplingPipeline + { + public SafeLLamaSamplerChainHandle CreateChainForTest(SafeLLamaContextHandle context) + { + return base.CreateChain(context); + } + } + + public class GreedySamplingPipelineTests + { + [Fact] + public void CreateChain_WithoutGrammar_DoesNotAddGrammarToChain() + { + // Arrange + var model = new SafeLlamaModelHandle(); + var lparams = LLamaContextParams.Default(); + var context = SafeLLamaContextHandle.Create(model, lparams); + var pipeline = new TestableGreedySamplingPipeline(); + + // Act + var chain = pipeline.CreateChainForTest(context); + var testableChain = new TestableSafeLLamaSamplerChainHandle(chain) + { + Grammars = new List() + }; + + // Assert + Assert.Empty(testableChain.Grammars); + } + + [Fact] + public void Get_Grammar_ReturnsExpectedValue() + { + // Arrange + var expectedGrammar = new Grammar("test_gbnf", "test_root"); + var pipeline = new TestableGreedySamplingPipeline { Grammar = expectedGrammar }; + + // Act + var actualGrammar = pipeline.Grammar; + + // Assert + Assert.Equal(expectedGrammar, actualGrammar); + } + + [Fact] + public void Get_Grammar_ReturnsNullByDefault() + { + // Arrange + var pipeline = new TestableGreedySamplingPipeline(); + + // Act + var actualGrammar = pipeline.Grammar; + + // Assert + Assert.Null(actualGrammar); + } + + [Fact] + public void Set_Grammar_SetsExpectedValue() + { + // Arrange + var expectedGrammar = new Grammar("test_gbnf", "test_root"); + var pipeline = new TestableGreedySamplingPipeline { Grammar = expectedGrammar }; + + // Act and Assert + Assert.Equal(expectedGrammar, pipeline.Grammar); + } + } +} diff --git a/LLama.Unittest/LLama.Unittest.csproj b/LLama.Unittest/LLama.Unittest.csproj index 5d7277e7a..6d926fc3f 100644 --- a/LLama.Unittest/LLama.Unittest.csproj +++ b/LLama.Unittest/LLama.Unittest.csproj @@ -1,4 +1,4 @@ - + net8.0 @@ -15,6 +15,7 @@ + diff --git a/LLama.Unittest/SamplingPipelineTests.cs b/LLama.Unittest/SamplingPipelineTests.cs new file mode 100644 index 000000000..a07659938 --- /dev/null +++ b/LLama.Unittest/SamplingPipelineTests.cs @@ -0,0 +1,106 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using LLama.Native; +using LLama.Sampling; +using Moq; + +namespace LLama.Unittest +{ + public class DefaultSamplingPipelineTests + { + [Fact] + public void FrequencyPenalty_ThrowsException_WhenValueIsLessThanMinusTwo() + { + Assert.Throws(() => new DefaultSamplingPipeline + { + FrequencyPenalty = -2.1f + }); + } + + [Fact] + public void FrequencyPenalty_ThrowsException_WhenValueIsGreaterThanTwo() + { + Assert.Throws(() => new DefaultSamplingPipeline + { + FrequencyPenalty = 2.1f + }); + } + + [Fact] + public void PresencePenalty_ThrowsException_WhenValueIsLessThanMinusTwo() + { + Assert.Throws(() => new DefaultSamplingPipeline + { + PresencePenalty = -2.1f + }); + } + + [Fact] + public void PresencePenalty_ThrowsException_WhenValueIsGreaterThanTwo() + { + Assert.Throws(() => new DefaultSamplingPipeline + { + PresencePenalty = 2.1f + }); + } + + [Fact] + public void DefaultValues_AreSetCorrectly() + { + var pipeline = new DefaultSamplingPipeline(); + + Assert.Equal(1, pipeline.RepeatPenalty); + Assert.Equal(0.75f, pipeline.Temperature); + Assert.Equal(40, pipeline.TopK); + Assert.Equal(1, pipeline.TypicalP); + Assert.Equal(0.9f, pipeline.TopP); + Assert.Equal(0.1f, pipeline.MinP); + Assert.Equal(64, pipeline.PenaltyCount); + Assert.False(pipeline.PenalizeNewline); + Assert.False(pipeline.PreventEOS); + } + + [Fact] + public void Seed_IsInitializedWithRandomValue() + { + // Arrange + var pipeline = new DefaultSamplingPipeline(); + + // Act + uint seed = pipeline.Seed; + + // Assert + Assert.InRange(seed, 0u, uint.MaxValue); + } + + + + // Example test for CreateChain method + //[Fact] + //public void CreateChain_CreatesSamplerChainCorrectly() + //{ + // // Arrange + // var contextMock = new Mock(); + // contextMock.Setup(c => c.Vocab.Count).Returns(100); + + // var pipeline = new DefaultSamplingPipeline + // { + // LogitBias = new Dictionary + // { + // { new LLamaToken(), 1.0f } + // }, + // Grammar = new Grammar("testGbnf", "root") + // }; + + // // Act + // var chain = pipeline.CreateChain(contextMock.Object); + + // // Assert + // // Add assertions here based on the behavior of CreateChain method + // Assert.NotNull(chain); + //} + } +} diff --git a/LLama.Unittest/SamplingTests.cs b/LLama.Unittest/SamplingTests.cs index 615a7c79e..b684ed85a 100644 --- a/LLama.Unittest/SamplingTests.cs +++ b/LLama.Unittest/SamplingTests.cs @@ -1,5 +1,6 @@ using LLama.Common; using LLama.Native; +using LLama.Sampling; using System.Numerics.Tensors; using System.Text; @@ -177,5 +178,192 @@ private static SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandle co return chain; } + + [Fact] + public void SamplingWithMockTopK() + { + // Manually create a mock logits array with a fixed, well-distributed set of values + var logits = new float[] + { + 0.56f, -0.85f, 0.74f, -0.33f, 0.92f, -0.44f, 0.61f, -0.77f, 0.18f, -0.29f, + 0.87f, -0.52f, 0.31f, -0.66f, 0.28f, -0.91f, 0.75f, -0.58f, 0.42f, -0.62f, + 0.39f, -0.48f, 0.94f, -0.72f, 0.53f, -0.15f, 0.68f, -0.41f, 0.81f, -0.35f, + 0.76f, -0.27f, 0.63f, -0.69f, 0.21f, -0.11f, 0.59f, -0.79f, 0.33f, -0.87f, + 0.46f, -0.53f, 0.71f, -0.23f, 0.66f, -0.39f, 0.29f, -0.65f, 0.83f, -0.49f, + 0.35f, -0.71f, 0.61f, -0.13f, 0.57f, -0.43f, 0.93f, -0.37f, 0.82f, -0.54f, + 0.44f, -0.22f, 0.88f, -0.46f, 0.72f, -0.18f, 0.64f, -0.55f, 0.95f, -0.33f, + 0.41f, -0.63f, 0.79f, -0.28f, 0.31f, -0.67f, 0.74f, -0.44f, 0.85f, -0.32f, + 0.54f, -0.16f, 0.66f, -0.38f, 0.73f, -0.49f, 0.36f, -0.79f, 0.61f, -0.24f, + 0.77f, -0.55f, 0.52f, -0.41f, 0.81f, -0.36f, 0.69f, -0.26f, 0.45f, -0.17f + }; + + // Mock LLamaTokenDataArray and LLamaTokenDataArrayNative + var array = LLamaTokenDataArray.Create(logits); + + // First sampling (TopK=5) + using var _ = LLamaTokenDataArrayNative.Create(array, out var cur_p); + using var chain5 = SafeLLamaSamplerChainHandle.Create(LLamaSamplerChainParams.Default()); + + chain5.AddTopK(5); + chain5.Apply(ref cur_p); + var top5 = new List(); + for (int i = 0; i < 5; i++) + { + top5.Add(cur_p.Data[i].Logit); + } + + // Second sampling (TopK=50) + using var _2 = LLamaTokenDataArrayNative.Create(array, out var cur_p_broader); + using var chain50 = SafeLLamaSamplerChainHandle.Create(LLamaSamplerChainParams.Default()); + + chain50.AddTopK(50); + chain50.Apply(ref cur_p_broader); + var top50 = new List(); + for (int i = 0; i < 50; i++) + { + top50.Add(cur_p_broader.Data[i].Logit); + } + + // Assert that the top 5 logits are present in the top 50 logits + Assert.True(top5.All(logit => top50.Contains(logit))); + } + + + + /// + /// test frequency penalty out of range exception when less than -2 + /// + [Fact] + public void FrequencyPenalty_ThrowsException_WhenValueIsLessThanMinusTwo() + { + Assert.Throws(() => new DefaultSamplingPipeline + { + FrequencyPenalty = -2.1f + }); + } + + + /// + /// test frequency penalty out of range exception when greater than 2 + /// + [Fact] + public void FrequencyPenalty_ThrowsException_WhenValueIsGreaterThanTwo() + { + Assert.Throws(() => new DefaultSamplingPipeline + { + FrequencyPenalty = 2.1f + }); + } + + /// + /// Test Argument out of range exception when presence penalty less than -2 + /// + [Fact] + public void PresencePenalty_ThrowsException_WhenValueIsLessThanMinusTwo() + { + Assert.Throws(() => new DefaultSamplingPipeline + { + PresencePenalty = -2.1f + }); + } + + /// + /// Test argument out of range exception when presence penalty is greater than 2 + /// + [Fact] + public void PresencePenalty_ThrowsException_WhenValueIsGreaterThanTwo() + { + Assert.Throws(() => new DefaultSamplingPipeline + { + PresencePenalty = 2.1f + }); + } + + /// + /// Test the default sampling pipeline defaults + /// + [Fact] + public void DefaultValues_AreSetCorrectly() + { + var pipeline = new DefaultSamplingPipeline(); + + Assert.Equal(1, pipeline.RepeatPenalty); + Assert.Equal(0.75f, pipeline.Temperature); + Assert.Equal(40, pipeline.TopK); + Assert.Equal(1, pipeline.TypicalP); + Assert.Equal(0.9f, pipeline.TopP); + Assert.Equal(0.1f, pipeline.MinP); + Assert.Equal(64, pipeline.PenaltyCount); + Assert.False(pipeline.PenalizeNewline); + Assert.False(pipeline.PreventEOS); + } + + [Fact] + public void Seed_HasLowProbabilityOfCollision() + { + var seedSet = new HashSet(); + const int numberOfInitializations = 1000; // Run the test 1000 times + const int maxAllowedDuplicates = 2; + + int duplicateCount = 0; + + for (int i = 0; i < numberOfInitializations; i++) + { + var pipeline = new DefaultSamplingPipeline(); + uint seed = pipeline.Seed; + if (!seedSet.Add(seed)) + { + duplicateCount++; + } + } + + // Assert that the number of duplicates is within the acceptable threshold + Assert.True(duplicateCount <= maxAllowedDuplicates, $"Too many duplicate seeds: {duplicateCount}"); + } + + + /// + /// test the pipeline seed with a specific value + /// + [Fact] + public void Seed_IsInitializedWithSpecificValue() + { + // Arrange + var pipeline = new DefaultSamplingPipeline(); + + // Act + uint seed = 32; + + // Assert + Assert.Equal(32, (float)seed); + } + /// + /// test minkeep with a specific value + /// + [Fact] + public void SetMinKeep() + { + // Arrange + var pipeline = new DefaultSamplingPipeline(); + + //Act + pipeline.MinKeep = 5; + + //Assert + Assert.Equal(5, pipeline.MinKeep); + } + + /// + /// test the minkeep default + /// + [Fact] + public void GetMinKeepDefault() + { + // Arrange + var pipeline = new DefaultSamplingPipeline(); + + //Assert + Assert.Equal(1, pipeline.MinKeep); + } } } diff --git a/LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs b/LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs index d75a8d4b4..264301f4c 100644 --- a/LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs +++ b/LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs @@ -88,6 +88,18 @@ public void ChatRequestSettings_FromAIRequestSettings() Assert.Equal(originalRequestSettings.ModelId, requestSettings.ModelId); } + [Fact] + public void ChatRequestSettings_Null_ReturnsDefaultObject() + { + // Act + var requestSettings = LLamaSharpPromptExecutionSettings.FromRequestSettings(null); + + // Assert + Assert.NotNull(requestSettings); // Ensure it is NOT null + Assert.Null(requestSettings.MaxTokens); // Default behavior + } + + [Fact] public void ChatRequestSettings_FromAIRequestSettingsWithExtraPropertiesInSnakeCase() { diff --git a/LLama.Unittest/SemanticKernel/ExtensionMethodsTests.cs b/LLama.Unittest/SemanticKernel/ExtensionMethodsTests.cs index 41e842737..33ab42897 100644 --- a/LLama.Unittest/SemanticKernel/ExtensionMethodsTests.cs +++ b/LLama.Unittest/SemanticKernel/ExtensionMethodsTests.cs @@ -20,6 +20,20 @@ public void ToLLamaSharpChatHistory_StateUnderTest_ExpectedBehavior() Assert.NotNull(result); } + [Fact] + public void ToLLamaSharpChatHistory_NullChatHistory_ThrowsArgumentNullException() + { + // Arrange + Microsoft.SemanticKernel.ChatCompletion.ChatHistory chatHistory = null; + bool ignoreCase = true; + + // Act & Assert + var exception = Assert.Throws(() => + ExtensionMethods.ToLLamaSharpChatHistory(chatHistory, ignoreCase)); + + Assert.Equal("chatHistory", exception.ParamName); + } + [Fact] public void ToLLamaSharpInferenceParams_StateUnderTest_ExpectedBehavior() { @@ -33,5 +47,20 @@ public void ToLLamaSharpInferenceParams_StateUnderTest_ExpectedBehavior() // Assert Assert.NotNull(result); } + + [Fact] + public void ToLLamaSharpInferenceParams_NullRequestSettings_ThrowsArgumentNullException() + { + // Arrange + LLamaSharpPromptExecutionSettings requestSettings = null; + + // Act & Assert + var exception = Assert.Throws(() => + ExtensionMethods.ToLLamaSharpInferenceParams(requestSettings)); + + // Ensure the exception is thrown for the correct parameter + Assert.Equal("requestSettings", exception.ParamName); + } + } } diff --git a/LLama.Unittest/SemanticKernel/LLamaSharpChatCompletionTests.cs b/LLama.Unittest/SemanticKernel/LLamaSharpChatCompletionTests.cs index 0873a713b..41d08cc5b 100644 --- a/LLama.Unittest/SemanticKernel/LLamaSharpChatCompletionTests.cs +++ b/LLama.Unittest/SemanticKernel/LLamaSharpChatCompletionTests.cs @@ -3,6 +3,7 @@ using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.ChatCompletion; using Moq; +using Xunit; namespace LLama.Unittest.SemanticKernel { @@ -24,12 +25,63 @@ private LLamaSharpChatCompletion CreateLLamaSharpChatCompletion() null); } + [Fact] + public void CreateNewChat_NoInstructions_ReturnsEmptyChatHistory() + { + // Arrange + var unitUnderTest = this.CreateLLamaSharpChatCompletion(); + + // Act + var result = unitUnderTest.CreateNewChat(); + + // Assert + Assert.NotNull(result); + Assert.Empty(result); // No system message should be added + } + + [Fact] + public void CreateNewChat_WithInstructions_AddsSystemMessage() + { + // Arrange + var unitUnderTest = this.CreateLLamaSharpChatCompletion(); + string instructions = "This is a system instruction"; + + // Act + var result = unitUnderTest.CreateNewChat(instructions); + + // Assert + Assert.NotNull(result); + Assert.Single(result); // One message should be present + Assert.Equal(instructions, result[0].Content); // System message should match the instructions + } + + [Fact] + public void CreateNewChat_NullInstructions_ReturnsEmptyChatHistory() + { + // Arrange + var unitUnderTest = this.CreateLLamaSharpChatCompletion(); + + // Act + var result = unitUnderTest.CreateNewChat(null); + + // Assert + Assert.NotNull(result); + Assert.Empty(result); // Should not add a system message + } + [Fact] public async Task GetChatMessageContentsAsync_StateUnderTest_ExpectedBehavior() { // Arrange var unitUnderTest = this.CreateLLamaSharpChatCompletion(); ChatHistory chatHistory = new ChatHistory(); + + // Add to Chat History + chatHistory.AddMessage(new AuthorRole("User"), "Hello"); + chatHistory.AddMessage(new AuthorRole("User"), "World"); + chatHistory.AddMessage(new AuthorRole("User"), "Goodbye"); + chatHistory.AddMessage(new AuthorRole("InvalidRole"), "This should trigger Unknown role"); + PromptExecutionSettings? executionSettings = null; Kernel? kernel = null; CancellationToken cancellationToken = default; @@ -42,7 +94,7 @@ public async Task GetChatMessageContentsAsync_StateUnderTest_ExpectedBehavior() executionSettings, kernel, cancellationToken); - + // Assert Assert.True(result.Count > 0); } diff --git a/LLama.Unittest/SessionStateTests.cs b/LLama.Unittest/SessionStateTests.cs new file mode 100644 index 000000000..6f8384f6d --- /dev/null +++ b/LLama.Unittest/SessionStateTests.cs @@ -0,0 +1,31 @@ +using LLama.Common; +using Xunit.Abstractions; + +namespace LLama.Unittest +{ + public sealed class SessionStateTests + : IDisposable + { + private readonly ITestOutputHelper _testOutputHelper; + private readonly ModelParams _params; + private readonly LLamaWeights _model; + + public SessionStateTests(ITestOutputHelper testOutputHelper) + { + _testOutputHelper = testOutputHelper; + _params = new ModelParams(Constants.GenerativeModelPath2) + { + ContextSize = 128, + GpuLayerCount = Constants.CIGpuLayerCount + }; + _model = LLamaWeights.LoadFromFile(_params); + } + + public void Dispose() + { + _model.Dispose(); + } + + + } +} \ No newline at end of file diff --git a/llama.cpp b/llama.cpp index 5783575c9..11b84eb45 160000 --- a/llama.cpp +++ b/llama.cpp @@ -1 +1 @@ -Subproject commit 5783575c9d99c4d9370495800663aa5397ceb0be +Subproject commit 11b84eb4578864827afcf956db5b571003f18180