Commit 7b4e2d3
Update LLMInterface to restore LC compatibility (#416)
* Update LLMInterface to restore LC compatibility
* Update AnthropicLLM
* Update MistralAILLM
* Update OllamaLLM
* Update CohereLLM
* Mypy / ruff
* Update VertexAILLM
* Update (a)invoke_with_tools methods in the same way
* Rename method and return directly list[LLMMessage]
* Update GraphRAG to restore full LC compatibility
* Test for the utils functions
* WIP: update tests
* Improve test coverage for utils and base modules
* Fix UT OpenAILLM
* Update Ollama tests
* Update Ollama/Anthropic
* WIP update cohere
* CHANGELOG.md
* Ruff after rebase
* More fixes on cohere tests
* Add tests for retry behavior
* Fix MistralAILLM tests
* Fix VertexAILLM tests
* mypy
* Address comments
* Fix mypy again
* Fix e2e
* Fix CI
* Comments
1 parent 62d62b4 · commit 7b4e2d3

31 files changed: +989 -1177 lines

CHANGELOG.md (7 additions, 0 deletions)

@@ -4,8 +4,15 @@

 ### Added

+- Document node is now always created when running SimpleKGPipeline, even if `from_pdf=False`.
+- Document metadata is exposed in SimpleKGPipeline run method.
 - Added automatic rate limiting with retry logic and exponential backoff for all Embedding providers using tenacity. The `RateLimitHandler` interface allows for custom rate limiting strategies, including the ability to disable rate limiting entirely.

+### Fixed
+
+- LangChain Chat models compatibility is now working again.
+
 ## 1.10.0

 ### Added
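
The Fixed entry is the point of the whole commit: a LangChain chat model can once again stand in where the library expects an LLM. Below is a rough sketch only, assuming a LangChain model can be handed directly to GraphRAG (the shipped example lives at examples/customize/answer/langchain_compatibility.py; the driver URI, credentials, and index name here are placeholders, not from this diff):

import neo4j
from langchain_openai import ChatOpenAI
from neo4j_graphrag.generation import GraphRAG
from neo4j_graphrag.retrievers import VectorRetriever

# Placeholder connection and index; substitute your own.
driver = neo4j.GraphDatabase.driver("neo4j://localhost:7687", auth=("neo4j", "password"))
retriever = VectorRetriever(driver, index_name="my-index")  # hypothetical index name

llm = ChatOpenAI(model="gpt-4o")  # a LangChain chat model, not an LLMInterface subclass
rag = GraphRAG(retriever=retriever, llm=llm)
print(rag.search(query_text="say something").answer)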

examples/README.md (2 additions, 2 deletions)

@@ -69,7 +69,7 @@ are listed in [the last section of this file](#customize).

 - [OpenAI (GPT)](./customize/llms/openai_llm.py)
 - [Azure OpenAI]()
 - [VertexAI (Gemini)](./customize/llms/vertexai_llm.py)
-- [MistralAI](./customize/llms/mistalai_llm.py)
+- [MistralAI](customize/llms/mistralai_llm.py)
 - [Cohere](./customize/llms/cohere_llm.py)
 - [Anthropic (Claude)](./customize/llms/anthropic_llm.py)
 - [Ollama](./customize/llms/ollama_llm.py)

@@ -142,7 +142,7 @@ are listed in [the last section of this file](#customize).

 ### Answer: GraphRAG

-- [LangChain compatibility](./customize/answer/langchain_compatiblity.py)
+- [LangChain compatibility](customize/answer/langchain_compatibility.py)
 - [Use a custom prompt](./customize/answer/custom_prompt.py)
examples/customize/answer/langchain_compatiblity.py → examples/customize/answer/langchain_compatibility.py

File renamed without changes.
examples/customize/llms/anthropic_llm.py (17 additions, 1 deletion)

@@ -1,12 +1,28 @@
 from neo4j_graphrag.llm import AnthropicLLM, LLMResponse
+from neo4j_graphrag.types import LLMMessage

 # set api key here or in the ANTHROPIC_API_KEY env var
 api_key = None

+messages: list[LLMMessage] = [
+    {
+        "role": "system",
+        "content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.",
+    },
+    {
+        "role": "user",
+        "content": "say something",
+    },
+]
+
+
 llm = AnthropicLLM(
     model_name="claude-3-opus-20240229",
     model_params={"max_tokens": 1000},  # max_tokens must be specified
     api_key=api_key,
 )
-res: LLMResponse = llm.invoke("say something")
+res: LLMResponse = llm.invoke(
+    # "say something",
+    messages,
+)
 print(res.content)
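
The commit message says the (a)invoke methods were updated in the same way, so the async path should accept the same message list. A hedged sketch of the async variant under that assumption:

import asyncio

async_res: LLMResponse = asyncio.run(llm.ainvoke(messages))  # assumes ainvoke mirrors invoke's new signature
print(async_res.content)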
examples/customize/llms/cohere_llm.py (13 additions, 1 deletion)

@@ -1,11 +1,23 @@
 from neo4j_graphrag.llm import CohereLLM, LLMResponse
+from neo4j_graphrag.types import LLMMessage

 # set api key here or in the CO_API_KEY env var
 api_key = None

+messages: list[LLMMessage] = [
+    {
+        "role": "system",
+        "content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.",
+    },
+    {
+        "role": "user",
+        "content": "say something",
+    },
+]
+
 llm = CohereLLM(
     model_name="command-r",
     api_key=api_key,
 )
-res: LLMResponse = llm.invoke("say something")
+res: LLMResponse = llm.invoke(input=messages)
 print(res.content)

examples/customize/llms/custom_llm.py (7 additions, 18 deletions)

@@ -1,14 +1,13 @@
 import random
 import string
-from typing import Any, Awaitable, Callable, List, Optional, TypeVar, Union
+from typing import Any, Awaitable, Callable, Optional, TypeVar

 from neo4j_graphrag.llm import LLMInterface, LLMResponse
 from neo4j_graphrag.utils.rate_limit import (
     RateLimitHandler,
     # rate_limit_handler,
     # async_rate_limit_handler,
 )
-from neo4j_graphrag.message_history import MessageHistory
 from neo4j_graphrag.types import LLMMessage

@@ -18,37 +17,27 @@ def __init__(
     ):
         super().__init__(model_name, **kwargs)

-    # Optional: Apply rate limit handling to synchronous invoke method
-    # @rate_limit_handler
-    def invoke(
+    def _invoke(
         self,
-        input: str,
-        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
-        system_instruction: Optional[str] = None,
+        input: list[LLMMessage],
     ) -> LLMResponse:
         content: str = (
             self.model_name + ": " + "".join(random.choices(string.ascii_letters, k=30))
         )
         return LLMResponse(content=content)

-    # Optional: Apply rate limit handling to asynchronous ainvoke method
-    # @async_rate_limit_handler
-    async def ainvoke(
+    async def _ainvoke(
         self,
-        input: str,
-        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
-        system_instruction: Optional[str] = None,
+        input: list[LLMMessage],
     ) -> LLMResponse:
         raise NotImplementedError()


-llm = CustomLLM(
-    ""
-)  # if rate_limit_handler and async_rate_limit_handler decorators are used, the default rate limit handler will be applied automatically (retry with exponential backoff)
+llm = CustomLLM("")
 res: LLMResponse = llm.invoke("text")
 print(res.content)

-# If rate_limit_handler and async_rate_limit_handler decorators are used and you want to use a custom rate limit handler
+# If you want to use a custom rate limit handler
 # Type variables for function signatures used in rate limit handlers
 F = TypeVar("F", bound=Callable[..., Any])
 AF = TypeVar("AF", bound=Callable[..., Awaitable[Any]])
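
The F and AF type variables above are shaped like decorator hooks: a handler receives a callable and returns a wrapped callable. A minimal sketch of a custom handler built on that reading, assuming the RateLimitHandler interface imported above exposes handle_sync/handle_async hooks and that the base LLM constructor accepts a rate_limit_handler keyword (both are assumptions about the API, not confirmed by this diff):

class PassthroughRateLimitHandler(RateLimitHandler):
    # Illustrative only: applies no retry or backoff at all.
    def handle_sync(self, func: F) -> F:
        return func  # return the sync callable unchanged

    def handle_async(self, func: AF) -> AF:
        return func  # return the async callable unchanged


# llm_no_retry = CustomLLM("", rate_limit_handler=PassthroughRateLimitHandler())  # assumed kwarg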

examples/customize/llms/mistalai_llm.py (0 additions, 10 deletions)

This file was deleted.
examples/customize/llms/mistralai_llm.py (32 additions, 0 deletions)

@@ -0,0 +1,32 @@
+from neo4j_graphrag.llm import MistralAILLM, LLMResponse
+from neo4j_graphrag.message_history import InMemoryMessageHistory
+from neo4j_graphrag.types import LLMMessage
+
+# set api key here or in the MISTRAL_API_KEY env var
+api_key = None
+
+
+messages: list[LLMMessage] = [
+    {
+        "role": "system",
+        "content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.",
+    },
+    {
+        "role": "user",
+        "content": "say something",
+    },
+]
+
+
+llm = MistralAILLM(
+    model_name="mistral-small-latest",
+    api_key=api_key,
+)
+res: LLMResponse = llm.invoke(
+    # "say something",
+    # messages,
+    InMemoryMessageHistory(
+        messages=messages,
+    )
+)
+print(res.content)
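
The commented-out arguments in the invoke call above suggest the reworked LLMInterface accepts three interchangeable input forms. A short sketch of all three, assuming they stay equivalent:

res_str: LLMResponse = llm.invoke("say something")  # plain string prompt
res_msgs: LLMResponse = llm.invoke(messages)  # list[LLMMessage]
res_hist: LLMResponse = llm.invoke(InMemoryMessageHistory(messages=messages))  # MessageHistory object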

examples/customize/llms/ollama_llm.py (17 additions, 2 deletions)

@@ -3,11 +3,26 @@
 """

 from neo4j_graphrag.llm import LLMResponse, OllamaLLM
+from neo4j_graphrag.types import LLMMessage
+
+messages: list[LLMMessage] = [
+    {
+        "role": "system",
+        "content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.",
+    },
+    {
+        "role": "user",
+        "content": "say something",
+    },
+]
+

 llm = OllamaLLM(
-    model_name="<model_name>",
+    model_name="orca-mini:latest",
     # model_params={"options": {"temperature": 0}, "format": "json"},
     # host="...",  # if using a remote server
 )
-res: LLMResponse = llm.invoke("What is the additive color model?")
+res: LLMResponse = llm.invoke(
+    messages,
+)
 print(res.content)
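
Note that the example now hard-codes orca-mini:latest instead of a <model_name> placeholder; it only runs if that model is available to the Ollama server, which presumably means pulling it first (for example with `ollama pull orca-mini:latest`).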
examples/customize/llms/openai_llm.py (21 additions, 1 deletion)

@@ -1,8 +1,28 @@
 from neo4j_graphrag.llm import LLMResponse, OpenAILLM
+from neo4j_graphrag.message_history import InMemoryMessageHistory
+from neo4j_graphrag.types import LLMMessage

 # set api key here or in the OPENAI_API_KEY env var
 api_key = None

+messages: list[LLMMessage] = [
+    {
+        "role": "system",
+        "content": "You are a seasoned actor and expert performer, renowned for your one-man shows and comedic talent.",
+    },
+    {
+        "role": "user",
+        "content": "say something",
+    },
+]
+
+
 llm = OpenAILLM(model_name="gpt-4o", api_key=api_key)
-res: LLMResponse = llm.invoke("say something")
+res: LLMResponse = llm.invoke(
+    # "say something",
+    # messages,
+    InMemoryMessageHistory(
+        messages=messages,
+    )
+)
 print(res.content)
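
Because the history object is mutable, it can carry a conversation across turns. A sketch of that accumulation, assuming the message history exposes an add_message method for appending turns (an assumption about the MessageHistory API, not confirmed by this diff):

history = InMemoryMessageHistory(messages=messages)
first: LLMResponse = llm.invoke(history)
history.add_message({"role": "assistant", "content": first.content})  # assumed method
history.add_message({"role": "user", "content": "now say something else"})
second: LLMResponse = llm.invoke(history)
print(second.content)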
