forked from docker/genai-stack
Commit
Refactor logic into separate files for easier reading
Showing 6 changed files with 205 additions and 215 deletions.
@@ -0,0 +1,147 @@
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import OllamaEmbeddings, SentenceTransformerEmbeddings
from langchain.chat_models import ChatOpenAI, ChatOllama
from langchain.vectorstores.neo4j_vector import Neo4jVector
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from typing import List, Any
from utils import BaseLogger


def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config={}):
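    """Return the embeddings backend (Ollama, OpenAI, or SentenceTransformer) and its vector dimension."""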
    if embedding_model_name == "ollama":
        embeddings = OllamaEmbeddings(
            base_url=config["ollama_base_url"], model="llama2"
        )
        dimension = 4096
        logger.info("Embedding: Using Ollama")
    elif embedding_model_name == "openai":
        embeddings = OpenAIEmbeddings()
        dimension = 1536
        logger.info("Embedding: Using OpenAI")
    else:
        embeddings = SentenceTransformerEmbeddings(
            model_name="all-MiniLM-L6-v2", cache_folder="/embedding_model"
        )
        dimension = 384
        logger.info("Embedding: Using SentenceTransformer")
    return embeddings, dimension


def load_llm(llm_name: str, logger=BaseLogger(), config={}):
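    """Return a chat model: GPT-4, GPT-3.5, or any other non-empty name treated as an Ollama model (GPT-3.5 by default)."""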
    if llm_name == "gpt-4":
        logger.info("LLM: Using GPT-4")
        return ChatOpenAI(temperature=0, model_name="gpt-4", streaming=True)
    elif llm_name == "gpt-3.5":
        logger.info("LLM: Using GPT-3.5")
        return ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)
    elif len(llm_name):
        logger.info(f"LLM: Using Ollama: {llm_name}")
        return ChatOllama(
            temperature=0,
            base_url=config["ollama_base_url"],
            model=llm_name,
            streaming=True,
            top_k=10,  # A higher value (100) will give more diverse answers, while a lower value (10) will be more conservative.
            top_p=0.3,  # Higher value (0.95) will lead to more diverse text, while a lower value (0.5) will generate more focused text.
            num_ctx=3072,  # Sets the size of the context window used to generate the next token.
        )
    logger.info("LLM: Using GPT-3.5")
    return ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", streaming=True)


def configure_llm_only_chain(llm):
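    """Build a plain LLM chat chain (no retrieval) and return a callable that produces an answer dict."""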
    # LLM only response
    template = """
    You are a helpful assistant that helps a support agent with answering programming questions.
    If you don't know the answer, just say that you don't know, don't try to make up an answer.
    """
    system_message_prompt = SystemMessagePromptTemplate.from_template(template)
    human_template = "{text}"
    human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
    chat_prompt = ChatPromptTemplate.from_messages(
        [system_message_prompt, human_message_prompt]
    )

    def generate_llm_output(
        user_input: str, callbacks: List[Any], prompt=chat_prompt
    ) -> dict:
        answer = llm(
            prompt.format_prompt(
                text=user_input,
            ).to_messages(),
            callbacks=callbacks,
        ).content
        return {"answer": answer}

    return generate_llm_output


def configure_qa_rag_chain(llm, embeddings, embeddings_store_url, username, password):
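    """Build a RAG chain backed by a Neo4j vector index over Stack Overflow questions and their top answers."""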
    # RAG response
    general_system_template = """
    Use the following pieces of context to answer the question at the end.
    The context contains question-answer pairs and their links from Stackoverflow.
    You should prefer information from accepted or more upvoted answers.
    Make sure to rely on information from the answers and not on questions to provide accurate responses.
    When you find a particular answer in the context useful, make sure to cite it in the answer using the link.
    If you don't know the answer, just say that you don't know, don't try to make up an answer.
    ----
    {summaries}
    ----
    Each answer you generate should contain a section at the end with links to the
    Stackoverflow questions and answers you found useful, which are described under Source value.
    You can only use links to StackOverflow questions that are present in the context, and always
    add the links to the end of the answer in the style of citations.
    Generate concise answers with a references section of links to
    relevant StackOverflow questions only at the end of the answer.
    """
    general_user_template = "Question:```{question}```"
    messages = [
        SystemMessagePromptTemplate.from_template(general_system_template),
        HumanMessagePromptTemplate.from_template(general_user_template),
    ]
    qa_prompt = ChatPromptTemplate.from_messages(messages)

    qa_chain = load_qa_with_sources_chain(
        llm,
        chain_type="stuff",
        prompt=qa_prompt,
    )

    # Vector + Knowledge Graph response
    kg = Neo4jVector.from_existing_index(
        embedding=embeddings,
        url=embeddings_store_url,
        username=username,
        password=password,
        database="neo4j",  # neo4j by default
        index_name="stackoverflow",  # vector by default
        text_node_property="body",  # text by default
        retrieval_query="""
    WITH node AS question, score AS similarity
    CALL { WITH question
        MATCH (question)<-[:ANSWERS]-(answer)
        WITH answer
        ORDER BY answer.is_accepted DESC, answer.score DESC
        WITH collect(answer)[..2] AS answers
        RETURN reduce(str='', answer IN answers | str +
                '\n### Answer (Accepted: '+ answer.is_accepted +
                ' Score: ' + answer.score+ '): '+ answer.body + '\n') AS answerTexts
    }
    RETURN '##Question: ' + question.title + '\n' + question.body + '\n'
        + answerTexts AS text, similarity AS score, {source: question.link} AS metadata
    ORDER BY similarity ASC // so that best answers are the last
    """,
    )

    kg_qa = RetrievalQAWithSourcesChain(
        combine_documents_chain=qa_chain,
        retriever=kg.as_retriever(search_kwargs={"k": 2}),
        reduce_k_below_max_tokens=False,
        max_tokens_limit=3375,
    )
    return kg_qa
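
For orientation, here is a minimal sketch of how these helpers might be wired together by the application. The environment variable names, default model choices, and Neo4j connection details below are assumptions for illustration, not part of this commit.

# Illustrative wiring only; variable names and defaults here are assumptions.
import os

config = {"ollama_base_url": os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")}

# Any value other than "ollama"/"openai" falls back to SentenceTransformer.
embeddings, dimension = load_embedding_model(
    os.getenv("EMBEDDING_MODEL", "sentence_transformer"), config=config
)
llm = load_llm(os.getenv("LLM", "gpt-3.5"), config=config)

llm_chain = configure_llm_only_chain(llm)
rag_chain = configure_qa_rag_chain(
    llm,
    embeddings,
    embeddings_store_url=os.getenv("NEO4J_URI", "bolt://localhost:7687"),
    username=os.getenv("NEO4J_USERNAME", "neo4j"),
    password=os.getenv("NEO4J_PASSWORD", "password"),
)

# LLM-only answer: the callable returned by configure_llm_only_chain takes the
# user input plus a list of LangChain callbacks and returns {"answer": ...}.
print(llm_chain("How do I reverse a list in Python?", callbacks=[])["answer"])

# RAG answer: RetrievalQAWithSourcesChain expects a "question" key and returns
# "answer" and "sources".
print(rag_chain({"question": "How do I reverse a list in Python?"}, return_only_outputs=True))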