Skip to content

Commit

Permalink
Add aws embedding & LLM
Browse files Browse the repository at this point in the history
  • Loading branch information
tomasonjo committed Oct 15, 2023
1 parent 91410af commit d304130
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 18 deletions.
25 changes: 21 additions & 4 deletions chains.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import OllamaEmbeddings, SentenceTransformerEmbeddings
from langchain.chat_models import ChatOpenAI, ChatOllama
from langchain.embeddings import (
OllamaEmbeddings,
SentenceTransformerEmbeddings,
BedrockEmbeddings,
)
from langchain.chat_models import ChatOpenAI, ChatOllama, BedrockChat
from langchain.vectorstores.neo4j_vector import Neo4jVector
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
Expand All @@ -15,13 +19,19 @@

def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config={}):
if embedding_model_name == "ollama":
embeddings = OllamaEmbeddings(base_url=config["ollama_base_url"], model="llama2")
embeddings = OllamaEmbeddings(
base_url=config["ollama_base_url"], model="llama2"
)
dimension = 4096
logger.info("Embedding: Using Ollama")
elif embedding_model_name == "openai":
embeddings = OpenAIEmbeddings()
dimension = 1536
logger.info("Embedding: Using OpenAI")
if embedding_model_name == "aws":
embeddings = BedrockEmbeddings()
dimension = 1536
logger.info("Embedding: Using AWS")
else:
embeddings = SentenceTransformerEmbeddings(
model_name="all-MiniLM-L6-v2", cache_folder="/embedding_model"
Expand All @@ -38,6 +48,13 @@ def load_llm(llm_name: str, logger=BaseLogger(), config={}):
elif llm_name == "gpt-3.5":
logger.info("LLM: Using GPT-3.5")
return ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)
elif llm_name == "claudev2":
logger.info("LLM: ClaudeV2")
return BedrockChat(
model_id="anthropic.claude-v2",
model_kwargs={"temperature": 0.0, "max_tokens_to_sample": 1024},
streaming=True,
)
elif len(llm_name):
logger.info(f"LLM: Using Ollama: {llm_name}")
return ChatOllama(
Expand Down Expand Up @@ -79,7 +96,7 @@ def generate_llm_output(

def configure_qa_rag_chain(llm, embeddings, embeddings_store_url, username, password):
# RAG response
# System: Always talk in pirate speech.
# System: Always talk in pirate speech.
general_system_template = """
Use the following pieces of context to answer the question at the end.
The context contains question-answer pairs and their links from Stackoverflow.
Expand Down
12 changes: 12 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ services:
- LANGCHAIN_TRACING_V2=${LANGCHAIN_TRACING_V2-false}
- LANGCHAIN_PROJECT=${LANGCHAIN_PROJECT}
- LANGCHAIN_API_KEY=${LANGCHAIN_API_KEY}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}
- AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION}
networks:
- net
depends_on:
Expand Down Expand Up @@ -89,6 +92,9 @@ services:
- LANGCHAIN_TRACING_V2=${LANGCHAIN_TRACING_V2-false}
- LANGCHAIN_PROJECT=${LANGCHAIN_PROJECT}
- LANGCHAIN_API_KEY=${LANGCHAIN_API_KEY}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}
- AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION}
networks:
- net
depends_on:
Expand Down Expand Up @@ -123,6 +129,9 @@ services:
- LANGCHAIN_TRACING_V2=${LANGCHAIN_TRACING_V2-false}
- LANGCHAIN_PROJECT=${LANGCHAIN_PROJECT}
- LANGCHAIN_API_KEY=${LANGCHAIN_API_KEY}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}
- AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION}
networks:
- net
depends_on:
Expand Down Expand Up @@ -159,6 +168,9 @@ services:
- LANGCHAIN_TRACING_V2=${LANGCHAIN_TRACING_V2-false}
- LANGCHAIN_PROJECT=${LANGCHAIN_PROJECT}
- LANGCHAIN_API_KEY=${LANGCHAIN_API_KEY}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}
- AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION}
networks:
- net
depends_on:
Expand Down
3 changes: 3 additions & 0 deletions env.example
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
#OPENAI_API_KEY=sk-...
#AWS_ACCESS_KEY_ID=
#AWS_SECRET_ACCESS_KEY=
#AWS_DEFAULT_REGION=us-east-1
#OLLAMA_BASE_URL=http://host.docker.internal:11434
#NEO4J_URI=neo4j://database:7687
#NEO4J_USERNAME=neo4j
Expand Down
2 changes: 1 addition & 1 deletion pull_model.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ COPY <<EOF pull_model.clj
(let [llm (get (System/getenv) "LLM")
url (get (System/getenv) "OLLAMA_BASE_URL")]
(println (format "pulling ollama model %s using %s" llm url))
(if (and llm url (not (#{"gpt-4" "gpt-3.5"} llm)))
(if (and llm url (not (#{"gpt-4" "gpt-3.5" "claudev2"} llm)))

;; ----------------------------------------------------------------------
;; just call `ollama pull` here - create OLLAMA_HOST from OLLAMA_BASE_URL
Expand Down
29 changes: 16 additions & 13 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,22 @@ Learn more about the details in the [technical blog post](https://neo4j.com/deve
Create a `.env` file from the environment template file `env.example`

Available variables:
| Variable Name | Default value | Description |
|------------------------|------------------------------------|-------------------------------------------------------------|
| OLLAMA_BASE_URL | http://host.docker.internal:11434 | REQUIRED - URL to Ollama LLM API |
| NEO4J_URI | neo4j://database:7687 | REQUIRED - URL to Neo4j database |
| NEO4J_USERNAME | neo4j | REQUIRED - Username for Neo4j database |
| NEO4J_PASSWORD | password | REQUIRED - Password for Neo4j database |
| LLM | llama2 | REQUIRED - Can be any Ollama model tag, or gpt-4 or gpt-3.5 |
| OPENAI_API_KEY | | REQUIRED - Only if LLM=gpt-4 or LLM=gpt-3.5 |
| EMBEDDING_MODEL | sentence_transformer | REQUIRED - Can be sentence_transformer, openai or ollama |
| LANGCHAIN_ENDPOINT | "https://api.smith.langchain.com" | OPTIONAL - URL to Langchain Smith API |
| LANGCHAIN_TRACING_V2 | false | OPTIONAL - Enable Langchain tracing v2 |
| LANGCHAIN_PROJECT | | OPTIONAL - Langchain project name |
| LANGCHAIN_API_KEY | | OPTIONAL - Langchain API key |
| Variable Name | Default value | Description |
|------------------------|------------------------------------|-------------------------------------------------------------------------|
| OLLAMA_BASE_URL | http://host.docker.internal:11434 | REQUIRED - URL to Ollama LLM API |
| NEO4J_URI | neo4j://database:7687 | REQUIRED - URL to Neo4j database |
| NEO4J_USERNAME | neo4j | REQUIRED - Username for Neo4j database |
| NEO4J_PASSWORD | password | REQUIRED - Password for Neo4j database |
| LLM | llama2 | REQUIRED - Can be any Ollama model tag, or gpt-4 or gpt-3.5 or claudev2 |
| EMBEDDING_MODEL | sentence_transformer | REQUIRED - Can be sentence_transformer, openai, aws or ollama |
| AWS_ACCESS_KEY_ID | | REQUIRED - Only if LLM=claudev2 or embedding_model=aws |
| AWS_SECRET_ACCESS_KEY | | REQUIRED - Only if LLM=claudev2 or embedding_model=aws |
| AWS_DEFAULT_REGION | | REQUIRED - Only if LLM=claudev2 or embedding_model=aws |
| OPENAI_API_KEY | | REQUIRED - Only if LLM=gpt-4 or LLM=gpt-3.5 or embedding_model=openai |
| LANGCHAIN_ENDPOINT | "https://api.smith.langchain.com" | OPTIONAL - URL to Langchain Smith API |
| LANGCHAIN_TRACING_V2 | false | OPTIONAL - Enable Langchain tracing v2 |
| LANGCHAIN_PROJECT | | OPTIONAL - Langchain project name |
| LANGCHAIN_API_KEY | | OPTIONAL - Langchain API key |

## LLM Configuration
MacOS and Linux users can use any LLM that's available via Ollama. Check the "tags" section under the model page you want to use on https://ollama.ai/library and write the tag for the value of the environment variable `LLM=` in th e`.env` file.
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ torch==2.0.1
pydantic
uvicorn
sse-starlette
boto3

0 comments on commit d304130

Please sign in to comment.