from langchain_anthropic import ChatAnthropic
from langchain_cohere import ChatCohere
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_community.chat_message_histories import SQLChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import gradio as gr
from argparse import ArgumentParser
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer
from utils import PDFdatabase, NeuralSearcher, Translation
import os

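# Command-line options: the script accepts an explicit list of pdf files (-pf)
# or a directory to scan for pdfs (-d), plus optionally the language the pdfs
# are written in (-l); the defaults below act as "not provided" sentinels.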
parser = ArgumentParser()

parser.add_argument(
    "-pf",
    "--pdf_file",
    help="A single pdf file, or N pdfs passed comma-separated: /path/to/file1.pdf,/path/to/file2.pdf,...,/path/to/fileN.pdf (there is no strict naming convention)",
    required=False,
    default="No file",
)

parser.add_argument(
    "-d",
    "--directory",
    help="Directory where all your pdfs of interest are stored",
    required=False,
    default="No directory",
)

parser.add_argument(
    "-l",
    "--language",
    help="Language of the written content contained in the pdfs",
    required=False,
    default="Same as query",
)

args = parser.parse_args()

pdff = args.pdf_file
dirs = args.directory
lan = args.language

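# Decide where the pdfs come from: if files were passed explicitly and no
# directory was given, use the comma-separated list; otherwise scan the
# directory for every .pdf it contains.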
if pdff.replace("\\", "").replace("'", "") != "No file" and dirs.replace("\\", "").replace("'", "") == "No directory":
    pdfs = pdff.replace("\\", "/").replace("'", "").split(",")
else:
    dirpath = dirs.replace("\\", "/").replace("'", "")
    pdfs = [os.path.join(dirpath, f) for f in os.listdir(dirpath) if f.endswith(".pdf")]

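# Vector store and embedder: host.docker.internal assumes the app runs inside
# a Docker container talking to a Qdrant instance on the host machine.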
client = QdrantClient(host="host.docker.internal", port=6333)
encoder = SentenceTransformer("all-MiniLM-L6-v2")

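# Index the pdfs: parse and chunk them, gather the records, then create the
# Qdrant collection and upload the embeddings (implemented in utils.PDFdatabase).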
pdfdb = PDFdatabase(pdfs, encoder, client)
pdfdb.preprocess()
pdfdb.collect_data()
pdfdb.qdrant_collection_and_upload()

# Start every launch from a clean slate: drop any chat memory left over from
# a previous run.
if os.path.exists("memory.db"):
    os.remove("memory.db")

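# Each session id maps to its own SQLite-backed message history, which is what
# RunnableWithMessageHistory uses to keep conversation state between turns.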
def get_session_history(session_id):
    return SQLChatMessageHistory(session_id, "sqlite:///memory.db")

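# Map each selectable model name to its LangChain chat class and to the
# environment variable that holds the matching API key.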
NAME2CHAT = {
    "Cohere": ChatCohere,
    "claude-3-opus-20240229": ChatAnthropic,
    "claude-3-sonnet-20240229": ChatAnthropic,
    "claude-3-haiku-20240307": ChatAnthropic,
    "llama3-8b-8192": ChatGroq,
    "llama3-70b-8192": ChatGroq,
    "mixtral-8x7b-32768": ChatGroq,
    "gemma-7b-it": ChatGroq,
    "gpt-4o": ChatOpenAI,
    "gpt-3.5-turbo-0125": ChatOpenAI,
}
NAME2APIKEY = {
    "Cohere": "COHERE_API_KEY",
    "claude-3-opus-20240229": "ANTHROPIC_API_KEY",
    "claude-3-sonnet-20240229": "ANTHROPIC_API_KEY",
    "claude-3-haiku-20240307": "ANTHROPIC_API_KEY",
    "llama3-8b-8192": "GROQ_API_KEY",
    "llama3-70b-8192": "GROQ_API_KEY",
    "mixtral-8x7b-32768": "GROQ_API_KEY",
    "gemma-7b-it": "GROQ_API_KEY",
    "gpt-4o": "OPENAI_API_KEY",
    "gpt-3.5-turbo-0125": "OPENAI_API_KEY",
}

system_template = "You are a helpful assistant. Rely on this context: {context} and on the previous message history to build a context- and history-aware reply to the following user input:"

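# Core chat handler. It wires the selected model into a history-aware chain,
# retrieves the most relevant pdf chunk from Qdrant, and handles four language
# combinations: the query language and the pdf language (-l flag) can each be
# English or not, and context and replies are translated accordingly. The
# semantics of utils.Translation are inferred from its usage here: .original
# is the detected source language and .translatef() performs the translation.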
def reply(message, history, name, api_key, temperature, max_new_tokens, sessionid):
    global pdfdb
    os.environ[NAME2APIKEY[name]] = api_key
    # The Cohere wrapper falls back to its default model; the other wrappers
    # need the model name passed explicitly.
    if name == "Cohere":
        model = NAME2CHAT[name](temperature=temperature, max_tokens=max_new_tokens)
    else:
        model = NAME2CHAT[name](model=name, temperature=temperature, max_tokens=max_new_tokens)
    prompt_template = ChatPromptTemplate.from_messages(
        [("system", system_template),
         MessagesPlaceholder(variable_name="history"),
         ("human", "{input}")]
    )
    chain = prompt_template | model
    runnable_with_history = RunnableWithMessageHistory(
        chain,
        get_session_history,
        input_messages_key="input",
        history_messages_key="history",
    )
    txt = Translation(message, "en")
    txt2txt = NeuralSearcher(pdfdb.collection_name, pdfdb.client, pdfdb.encoder)
    if txt.original == "en" and lan.replace("\\", "").replace("'", "") == "Same as query":
        # English query, pdfs in the query language: search and reply directly.
        results = txt2txt.search(message)
        response = runnable_with_history.invoke({"context": results[0]["text"], "input": message}, config={"configurable": {"session_id": sessionid}})
        return response.content
    elif txt.original == "en" and lan.replace("\\", "").replace("'", "") != "Same as query":
        # English query, pdfs in another language: translate the query into the
        # pdf language before searching, then bring the retrieved context back
        # to English.
        transl = Translation(message, lan.replace("\\", "").replace("'", ""))
        message = transl.translatef()
        results = txt2txt.search(message)
        t = Translation(results[0]["text"], txt.original)
        res = t.translatef()
        response = runnable_with_history.invoke({"context": res, "input": message}, config={"configurable": {"session_id": sessionid}})
        return response.content
    elif txt.original != "en" and lan.replace("\\", "").replace("'", "") == "Same as query":
        # Non-English query, pdfs in the query language: translate the retrieved
        # context to English for the model, then translate the reply back into
        # the user's language.
        results = txt2txt.search(message)
        transl = Translation(results[0]["text"], "en")
        translation = transl.translatef()
        response = runnable_with_history.invoke({"context": translation, "input": message}, config={"configurable": {"session_id": sessionid}})
        t = Translation(response.content, txt.original)
        res = t.translatef()
        return res
    else:
        # Non-English query, pdfs in yet another language: translate the query
        # into the pdf language for the search, then translate the context and
        # the reply back into the user's language.
        transl = Translation(message, lan.replace("\\", "").replace("'", ""))
        message = transl.translatef()
        results = txt2txt.search(message)
        t = Translation(results[0]["text"], txt.original)
        res = t.translatef()
        response = runnable_with_history.invoke({"context": res, "input": message}, config={"configurable": {"session_id": sessionid}})
        tr = Translation(response.content, txt.original)
        ress = tr.translatef()
        return ress

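# Gradio widgets for the settings accordion shown above the chat box.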
chat_model = gr.Dropdown(
    list(NAME2CHAT), label="Chat Model", info="Choose one of the available chat models"
)

user_api_key = gr.Textbox(
    label="API key",
    info="Paste your API key here",
    lines=1,
    type="password",
)

user_temperature = gr.Slider(0, 1, value=0.5, label="Temperature", info="Select model temperature")

user_max_new_tokens = gr.Slider(0, 8192, value=1024, label="Max new tokens", info="Select the maximum number of output tokens (a higher limit can increase latency)")

user_session_id = gr.Textbox(label="Session ID", info="This alphanumeric code links the model's replies to a specific message history, which the model is aware of when replying. Changing it makes the model lose its memory of the conversation", value="1")

additional_accordion = gr.Accordion(label="Parameters to be set before you start chatting", open=True)

demo = gr.ChatInterface(fn=reply, additional_inputs=[chat_model, user_api_key, user_temperature, user_max_new_tokens, user_session_id], additional_inputs_accordion=additional_accordion, title="everything-ai-buildyourllm")

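# A minimal usage sketch (assuming this script is saved as app.py and Qdrant
# is reachable): python app.py -d /path/to/pdfs -l en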
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", share=False)