Refactor embedding instantiation
oskarhane committed Sep 28, 2023
1 parent 4572231 commit c2cf59e
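This commit replaces the duplicated if/elif embedding-selection blocks in bot.py and loader.py with a single load_embedding_model helper in utils.py, so the embedding backend and its vector dimension are chosen in one place. A minimal sketch of the resulting call pattern (illustrative only; the string config key mirrors the call sites in the diff below):

import os
from streamlit.logger import get_logger
from utils import load_embedding_model

logger = get_logger(__name__)

embeddings, dimension = load_embedding_model(
    os.getenv("EMBEDDING_MODEL"),
    config={"ollama_base_url": os.getenv("OLLAMA_BASE_URL")},
    logger=logger,
)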
Showing 4 changed files with 33 additions and 40 deletions.
25 changes: 6 additions & 19 deletions bot.py
@@ -5,8 +5,7 @@
from streamlit.logger import get_logger
from langchain.callbacks.base import BaseCallbackHandler
from langchain.vectorstores.neo4j_vector import Neo4jVector
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import OllamaEmbeddings, SentenceTransformerEmbeddings

from langchain.chat_models import ChatOpenAI, ChatOllama
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
@@ -17,7 +16,7 @@
)
from langchain.graphs import Neo4jGraph
from dotenv import load_dotenv
from utils import extract_title_and_question
from utils import extract_title_and_question, load_embedding_model

load_dotenv(".env")

@@ -27,12 +26,11 @@
ollama_base_url = os.getenv("OLLAMA_BASE_URL")
embedding_model_name = os.getenv("EMBEDDING_MODEL")
llm_name = os.getenv("LLM")

# Remapping for Langchain Neo4j integration
os.environ["NEO4J_URL"] = url

logger = get_logger(__name__)


neo4j_graph = Neo4jGraph(url=url, username=username, password=password)


@@ -59,20 +57,9 @@ def on_llm_new_token(self, token: str, **kwargs) -> None:
        self.container.markdown(self.text)


if embedding_model_name == "ollama":
    embeddings = OllamaEmbeddings(base_url=ollama_base_url, model="llama2")
    dimension = 4096
    logger.info("Embedding: Using Ollama")
elif embedding_model_name == "openai":
    embeddings = OpenAIEmbeddings()
    dimension = 1536
    logger.info("Embedding: Using OpenAI")
else:
    embeddings = SentenceTransformerEmbeddings(
        model_name="all-MiniLM-L6-v2", cache_folder="/embedding_model"
    )
    dimension = 384
    logger.info("Embedding: Using SentenceTransformer")
embeddings, dimension = load_embedding_model(
    embedding_model_name, config={"ollama_base_url": ollama_base_url}, logger=logger
)

create_vector_index(dimension)

1 change: 1 addition & 0 deletions loader.Dockerfile
@@ -15,6 +15,7 @@ RUN pip install --upgrade -r requirements.txt

# COPY .env .
COPY loader.py .
COPY utils.py .

EXPOSE 8502

25 changes: 4 additions & 21 deletions loader.py
@@ -1,16 +1,10 @@
import os
import requests

from dotenv import load_dotenv
from langchain.embeddings import (
    OllamaEmbeddings,
    OpenAIEmbeddings,
    SentenceTransformerEmbeddings,
)
from langchain.graphs import Neo4jGraph

import streamlit as st
from streamlit.logger import get_logger
from utils import load_embedding_model

load_dotenv(".env")

@@ -24,20 +18,9 @@

logger = get_logger(__name__)

if embedding_model_name == "ollama":
    embeddings = OllamaEmbeddings(base_url=ollama_base_url, model="llama2")
    dimension = 4096
    logger.info("Embedding: Using Ollama")
elif embedding_model_name == "openai":
    embeddings = OpenAIEmbeddings()
    dimension = 1536
    logger.info("Embedding: Using OpenAI")
else:
    embeddings = SentenceTransformerEmbeddings(
        model_name="all-MiniLM-L6-v2", cache_folder="/embedding_model"
    )
    dimension = 384
    logger.info("Embedding: Using SentenceTransformer")
embeddings, dimension = load_embedding_model(
    embedding_model_name, config={"ollama_base_url": ollama_base_url}, logger=logger
)

neo4j_graph = Neo4jGraph(url=url, username=username, password=password)

22 changes: 22 additions & 0 deletions utils.py
@@ -1,3 +1,25 @@
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import OllamaEmbeddings, SentenceTransformerEmbeddings


def load_embedding_model(embedding_model_name: str, config={}, logger=print):
    if embedding_model_name == "ollama":
        embeddings = OllamaEmbeddings(base_url=config["ollama_base_url"], model="llama2")
        dimension = 4096
        logger.info("Embedding: Using Ollama")
    elif embedding_model_name == "openai":
        embeddings = OpenAIEmbeddings()
        dimension = 1536
        logger.info("Embedding: Using OpenAI")
    else:
        embeddings = SentenceTransformerEmbeddings(
            model_name="all-MiniLM-L6-v2", cache_folder="/embedding_model"
        )
        dimension = 384
        logger.info("Embedding: Using SentenceTransformer")
    return embeddings, dimension


def extract_title_and_question(input_string):
    lines = input_string.strip().split("\n")

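For reference, a standalone sketch of exercising the new helper outside the Streamlit apps; the fallback model name and the plain logging setup are assumptions for illustration, not part of this commit:

import logging
import os

from utils import load_embedding_model

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Unknown or unset EMBEDDING_MODEL values fall through to the SentenceTransformer branch.
embeddings, dimension = load_embedding_model(
    os.getenv("EMBEDDING_MODEL", "sentence_transformer"),
    config={"ollama_base_url": os.getenv("OLLAMA_BASE_URL")},
    logger=logger,
)

# LangChain embedding classes expose embed_query; the returned vector length
# should match the dimension used to create the Neo4j vector index.
vector = embeddings.embed_query("What is a vector index?")
assert len(vector) == dimension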