diff --git a/CHANGELOG.md b/CHANGELOG.md index 35a2154fd..2d5d7388f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,16 @@ ## Next +### Added + +- Document node is now always created when running SimpleKGPipeline, even if `from_pdf=False`. +- Document metadata is exposed in SimpleKGPipeline run method. + +### Fixed + +- LangChain Chat models compatibility is now working again. + + ## 1.10.0 ### Added @@ -15,12 +25,6 @@ - Fixed an edge case where the LLM can output a property with type 'map', which was causing errors during import as it is not a valid property type in Neo4j. -### Added - -- Document node is now always created when running SimpleKGPipeline, even if `from_pdf=False`. -- Document metadata is exposed in SimpleKGPipeline run method. - - ## 1.9.1 ### Fixed diff --git a/examples/README.md b/examples/README.md index 774739b32..6cd0e758b 100644 --- a/examples/README.md +++ b/examples/README.md @@ -69,7 +69,7 @@ are listed in [the last section of this file](#customize). - [OpenAI (GPT)](./customize/llms/openai_llm.py) - [Azure OpenAI]() - [VertexAI (Gemini)](./customize/llms/vertexai_llm.py) -- [MistralAI](./customize/llms/mistalai_llm.py) +- [MistralAI](customize/llms/mistralai_llm.py) - [Cohere](./customize/llms/cohere_llm.py) - [Anthropic (Claude)](./customize/llms/anthropic_llm.py) - [Ollama](./customize/llms/ollama_llm.py) @@ -142,7 +142,7 @@ are listed in [the last section of this file](#customize). ### Answer: GraphRAG -- [LangChain compatibility](./customize/answer/langchain_compatiblity.py) +- [LangChain compatibility](customize/answer/langchain_compatibility.py) - [Use a custom prompt](./customize/answer/custom_prompt.py) diff --git a/examples/customize/answer/langchain_compatiblity.py b/examples/customize/answer/langchain_compatibility.py similarity index 100% rename from examples/customize/answer/langchain_compatiblity.py rename to examples/customize/answer/langchain_compatibility.py diff --git a/examples/customize/llms/anthropic_llm.py b/examples/customize/llms/anthropic_llm.py index 85c4ad03a..dbd3f56fd 100644 --- a/examples/customize/llms/anthropic_llm.py +++ b/examples/customize/llms/anthropic_llm.py @@ -1,12 +1,28 @@ from neo4j_graphrag.llm import AnthropicLLM, LLMResponse +from neo4j_graphrag.types import LLMMessage # set api key here on in the ANTHROPIC_API_KEY env var api_key = None +messages: list[LLMMessage] = [ + { + "role": "system", + "content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.", + }, + { + "role": "user", + "content": "say something", + }, +] + + llm = AnthropicLLM( model_name="claude-3-opus-20240229", model_params={"max_tokens": 1000}, # max_tokens must be specified api_key=api_key, ) -res: LLMResponse = llm.invoke("say something") +res: LLMResponse = llm.invoke( + # "say something", + messages, +) print(res.content) diff --git a/examples/customize/llms/cohere_llm.py b/examples/customize/llms/cohere_llm.py index d631d3e41..daa3926ef 100644 --- a/examples/customize/llms/cohere_llm.py +++ b/examples/customize/llms/cohere_llm.py @@ -1,11 +1,23 @@ from neo4j_graphrag.llm import CohereLLM, LLMResponse +from neo4j_graphrag.types import LLMMessage # set api key here on in the CO_API_KEY env var api_key = None +messages: list[LLMMessage] = [ + { + "role": "system", + "content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.", + }, + { + "role": "user", + "content": "say something", + }, +] + llm = CohereLLM( model_name="command-r", 
api_key=api_key, ) -res: LLMResponse = llm.invoke("say something") +res: LLMResponse = llm.invoke(input=messages) print(res.content) diff --git a/examples/customize/llms/custom_llm.py b/examples/customize/llms/custom_llm.py index 0eecfd878..953ba8467 100644 --- a/examples/customize/llms/custom_llm.py +++ b/examples/customize/llms/custom_llm.py @@ -1,6 +1,6 @@ import random import string -from typing import Any, Awaitable, Callable, List, Optional, TypeVar, Union +from typing import Any, Awaitable, Callable, Optional, TypeVar from neo4j_graphrag.llm import LLMInterface, LLMResponse from neo4j_graphrag.llm.rate_limit import ( @@ -8,7 +8,6 @@ # rate_limit_handler, # async_rate_limit_handler, ) -from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.types import LLMMessage @@ -18,38 +17,26 @@ def __init__( ): super().__init__(model_name, **kwargs) - # Optional: Apply rate limit handling to synchronous invoke method - # @rate_limit_handler - def invoke( + def _invoke( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> LLMResponse: content: str = ( self.model_name + ": " + "".join(random.choices(string.ascii_letters, k=30)) ) return LLMResponse(content=content) - # Optional: Apply rate limit handling to asynchronous ainvoke method - # @async_rate_limit_handler - async def ainvoke( + async def _ainvoke( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> LLMResponse: raise NotImplementedError() -llm = CustomLLM( - "" -) # if rate_limit_handler and async_rate_limit_handler decorators are used, the default rate limit handler will be applied automatically (retry with exponential backoff) +llm = CustomLLM("") res: LLMResponse = llm.invoke("text") print(res.content) -# If rate_limit_handler and async_rate_limit_handler decorators are used and you want to use a custom rate limit handler -# Type variables for function signatures used in rate limit handlers F = TypeVar("F", bound=Callable[..., Any]) AF = TypeVar("AF", bound=Callable[..., Awaitable[Any]]) diff --git a/examples/customize/llms/mistalai_llm.py b/examples/customize/llms/mistalai_llm.py deleted file mode 100644 index b829baad4..000000000 --- a/examples/customize/llms/mistalai_llm.py +++ /dev/null @@ -1,10 +0,0 @@ -from neo4j_graphrag.llm import MistralAILLM - -# set api key here on in the MISTRAL_API_KEY env var -api_key = None - -llm = MistralAILLM( - model_name="mistral-small-latest", - api_key=api_key, -) -llm.invoke("say something") diff --git a/examples/customize/llms/mistralai_llm.py b/examples/customize/llms/mistralai_llm.py new file mode 100644 index 000000000..66db280b1 --- /dev/null +++ b/examples/customize/llms/mistralai_llm.py @@ -0,0 +1,32 @@ +from neo4j_graphrag.llm import MistralAILLM, LLMResponse +from neo4j_graphrag.message_history import InMemoryMessageHistory +from neo4j_graphrag.types import LLMMessage + +# set api key here on in the MISTRAL_API_KEY env var +api_key = None + + +messages: list[LLMMessage] = [ + { + "role": "system", + "content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.", + }, + { + "role": "user", + "content": "say something", + }, +] + + +llm = MistralAILLM( + model_name="mistral-small-latest", + api_key=api_key, +) +res: LLMResponse = llm.invoke( + # "say something", + # messages, + 
InMemoryMessageHistory( + messages=messages, + ) +) +print(res.content) diff --git a/examples/customize/llms/ollama_llm.py b/examples/customize/llms/ollama_llm.py index dc42f7466..37dd1dbec 100644 --- a/examples/customize/llms/ollama_llm.py +++ b/examples/customize/llms/ollama_llm.py @@ -3,11 +3,26 @@ """ from neo4j_graphrag.llm import LLMResponse, OllamaLLM +from neo4j_graphrag.types import LLMMessage + +messages: list[LLMMessage] = [ + { + "role": "system", + "content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.", + }, + { + "role": "user", + "content": "say something", + }, +] + llm = OllamaLLM( - model_name="", + model_name="orca-mini:latest", # model_params={"options": {"temperature": 0}, "format": "json"}, # host="...", # if using a remote server ) -res: LLMResponse = llm.invoke("What is the additive color model?") +res: LLMResponse = llm.invoke( + messages, +) print(res.content) diff --git a/examples/customize/llms/openai_llm.py b/examples/customize/llms/openai_llm.py index d4b38244e..501ccdb53 100644 --- a/examples/customize/llms/openai_llm.py +++ b/examples/customize/llms/openai_llm.py @@ -1,8 +1,28 @@ from neo4j_graphrag.llm import LLMResponse, OpenAILLM +from neo4j_graphrag.message_history import InMemoryMessageHistory +from neo4j_graphrag.types import LLMMessage # set api key here on in the OPENAI_API_KEY env var api_key = None +messages: list[LLMMessage] = [ + { + "role": "system", + "content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.", + }, + { + "role": "user", + "content": "say something", + }, +] + + llm = OpenAILLM(model_name="gpt-4o", api_key=api_key) -res: LLMResponse = llm.invoke("say something") +res: LLMResponse = llm.invoke( + # "say something", + # messages, + InMemoryMessageHistory( + messages=messages, + ) +) print(res.content) diff --git a/examples/customize/llms/vertexai_llm.py b/examples/customize/llms/vertexai_llm.py index f43864935..34fc179ae 100644 --- a/examples/customize/llms/vertexai_llm.py +++ b/examples/customize/llms/vertexai_llm.py @@ -1,6 +1,20 @@ from neo4j_graphrag.llm import LLMResponse, VertexAILLM from vertexai.generative_models import GenerationConfig +from neo4j_graphrag.types import LLMMessage + +messages: list[LLMMessage] = [ + { + "role": "system", + "content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.", + }, + { + "role": "user", + "content": "say something", + }, +] + + generation_config = GenerationConfig(temperature=1.0) llm = VertexAILLM( model_name="gemini-2.0-flash-001", @@ -9,7 +23,6 @@ # vertexai.generative_models.GenerativeModel client ) res: LLMResponse = llm.invoke( - "say something", - system_instruction="You are living in 3000 where AI rules the world", + input=messages, ) print(res.content) diff --git a/src/neo4j_graphrag/generation/graphrag.py b/src/neo4j_graphrag/generation/graphrag.py index 08f08a368..e79622dc3 100644 --- a/src/neo4j_graphrag/generation/graphrag.py +++ b/src/neo4j_graphrag/generation/graphrag.py @@ -27,6 +27,7 @@ from neo4j_graphrag.generation.prompts import RagTemplate from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel from neo4j_graphrag.llm import LLMInterface +from neo4j_graphrag.llm.utils import legacy_inputs_to_messages from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.retrievers.base import Retriever from neo4j_graphrag.types import LLMMessage, RetrieverResult @@ 
-145,12 +146,17 @@ def search( prompt = self.prompt_template.format( query_text=query_text, context=context, examples=validated_data.examples ) + + messages = legacy_inputs_to_messages( + prompt, + message_history=message_history, + system_instruction=self.prompt_template.system_instructions, + ) + logger.debug(f"RAG: retriever_result={prettify(retriever_result)}") logger.debug(f"RAG: prompt={prompt}") llm_response = self.llm.invoke( - prompt, - message_history, - system_instruction=self.prompt_template.system_instructions, + messages, ) answer = llm_response.content result: dict[str, Any] = {"answer": answer} @@ -168,9 +174,12 @@ def _build_query( summarization_prompt = self._chat_summary_prompt( message_history=message_history ) - summary = self.llm.invoke( - input=summarization_prompt, + messages = legacy_inputs_to_messages( + summarization_prompt, system_instruction=summary_system_message, + ) + summary = self.llm.invoke( + messages, ).content return self.conversation_prompt(summary=summary, current_query=query_text) return query_text diff --git a/src/neo4j_graphrag/llm/__init__.py b/src/neo4j_graphrag/llm/__init__.py index 3c4f65d9a..79360e57b 100644 --- a/src/neo4j_graphrag/llm/__init__.py +++ b/src/neo4j_graphrag/llm/__init__.py @@ -22,8 +22,6 @@ RateLimitHandler, NoOpRateLimitHandler, RetryRateLimitHandler, - rate_limit_handler, - async_rate_limit_handler, ) from .types import LLMResponse from .vertexai_llm import VertexAILLM @@ -42,6 +40,4 @@ "RateLimitHandler", "NoOpRateLimitHandler", "RetryRateLimitHandler", - "rate_limit_handler", - "async_rate_limit_handler", ] diff --git a/src/neo4j_graphrag/llm/anthropic_llm.py b/src/neo4j_graphrag/llm/anthropic_llm.py index 6bafef85b..bfb0c9d5f 100644 --- a/src/neo4j_graphrag/llm/anthropic_llm.py +++ b/src/neo4j_graphrag/llm/anthropic_llm.py @@ -13,28 +13,21 @@ # limitations under the License. 
from __future__ import annotations -from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union, cast - -from pydantic import ValidationError +from typing import TYPE_CHECKING, Any, Iterable, Optional, Union from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface from neo4j_graphrag.llm.rate_limit import ( RateLimitHandler, - rate_limit_handler, - async_rate_limit_handler, ) from neo4j_graphrag.llm.types import ( - BaseMessage, LLMResponse, - MessageList, - UserMessage, ) -from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.types import LLMMessage if TYPE_CHECKING: from anthropic.types.message_param import MessageParam + from anthropic import NotGiven class AnthropicLLM(LLMInterface): @@ -84,46 +77,39 @@ def __init__( def get_messages( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - ) -> Iterable[MessageParam]: - messages: list[dict[str, str]] = [] - if message_history: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - try: - MessageList(messages=cast(list[BaseMessage], message_history)) - except ValidationError as e: - raise LLMGenerationError(e.errors()) from e - messages.extend(cast(Iterable[dict[str, Any]], message_history)) - messages.append(UserMessage(content=input).model_dump()) - return messages # type: ignore - - @rate_limit_handler - def invoke( + input: list[LLMMessage], + ) -> tuple[Union[str, NotGiven], Iterable[MessageParam]]: + messages: list[MessageParam] = [] + system_instruction: Union[str, NotGiven] = self.anthropic.NOT_GIVEN + for i in input: + if i["role"] == "system": + system_instruction = i["content"] + else: + messages.append( + self.anthropic.types.MessageParam( + role=i["role"], + content=i["content"], + ) + ) + return system_instruction, messages + + def _invoke( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> LLMResponse: """Sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. """ try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - messages = self.get_messages(input, message_history) + system_instruction, messages = self.get_messages(input) response = self.client.messages.create( model=self.model_name, - system=system_instruction or self.anthropic.NOT_GIVEN, + system=system_instruction, messages=messages, **self.model_params, ) @@ -136,31 +122,23 @@ def invoke( except self.anthropic.APIError as e: raise LLMGenerationError(e) - @async_rate_limit_handler - async def ainvoke( + async def _ainvoke( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> LLMResponse: """Asynchronously sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. 
- system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. """ try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - messages = self.get_messages(input, message_history) + system_instruction, messages = self.get_messages(input) response = await self.async_client.messages.create( model=self.model_name, - system=system_instruction or self.anthropic.NOT_GIVEN, + system=system_instruction, messages=messages, **self.model_params, ) diff --git a/src/neo4j_graphrag/llm/base.py b/src/neo4j_graphrag/llm/base.py index cca710bc9..8e77b9dc3 100644 --- a/src/neo4j_graphrag/llm/base.py +++ b/src/neo4j_graphrag/llm/base.py @@ -17,17 +17,23 @@ from abc import ABC, abstractmethod from typing import Any, List, Optional, Sequence, Union +from pydantic import ValidationError + from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.types import LLMMessage +from .rate_limit import rate_limit_handler from .types import LLMResponse, ToolCallResponse from .rate_limit import ( DEFAULT_RATE_LIMIT_HANDLER, + async_rate_limit_handler, ) from neo4j_graphrag.tool import Tool from .rate_limit import RateLimitHandler +from .utils import legacy_inputs_to_messages +from ..exceptions import LLMGenerationError class LLMInterface(ABC): @@ -55,20 +61,30 @@ def __init__( else: self._rate_limit_handler = DEFAULT_RATE_LIMIT_HANDLER - @abstractmethod + @rate_limit_handler def invoke( self, - input: str, + input: Union[str, List[LLMMessage], MessageHistory], message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, + ) -> LLMResponse: + try: + messages = legacy_inputs_to_messages( + input, message_history, system_instruction + ) + except ValidationError as e: + raise LLMGenerationError("Input validation failed") from e + return self._invoke(messages) + + @abstractmethod + def _invoke( + self, + input: list[LLMMessage], ) -> LLMResponse: """Sends a text input to the LLM and retrieves a response. Args: - input (str): Text sent to the LLM. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. + input (MessageHistory): Text sent to the LLM. Returns: LLMResponse: The response from the LLM. @@ -77,20 +93,25 @@ def invoke( LLMGenerationError: If anything goes wrong. """ - @abstractmethod + @async_rate_limit_handler async def ainvoke( self, - input: str, + input: Union[str, List[LLMMessage], MessageHistory], message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, + ) -> LLMResponse: + messages = legacy_inputs_to_messages(input, message_history, system_instruction) + return await self._ainvoke(messages) + + @abstractmethod + async def _ainvoke( + self, + input: list[LLMMessage], ) -> LLMResponse: """Asynchronously sends a text input to the LLM and retrieves a response. Args: input (str): Text sent to the LLM. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. 
@@ -99,6 +120,7 @@ async def ainvoke( LLMGenerationError: If anything goes wrong. """ + @rate_limit_handler def invoke_with_tools( self, input: str, @@ -124,6 +146,17 @@ def invoke_with_tools( LLMGenerationError: If anything goes wrong. NotImplementedError: If the LLM provider does not support tool calling. """ + try: + messages = legacy_inputs_to_messages( + input, message_history, system_instruction + ) + except ValidationError as e: + raise LLMGenerationError("Input validation failed") from e + return self._invoke_with_tools(messages, tools) + + def _invoke_with_tools( + self, inputs: list[LLMMessage], tools: Sequence[Tool] + ) -> ToolCallResponse: raise NotImplementedError("This LLM provider does not support tool calling.") async def ainvoke_with_tools( @@ -151,4 +184,10 @@ async def ainvoke_with_tools( LLMGenerationError: If anything goes wrong. NotImplementedError: If the LLM provider does not support tool calling. """ + messages = legacy_inputs_to_messages(input, message_history, system_instruction) + return await self._ainvoke_with_tools(messages, tools) + + async def _ainvoke_with_tools( + self, inputs: list[LLMMessage], tools: Sequence[Tool] + ) -> ToolCallResponse: raise NotImplementedError("This LLM provider does not support tool calling.") diff --git a/src/neo4j_graphrag/llm/cohere_llm.py b/src/neo4j_graphrag/llm/cohere_llm.py index 7c3905500..c8d51c93a 100644 --- a/src/neo4j_graphrag/llm/cohere_llm.py +++ b/src/neo4j_graphrag/llm/cohere_llm.py @@ -14,25 +14,16 @@ # limitations under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union, cast - -from pydantic import ValidationError +from typing import TYPE_CHECKING, Any, Optional from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface from neo4j_graphrag.llm.rate_limit import ( RateLimitHandler, - rate_limit_handler, - async_rate_limit_handler, ) from neo4j_graphrag.llm.types import ( - BaseMessage, LLMResponse, - MessageList, - SystemMessage, - UserMessage, ) -from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.types import LLMMessage if TYPE_CHECKING: @@ -84,46 +75,34 @@ def __init__( def get_messages( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> ChatMessages: - messages = [] - if system_instruction: - messages.append(SystemMessage(content=system_instruction).model_dump()) - if message_history: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - try: - MessageList(messages=cast(list[BaseMessage], message_history)) - except ValidationError as e: - raise LLMGenerationError(e.errors()) from e - messages.extend(cast(Iterable[dict[str, Any]], message_history)) - messages.append(UserMessage(content=input).model_dump()) - return messages # type: ignore - - @rate_limit_handler - def invoke( + messages: ChatMessages = [] + for i in input: + if i["role"] == "system": + messages.append(self.cohere.SystemChatMessageV2(content=i["content"])) + if i["role"] == "user": + messages.append(self.cohere.UserChatMessageV2(content=i["content"])) + if i["role"] == "assistant": + messages.append( + self.cohere.AssistantChatMessageV2(content=i["content"]) + ) + return messages + + def _invoke( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: 
list[LLMMessage], ) -> LLMResponse: """Sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. """ try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - messages = self.get_messages(input, message_history, system_instruction) + messages = self.get_messages(input) res = self.client.chat( messages=messages, model=self.model_name, @@ -134,28 +113,20 @@ def invoke( content=res.message.content[0].text if res.message.content else "", ) - @async_rate_limit_handler - async def ainvoke( + async def _ainvoke( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> LLMResponse: """Asynchronously sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. """ try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - messages = self.get_messages(input, message_history, system_instruction) + messages = self.get_messages(input) res = await self.async_client.chat( messages=messages, model=self.model_name, diff --git a/src/neo4j_graphrag/llm/mistralai_llm.py b/src/neo4j_graphrag/llm/mistralai_llm.py index ae2a6312f..93f54146d 100644 --- a/src/neo4j_graphrag/llm/mistralai_llm.py +++ b/src/neo4j_graphrag/llm/mistralai_llm.py @@ -15,33 +15,31 @@ from __future__ import annotations import os -from typing import Any, Iterable, List, Optional, Union, cast - -from pydantic import ValidationError +from typing import Any, Optional from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface from neo4j_graphrag.llm.rate_limit import ( RateLimitHandler, - rate_limit_handler, - async_rate_limit_handler, ) from neo4j_graphrag.llm.types import ( - BaseMessage, LLMResponse, - MessageList, - SystemMessage, - UserMessage, ) -from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.types import LLMMessage try: - from mistralai import Messages, Mistral + from mistralai import ( + Messages, + UserMessage, + AssistantMessage, + SystemMessage, + Mistral, + ) from mistralai.models.sdkerror import SDKError except ImportError: Mistral = None # type: ignore SDKError = None # type: ignore + Messages = None # type: ignore class MistralAILLM(LLMInterface): @@ -75,38 +73,30 @@ def __init__( def get_messages( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> list[Messages]: - messages = [] - if system_instruction: - messages.append(SystemMessage(content=system_instruction).model_dump()) - if message_history: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - try: - MessageList(messages=cast(list[BaseMessage], 
message_history)) - except ValidationError as e: - raise LLMGenerationError(e.errors()) from e - messages.extend(cast(Iterable[dict[str, Any]], message_history)) - messages.append(UserMessage(content=input).model_dump()) - return cast(list[Messages], messages) - - @rate_limit_handler - def invoke( + messages: list[Messages] = [] + for m in input: + if m["role"] == "system": + messages.append(SystemMessage(content=m["content"])) + continue + if m["role"] == "user": + messages.append(UserMessage(content=m["content"])) + continue + if m["role"] == "assistant": + messages.append(AssistantMessage(content=m["content"])) + continue + return messages + + def _invoke( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> LLMResponse: """Sends a text input to the Mistral chat completion model and returns the response's content. Args: input (str): Text sent to the LLM. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from MistralAI. @@ -115,9 +105,7 @@ def invoke( LLMGenerationError: If anything goes wrong. """ try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - messages = self.get_messages(input, message_history, system_instruction) + messages = self.get_messages(input) response = self.client.chat.complete( model=self.model_name, messages=messages, @@ -132,21 +120,15 @@ def invoke( except SDKError as e: raise LLMGenerationError(e) - @async_rate_limit_handler - async def ainvoke( + async def _ainvoke( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> LLMResponse: """Asynchronously sends a text input to the MistralAI chat completion model and returns the response's content. Args: input (str): Text sent to the LLM. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from MistralAI. @@ -155,9 +137,7 @@ async def ainvoke( LLMGenerationError: If anything goes wrong. 
""" try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - messages = self.get_messages(input, message_history, system_instruction) + messages = self.get_messages(input) response = await self.client.chat.complete_async( model=self.model_name, messages=messages, diff --git a/src/neo4j_graphrag/llm/ollama_llm.py b/src/neo4j_graphrag/llm/ollama_llm.py index 214640625..ab7d77874 100644 --- a/src/neo4j_graphrag/llm/ollama_llm.py +++ b/src/neo4j_graphrag/llm/ollama_llm.py @@ -15,22 +15,15 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Sequence, Union, cast - -from pydantic import ValidationError +from typing import TYPE_CHECKING, Any, Optional, Sequence from neo4j_graphrag.exceptions import LLMGenerationError -from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.types import LLMMessage from .base import LLMInterface -from .rate_limit import RateLimitHandler, rate_limit_handler, async_rate_limit_handler +from .rate_limit import RateLimitHandler from .types import ( - BaseMessage, LLMResponse, - MessageList, - SystemMessage, - UserMessage, ) if TYPE_CHECKING: @@ -76,48 +69,26 @@ def __init__( def get_messages( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> Sequence[Message]: - messages = [] - if system_instruction: - messages.append(SystemMessage(content=system_instruction).model_dump()) - if message_history: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - try: - MessageList(messages=cast(list[BaseMessage], message_history)) - except ValidationError as e: - raise LLMGenerationError(e.errors()) from e - messages.extend(cast(Iterable[dict[str, Any]], message_history)) - messages.append(UserMessage(content=input).model_dump()) - return messages # type: ignore + return [self.ollama.Message(**i) for i in input] - @rate_limit_handler - def invoke( + def _invoke( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> LLMResponse: """Sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. """ try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages response = self.client.chat( model=self.model_name, - messages=self.get_messages(input, message_history, system_instruction), + messages=self.get_messages(input), **self.model_params, ) content = response.message.content or "" @@ -125,21 +96,15 @@ def invoke( except self.ollama.ResponseError as e: raise LLMGenerationError(e) - @async_rate_limit_handler - async def ainvoke( + async def _ainvoke( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> LLMResponse: """Asynchronously sends a text input to the OpenAI chat completion model and returns the response's content. Args: input (str): Text sent to the LLM. 
- message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from OpenAI. @@ -148,11 +113,9 @@ async def ainvoke( LLMGenerationError: If anything goes wrong. """ try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages response = await self.async_client.chat( model=self.model_name, - messages=self.get_messages(input, message_history, system_instruction), + messages=self.get_messages(input), options=self.model_params, ) content = response.message.content or "" diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py index d74c83dcf..936f8a355 100644 --- a/src/neo4j_graphrag/llm/openai_llm.py +++ b/src/neo4j_graphrag/llm/openai_llm.py @@ -24,26 +24,18 @@ Optional, Iterable, Sequence, - Union, cast, + Type, ) -from pydantic import ValidationError - -from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.types import LLMMessage from ..exceptions import LLMGenerationError from .base import LLMInterface -from .rate_limit import RateLimitHandler, rate_limit_handler, async_rate_limit_handler from .types import ( - BaseMessage, LLMResponse, - MessageList, ToolCall, ToolCallResponse, - SystemMessage, - UserMessage, ) from neo4j_graphrag.tool import Tool @@ -54,11 +46,13 @@ ChatCompletionToolParam, ) from openai import OpenAI, AsyncOpenAI + from .rate_limit import RateLimitHandler else: ChatCompletionMessageParam = Any ChatCompletionToolParam = Any OpenAI = Any AsyncOpenAI = Any + RateLimitHandler = Any class BaseOpenAILLM(LLMInterface, abc.ABC): @@ -93,23 +87,28 @@ def __init__( def get_messages( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + messages: list[LLMMessage], ) -> Iterable[ChatCompletionMessageParam]: - messages = [] - if system_instruction: - messages.append(SystemMessage(content=system_instruction).model_dump()) - if message_history: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - try: - MessageList(messages=cast(list[BaseMessage], message_history)) - except ValidationError as e: - raise LLMGenerationError(e.errors()) from e - messages.extend(cast(Iterable[dict[str, Any]], message_history)) - messages.append(UserMessage(content=input).model_dump()) - return messages # type: ignore + chat_messages = [] + for m in messages: + message_type: Type[ChatCompletionMessageParam] + if m["role"] == "system": + message_type = self.openai.types.chat.ChatCompletionSystemMessageParam + elif m["role"] == "user": + message_type = self.openai.types.chat.ChatCompletionUserMessageParam + elif m["role"] == "assistant": + message_type = ( + self.openai.types.chat.ChatCompletionAssistantMessageParam + ) + else: + raise ValueError(f"Unknown role: {m['role']}") + chat_messages.append( + message_type( + role=m["role"], # type: ignore + content=m["content"], + ) + ) + return chat_messages def _convert_tool_to_openai_format(self, tool: Tool) -> Dict[str, Any]: """Convert a Tool object to OpenAI's expected format. 
@@ -132,21 +131,15 @@ def _convert_tool_to_openai_format(self, tool: Tool) -> Dict[str, Any]: except AttributeError: raise LLMGenerationError(f"Tool {tool} is not a valid Tool object") - @rate_limit_handler - def invoke( + def _invoke( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> LLMResponse: """Sends a text input to the OpenAI chat completion model and returns the response's content. Args: input (str): Text sent to the LLM. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from OpenAI. @@ -155,10 +148,8 @@ def invoke( LLMGenerationError: If anything goes wrong. """ try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages response = self.client.chat.completions.create( - messages=self.get_messages(input, message_history, system_instruction), + messages=self.get_messages(input), model=self.model_name, **self.model_params, ) @@ -167,13 +158,10 @@ def invoke( except self.openai.OpenAIError as e: raise LLMGenerationError(e) - @rate_limit_handler - def invoke_with_tools( + def _invoke_with_tools( self, - input: str, - tools: Sequence[Tool], # Tools definition as a sequence of Tool objects - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], + tools: Sequence[Tool], ) -> ToolCallResponse: """Sends a text input to the OpenAI chat completion model with tool definitions and retrieves a tool call response. @@ -181,9 +169,6 @@ def invoke_with_tools( Args: input (str): Text sent to the LLM. tools (List[Tool]): List of Tools for the LLM to choose from. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: ToolCallResponse: The response from the LLM containing a tool call. @@ -192,9 +177,6 @@ def invoke_with_tools( LLMGenerationError: If anything goes wrong. """ try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - params = self.model_params.copy() if self.model_params else {} if "temperature" not in params: params["temperature"] = 0.0 @@ -206,7 +188,7 @@ def invoke_with_tools( openai_tools.append(cast(ChatCompletionToolParam, openai_format_tool)) response = self.client.chat.completions.create( - messages=self.get_messages(input, message_history, system_instruction), + messages=self.get_messages(input), model=self.model_name, tools=openai_tools, tool_choice="auto", @@ -242,21 +224,15 @@ def invoke_with_tools( except self.openai.OpenAIError as e: raise LLMGenerationError(e) - @async_rate_limit_handler - async def ainvoke( + async def _ainvoke( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> LLMResponse: """Asynchronously sends a text input to the OpenAI chat completion model and returns the response's content. Args: input (str): Text sent to the LLM. 
- message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from OpenAI. @@ -265,10 +241,8 @@ async def ainvoke( LLMGenerationError: If anything goes wrong. """ try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages response = await self.async_client.chat.completions.create( - messages=self.get_messages(input, message_history, system_instruction), + messages=self.get_messages(input), model=self.model_name, **self.model_params, ) @@ -277,13 +251,10 @@ async def ainvoke( except self.openai.OpenAIError as e: raise LLMGenerationError(e) - @async_rate_limit_handler - async def ainvoke_with_tools( + async def _ainvoke_with_tools( self, - input: str, + input: list[LLMMessage], tools: Sequence[Tool], # Tools definition as a sequence of Tool objects - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, ) -> ToolCallResponse: """Asynchronously sends a text input to the OpenAI chat completion model with tool definitions and retrieves a tool call response. @@ -291,9 +262,6 @@ async def ainvoke_with_tools( Args: input (str): Text sent to the LLM. tools (List[Tool]): List of Tools for the LLM to choose from. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: ToolCallResponse: The response from the LLM containing a tool call. @@ -302,9 +270,6 @@ async def ainvoke_with_tools( LLMGenerationError: If anything goes wrong. """ try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - params = self.model_params.copy() if "temperature" not in params: params["temperature"] = 0.0 @@ -316,7 +281,7 @@ async def ainvoke_with_tools( openai_tools.append(cast(ChatCompletionToolParam, openai_format_tool)) response = await self.async_client.chat.completions.create( - messages=self.get_messages(input, message_history, system_instruction), + messages=self.get_messages(input), model=self.model_name, tools=openai_tools, tool_choice="auto", diff --git a/src/neo4j_graphrag/llm/utils.py b/src/neo4j_graphrag/llm/utils.py new file mode 100644 index 000000000..b61a880f4 --- /dev/null +++ b/src/neo4j_graphrag/llm/utils.py @@ -0,0 +1,69 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import warnings +from typing import Union, Optional + +from pydantic import TypeAdapter + +from neo4j_graphrag.message_history import MessageHistory +from neo4j_graphrag.types import LLMMessage + + +def system_instruction_from_messages(messages: list[LLMMessage]) -> str | None: + for message in messages: + if message["role"] == "system": + return message["content"] + return None + + +llm_messages_adapter = TypeAdapter(list[LLMMessage]) + + +def legacy_inputs_to_messages( + input: Union[str, list[LLMMessage], MessageHistory], + message_history: Optional[Union[list[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, +) -> list[LLMMessage]: + if message_history: + if isinstance(message_history, MessageHistory): + messages = message_history.messages + else: # list[LLMMessage] + messages = llm_messages_adapter.validate_python(message_history) + else: + messages = [] + if system_instruction is not None: + if system_instruction_from_messages(messages) is not None: + warnings.warn( + "system_instruction provided but ignored as the message history already contains a system message", + UserWarning, + ) + else: + messages.insert( + 0, + LLMMessage( + role="system", + content=system_instruction, + ), + ) + + if isinstance(input, str): + messages.append(LLMMessage(role="user", content=input)) + return messages + if isinstance(input, list): + messages.extend(input) + return messages + # input is a MessageHistory instance + messages.extend(input.messages) + return messages diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py index 0b4926978..d5ca1c3c8 100644 --- a/src/neo4j_graphrag/llm/vertexai_llm.py +++ b/src/neo4j_graphrag/llm/vertexai_llm.py @@ -13,25 +13,19 @@ # limitations under the License. 
from __future__ import annotations -from typing import Any, List, Optional, Union, cast, Sequence +from typing import Any, Optional, Sequence -from pydantic import ValidationError from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface from neo4j_graphrag.llm.rate_limit import ( RateLimitHandler, - rate_limit_handler, - async_rate_limit_handler, ) from neo4j_graphrag.llm.types import ( - BaseMessage, LLMResponse, - MessageList, ToolCall, ToolCallResponse, ) -from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.tool import Tool from neo4j_graphrag.types import LLMMessage @@ -98,92 +92,73 @@ def __init__( def get_messages( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - ) -> list[Content]: + input: list[LLMMessage], + ) -> tuple[str | None, list[Content]]: messages = [] - if message_history: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - try: - MessageList(messages=cast(list[BaseMessage], message_history)) - except ValidationError as e: - raise LLMGenerationError(e.errors()) from e - - for message in message_history: - if message.get("role") == "user": - messages.append( - Content( - role="user", - parts=[Part.from_text(message.get("content", ""))], - ) + system_instruction = self.system_instruction + for message in input: + if message.get("role") == "system": + system_instruction = message.get("content") + continue + if message.get("role") == "user": + messages.append( + Content( + role="user", + parts=[Part.from_text(message.get("content", ""))], ) - elif message.get("role") == "assistant": - messages.append( - Content( - role="model", - parts=[Part.from_text(message.get("content", ""))], - ) + ) + continue + if message.get("role") == "assistant": + messages.append( + Content( + role="model", + parts=[Part.from_text(message.get("content", ""))], ) + ) + continue + return system_instruction, messages - messages.append(Content(role="user", parts=[Part.from_text(input)])) - return messages - - @rate_limit_handler - def invoke( + def _invoke( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> LLMResponse: """Sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. 
""" + system_instruction, messages = self.get_messages(input) model = self._get_model( system_instruction=system_instruction, ) try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - options = self._get_call_params(input, message_history, tools=None) + options = self._get_call_params(messages, tools=None) response = model.generate_content(**options) return self._parse_content_response(response) except ResponseValidationError as e: raise LLMGenerationError("Error calling VertexAILLM") from e - @async_rate_limit_handler - async def ainvoke( + async def _ainvoke( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], ) -> LLMResponse: """Asynchronously sends text to the LLM and returns a response. Args: input (str): The text to send to the LLM. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. Returns: LLMResponse: The response from the LLM. """ try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages + system_instruction, messages = self.get_messages(input) model = self._get_model( system_instruction=system_instruction, ) - options = self._get_call_params(input, message_history, tools=None) + options = self._get_call_params(messages, tools=None) response = await model.generate_content_async(**options) return self._parse_content_response(response) except ResponseValidationError as e: @@ -213,7 +188,6 @@ def _get_model( self, system_instruction: Optional[str] = None, ) -> GenerativeModel: - # system_message = [system_instruction] if system_instruction is not None else [] model = GenerativeModel( model_name=self.model_name, system_instruction=system_instruction, @@ -222,8 +196,7 @@ def _get_model( def _get_call_params( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]], + contents: list[Content], tools: Optional[Sequence[Tool]], ) -> dict[str, Any]: options = dict(self.options) @@ -240,32 +213,28 @@ def _get_call_params( else: # no tools, remove tool_config if defined options.pop("tool_config", None) - - messages = self.get_messages(input, message_history) - options["contents"] = messages + options["contents"] = contents return options async def _acall_llm( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], tools: Optional[Sequence[Tool]] = None, ) -> GenerationResponse: - model = self._get_model(system_instruction=system_instruction) - options = self._get_call_params(input, message_history, tools) + system_instruction, contents = self.get_messages(input) + model = self._get_model(system_instruction) + options = self._get_call_params(contents, tools) response = await model.generate_content_async(**options) return response # type: ignore[no-any-return] def _call_llm( self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, + input: list[LLMMessage], tools: Optional[Sequence[Tool]] = None, ) -> GenerationResponse: - model = self._get_model(system_instruction=system_instruction) - options = self._get_call_params(input, message_history, tools) + 
system_instruction, contents = self.get_messages(input) + model = self._get_model(system_instruction) + options = self._get_call_params(contents, tools) response = model.generate_content(**options) return response # type: ignore[no-any-return] @@ -287,32 +256,24 @@ def _parse_content_response(self, response: GenerationResponse) -> LLMResponse: content=response.text, ) - async def ainvoke_with_tools( + async def _ainvoke_with_tools( self, - input: str, + input: list[LLMMessage], tools: Sequence[Tool], - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, ) -> ToolCallResponse: response = await self._acall_llm( input, - message_history=message_history, - system_instruction=system_instruction, tools=tools, ) return self._parse_tool_response(response) - def invoke_with_tools( + def _invoke_with_tools( self, - input: str, + input: list[LLMMessage], tools: Sequence[Tool], - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, ) -> ToolCallResponse: response = self._call_llm( input, - message_history=message_history, - system_instruction=system_instruction, tools=tools, ) return self._parse_tool_response(response) diff --git a/src/neo4j_graphrag/message_history.py b/src/neo4j_graphrag/message_history.py index 59ba033d9..f4df4576f 100644 --- a/src/neo4j_graphrag/message_history.py +++ b/src/neo4j_graphrag/message_history.py @@ -74,6 +74,9 @@ class MessageHistory(ABC): @abstractmethod def messages(self) -> List[LLMMessage]: ... + def is_empty(self) -> bool: + return len(self.messages) == 0 + @abstractmethod def add_message(self, message: LLMMessage) -> None: ... diff --git a/tests/unit/llm/test_anthropic_llm.py b/tests/unit/llm/test_anthropic_llm.py index 029d75778..326f027d1 100644 --- a/tests/unit/llm/test_anthropic_llm.py +++ b/tests/unit/llm/test_anthropic_llm.py @@ -19,9 +19,11 @@ import anthropic import pytest -from neo4j_graphrag.exceptions import LLMGenerationError +from anthropic import NOT_GIVEN, NotGiven + +from neo4j_graphrag.llm import LLMResponse from neo4j_graphrag.llm.anthropic_llm import AnthropicLLM -from neo4j_graphrag.llm.types import LLMResponse +from neo4j_graphrag.types import LLMMessage @pytest.fixture @@ -40,132 +42,65 @@ def test_anthropic_llm_missing_dependency(mock_import: Mock) -> None: AnthropicLLM(model_name="claude-3-opus-20240229") -def test_anthropic_invoke_happy_path(mock_anthropic: Mock) -> None: - mock_anthropic.Anthropic.return_value.messages.create.return_value = MagicMock( - content=[MagicMock(text="generated text")] - ) - model_params = {"temperature": 0.3} - llm = AnthropicLLM("claude-3-opus-20240229", model_params=model_params) - input_text = "may thy knife chip and shatter" - response = llm.invoke(input_text) - assert response.content == "generated text" - llm.client.messages.create.assert_called_once_with( # type: ignore - messages=[{"role": "user", "content": input_text}], - model="claude-3-opus-20240229", - system=anthropic.NOT_GIVEN, - **model_params, - ) - - -def test_anthropic_invoke_with_message_history_happy_path(mock_anthropic: Mock) -> None: - mock_anthropic.Anthropic.return_value.messages.create.return_value = MagicMock( - content=[MagicMock(text="generated text")] - ) - model_params = {"temperature": 0.3} - llm = AnthropicLLM( - "claude-3-opus-20240229", - model_params=model_params, - ) +def test_anthropic_llm_get_messages_with_system_instructions() -> None: + llm = AnthropicLLM(api_key="my key", 
model_name="claude") message_history = [ - {"role": "user", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, + LLMMessage(**{"role": "system", "content": "do something"}), + LLMMessage( + **{"role": "user", "content": "When does the sun come up in the summer?"} + ), + LLMMessage(**{"role": "assistant", "content": "Usually around 6am."}), ] - question = "What about next season?" - - response = llm.invoke(question, message_history) # type: ignore - assert response.content == "generated text" - message_history.append({"role": "user", "content": question}) - llm.client.messages.create.assert_called_once_with( # type: ignore[attr-defined] - messages=message_history, - model="claude-3-opus-20240229", - system=anthropic.NOT_GIVEN, - **model_params, - ) + system_instruction, messages = llm.get_messages(message_history) + assert isinstance(system_instruction, str) + assert system_instruction == "do something" + assert isinstance(messages, list) + assert len(messages) == 2 # exclude system instruction + for actual, expected in zip(messages, message_history[1:]): + assert isinstance(actual, dict) + assert actual["role"] == expected["role"] + assert actual["content"] == expected["content"] -def test_anthropic_invoke_with_system_instruction( - mock_anthropic: Mock, -) -> None: - mock_anthropic.Anthropic.return_value.messages.create.return_value = MagicMock( - content=[MagicMock(text="generated text")] - ) - model_params = {"temperature": 0.3} - system_instruction = "You are a helpful assistant." - llm = AnthropicLLM( - "claude-3-opus-20240229", - model_params=model_params, - ) - question = "When does it come up in the winter?" - response = llm.invoke(question, system_instruction=system_instruction) - assert isinstance(response, LLMResponse) - assert response.content == "generated text" - messages = [{"role": "user", "content": question}] - llm.client.messages.create.assert_called_with( # type: ignore[attr-defined] - model="claude-3-opus-20240229", - system=system_instruction, - messages=messages, - **model_params, - ) +def test_anthropic_llm_get_messages_without_system_instructions() -> None: + llm = AnthropicLLM(api_key="my key", model_name="claude") + message_history = [ + LLMMessage( + **{"role": "user", "content": "When does the sun come up in the summer?"} + ), + LLMMessage(**{"role": "assistant", "content": "Usually around 6am."}), + ] - assert llm.client.messages.create.call_count == 1 # type: ignore + system_instruction, messages = llm.get_messages(message_history) + assert isinstance(system_instruction, NotGiven) + assert system_instruction == NOT_GIVEN + assert isinstance(messages, list) + assert len(messages) == 2 + for actual, expected in zip(messages, message_history): + assert isinstance(actual, dict) + assert actual["role"] == expected["role"] + assert actual["content"] == expected["content"] -def test_anthropic_invoke_with_message_history_and_system_instruction( - mock_anthropic: Mock, -) -> None: +def test_anthropic_invoke_happy_path(mock_anthropic: Mock) -> None: mock_anthropic.Anthropic.return_value.messages.create.return_value = MagicMock( content=[MagicMock(text="generated text")] ) + mock_anthropic.types.MessageParam.return_value = {"role": "user", "content": "hi"} model_params = {"temperature": 0.3} - system_instruction = "You are a helpful assistant." 
- llm = AnthropicLLM( - "claude-3-opus-20240229", - model_params=model_params, - ) - message_history = [ - {"role": "user", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, - ] - - question = "When does it come up in the winter?" - response = llm.invoke(question, message_history, system_instruction) # type: ignore + llm = AnthropicLLM("claude-3-opus-20240229", model_params=model_params) + input_text = "may thy knife chip and shatter" + response = llm.invoke(input_text) assert isinstance(response, LLMResponse) assert response.content == "generated text" - message_history.append({"role": "user", "content": question}) - llm.client.messages.create.assert_called_with( # type: ignore[attr-defined] + llm.client.messages.create.assert_called_once_with( # type: ignore + messages=[{"role": "user", "content": "hi"}], model="claude-3-opus-20240229", - system=system_instruction, - messages=message_history, + system=anthropic.NOT_GIVEN, **model_params, ) - assert llm.client.messages.create.call_count == 1 # type: ignore - - -def test_anthropic_invoke_with_message_history_validation_error( - mock_anthropic: Mock, -) -> None: - mock_anthropic.Anthropic.return_value.messages.create.return_value = MagicMock( - content=[MagicMock(text="generated text")] - ) - model_params = {"temperature": 0.3} - system_instruction = "You are a helpful assistant." - llm = AnthropicLLM( - "claude-3-opus-20240229", - model_params=model_params, - system_instruction=system_instruction, - ) - message_history = [ - {"role": "human", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, - ] - question = "What about next season?" - - with pytest.raises(LLMGenerationError) as exc_info: - llm.invoke(question, message_history) # type: ignore - assert "Input should be 'user', 'assistant' or 'system'" in str(exc_info.value) - @pytest.mark.asyncio async def test_anthropic_ainvoke_happy_path(mock_anthropic: Mock) -> None: @@ -173,14 +108,16 @@ async def test_anthropic_ainvoke_happy_path(mock_anthropic: Mock) -> None: mock_response.content = [MagicMock(text="Return text")] mock_model = mock_anthropic.AsyncAnthropic.return_value mock_model.messages.create = AsyncMock(return_value=mock_response) + mock_anthropic.types.MessageParam.return_value = {"role": "user", "content": "hi"} model_params = {"temperature": 0.3} llm = AnthropicLLM("claude-3-opus-20240229", model_params) input_text = "may thy knife chip and shatter" response = await llm.ainvoke(input_text) + assert isinstance(response, LLMResponse) assert response.content == "Return text" llm.async_client.messages.create.assert_awaited_once_with( # type: ignore model="claude-3-opus-20240229", system=anthropic.NOT_GIVEN, - messages=[{"role": "user", "content": input_text}], + messages=[{"role": "user", "content": "hi"}], **model_params, ) diff --git a/tests/unit/llm/test_base.py b/tests/unit/llm/test_base.py new file mode 100644 index 000000000..9a927b193 --- /dev/null +++ b/tests/unit/llm/test_base.py @@ -0,0 +1,136 @@ +from typing import Type, Generator +from unittest.mock import patch, Mock + +import pytest +from joblib.testing import fixture +from pydantic import ValidationError + +from neo4j_graphrag.exceptions import LLMGenerationError +from neo4j_graphrag.llm import LLMInterface +from neo4j_graphrag.types import LLMMessage + + +@fixture(scope="module") # type: ignore[misc] +def llm_interface() -> Generator[Type[LLMInterface], None, None]: + 
real_abstract_methods = LLMInterface.__abstractmethods__ + LLMInterface.__abstractmethods__ = frozenset() + + class CustomLLMInterface(LLMInterface): + pass + + yield CustomLLMInterface + + LLMInterface.__abstractmethods__ = real_abstract_methods + + +@patch("neo4j_graphrag.llm.base.legacy_inputs_to_messages") +def test_base_llm_interface_invoke_with_input_as_str( + mock_inputs: Mock, llm_interface: Type[LLMInterface] +) -> None: + mock_inputs.return_value = [ + LLMMessage( + role="user", + content="return value of the legacy_inputs_to_messages function", + ) + ] + llm = llm_interface(model_name="test") + message_history = [ + LLMMessage( + **{"role": "user", "content": "When does the sun come up in the summer?"} + ), + LLMMessage(**{"role": "assistant", "content": "Usually around 6am."}), + ] + question = "What about next season?" + system_instruction = "You are a genius." + + with patch.object(llm, "_invoke") as mock_invoke: + llm.invoke(question, message_history, system_instruction) + mock_invoke.assert_called_once_with( + [ + LLMMessage( + role="user", + content="return value of the legacy_inputs_to_messages function", + ) + ] + ) + mock_inputs.assert_called_once_with( + question, + message_history, + system_instruction, + ) + + +@patch("neo4j_graphrag.llm.base.legacy_inputs_to_messages") +def test_base_llm_interface_invoke_with_invalid_inputs( + mock_inputs: Mock, llm_interface: Type[LLMInterface] +) -> None: + mock_inputs.side_effect = [ + ValidationError.from_exception_data("Invalid data", line_errors=[]) + ] + llm = llm_interface(model_name="test") + question = "What about next season?" + + with pytest.raises(LLMGenerationError, match="Input validation failed"): + llm.invoke(question) + mock_inputs.assert_called_once_with( + question, + None, + None, + ) + + +@patch("neo4j_graphrag.llm.base.legacy_inputs_to_messages") +def test_base_llm_interface_invoke_with_tools_with_input_as_str( + mock_inputs: Mock, llm_interface: Type[LLMInterface] +) -> None: + mock_inputs.return_value = [ + LLMMessage( + role="user", + content="return value of the legacy_inputs_to_messages function", + ) + ] + llm = llm_interface(model_name="test") + message_history = [ + LLMMessage( + **{"role": "user", "content": "When does the sun come up in the summer?"} + ), + LLMMessage(**{"role": "assistant", "content": "Usually around 6am."}), + ] + question = "What about next season?" + system_instruction = "You are a genius." + + with patch.object(llm, "_invoke_with_tools") as mock_invoke: + llm.invoke_with_tools(question, [], message_history, system_instruction) + mock_invoke.assert_called_once_with( + [ + LLMMessage( + role="user", + content="return value of the legacy_inputs_to_messages function", + ) + ], + [], # tools + ) + mock_inputs.assert_called_once_with( + question, + message_history, + system_instruction, + ) + + +@patch("neo4j_graphrag.llm.base.legacy_inputs_to_messages") +def test_base_llm_interface_invoke_with_tools_with_invalid_inputs( + mock_inputs: Mock, llm_interface: Type[LLMInterface] +) -> None: + mock_inputs.side_effect = [ + ValidationError.from_exception_data("Invalid data", line_errors=[]) + ] + llm = llm_interface(model_name="test") + question = "What about next season?" 
+ + with pytest.raises(LLMGenerationError, match="Input validation failed"): + llm.invoke_with_tools(question, []) + mock_inputs.assert_called_once_with( + question, + None, + None, + ) diff --git a/tests/unit/llm/test_cohere_llm.py b/tests/unit/llm/test_cohere_llm.py index 10a02ec86..c3b43dbc4 100644 --- a/tests/unit/llm/test_cohere_llm.py +++ b/tests/unit/llm/test_cohere_llm.py @@ -41,86 +41,17 @@ def test_cohere_llm_happy_path(mock_cohere: Mock) -> None: chat_response_mock = MagicMock() chat_response_mock.message.content = [MagicMock(text="cohere response text")] mock_cohere.ClientV2.return_value.chat.return_value = chat_response_mock + mock_cohere.UserChatMessageV2.return_value = {"role": "user", "content": "test"} llm = CohereLLM(model_name="something") res = llm.invoke("my text") assert isinstance(res, LLMResponse) assert res.content == "cohere response text" - - -def test_cohere_llm_invoke_with_message_history_happy_path(mock_cohere: Mock) -> None: - chat_response_mock = MagicMock() - chat_response_mock.message.content = [MagicMock(text="cohere response text")] - mock_cohere_client_chat = mock_cohere.ClientV2.return_value.chat - mock_cohere_client_chat.return_value = chat_response_mock - - system_instruction = "You are a helpful assistant." - llm = CohereLLM(model_name="something") - message_history = [ - {"role": "user", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, - ] - question = "What about next season?" - - res = llm.invoke(question, message_history, system_instruction=system_instruction) # type: ignore - assert isinstance(res, LLMResponse) - assert res.content == "cohere response text" - messages = [{"role": "system", "content": system_instruction}] - messages.extend(message_history) - messages.append({"role": "user", "content": question}) - mock_cohere_client_chat.assert_called_once_with( - messages=messages, + mock_cohere.ClientV2.return_value.chat.assert_called_once_with( + messages=[{"role": "user", "content": "test"}], model="something", ) -def test_cohere_llm_invoke_with_message_history_and_system_instruction( - mock_cohere: Mock, -) -> None: - chat_response_mock = MagicMock() - chat_response_mock.message.content = [MagicMock(text="cohere response text")] - mock_cohere_client_chat = mock_cohere.ClientV2.return_value.chat - mock_cohere_client_chat.return_value = chat_response_mock - - system_instruction = "You are a helpful assistant." - llm = CohereLLM(model_name="gpt") - message_history = [ - {"role": "user", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, - ] - question = "What about next season?" - - res = llm.invoke(question, message_history, system_instruction=system_instruction) # type: ignore - assert isinstance(res, LLMResponse) - assert res.content == "cohere response text" - messages = [{"role": "system", "content": system_instruction}] - messages.extend(message_history) - messages.append({"role": "user", "content": question}) - mock_cohere_client_chat.assert_called_once_with( - messages=messages, - model="gpt", - ) - - -def test_cohere_llm_invoke_with_message_history_validation_error( - mock_cohere: Mock, -) -> None: - chat_response_mock = MagicMock() - chat_response_mock.message.content = [MagicMock(text="cohere response text")] - mock_cohere.ClientV2.return_value.chat.return_value = chat_response_mock - - system_instruction = "You are a helpful assistant." 
- llm = CohereLLM(model_name="something", system_instruction=system_instruction) - message_history = [ - {"role": "robot", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, - ] - question = "What about next season?" - - with pytest.raises(LLMGenerationError) as exc_info: - llm.invoke(question, message_history) # type: ignore - assert "Input should be 'user', 'assistant' or 'system" in str(exc_info.value) - - @pytest.mark.asyncio async def test_cohere_llm_happy_path_async(mock_cohere: Mock) -> None: chat_response_mock = MagicMock( @@ -139,9 +70,8 @@ async def test_cohere_llm_happy_path_async(mock_cohere: Mock) -> None: def test_cohere_llm_failed(mock_cohere: Mock) -> None: mock_cohere.ClientV2.return_value.chat.side_effect = cohere.core.ApiError llm = CohereLLM(model_name="something") - with pytest.raises(LLMGenerationError) as excinfo: + with pytest.raises(LLMGenerationError, match="ApiError"): llm.invoke("my text") - assert "ApiError" in str(excinfo) @pytest.mark.asyncio @@ -149,6 +79,5 @@ async def test_cohere_llm_failed_async(mock_cohere: Mock) -> None: mock_cohere.AsyncClientV2.return_value.chat.side_effect = cohere.core.ApiError llm = CohereLLM(model_name="something") - with pytest.raises(LLMGenerationError) as excinfo: + with pytest.raises(LLMGenerationError, match="ApiError"): await llm.ainvoke("my text") - assert "ApiError" in str(excinfo) diff --git a/tests/unit/llm/test_ollama_llm.py b/tests/unit/llm/test_ollama_llm.py index c1d3f9fdc..6ecb9fb13 100644 --- a/tests/unit/llm/test_ollama_llm.py +++ b/tests/unit/llm/test_ollama_llm.py @@ -41,6 +41,7 @@ def test_ollama_llm_happy_path_deprecated_options(mock_import: Mock) -> None: mock_ollama.Client.return_value.chat.return_value = MagicMock( message=MagicMock(content="ollama chat response"), ) + mock_ollama.Message.return_value = {"role": "user", "content": "test"} model = "gpt" model_params = {"temperature": 0.3} with pytest.warns(DeprecationWarning) as record: @@ -59,11 +60,10 @@ def test_ollama_llm_happy_path_deprecated_options(mock_import: Mock) -> None: res = llm.invoke(question) assert isinstance(res, LLMResponse) assert res.content == "ollama chat response" - messages = [ - {"role": "user", "content": question}, - ] llm.client.chat.assert_called_once_with( # type: ignore[attr-defined] - model=model, messages=messages, options={"temperature": 0.3} + model=model, + messages=[{"role": "user", "content": "test"}], + options={"temperature": 0.3}, ) @@ -90,6 +90,7 @@ def test_ollama_llm_happy_path(mock_import: Mock) -> None: mock_ollama.Client.return_value.chat.return_value = MagicMock( message=MagicMock(content="ollama chat response"), ) + mock_ollama.Message.return_value = {"role": "user", "content": "test"} model = "gpt" options = {"temperature": 0.3} model_params = {"options": options, "format": "json"} @@ -102,7 +103,7 @@ def test_ollama_llm_happy_path(mock_import: Mock) -> None: assert isinstance(res, LLMResponse) assert res.content == "ollama chat response" messages = [ - {"role": "user", "content": question}, + {"role": "user", "content": "test"}, ] llm.client.chat.assert_called_once_with( # type: ignore[attr-defined] model=model, @@ -112,102 +113,6 @@ def test_ollama_llm_happy_path(mock_import: Mock) -> None: ) -@patch("builtins.__import__") -def test_ollama_invoke_with_system_instruction_happy_path(mock_import: Mock) -> None: - mock_ollama = get_mock_ollama() - mock_import.return_value = mock_ollama - mock_ollama.Client.return_value.chat.return_value = 
MagicMock( - message=MagicMock(content="ollama chat response"), - ) - model = "gpt" - options = {"temperature": 0.3} - model_params = {"options": options, "format": "json"} - llm = OllamaLLM( - model, - model_params=model_params, - ) - system_instruction = "You are a helpful assistant." - question = "What about next season?" - - response = llm.invoke(question, system_instruction=system_instruction) - assert response.content == "ollama chat response" - messages = [{"role": "system", "content": system_instruction}] - messages.append({"role": "user", "content": question}) - llm.client.chat.assert_called_once_with( # type: ignore[attr-defined] - model=model, - messages=messages, - options=options, - format="json", - ) - - -@patch("builtins.__import__") -def test_ollama_invoke_with_message_history_happy_path(mock_import: Mock) -> None: - mock_ollama = get_mock_ollama() - mock_import.return_value = mock_ollama - mock_ollama.Client.return_value.chat.return_value = MagicMock( - message=MagicMock(content="ollama chat response"), - ) - model = "gpt" - options = {"temperature": 0.3} - model_params = {"options": options} - llm = OllamaLLM( - model, - model_params=model_params, - ) - message_history = [ - {"role": "user", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, - ] - question = "What about next season?" - - response = llm.invoke(question, message_history) # type: ignore - assert response.content == "ollama chat response" - messages = [m for m in message_history] - messages.append({"role": "user", "content": question}) - llm.client.chat.assert_called_once_with( # type: ignore[attr-defined] - model=model, messages=messages, options=options - ) - - -@patch("builtins.__import__") -def test_ollama_invoke_with_message_history_and_system_instruction( - mock_import: Mock, -) -> None: - mock_ollama = get_mock_ollama() - mock_import.return_value = mock_ollama - mock_ollama.Client.return_value.chat.return_value = MagicMock( - message=MagicMock(content="ollama chat response"), - ) - model = "gpt" - options = {"temperature": 0.3} - model_params = {"options": options} - system_instruction = "You are a helpful assistant." - llm = OllamaLLM( - model, - model_params=model_params, - ) - message_history = [ - {"role": "user", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, - ] - question = "What about next season?" - - response = llm.invoke( - question, - message_history, # type: ignore - system_instruction=system_instruction, - ) - assert response.content == "ollama chat response" - messages = [{"role": "system", "content": system_instruction}] - messages.extend(message_history) - messages.append({"role": "user", "content": question}) - llm.client.chat.assert_called_once_with( # type: ignore[attr-defined] - model=model, messages=messages, options=options - ) - assert llm.client.chat.call_count == 1 # type: ignore - - @patch("builtins.__import__") def test_ollama_invoke_with_message_history_validation_error(mock_import: Mock) -> None: mock_ollama = get_mock_ollama() @@ -228,9 +133,8 @@ def test_ollama_invoke_with_message_history_validation_error(mock_import: Mock) ] question = "What about next season?" 
- with pytest.raises(LLMGenerationError) as exc_info: + with pytest.raises(LLMGenerationError, match="Input validation failed"): llm.invoke(question, message_history) # type: ignore - assert "Input should be 'user', 'assistant' or 'system" in str(exc_info.value) @pytest.mark.asyncio diff --git a/tests/unit/llm/test_openai_llm.py b/tests/unit/llm/test_openai_llm.py index 3c5ee1b9e..2b2cb29f2 100644 --- a/tests/unit/llm/test_openai_llm.py +++ b/tests/unit/llm/test_openai_llm.py @@ -21,6 +21,7 @@ from neo4j_graphrag.llm.openai_llm import AzureOpenAILLM, OpenAILLM from neo4j_graphrag.llm.types import ToolCallResponse from neo4j_graphrag.tool import Tool +from neo4j_graphrag.types import LLMMessage def get_mock_openai() -> MagicMock: @@ -36,7 +37,7 @@ def test_openai_llm_missing_dependency(mock_import: Mock) -> None: @patch("builtins.__import__") -def test_openai_llm_happy_path(mock_import: Mock) -> None: +def test_openai_llm_happy_path_e2e(mock_import: Mock) -> None: mock_openai = get_mock_openai() mock_import.return_value = mock_openai mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock( @@ -49,89 +50,31 @@ def test_openai_llm_happy_path(mock_import: Mock) -> None: assert res.content == "openai chat response" -@patch("builtins.__import__") -def test_openai_llm_with_message_history_happy_path(mock_import: Mock) -> None: - mock_openai = get_mock_openai() - mock_import.return_value = mock_openai - mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock( - choices=[MagicMock(message=MagicMock(content="openai chat response"))], - ) +def test_openai_llm_get_messages() -> None: llm = OpenAILLM(api_key="my key", model_name="gpt") message_history = [ - {"role": "user", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, - ] - question = "What about next season?" - - res = llm.invoke(question, message_history) # type: ignore - assert isinstance(res, LLMResponse) - assert res.content == "openai chat response" - message_history.append({"role": "user", "content": question}) - # Use assert_called_once() instead of assert_called_once_with() to avoid issues with overloaded functions - llm.client.chat.completions.create.assert_called_once() # type: ignore - # Check call arguments individually - call_args = llm.client.chat.completions.create.call_args[ # type: ignore - 1 - ] # Get the keyword arguments - assert call_args["messages"] == message_history - assert call_args["model"] == "gpt" - - -@patch("builtins.__import__") -def test_openai_llm_with_message_history_and_system_instruction( - mock_import: Mock, -) -> None: - mock_openai = get_mock_openai() - mock_import.return_value = mock_openai - mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock( - choices=[MagicMock(message=MagicMock(content="openai chat response"))], - ) - system_instruction = "You are a helpful assistent." - llm = OpenAILLM( - api_key="my key", - model_name="gpt", - ) - message_history = [ - {"role": "user", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, + LLMMessage(**{"role": "system", "content": "do something"}), + LLMMessage( + **{"role": "user", "content": "When does the sun come up in the summer?"} + ), + LLMMessage(**{"role": "assistant", "content": "Usually around 6am."}), ] - question = "What about next season?" 
- - res = llm.invoke(question, message_history, system_instruction=system_instruction) # type: ignore - assert isinstance(res, LLMResponse) - assert res.content == "openai chat response" - messages = [{"role": "system", "content": system_instruction}] - messages.extend(message_history) - messages.append({"role": "user", "content": question}) - # Use assert_called_once() instead of assert_called_once_with() to avoid issues with overloaded functions - llm.client.chat.completions.create.assert_called_once() # type: ignore - # Check call arguments individually - call_args = llm.client.chat.completions.create.call_args[ # type: ignore - 1 - ] # Get the keyword arguments - assert call_args["messages"] == messages - assert call_args["model"] == "gpt" - assert llm.client.chat.completions.create.call_count == 1 # type: ignore + messages = llm.get_messages(message_history) + assert isinstance(messages, list) + for actual, expected in zip(messages, message_history): + assert isinstance(actual, dict) + assert actual["role"] == expected["role"] + assert actual["content"] == expected["content"] -@patch("builtins.__import__") -def test_openai_llm_with_message_history_validation_error(mock_import: Mock) -> None: - mock_openai = get_mock_openai() - mock_import.return_value = mock_openai - mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock( - choices=[MagicMock(message=MagicMock(content="openai chat response"))], - ) +def test_openai_llm_get_messages_unknown_role() -> None: llm = OpenAILLM(api_key="my key", model_name="gpt") message_history = [ - {"role": "human", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, + LLMMessage(**{"role": "unknown role", "content": "Usually around 6am."}), ] - question = "What about next season?" - - with pytest.raises(LLMGenerationError) as exc_info: - llm.invoke(question, message_history) # type: ignore - assert "Input should be 'user', 'assistant' or 'system'" in str(exc_info.value) + with pytest.raises(ValueError, match="Unknown role"): + llm.get_messages(message_history) @patch("builtins.__import__") @@ -176,130 +119,6 @@ def test_openai_llm_invoke_with_tools_happy_path( assert res.content == "openai tool response" -@patch("builtins.__import__") -@patch("json.loads") -def test_openai_llm_invoke_with_tools_with_message_history( - mock_json_loads: Mock, - mock_import: Mock, - test_tool: Tool, -) -> None: - # Set up json.loads to return a dictionary - mock_json_loads.return_value = {"param1": "value1"} - - mock_openai = get_mock_openai() - mock_import.return_value = mock_openai - - # Mock the tool call response - mock_function = MagicMock() - mock_function.name = "test_tool" - mock_function.arguments = '{"param1": "value1"}' - - mock_tool_call = MagicMock() - mock_tool_call.function = mock_function - - mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock( - choices=[ - MagicMock( - message=MagicMock( - content="openai tool response", tool_calls=[mock_tool_call] - ) - ) - ], - ) - - llm = OpenAILLM(api_key="my key", model_name="gpt") - tools = [test_tool] - - message_history = [ - {"role": "user", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, - ] - question = "What about next season?" 
- - res = llm.invoke_with_tools(question, tools, message_history) # type: ignore - assert isinstance(res, ToolCallResponse) - assert len(res.tool_calls) == 1 - assert res.tool_calls[0].name == "test_tool" - assert res.tool_calls[0].arguments == {"param1": "value1"} - - # Verify the correct messages were passed - message_history.append({"role": "user", "content": question}) - # Use assert_called_once() instead of assert_called_once_with() to avoid issues with overloaded functions - llm.client.chat.completions.create.assert_called_once() # type: ignore - # Check call arguments individually - call_args = llm.client.chat.completions.create.call_args[ # type: ignore - 1 - ] # Get the keyword arguments - assert call_args["messages"] == message_history - assert call_args["model"] == "gpt" - # Check tools content rather than direct equality - assert len(call_args["tools"]) == 1 - assert call_args["tools"][0]["type"] == "function" - assert call_args["tools"][0]["function"]["name"] == "test_tool" - assert call_args["tools"][0]["function"]["description"] == "A test tool" - assert call_args["tool_choice"] == "auto" - assert call_args["temperature"] == 0.0 - - -@patch("builtins.__import__") -@patch("json.loads") -def test_openai_llm_invoke_with_tools_with_system_instruction( - mock_json_loads: Mock, - mock_import: Mock, - test_tool: Mock, -) -> None: - # Set up json.loads to return a dictionary - mock_json_loads.return_value = {"param1": "value1"} - - mock_openai = get_mock_openai() - mock_import.return_value = mock_openai - - # Mock the tool call response - mock_function = MagicMock() - mock_function.name = "test_tool" - mock_function.arguments = '{"param1": "value1"}' - - mock_tool_call = MagicMock() - mock_tool_call.function = mock_function - - mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock( - choices=[ - MagicMock( - message=MagicMock( - content="openai tool response", tool_calls=[mock_tool_call] - ) - ) - ], - ) - - llm = OpenAILLM(api_key="my key", model_name="gpt") - tools = [test_tool] - - system_instruction = "You are a helpful assistant." 
- - res = llm.invoke_with_tools("my text", tools, system_instruction=system_instruction) - assert isinstance(res, ToolCallResponse) - - # Verify system instruction was included - messages = [{"role": "system", "content": system_instruction}] - messages.append({"role": "user", "content": "my text"}) - # Use assert_called_once() instead of assert_called_once_with() to avoid issues with overloaded functions - llm.client.chat.completions.create.assert_called_once() # type: ignore - # Check call arguments individually - call_args = llm.client.chat.completions.create.call_args[ # type: ignore - 1 - ] # Get the keyword arguments - assert call_args["messages"] == messages - assert call_args["model"] == "gpt" - # Check tools content rather than direct equality - assert len(call_args["tools"]) == 1 - assert call_args["tools"][0]["type"] == "function" - assert call_args["tools"][0]["function"]["name"] == "test_tool" - assert call_args["tools"][0]["function"]["description"] == "A test tool" - assert call_args["tool_choice"] == "auto" - assert call_args["temperature"] == 0.0 - - @patch("builtins.__import__") def test_openai_llm_invoke_with_tools_error(mock_import: Mock, test_tool: Tool) -> None: mock_openai = get_mock_openai() @@ -342,67 +161,3 @@ def test_azure_openai_llm_happy_path(mock_import: Mock) -> None: res = llm.invoke("my text") assert isinstance(res, LLMResponse) assert res.content == "openai chat response" - - -@patch("builtins.__import__") -def test_azure_openai_llm_with_message_history_happy_path(mock_import: Mock) -> None: - mock_openai = get_mock_openai() - mock_import.return_value = mock_openai - mock_openai.AzureOpenAI.return_value.chat.completions.create.return_value = ( - MagicMock( - choices=[MagicMock(message=MagicMock(content="openai chat response"))], - ) - ) - llm = AzureOpenAILLM( - model_name="gpt", - azure_endpoint="https://test.openai.azure.com/", - api_key="my key", - api_version="version", - ) - - message_history = [ - {"role": "user", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, - ] - question = "What about next season?" - - res = llm.invoke(question, message_history) # type: ignore - assert isinstance(res, LLMResponse) - assert res.content == "openai chat response" - message_history.append({"role": "user", "content": question}) - # Use assert_called_once() instead of assert_called_once_with() to avoid issues with overloaded functions - llm.client.chat.completions.create.assert_called_once() # type: ignore - # Check call arguments individually - call_args = llm.client.chat.completions.create.call_args[ # type: ignore - 1 - ] # Get the keyword arguments - assert call_args["messages"] == message_history - assert call_args["model"] == "gpt" - - -@patch("builtins.__import__") -def test_azure_openai_llm_with_message_history_validation_error( - mock_import: Mock, -) -> None: - mock_openai = get_mock_openai() - mock_import.return_value = mock_openai - mock_openai.AzureOpenAI.return_value.chat.completions.create.return_value = ( - MagicMock( - choices=[MagicMock(message=MagicMock(content="openai chat response"))], - ) - ) - llm = AzureOpenAILLM( - model_name="gpt", - azure_endpoint="https://test.openai.azure.com/", - api_key="my key", - api_version="version", - ) - - message_history = [ - {"role": "user", "content": 33}, - ] - question = "What about next season?" 
- - with pytest.raises(LLMGenerationError) as exc_info: - llm.invoke(question, message_history) # type: ignore - assert "Input should be a valid string" in str(exc_info.value) diff --git a/tests/unit/llm/test_utils.py b/tests/unit/llm/test_utils.py new file mode 100644 index 000000000..6a969864d --- /dev/null +++ b/tests/unit/llm/test_utils.py @@ -0,0 +1,144 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +from pydantic import ValidationError + +from neo4j_graphrag.llm.utils import ( + system_instruction_from_messages, + legacy_inputs_to_messages, +) +from neo4j_graphrag.message_history import InMemoryMessageHistory +from neo4j_graphrag.types import LLMMessage + + +def test_system_instruction_from_messages() -> None: + messages = [ + LLMMessage(role="system", content="text"), + ] + assert system_instruction_from_messages(messages) == "text" + + messages = [] + assert system_instruction_from_messages(messages) is None + + messages = [ + LLMMessage(role="assistant", content="text"), + ] + assert system_instruction_from_messages(messages) is None + + +def test_legacy_inputs_to_messages_only_input_as_llm_message_list() -> None: + messages = legacy_inputs_to_messages( + input=[ + LLMMessage(role="user", content="text"), + ] + ) + assert messages == [ + LLMMessage(role="user", content="text"), + ] + + +def test_legacy_inputs_to_messages_only_input_as_message_history() -> None: + messages = legacy_inputs_to_messages( + input=InMemoryMessageHistory( + messages=[ + LLMMessage(role="user", content="text"), + ] + ) + ) + assert messages == [ + LLMMessage(role="user", content="text"), + ] + + +def test_legacy_inputs_to_messages_only_input_as_str() -> None: + messages = legacy_inputs_to_messages(input="text") + assert messages == [ + LLMMessage(role="user", content="text"), + ] + + +def test_legacy_inputs_to_messages_input_as_str_and_message_history_as_llm_message_list() -> ( + None +): + messages = legacy_inputs_to_messages( + input="text", + message_history=[ + LLMMessage(role="assistant", content="How can I assist you today?"), + ], + ) + assert messages == [ + LLMMessage(role="assistant", content="How can I assist you today?"), + LLMMessage(role="user", content="text"), + ] + + +def test_legacy_inputs_to_messages_input_as_str_and_message_history_as_message_history() -> ( + None +): + messages = legacy_inputs_to_messages( + input="text", + message_history=InMemoryMessageHistory( + messages=[ + LLMMessage(role="assistant", content="How can I assist you today?"), + ] + ), + ) + assert messages == [ + LLMMessage(role="assistant", content="How can I assist you today?"), + LLMMessage(role="user", content="text"), + ] + + +def test_legacy_inputs_to_messages_with_explicit_system_instruction() -> None: + messages = legacy_inputs_to_messages( + input="text", + message_history=[ + LLMMessage(role="assistant", content="How can I assist you today?"), + ], + system_instruction="You are a genius.", + ) + assert messages == [ + 
LLMMessage(role="system", content="You are a genius."), + LLMMessage(role="assistant", content="How can I assist you today?"), + LLMMessage(role="user", content="text"), + ] + + +def test_legacy_inputs_to_messages_do_not_duplicate_system_instruction() -> None: + with pytest.warns( + UserWarning, + match="system_instruction provided but ignored as the message history already contains a system message", + ): + messages = legacy_inputs_to_messages( + input="text", + message_history=[ + LLMMessage(role="system", content="You are super smart."), + ], + system_instruction="You are a genius.", + ) + assert messages == [ + LLMMessage(role="system", content="You are super smart."), + LLMMessage(role="user", content="text"), + ] + + +def test_legacy_inputs_to_messages_wrong_type_in_message_list() -> None: + with pytest.raises(ValidationError, match="Input should be a valid string"): + legacy_inputs_to_messages( + input="text", + message_history=[ + {"role": "system", "content": 10}, # type: ignore + ], + )
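
Note on the helpers exercised by tests/unit/llm/test_utils.py: the short sketch below restates, as a plain usage example, the behaviour the new tests pin down for legacy_inputs_to_messages and system_instruction_from_messages. It is illustrative only and not part of the patch; it assumes these helpers keep the signatures and message ordering asserted in the tests above.

# Illustrative usage of the new input-normalisation helpers (not part of the patch).
# Assumes the behaviour asserted in tests/unit/llm/test_utils.py above.
from neo4j_graphrag.llm.utils import (
    legacy_inputs_to_messages,
    system_instruction_from_messages,
)
from neo4j_graphrag.message_history import InMemoryMessageHistory
from neo4j_graphrag.types import LLMMessage

# A bare string input becomes a single user message.
assert legacy_inputs_to_messages(input="text") == [
    LLMMessage(role="user", content="text")
]

# Message history (a list or an InMemoryMessageHistory) is kept in order,
# an explicit system_instruction is prepended as a system message, and the
# new input is appended as the final user message.
history = InMemoryMessageHistory(
    messages=[LLMMessage(role="assistant", content="How can I assist you today?")]
)
messages = legacy_inputs_to_messages(
    input="text",
    message_history=history,
    system_instruction="You are a genius.",
)
assert messages[0] == LLMMessage(role="system", content="You are a genius.")
assert messages[-1] == LLMMessage(role="user", content="text")
# Consistent with test_system_instruction_from_messages above: the leading
# system message is what system_instruction_from_messages extracts.
assert system_instruction_from_messages(messages) == "You are a genius."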