diff --git a/api.py b/api.py
index 60b87dc3..ed1d3480 100644
--- a/api.py
+++ b/api.py
@@ -11,6 +11,7 @@
     load_llm,
     configure_llm_only_chain,
     configure_qa_rag_chain,
+    generate_ticket,
 )
 from fastapi import FastAPI, Depends
 from pydantic import BaseModel
@@ -112,6 +113,10 @@ class Question(BaseModel):
     rag: bool = False


+class BaseTicket(BaseModel):
+    text: str
+
+
 @app.get("/query-stream")
 def qstream(question: Question = Depends()):
     output_function = llm_chain
@@ -143,4 +148,14 @@ async def ask(question: Question = Depends()):
         {"question": question.text, "chat_history": []}, callbacks=[]
     )
-    return json.dumps({"result": result["answer"], "model": llm_name})
+    return {"result": result["answer"], "model": llm_name}
+
+
+@app.get("/generate-ticket")
+async def generate_ticket_api(question: BaseTicket = Depends()):
+    new_title, new_question = generate_ticket(
+        neo4j_graph=neo4j_graph,
+        llm_chain=llm_chain,
+        input_question=question.text,
+    )
+    return {"result": {"title": new_title, "text": new_question}, "model": llm_name}
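
Two behavioural notes on this file: `/ask` now returns the dict directly, so FastAPI serialises it as a JSON object rather than the double-encoded string `json.dumps` produced, and the new `/generate-ticket` endpoint maps `BaseTicket` onto query parameters via `Depends()`, so a plain GET exercises it. A minimal smoke test, assuming the API is reachable on localhost:8504 (the port the front end's existing `/query-stream` call uses; adjust for your deployment):

```python
# Sketch only: the base URL is an assumption taken from the front end's
# EventSource URL; the response shape comes from generate_ticket_api above.
import requests

resp = requests.get(
    "http://localhost:8504/generate-ticket",
    params={"text": "How do I connect the bot to my own Neo4j instance?"},
)
resp.raise_for_status()
data = resp.json()
print(data["result"]["title"])  # generated ticket title
print(data["result"]["text"])   # generated ticket body
print(data["model"])            # name of the LLM that produced it
```
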
diff --git a/bot.py b/bot.py
index 2290602a..61a5cd98 100644
--- a/bot.py
+++ b/bot.py
@@ -3,15 +3,9 @@
 import streamlit as st
 from streamlit.logger import get_logger
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.prompts.chat import (
-    ChatPromptTemplate,
-    SystemMessagePromptTemplate,
-    HumanMessagePromptTemplate,
-)
 from langchain.graphs import Neo4jGraph
 from dotenv import load_dotenv
 from utils import (
-    extract_title_and_question,
     create_vector_index,
 )
 from chains import (
@@ -19,6 +13,7 @@
     load_llm,
     configure_llm_only_chain,
     configure_qa_rag_chain,
+    generate_ticket,
 )

 load_dotenv(".env")
@@ -148,65 +143,6 @@ def mode_select() -> str:
     output_function = rag_chain


-def generate_ticket():
-    # Get high ranked questions
-    records = neo4j_graph.query(
-        "MATCH (q:Question) RETURN q.title AS title, q.body AS body ORDER BY q.score DESC LIMIT 3"
-    )
-    questions = []
-    for i, question in enumerate(records, start=1):
-        questions.append((question["title"], question["body"]))
-    # Ask LLM to generate new question in the same style
-    questions_prompt = ""
-    for i, question in enumerate(questions, start=1):
-        questions_prompt += f"{i}. {question[0]}\n"
-        questions_prompt += f"{question[1]}\n\n"
-        questions_prompt += "----\n\n"
-
-    gen_system_template = f"""
-    You're an expert in formulating high quality questions.
-    Can you formulate a question in the same style, detail and tone as the following example questions?
-    {questions_prompt}
-    ---
-
-    Don't make anything up, only use information in the following question.
-    Return a title for the question, and the question post itself.
-
-    Return example:
-    ---
-    Title: How do I use the Neo4j Python driver?
-    Question: I'm trying to connect to Neo4j using the Python driver, but I'm getting an error.
-    ---
-    """
-    # we need jinja2 since the questions themselves contain curly braces
-    system_prompt = SystemMessagePromptTemplate.from_template(
-        gen_system_template, template_format="jinja2"
-    )
-    q_prompt = st.session_state[f"user_input"][-1]
-    chat_prompt = ChatPromptTemplate.from_messages(
-        [
-            system_prompt,
-            SystemMessagePromptTemplate.from_template(
-                """
-                Respond in the following format or you will be unplugged.
-                ---
-                Title: New title
-                Question: New question
-                ---
-                """
-            ),
-            HumanMessagePromptTemplate.from_template("{text}"),
-        ]
-    )
-    llm_response = llm_chain(
-        f"Here's the question to rewrite in the expected format: ```{q_prompt}```",
-        [],
-        chat_prompt,
-    )
-    new_title, new_question = extract_title_and_question(llm_response["answer"])
-    return (new_title, new_question)
-
-
 def open_sidebar():
     st.session_state.open_sidebar = True

@@ -218,7 +154,11 @@ def close_sidebar():
 if not "open_sidebar" in st.session_state:
     st.session_state.open_sidebar = False
 if st.session_state.open_sidebar:
-    new_title, new_question = generate_ticket()
+    new_title, new_question = generate_ticket(
+        neo4j_graph=neo4j_graph,
+        llm_chain=llm_chain,
+        input_question=st.session_state[f"user_input"][-1],
+    )
     with st.sidebar:
         st.title("Ticket draft")
         st.write("Auto generated draft ticket")
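
With the helper moved out of bot.py, the Streamlit app and the API now share one implementation, and the explicit arguments make it callable outside a Streamlit session. A sketch, assuming `neo4j_graph` and `llm_chain` are already configured the way both apps build them (`Neo4jGraph(...)` plus `configure_llm_only_chain(...)`):

```python
# Sketch only: neo4j_graph and llm_chain are assumed to be set up
# elsewhere, exactly as bot.py and api.py construct them.
from chains import generate_ticket

new_title, new_question = generate_ticket(
    neo4j_graph=neo4j_graph,
    llm_chain=llm_chain,
    input_question="How do I run the stack against an existing Neo4j database?",
)
print(new_title)
print(new_question)
```
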
diff --git a/chains.py b/chains.py
index b539a7e0..e0d563de 100644
--- a/chains.py
+++ b/chains.py
@@ -14,7 +14,7 @@
     HumanMessagePromptTemplate,
 )
 from typing import List, Any
-from utils import BaseLogger
+from utils import BaseLogger, extract_title_and_question


 def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config={}):
@@ -88,7 +88,9 @@ def generate_llm_output(
         user_input: str, callbacks: List[Any], prompt=chat_prompt
     ) -> str:
         chain = prompt | llm
-        answer = chain.invoke(user_input, config={"callbacks": callbacks}).content
+        answer = chain.invoke(
+            {"question": user_input}, config={"callbacks": callbacks}
+        ).content
         return {"answer": answer}

     return generate_llm_output
@@ -160,3 +162,61 @@ def configure_qa_rag_chain(llm, embeddings, embeddings_store_url, username, pass
         max_tokens_limit=3375,
     )
     return kg_qa
+
+
+def generate_ticket(neo4j_graph, llm_chain, input_question):
+    # Get high ranked questions
+    records = neo4j_graph.query(
+        "MATCH (q:Question) RETURN q.title AS title, q.body AS body ORDER BY q.score DESC LIMIT 3"
+    )
+    questions = []
+    for i, question in enumerate(records, start=1):
+        questions.append((question["title"], question["body"]))
+    # Ask LLM to generate new question in the same style
+    questions_prompt = ""
+    for i, question in enumerate(questions, start=1):
+        questions_prompt += f"{i}. {question[0]}\n"
+        questions_prompt += f"{question[1]}\n\n"
+        questions_prompt += "----\n\n"
+
+    gen_system_template = f"""
+    You're an expert in formulating high quality questions.
+    Can you formulate a question in the same style, detail and tone as the following example questions?
+    {questions_prompt}
+    ---
+
+    Don't make anything up, only use information in the following question.
+    Return a title for the question, and the question post itself.
+
+    Return example:
+    ---
+    Title: How do I use the Neo4j Python driver?
+    Question: I'm trying to connect to Neo4j using the Python driver, but I'm getting an error.
+    ---
+    """
+    # we need jinja2 since the questions themselves contain curly braces
+    system_prompt = SystemMessagePromptTemplate.from_template(
+        gen_system_template, template_format="jinja2"
+    )
+    chat_prompt = ChatPromptTemplate.from_messages(
+        [
+            system_prompt,
+            SystemMessagePromptTemplate.from_template(
+                """
+                Respond in the following format or you will be unplugged.
+                ---
+                Title: New title
+                Question: New question
+                ---
+                """
+            ),
+            HumanMessagePromptTemplate.from_template("{question}"),
+        ]
+    )
+    llm_response = llm_chain(
+        f"Here's the question to rewrite in the expected format: ```{input_question}```",
+        [],
+        chat_prompt,
+    )
+    new_title, new_question = extract_title_and_question(llm_response["answer"])
+    return (new_title, new_question)
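
The `generate_llm_output` change pairs with the templates in this file: the human message placeholder is now `"{question}"` (the old bot.py copy used `"{text}"`), so the chain's input is passed as a mapping keyed by that variable name. A small illustration of the pairing, using only the prompt classes already imported here:

```python
# Illustration only: mirrors generate_ticket's prompt setup; a chain built
# as `chat_prompt | llm` receives the same mapping through .invoke().
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate

chat_prompt = ChatPromptTemplate.from_messages(
    [HumanMessagePromptTemplate.from_template("{question}")]
)
messages = chat_prompt.format_messages(question="What is a knowledge graph?")
print(messages[0].content)  # -> "What is a knowledge graph?"
```
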
diff --git a/front-end/src/App.svelte b/front-end/src/App.svelte
index 246be414..e609f647 100644
--- a/front-end/src/App.svelte
+++ b/front-end/src/App.svelte
@@ -4,69 +4,21 @@
   import botImage from "./assets/images/bot.jpeg";
   import meImage from "./assets/images/me.jpeg";
   import MdLink from "./lib/MdLink.svelte";
+  import External from "./lib/External.svelte";
+  import { chatStates, chatStore } from "./lib/chat.store.js";
+  import Modal from "./lib/Modal.svelte";
+  import { generationStore } from "./lib/generation.store";

-  let messages = [];
-  let ragMode = true;
+  let ragMode = false;
   let question = "How can I create a chatbot on top of my local PDF files using langchain?";
   let shouldAutoScroll = true;
   let input;
-  let appState = "idle"; // or receiving
   let senderImages = { bot: botImage, me: meImage };
+  let generationModalOpen = false;

-  async function send() {
-    if (!question.trim().length) {
-      return;
-    }
-    appState = "receiving";
-    addMessage("me", question, ragMode);
-    const messageId = addMessage("bot", "", ragMode);
-    try {
-      const evt = new EventSource(
-        `http://localhost:8504/query-stream?text=${encodeURI(question)}&rag=${ragMode}`
-      );
-      question = "";
-      evt.onmessage = (e) => {
-        if (e.data) {
-          const data = JSON.parse(e.data);
-          if (data.init) {
-            updateMessage(messageId, "", data.model);
-            return;
-          }
-          updateMessage(messageId, data.token);
-        }
-      };
-      evt.onerror = (e) => {
-        // Stream will end with an error
-        // and we want to close the connection on end (otherwise it will keep reconnecting)
-        evt.close();
-      };
-    } catch (e) {
-      updateMessage(messageId, "Error: " + e.message);
-    } finally {
-      appState = "idle";
-    }
-  }
-
-  function updateMessage(existingId, text, model = null) {
-    if (!existingId) {
-      return;
-    }
-    const existingIdIndex = messages.findIndex((m) => m.id === existingId);
-    if (existingIdIndex === -1) {
-      return;
-    }
-    messages[existingIdIndex].text += text;
-    if (model) {
-      messages[existingIdIndex].model = model;
-    }
-    messages = messages;
-  }
-
-  function addMessage(from, text, rag) {
-    const newId = Math.random().toString(36).substring(2, 9);
-    const message = { id: newId, from, text, rag };
-    messages = messages.concat([message]);
-    return newId;
+  function send() {
+    chatStore.send(question, ragMode);
+    question = "";
   }

   function scrollToBottom(node, _) {
@@ -79,7 +31,12 @@
     shouldAutoScroll = e.target.scrollTop + e.target.clientHeight > e.target.scrollHeight - 55;
   }
-  $: appState === "idle" && input && focus(input);
+  function generateTicket(text) {
+    generationStore.generate(text);
+    generationModalOpen = true;
+  }
+
+  $: $chatStore.state === chatStates.IDLE && input && focus(input);
   async function focus(node) {
     await tick();
     node.focus();
   }
-
+
-        {#each messages as message (message.id)}
+        {#each $chatStore.data as message (message.id)}
-
+
+          {#if message.from === "me"}
+
+          {/if}
-
+
         {#if message.from === "bot"}

+{#if generationModalOpen}
+  <Modal on:close={() => (generationModalOpen = false)} />
+{/if}
diff --git a/front-end/src/lib/Modal.svelte b/front-end/src/lib/Modal.svelte
new file mode 100644
index 00000000..5843e8b7
--- /dev/null
+++ b/front-end/src/lib/Modal.svelte
@@ -0,0 +1,60 @@
+
+
+
+
+
+      Create new internal ticket
+
+
+
+
+