diff --git a/chains.py b/chains.py
index aac0536a..c8b82b75 100644
--- a/chains.py
+++ b/chains.py
@@ -1,6 +1,10 @@
 from langchain.embeddings.openai import OpenAIEmbeddings
-from langchain.embeddings import OllamaEmbeddings, SentenceTransformerEmbeddings
-from langchain.chat_models import ChatOpenAI, ChatOllama
+from langchain.embeddings import (
+    OllamaEmbeddings,
+    SentenceTransformerEmbeddings,
+    BedrockEmbeddings,
+)
+from langchain.chat_models import ChatOpenAI, ChatOllama, BedrockChat
 from langchain.vectorstores.neo4j_vector import Neo4jVector
 from langchain.chains import RetrievalQAWithSourcesChain
 from langchain.chains.qa_with_sources import load_qa_with_sources_chain
@@ -15,13 +19,19 @@
 
 def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config={}):
     if embedding_model_name == "ollama":
-        embeddings = OllamaEmbeddings(base_url=config["ollama_base_url"], model="llama2")
+        embeddings = OllamaEmbeddings(
+            base_url=config["ollama_base_url"], model="llama2"
+        )
         dimension = 4096
         logger.info("Embedding: Using Ollama")
     elif embedding_model_name == "openai":
         embeddings = OpenAIEmbeddings()
         dimension = 1536
         logger.info("Embedding: Using OpenAI")
+    elif embedding_model_name == "aws":
+        embeddings = BedrockEmbeddings()
+        dimension = 1536
+        logger.info("Embedding: Using AWS")
     else:
         embeddings = SentenceTransformerEmbeddings(
             model_name="all-MiniLM-L6-v2", cache_folder="/embedding_model"
@@ -38,6 +48,13 @@ def load_llm(llm_name: str, logger=BaseLogger(), config={}):
     elif llm_name == "gpt-3.5":
         logger.info("LLM: Using GPT-3.5")
         return ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)
+    elif llm_name == "claudev2":
+        logger.info("LLM: ClaudeV2")
+        return BedrockChat(
+            model_id="anthropic.claude-v2",
+            model_kwargs={"temperature": 0.0, "max_tokens_to_sample": 1024},
+            streaming=True,
+        )
     elif len(llm_name):
         logger.info(f"LLM: Using Ollama: {llm_name}")
         return ChatOllama(
@@ -79,7 +96,7 @@ def generate_llm_output(
 
 def configure_qa_rag_chain(llm, embeddings, embeddings_store_url, username, password):
     # RAG response
-# System: Always talk in pirate speech.
+    # System: Always talk in pirate speech.
     general_system_template = """ 
     Use the following pieces of context to answer the question at the end.
     The context contains question-answer pairs and their links from Stackoverflow.
diff --git a/docker-compose.yml b/docker-compose.yml
index c1a831f3..2fc12b6b 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -51,6 +51,9 @@ services:
       - LANGCHAIN_TRACING_V2=${LANGCHAIN_TRACING_V2-false}
       - LANGCHAIN_PROJECT=${LANGCHAIN_PROJECT}
       - LANGCHAIN_API_KEY=${LANGCHAIN_API_KEY}
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}
+      - AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION}
     networks:
       - net
     depends_on:
@@ -89,6 +92,9 @@ services:
       - LANGCHAIN_TRACING_V2=${LANGCHAIN_TRACING_V2-false}
       - LANGCHAIN_PROJECT=${LANGCHAIN_PROJECT}
       - LANGCHAIN_API_KEY=${LANGCHAIN_API_KEY}
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}
+      - AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION}
     networks:
       - net
     depends_on:
@@ -123,6 +129,9 @@ services:
       - LANGCHAIN_TRACING_V2=${LANGCHAIN_TRACING_V2-false}
      - LANGCHAIN_PROJECT=${LANGCHAIN_PROJECT}
       - LANGCHAIN_API_KEY=${LANGCHAIN_API_KEY}
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}
+      - AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION}
     networks:
       - net
     depends_on:
@@ -159,6 +168,9 @@ services:
       - LANGCHAIN_TRACING_V2=${LANGCHAIN_TRACING_V2-false}
       - LANGCHAIN_PROJECT=${LANGCHAIN_PROJECT}
       - LANGCHAIN_API_KEY=${LANGCHAIN_API_KEY}
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}
+      - AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION}
     networks:
       - net
     depends_on:
diff --git a/env.example b/env.example
index f661c366..9bc97ca1 100644
--- a/env.example
+++ b/env.example
@@ -1,4 +1,7 @@
 #OPENAI_API_KEY=sk-...
+#AWS_ACCESS_KEY_ID=
+#AWS_SECRET_ACCESS_KEY=
+#AWS_DEFAULT_REGION=us-east-1
 #OLLAMA_BASE_URL=http://host.docker.internal:11434
 #NEO4J_URI=neo4j://database:7687
 #NEO4J_USERNAME=neo4j
diff --git a/pull_model.Dockerfile b/pull_model.Dockerfile
index 8cd788ad..e858b7f9 100644
--- a/pull_model.Dockerfile
+++ b/pull_model.Dockerfile
@@ -15,7 +15,7 @@ COPY <