Refactor logic into separate files for easier reading
oskarhane committed Oct 2, 2023
1 parent 325a026 commit 48be0ca
Showing 6 changed files with 205 additions and 215 deletions.
2 changes: 1 addition & 1 deletion bot.Dockerfile
@@ -13,9 +13,9 @@ COPY requirements.txt .

RUN pip install --upgrade -r requirements.txt

# COPY .env .
COPY bot.py .
COPY utils.py .
COPY chains.py .

EXPOSE 8501

181 changes: 21 additions & 160 deletions bot.py
@@ -1,22 +1,25 @@
import os
from typing import List, Any

import streamlit as st
from streamlit.logger import get_logger
from langchain.callbacks.base import BaseCallbackHandler
from langchain.vectorstores.neo4j_vector import Neo4jVector

from langchain.chat_models import ChatOpenAI, ChatOllama
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.graphs import Neo4jGraph
from dotenv import load_dotenv
from utils import extract_title_and_question, load_embedding_model
from utils import (
    extract_title_and_question,
    create_vector_index,
)
from chains import (
    load_embedding_model,
    load_llm,
    configure_llm_only_chain,
    configure_qa_rag_chain,
)

load_dotenv(".env")

@@ -33,19 +36,10 @@

# if Neo4j is local, you can go to http://localhost:7474/ to browse the database
neo4j_graph = Neo4jGraph(url=url, username=username, password=password)


def create_vector_index(dimension: int) -> None:
    index_query = "CALL db.index.vector.createNodeIndex('stackoverflow', 'Question', 'embedding', $dimension, 'cosine')"
    try:
        neo4j_graph.query(index_query, {"dimension": dimension})
    except:  # Already exists
        pass
    index_query = "CALL db.index.vector.createNodeIndex('top_answers', 'Answer', 'embedding', $dimension, 'cosine')"
    try:
        neo4j_graph.query(index_query, {"dimension": dimension})
    except:  # Already exists
        pass
embeddings, dimension = load_embedding_model(
    embedding_model_name, config={ollama_base_url: ollama_base_url}, logger=logger
)
create_vector_index(neo4j_graph, dimension)


class StreamHandler(BaseCallbackHandler):
@@ -58,142 +52,11 @@ def on_llm_new_token(self, token: str, **kwargs) -> None:
        self.container.markdown(self.text)


embeddings, dimension = load_embedding_model(
    embedding_model_name, config={ollama_base_url: ollama_base_url}, logger=logger
)

create_vector_index(dimension)

if llm_name == "gpt-4":
    llm = ChatOpenAI(temperature=0, model_name="gpt-4", streaming=True)
    logger.info("LLM: Using GPT-4")
elif llm_name == "gpt-3.5":
    llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)
    logger.info("LLM: Using GPT-3.5 Turbo")
elif len(llm_name):
    llm = ChatOllama(
        temperature=0,
        base_url=ollama_base_url,
        model=llm_name,
        streaming=True,
        top_k=10,  # A higher value (100) will give more diverse answers, while a lower value (10) will be more conservative.
        top_p=0.3,  # Higher value (0.95) will lead to more diverse text, while a lower value (0.5) will generate more focused text.
        num_ctx=3072,  # Sets the size of the context window used to generate the next token.
    )
    logger.info(f"LLM: Using Ollama ({llm_name})")
else:
    llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)
    logger.info("LLM: Using GPT-3.5 Turbo")

# LLM only response
template = """
You are a helpful assistant that helps a support agent with answering programming questions.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
"""
system_message_prompt = SystemMessagePromptTemplate.from_template(template)
human_template = "{text}"
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
chat_prompt = ChatPromptTemplate.from_messages(
    [system_message_prompt, human_message_prompt]
)


def generate_llm_output(
    user_input: str, callbacks: List[Any], prompt=chat_prompt
) -> str:
    answer = llm(
        prompt.format_prompt(
            text=user_input,
        ).to_messages(),
        callbacks=callbacks,
    ).content
    return {"answer": answer}


# Vector response
neo4j_db = Neo4jVector.from_existing_index(
    embedding=embeddings,
    url=url,
    username=username,
    password=password,
    database="neo4j",  # neo4j by default
    index_name="top_answers",  # vector by default
    text_node_property="body",  # text by default
    retrieval_query="""
    OPTIONAL MATCH (node)-[:ANSWERS]->(question)
    RETURN 'Question: ' + question.title + '\n' + question.body + '\nAnswer: ' +
    coalesce(node.body,"") AS text, score, {source:question.link} AS metadata
    ORDER BY score ASC // so that best answer are the last
    """,
)

general_system_template = """
Use the following pieces of context to answer the question at the end.
The context contains question-answer pairs and their links from Stackoverflow.
You should prefer information from accepted or more upvoted answers.
Make sure to rely on information from the answers and not on questions to provide accuate responses.
When you find particular answer in the context useful, make sure to cite it in the answer using the link.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
----
{summaries}
----
Each answer you generate should contain a section at the end of links to
Stackoverflow questions and answers you found useful, which are described under Source value.
You can only use links to StackOverflow questions that are present in the context and always
add links to the end of the answer in the style of citations.
Generate concise answers with references sources section of links to
relevant StackOverflow questions only at the end of the answer.
"""
general_user_template = "Question:```{question}```"
messages = [
    SystemMessagePromptTemplate.from_template(general_system_template),
    HumanMessagePromptTemplate.from_template(general_user_template),
]
qa_prompt = ChatPromptTemplate.from_messages(messages)

qa_chain = load_qa_with_sources_chain(
    llm,
    chain_type="stuff",
    prompt=qa_prompt,
)
qa = RetrievalQAWithSourcesChain(
    combine_documents_chain=qa_chain,
    retriever=neo4j_db.as_retriever(search_kwargs={"k": 2}),
    reduce_k_below_max_tokens=True,
    max_tokens_limit=3375,
)

# Vector + Knowledge Graph response
kg = Neo4jVector.from_existing_index(
    embedding=embeddings,
    url=url,
    username=username,
    password=password,
    database="neo4j",  # neo4j by default
    index_name="stackoverflow",  # vector by default
    text_node_property="body",  # text by default
    retrieval_query="""
    WITH node AS question, score AS similarity
    CALL { with question
        MATCH (question)<-[:ANSWERS]-(answer)
        WITH answer
        ORDER BY answer.is_accepted DESC, answer.score DESC
        WITH collect(answer)[..2] as answers
        RETURN reduce(str='', answer IN answers | str +
            '\n### Answer (Accepted: '+ answer.is_accepted +
            ' Score: ' + answer.score+ '): '+ answer.body + '\n') as answerTexts
    }
    RETURN '##Question: ' + question.title + '\n' + question.body + '\n'
        + answerTexts AS text, similarity as score, {source: question.link} AS metadata
    ORDER BY similarity ASC // so that best answers are the last
    """,
)
llm = load_llm(llm_name, logger=logger, config={"ollama_base_url": ollama_base_url})

kg_qa = RetrievalQAWithSourcesChain(
    combine_documents_chain=qa_chain,
    retriever=kg.as_retriever(search_kwargs={"k": 2}),
    reduce_k_below_max_tokens=False,
    max_tokens_limit=3375,
llm_chain = configure_llm_only_chain(llm)
rag_chain = configure_qa_rag_chain(
    llm, embeddings, embeddings_store_url=url, username=username, password=password
)

# Streamlit UI
@@ -280,11 +143,9 @@ def mode_select() -> str:

name = mode_select()
if name == "LLM only" or name == "Disabled":
    output_function = generate_llm_output
elif name == "Vector":
    output_function = qa
    output_function = llm_chain
elif name == "Vector + Graph" or name == "Enabled":
    output_function = kg_qa
    output_function = rag_chain


def generate_ticket():
@@ -337,7 +198,7 @@ def generate_ticket():
            HumanMessagePromptTemplate.from_template("{text}"),
        ]
    )
    llm_response = generate_llm_output(
    llm_response = llm_chain(
        f"Here's the question to rewrite in the expected format: ```{q_prompt}```",
        [],
        chat_prompt,
147 changes: 147 additions & 0 deletions chains.py
@@ -0,0 +1,147 @@
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import OllamaEmbeddings, SentenceTransformerEmbeddings
from langchain.chat_models import ChatOpenAI, ChatOllama
from langchain.vectorstores.neo4j_vector import Neo4jVector
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from typing import List, Any
from utils import BaseLogger


def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config={}):
    if embedding_model_name == "ollama":
        embeddings = OllamaEmbeddings(base_url=config.ollama_base_url, model="llama2")
        dimension = 4096
        logger.info("Embedding: Using Ollama")
    elif embedding_model_name == "openai":
        embeddings = OpenAIEmbeddings()
        dimension = 1536
        logger.info("Embedding: Using OpenAI")
    else:
        embeddings = SentenceTransformerEmbeddings(
            model_name="all-MiniLM-L6-v2", cache_folder="/embedding_model"
        )
        dimension = 384
        logger.info("Embedding: Using SentenceTransformer")
    return embeddings, dimension


def load_llm(llm_name: str, logger=BaseLogger(), config={}):
    if llm_name == "gpt-4":
        logger.info("LLM: Using GPT-4")
        return ChatOpenAI(temperature=0, model_name="gpt-4", streaming=True)
    elif llm_name == "gpt-3.5":
        logger.info("LLM: Using GPT-3.5")
        return ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)
    elif len(llm_name):
        logger.info(f"LLM: Using Ollama: {llm_name}")
        return ChatOllama(
            temperature=0,
            base_url=config["ollama_base_url"],
            model=llm_name,
            streaming=True,
            top_k=10,  # A higher value (100) will give more diverse answers, while a lower value (10) will be more conservative.
            top_p=0.3,  # Higher value (0.95) will lead to more diverse text, while a lower value (0.5) will generate more focused text.
            num_ctx=3072,  # Sets the size of the context window used to generate the next token.
        )
    logger.info("LLM: Using GPT-3.5")
    return ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)


def configure_llm_only_chain(llm):
    # LLM only response
    template = """
    You are a helpful assistant that helps a support agent with answering programming questions.
    If you don't know the answer, just say that you don't know, don't try to make up an answer.
    """
    system_message_prompt = SystemMessagePromptTemplate.from_template(template)
    human_template = "{text}"
    human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
    chat_prompt = ChatPromptTemplate.from_messages(
        [system_message_prompt, human_message_prompt]
    )

    def generate_llm_output(
        user_input: str, callbacks: List[Any], prompt=chat_prompt
    ) -> str:
        answer = llm(
            prompt.format_prompt(
                text=user_input,
            ).to_messages(),
            callbacks=callbacks,
        ).content
        return {"answer": answer}

    return generate_llm_output


def configure_qa_rag_chain(llm, embeddings, embeddings_store_url, username, password):
    # RAG response
    general_system_template = """
    Use the following pieces of context to answer the question at the end.
    The context contains question-answer pairs and their links from Stackoverflow.
    You should prefer information from accepted or more upvoted answers.
    Make sure to rely on information from the answers and not on questions to provide accuate responses.
    When you find particular answer in the context useful, make sure to cite it in the answer using the link.
    If you don't know the answer, just say that you don't know, don't try to make up an answer.
    ----
    {summaries}
    ----
    Each answer you generate should contain a section at the end of links to
    Stackoverflow questions and answers you found useful, which are described under Source value.
    You can only use links to StackOverflow questions that are present in the context and always
    add links to the end of the answer in the style of citations.
    Generate concise answers with references sources section of links to
    relevant StackOverflow questions only at the end of the answer.
    """
    general_user_template = "Question:```{question}```"
    messages = [
        SystemMessagePromptTemplate.from_template(general_system_template),
        HumanMessagePromptTemplate.from_template(general_user_template),
    ]
    qa_prompt = ChatPromptTemplate.from_messages(messages)

    qa_chain = load_qa_with_sources_chain(
        llm,
        chain_type="stuff",
        prompt=qa_prompt,
    )

    # Vector + Knowledge Graph response
    kg = Neo4jVector.from_existing_index(
        embedding=embeddings,
        url=embeddings_store_url,
        username=username,
        password=password,
        database="neo4j",  # neo4j by default
        index_name="stackoverflow",  # vector by default
        text_node_property="body",  # text by default
        retrieval_query="""
    WITH node AS question, score AS similarity
    CALL { with question
        MATCH (question)<-[:ANSWERS]-(answer)
        WITH answer
        ORDER BY answer.is_accepted DESC, answer.score DESC
        WITH collect(answer)[..2] as answers
        RETURN reduce(str='', answer IN answers | str +
            '\n### Answer (Accepted: '+ answer.is_accepted +
            ' Score: ' + answer.score+ '): '+ answer.body + '\n') as answerTexts
    }
    RETURN '##Question: ' + question.title + '\n' + question.body + '\n'
        + answerTexts AS text, similarity as score, {source: question.link} AS metadata
    ORDER BY similarity ASC // so that best answers are the last
    """,
    )

    kg_qa = RetrievalQAWithSourcesChain(
        combine_documents_chain=qa_chain,
        retriever=kg.as_retriever(search_kwargs={"k": 2}),
        reduce_k_below_max_tokens=False,
        max_tokens_limit=3375,
    )
    return kg_qa
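
For orientation, below is a minimal usage sketch (not part of the commit) of how the new chains.py helpers are wired together, mirroring the calls the refactored bot.py makes above. The connection values, the "sentence_transformer" model choice, and the sample question are placeholder assumptions; the call signatures come from the functions shown in this diff, and running it assumes the Docker setup's /embedding_model cache folder and an OPENAI_API_KEY for the GPT-3.5 branch.

from langchain.graphs import Neo4jGraph
from chains import (
    load_embedding_model,
    load_llm,
    configure_llm_only_chain,
    configure_qa_rag_chain,
)
from utils import create_vector_index

# Placeholder connection settings; the real app reads these from .env.
url, username, password = "bolt://localhost:7687", "neo4j", "password"
ollama_base_url = "http://localhost:11434"  # only used when an Ollama model name is given

neo4j_graph = Neo4jGraph(url=url, username=username, password=password)

# SentenceTransformer fallback (384 dimensions); assumes /embedding_model exists for the cache.
embeddings, dimension = load_embedding_model("sentence_transformer", config={})
create_vector_index(neo4j_graph, dimension)

llm = load_llm("gpt-3.5", config={"ollama_base_url": ollama_base_url})

# configure_llm_only_chain returns a callable taking (user_input, callbacks) and
# returning {"answer": ...}; configure_qa_rag_chain returns a
# RetrievalQAWithSourcesChain backed by the 'stackoverflow' vector index.
llm_chain = configure_llm_only_chain(llm)
rag_chain = configure_qa_rag_chain(
    llm, embeddings, embeddings_store_url=url, username=username, password=password
)

# Requires OPENAI_API_KEY; the empty list is the callbacks argument.
answer = llm_chain("How do I reverse a list in Python?", [])["answer"]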
1 change: 1 addition & 0 deletions loader.Dockerfile
@@ -14,6 +14,7 @@ RUN pip install --upgrade -r requirements.txt

COPY loader.py .
COPY utils.py .
COPY chains.py .
COPY images ./images

EXPOSE 8502