add generic RAG chatbot example #250

Open · wants to merge 1 commit into main
9 changes: 9 additions & 0 deletions community/generic-rag-bot/.streamlit/config.toml
@@ -0,0 +1,9 @@
[client]
showErrorDetails = false

[theme]
primaryColor = "#76b900"
backgroundColor = "white"

[browser]
gatherUsageStats = false
42 changes: 42 additions & 0 deletions community/generic-rag-bot/README.md
@@ -0,0 +1,42 @@
# Tutorial for a Generic RAG-Based Chatbot

This tutorial shows how to build your own generic RAG chatbot. It is intended as a foundation for building more complex, domain-specific RAG-based assistants or tools. Note that no local GPU is needed to run it, as it uses NIMs from the NVIDIA NGC catalog.
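Since the models are served as hosted endpoints, you can sanity-check connectivity directly from Python. A minimal sketch (assuming the requirements below are installed and `NVIDIA_API_KEY` is exported, both covered in the steps):

```python
from langchain_nvidia_ai_endpoints import ChatNVIDIA

# Calls a hosted NIM endpoint; no local GPU is involved
llm = ChatNVIDIA(model="meta/llama-3.1-70b-instruct")
print(llm.invoke("Say hello in one sentence.").content)
```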

## Acknowledgements

- This implementation is based on [RAG in 5 Minutes](https://github.com/NVIDIA/GenerativeAIExamples/tree/main/community/5_mins_rag_no_gpu), with changes primarily made to the UI as well as updates to some outdated libraries, models, and dependencies.
- Alyssa Sawyer also contributed to updating and further developing this repo in her project, [Resume RAG Bot](https://github.com/alysawyer/resume-rag-nv).

## Steps

1. Create a Python virtual environment and activate it:

```console
python3 -m venv genai
source genai/bin/activate
```

1. From the root of this folder, `generic-rag-bot`, install the requirements:

```console
pip install -r requirements.txt
```

1. Add your NVIDIA API key as an environment variable:

```console
export NVIDIA_API_KEY="nvapi-*"
```

If you don't already have an API key, visit the NVIDIA API Catalog, select any model, then click **Get API Key**.
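To confirm the key is visible to Python before launching the app, one optional sanity check:

```console
python -c "import os; print('key found' if os.environ.get('NVIDIA_API_KEY', '').startswith('nvapi-') else 'key missing')"
```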

1. Run the example using Streamlit:

```console
streamlit run main.py
```
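Streamlit listens on port 8501 by default; if that port is already in use, its standard `--server.port` flag selects another, for example:

```console
streamlit run main.py --server.port 8502
```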

1. Test the deployed example by going to `http://<host_ip>:8501` in a web browser.

Click **Browse Files** and select the documents for your knowledge base.
After selecting, click **Upload!** to complete the ingestion process.
203 changes: 203 additions & 0 deletions community/generic-rag-bot/main.py
@@ -0,0 +1,203 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This is a simple standalone implementation showing a RAG pipeline using NVIDIA AI Foundation Models.
# It uses a simple Streamlit UI and a single-file implementation of a minimalistic RAG pipeline.


############################################
# Component #0.5 - UI / Header
############################################

import streamlit as st
import os

# Page settings
st.set_page_config(
    layout="wide",
    page_title="RAG Chatbot",
    page_icon="🤖",
    initial_sidebar_state="expanded")

# Page title
st.header('Generic RAG Chatbot 🤖📝', divider='rainbow')

# Custom CSS
def local_css(file_name):
    with open(file_name, "r") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
local_css("style.css")

# Page description
st.markdown('''Manually looking through vast amounts of data can be tedious and time-consuming. This chatbot can expedite that process by providing a platform to query your documents.''')
st.warning("This system leverages generative AI. Any output from the AI agent should be used in conjunction with the original data.", icon="⚠️")

############################################
# Component #1 - Document Loader
############################################

with st.sidebar:
    st.subheader("Upload your Documents")

    DOCS_DIR = os.path.abspath("./uploaded_docs")

    # Make dir to store uploaded documents
    if not os.path.exists(DOCS_DIR):
        os.makedirs(DOCS_DIR)

    # Define form on Streamlit page for uploading files to KB
    with st.form("my-form", clear_on_submit=True):
        uploaded_files = st.file_uploader("Upload a file to the Knowledge Base:", accept_multiple_files=True)
        submitted = st.form_submit_button("Upload!")

    # Save each uploaded file to disk and acknowledge the upload
    if uploaded_files and submitted:
        for uploaded_file in uploaded_files:
            st.success(f"File {uploaded_file.name} uploaded successfully!")
            with open(os.path.join(DOCS_DIR, uploaded_file.name), "wb") as f:
                f.write(uploaded_file.read())

############################################
# Component #2 - Initializing Embedding Model and LLM
############################################

from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings

# Make sure to export your NGC NV-Developer API key as NVIDIA_API_KEY!
API_KEY = os.environ['NVIDIA_API_KEY']

# Select embedding model and LLM
document_embedder = NVIDIAEmbeddings(model="nvidia/nv-embedqa-e5-v5", api_key=API_KEY, model_type="passage", truncate="END")
llm = ChatNVIDIA(model="meta/llama-3.1-70b-instruct", api_key=API_KEY, temperature=0)
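# temperature=0 minimizes sampling randomness, giving more repeatable answers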

############################################
# Component #3 - Vector Database Store
############################################

import pickle
import nltk
import ssl

# Disable SSL check NLTK uses on nltk.download calls
try:
    _create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
    pass
else:
    ssl._create_default_https_context = _create_unverified_https_context

# Download tokenizer models from NLTK
nltk.download('punkt')

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

# Option for using an existing vector store
with st.sidebar:
    use_existing_vector_store = st.radio("Use existing vector store if available", ["Yes", "No"], horizontal=True)

# Load raw documents from the directory
DOCS_DIR = os.path.abspath("./uploaded_docs")
raw_documents = DirectoryLoader(DOCS_DIR).load()

# Check for existing vector store file
vector_store_path = "vectorstore.pkl"
vector_store_exists = os.path.exists(vector_store_path)
vectorstore = None

if use_existing_vector_store == "Yes" and vector_store_exists:
    # Load existing vector store
    with open(vector_store_path, "rb") as f:
        vectorstore = pickle.load(f)
    with st.sidebar:
        st.info("Existing vector store loaded successfully.")
else:
    with st.sidebar:
        if raw_documents and use_existing_vector_store == "Yes":
            # Chunk documents
            with st.spinner("Splitting documents into chunks..."):
                text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=100)
                documents = text_splitter.split_documents(raw_documents)

            # Convert document chunks to embeddings, and save in a vector store
            with st.spinner("Adding document chunks to vector database..."):
                vectorstore = FAISS.from_documents(documents, document_embedder)

            # Save vector store
            with st.spinner("Saving vector store"):
                with open(vector_store_path, "wb") as f:
                    pickle.dump(vectorstore, f)
            st.success("Vector store created and saved.")
        else:
            st.warning("No documents available to process!", icon="⚠️")

############################################
# Component #4 - LLM Response Generation and Chat
############################################

st.subheader("Chat with your AI Assistant")

# Save chat history for this user session
if "messages" not in st.session_state:
st.session_state.messages = []

for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])

# Define prompt for LLM
prompt_template = ChatPromptTemplate.from_messages([
    ("system", "You are a helpful AI assistant. Use the provided context to inform your responses. If no context is available, please state that."),
    ("human", "{input}")
])

# Define simple prompt chain
chain = prompt_template | llm | StrOutputParser()
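# The | operator composes prompt, LLM, and output parser into one runnable (LangChain LCEL)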

# Chat input box; the string below appears as placeholder text in the empty input
user_query = st.chat_input("Please summarize these documents.")

# Runs every time the user submits a new query
if user_query:
    # Add user query to chat history
    st.session_state.messages.append({"role": "user", "content": user_query})
    with st.chat_message("user"):
        st.markdown(user_query)

    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        full_response = ""

        if vectorstore is not None and use_existing_vector_store == "Yes":
            # Retrieve relevant chunks for the given user query from the vector store
            retriever = vectorstore.as_retriever()
            retrieved_docs = retriever.invoke(user_query)

            # Concatenate retrieved chunks together as context for the LLM
            context = "\n\n".join([doc.page_content for doc in retrieved_docs])
            augmented_user_input = f"Context: {context}\n\nQuestion: {user_query}\n"
        else:
            augmented_user_input = f"Question: {user_query}\n"

        # Stream the LLM output into the placeholder as it is generated
        for response in chain.stream({"input": augmented_user_input}):
            full_response += response
            message_placeholder.markdown(full_response + "▌")
        message_placeholder.markdown(full_response)
        # Add AI assistant response to chat history
        st.session_state.messages.append({"role": "assistant", "content": full_response})
13 changes: 13 additions & 0 deletions community/generic-rag-bot/requirements.txt
@@ -0,0 +1,13 @@
streamlit
faiss-cpu==1.7.4
unstructured[all-docs]==0.11.2
langchain
langchain-community
langchain-core
langchain-nvidia-ai-endpoints
langchain-text-splitters
nltk==3.8.1
numpy==1.23.5
onnx==1.16.1
onnxruntime==1.15.1
python-magic
73 changes: 73 additions & 0 deletions community/generic-rag-bot/style.css
@@ -0,0 +1,73 @@
/* style.css */

/* custom footer */
.footer {
text-align: center;
color: #666;
font-size: 14px;
}

/* NVIDIA green for headers */
h1, h2, h3, h4, h5 {
color: #76b900;
}


/* add line when hovering over link */
.hover-link {
text-decoration: none;
color: inherit;
position: relative;
}

.hover-link::after {
content: '';
position: absolute;
width: 100%;
height: 1px;
bottom: 0;
left: 0;
background-color: #000;
transform: scaleX(0);
transition: transform 0.3s ease-in-out;
}

.hover-link:hover::after {
transform: scaleX(1);
}

/* Remove default formatting for links */
a {
color: #666;
text-decoration: none;
}

/* Remove streamlit bar */
header {
visibility: hidden;
}

/* custom container */

.custom-image-container img {
border-radius: 10px;
}

.custom-column-container {
background-color: #f0f0f0;
border-radius: 10px;
padding: 20px;
}

.custom-column-container .stMarkdown {
padding-right: 20px;
}

.streamlit-expanderHeader {
background-color: white;
color: #76b900;
}
.streamlit-expanderContent {
background-color: white;
color: black;
}