diff --git a/README.md b/README.md index b500d50..62127bd 100644 --- a/README.md +++ b/README.md @@ -1 +1,174 @@ -# Text-to-SQL RAG +# Natural Language to SQL Bot (Text to SQL for SQLite) + +## 📅 Overview + +This project is a **Text-to-SQL Bot** where users can ask questions in simple English, and the system will: +- Generate the correct SQL query. +- Execute the SQL on a **SQLite database** (`employee.db`). +- Return both the **query** and the **output**. + +The system is built using **Flask**, **LangChain**, **OpenAI GPT-3.5**, **FAISS**, and **SQLite**. + +--- + +## 📈 Step-by-Step Project Flow + +```text +Step 1: +I created a SQLite database with a basic Employee table having fields like name, age, city, gender, total experience, and blood group. + +Step 2: +I prepared a CSV file (employee_questions.csv) where I wrote simple natural language questions, their matching SQL queries, and short descriptions. +This CSV acts as few-shot examples to guide the model. + +Step 3: +I generated embeddings for these examples using OpenAI embeddings and stored them in FAISS, a fast vector database, for quick searching. + +Step 4: +I used LangChain to create: +- A FAISS retriever (to search examples based on user input) +- A Conversational Retrieval Chain that connects the retriever with a language model (LLM). + +Step 5: +I connected the system with OpenAI GPT-3.5-turbo as the model. +(But the setup is flexible — it can also work with Gemini, Llama, or any open-source model.) + +Step 6: +I built a Flask API where: +- The user sends a question. +- The system finds similar examples from FAISS. +- It creates a final prompt (schema + examples + user query). +- The model generates the SQL query. +- SQL is cleaned and validated. +- The query is run on the employee.db database. +- The API sends back both the SQL query and the query output. + +Step 7: +In future, I can enhance this system by: +- Adding support for multi-table joins and data modification queries (INSERT/UPDATE). +- Integrating conversation history, so the bot can understand previous context and give smarter, more connected answers. +- Replacing the model with open-source alternatives for cost-saving. +``` + +--- + +## 📞 System Flow + +```text +User Question + ↓ +Flask API Endpoint (POST /) + ↓ +FAISS Retriever (Semantic Search on employee_questions.csv examples) + ↓ +Prompt Formation (Database Schema + Retrieved Examples + User Question) + ↓ +LLM (OpenAI GPT-3.5 / Gemini / Llama etc.) + ↓ +Generated SQL Query + ↓ +SQL Cleaning & Validation + ↓ +Execution on SQLite Database (employee.db) + ↓ +Return Query + Output as API Response +``` + +--- + +## 📊 Technology Stack + +| Component | Technology | +|:----------|:-----------| +| API Server | Flask | +| Database | SQLite (employee.db) | +| Embeddings | OpenAI text-embedding-ada-002 | +| Vectorstore | FAISS | +| LLM | OpenAI GPT-3.5-turbo (flexible to switch) | +| Memory | LangChain ConversationBufferMemory | +| Retrieval Chain | LangChain ConversationalRetrievalChain | + +--- + +## 🔧 Setup Instructions + +### 1. Clone the repository +```bash +git clone +cd your-repo-folder +``` + +### 2. Create and activate a virtual environment +```bash +python3 -m venv env +source env/bin/activate # For Mac/Linux +# OR +env\Scripts\activate.bat # For Windows +``` + +### 3. Install dependencies +```bash +pip install -r requirements.txt +``` + +### 4. Set up environment variables +Create a `.env` file: +```bash +echo "OPENAI_API_KEY=your_openai_api_key_here" > .env +``` + +### 5. Start the server +```bash +python chat.py +``` +Server will run on: +```bash +http://127.0.0.1:8000/ +``` + +--- + +## 📡 API Usage + +- **Endpoint:** `POST /` +- **Request Body Example:** +```json +{ + "question": "How many employees have more than 5 years of experience?" +} +``` + +- **Response Example:** +```json +{ + "response": { + "query": "SELECT COUNT(*) FROM Employee WHERE total_experience > 5;", + "result": [{"COUNT(*)": 20}] + } +} +``` + +You can test using **Postman** or **cURL**. + +--- + +## 🛠️ Future Enhancements +- Add multi-table joins. +- Support Insert, Update, and Delete queries. +- Add conversation history to understand context better. +- Integrate open-source LLMs (to reduce cost and improve control). +- Build a simple frontend UI (Streamlit or React). + +--- + +## 💚 Final Notes + +This project shows how natural language questions can be turned into real SQL queries and executed live on a database. +It's a working example of how **RAG (Retrieval Augmented Generation)** can make databases talk in human language! + +✅ To check chatbot outputs, refer to the **`result_img` folder** available in the repository, where screenshots of working results are attached. + +--- + +# 🌟 Thank you for exploring the Natural Language to SQL Bot! + diff --git a/Result_img/1.png b/Result_img/1.png new file mode 100644 index 0000000..2e8b411 Binary files /dev/null and b/Result_img/1.png differ diff --git a/Result_img/2.png b/Result_img/2.png new file mode 100644 index 0000000..12c72e8 Binary files /dev/null and b/Result_img/2.png differ diff --git a/Result_img/3.png b/Result_img/3.png new file mode 100644 index 0000000..604880d Binary files /dev/null and b/Result_img/3.png differ diff --git a/Result_img/4.png b/Result_img/4.png new file mode 100644 index 0000000..113f1ec Binary files /dev/null and b/Result_img/4.png differ diff --git a/Result_img/5.png b/Result_img/5.png new file mode 100644 index 0000000..03f8fed Binary files /dev/null and b/Result_img/5.png differ diff --git a/Result_img/6.png b/Result_img/6.png new file mode 100644 index 0000000..459d9d4 Binary files /dev/null and b/Result_img/6.png differ diff --git a/requirement.txt b/requirement.txt new file mode 100644 index 0000000..a02321c --- /dev/null +++ b/requirement.txt @@ -0,0 +1,158 @@ +accelerate==0.0.1 +aiohttp==3.9.5 +aiosignal==1.3.1 +annotated-types==0.6.0 +anyio==4.0.0 +appnope==0.1.3 +argon2-cffi==23.1.0 +argon2-cffi-bindings==21.2.0 +arrow==1.3.0 +asttokens==2.4.0 +async-lru==2.0.4 +attrs==23.1.0 +Babel==2.13.0 +backcall==0.2.0 +beautifulsoup4==4.12.2 +bleach==6.1.0 +blinker==1.7.0 +cachetools==5.3.1 +certifi==2023.7.22 +cffi==1.16.0 +charset-normalizer==3.3.0 +click==8.1.7 +comm==0.1.4 +dataclasses-json==0.6.4 +debugpy==1.8.0 +decorator==5.1.1 +defusedxml==0.7.1 +distro==1.9.0 +executing==2.0.0 +faiss-cpu==1.8.0 +fastjsonschema==2.18.1 +Flask==3.0.3 +fqdn==1.5.1 +frozenlist==1.4.1 +google-ai-generativelanguage==0.3.3 +google-api-core==2.12.0 +google-auth==2.23.3 +google-generativeai==0.2.1 +googleapis-common-protos==1.61.0 +grpcio==1.59.0 +grpcio-status==1.59.0 +h11==0.14.0 +httpcore==1.0.5 +httpx==0.27.0 +hyperplane==0.0.1 +idna==3.4 +ipykernel==6.25.2 +ipython==8.16.1 +ipython-genutils==0.2.0 +ipywidgets==8.1.1 +isoduration==20.11.0 +itsdangerous==2.2.0 +jedi==0.19.1 +Jinja2==3.1.2 +json5==0.9.14 +jsonpatch==1.33 +jsonpointer==2.4 +jsonschema==4.19.1 +jsonschema-specifications==2023.7.1 +jupyter==1.0.0 +jupyter-console==6.6.3 +jupyter-events==0.8.0 +jupyter-lsp==2.2.0 +jupyter_client==8.4.0 +jupyter_core==5.4.0 +jupyter_server==2.8.0 +jupyter_server_terminals==0.4.4 +jupyterlab==4.0.7 +jupyterlab-pygments==0.2.2 +jupyterlab-widgets==3.0.9 +jupyterlab_server==2.25.0 +langchain==0.1.16 +langchain-community==0.0.34 +langchain-core==0.1.45 +langchain-openai==0.1.3 +langchain-text-splitters==0.0.1 +langsmith==0.1.49 +MarkupSafe==2.1.3 +marshmallow==3.21.1 +matplotlib-inline==0.1.6 +mistune==3.0.2 +multidict==6.0.5 +mypy-extensions==1.0.0 +nbclient==0.8.0 +nbconvert==7.9.2 +nbformat==5.9.2 +nest-asyncio==1.5.8 +notebook==7.0.6 +notebook_shim==0.2.3 +numpy==1.26.1 +openai==1.23.2 +orjson==3.10.1 +overrides==7.4.0 +packaging==23.2 +pandas==2.1.1 +pandocfilters==1.5.0 +parso==0.8.3 +pexpect==4.8.0 +pickleshare==0.7.5 +platformdirs==3.11.0 +prometheus-client==0.17.1 +prompt-toolkit==3.0.39 +proto-plus==1.22.3 +protobuf==4.24.4 +psutil==5.9.6 +ptyprocess==0.7.0 +pure-eval==0.2.2 +pyasn1==0.5.0 +pyasn1-modules==0.3.0 +pycparser==2.21 +pydantic==2.11.3 +pydantic_core==2.33.1 +Pygments==2.16.1 +PyPDF2==3.0.1 +python-dateutil==2.8.2 +python-dotenv==1.0.0 +python-json-logger==2.0.7 +pytz==2023.3.post1 +PyYAML==6.0.1 +pyzmq==25.1.1 +qtconsole==5.4.4 +QtPy==2.4.0 +referencing==0.30.2 +regex==2024.4.16 +requests==2.31.0 +rfc3339-validator==0.1.4 +rfc3986-validator==0.1.1 +rpds-py==0.10.6 +rsa==4.9 +Send2Trash==1.8.2 +setuptools==68.2.2 +six==1.16.0 +sniffio==1.3.0 +soupsieve==2.5 +SQLAlchemy==2.0.29 +stack-data==0.6.3 +tenacity==8.2.3 +terminado==0.17.1 +tiktoken==0.6.0 +tinycss2==1.2.1 +tornado==6.3.3 +tqdm==4.66.1 +traitlets==5.11.2 +types-python-dateutil==2.8.19.14 +typing-inspect==0.9.0 +typing-inspection==0.4.0 +typing_extensions==4.13.2 +tzdata==2023.3 +uri-template==1.3.0 +urllib3==2.0.7 +wcwidth==0.2.8 +webcolors==1.13 +webencodings==0.5.1 +websocket-client==1.6.4 +Werkzeug==3.0.2 +wheel==0.41.2 +widgetsnbextension==4.0.9 +yarl==1.9.4 diff --git a/sqlite/database.py b/sqlite/database.py new file mode 100644 index 0000000..bf9ff3e --- /dev/null +++ b/sqlite/database.py @@ -0,0 +1,68 @@ +import sqlite3 +import random + +conn = sqlite3.connect("employee.db") +cursor = conn.cursor() + +cursor.execute(""" +CREATE TABLE IF NOT EXISTS Employee ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT, + age INTEGER, + city TEXT, + gender TEXT, + total_experience REAL, + blood_group TEXT +) +""") + +first_names = [ + "Alice", "Bob", "Charlie", "Diana", "Eva", "Frank", "Grace", "Henry", "Ivy", "Jack", + "Karen", "Leo", "Mona", "Nate", "Olivia", "Paul", "Quinn", "Rachel", "Sam", "Tina", + "Uma", "Victor", "Wendy", "Xavier", "Yvonne", "Zachary", "Amber", "Ben", "Cathy", "Derek", + "Ella", "Fred", "Gina", "Harvey", "Isla", "Jake", "Kelly", "Liam", "Maya", "Noah", + "Olive", "Peter", "Queen", "Ryan", "Sophie", "Thomas", "Ursula", "Vince", "Willow", "Zane" +] + +last_names = [ + "Johnson", "Smith", "Lee", "Evans", "Green", "Wright", "Kim", "Adams", "Brown", "Wilson", + "Clark", "Carter", "Reed", "Foster", "Bell", "Walker", "Parker", "Morris", "Rogers", "Hughes", + "Patel", "Kelly", "Price", "James", "Scott", "King", "Russell", "Bailey", "Hayes", "Barnes", + "Chapman", "Jordan", "Armstrong", "Crawford", "Shaw", "Stone", "Norris", "Palmer", "Holmes", "Douglas", + "Barrett", "Reeves", "Soto", "Bush", "Lane", "Bates", "Ellis", "Flynn", "Flynn", "Flynn" +] + +cities = ["New York", "Los Angeles", "Chicago", "Houston", "Phoenix", "Seattle", "Miami", "Boston", "Dallas", "Atlanta"] +blood_groups = ["O+", "A-", "B+", "AB-", "O-", "B-", "AB+", "A+"] +genders = ["Male", "Female"] + +used_names = set() + +records = [] +while len(records) < 50: + first = random.choice(first_names) + last = random.choice(last_names) + name = f"{first} {last}" + + # Ensure unique name + if name in used_names: + continue + used_names.add(name) + + age = random.randint(22, 50) + city = random.choice(cities) + gender = random.choice(genders) + experience = round(random.uniform(1.0, 15.0), 1) + blood_group = random.choice(blood_groups) + + records.append((name, age, city, gender, experience, blood_group)) + +cursor.executemany(""" +INSERT INTO Employee (name, age, city, gender, total_experience, blood_group) +VALUES (?, ?, ?, ?, ?, ?) +""", records) + +conn.commit() +conn.close() + +print("Database created and 50 unique employee records inserted successfully!") diff --git a/sqlite/employee.db b/sqlite/employee.db new file mode 100644 index 0000000..988cf76 Binary files /dev/null and b/sqlite/employee.db differ diff --git a/text-to-sql/chat.py b/text-to-sql/chat.py new file mode 100644 index 0000000..54c867e --- /dev/null +++ b/text-to-sql/chat.py @@ -0,0 +1,149 @@ +from flask import Flask, request, jsonify +from dotenv import load_dotenv +import os +import pandas as pd +import sqlite3 +from langchain_openai import OpenAIEmbeddings +from langchain.vectorstores import FAISS +from langchain_community.chat_models import ChatOpenAI +from langchain.memory import ConversationBufferMemory +from langchain.chains import ConversationalRetrievalChain +import os +os.environ["LANGCHAIN_PYDANTIC_V1"] = "1" + +load_dotenv() +app = Flask(__name__) + +api_key = os.getenv("OPENAI_API_KEY") +os.environ["OPENAI_API_KEY"] = api_key + +vectorstore = None +conversation_chain = None +chat_history = [] + +db_connection = sqlite3.connect("employee.db", check_same_thread=False) + +table_schema = """ +Table: Employee +Columns: +- id (INTEGER PRIMARY KEY AUTOINCREMENT) +- name (TEXT) +- age (INTEGER) +- city (TEXT) +- gender (TEXT) +- total_experience (REAL) +- blood_group (TEXT) +""" + +def main(csv_file_path): + """Main function to initialize vectorstore and conversation_chain.""" + global vectorstore, conversation_chain + try: + raw_chunks = get_csv_chunks(csv_file_path) + vectorstore = get_vectorstore(raw_chunks) + conversation_chain = get_conversation_chain(vectorstore) + except Exception as e: + print(f"Error in main function: {str(e)}") + +def get_csv_chunks(csv_file_path): + """Extract and format text chunks from CSV file.""" + try: + df = pd.read_csv(csv_file_path) + chunks = [] + for _, row in df.iterrows(): + question = str(row.get('question', '')).strip() + query = str(row.get('query', '')).strip() + description = str(row.get('description', '')).strip() + formatted_text = f"question: {question} query: {query} description: {description}" + chunks.append(formatted_text) + return chunks + except Exception as e: + print(f"Error in get_csv_chunks function: {str(e)}") + +def get_vectorstore(text_chunks): + """Generate vectorstore from text chunks.""" + try: + embeddings = OpenAIEmbeddings() + vectorstore = FAISS.from_texts(texts=text_chunks, embedding=embeddings) + return vectorstore + except Exception as e: + print(f"Error in get_vectorstore function: {str(e)}") + +def get_conversation_chain(vectorstore): + """Create a conversation chain for handling user queries.""" + try: + llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.3) + conversation_chain = ConversationalRetrievalChain.from_llm( + llm=llm, + retriever=vectorstore.as_retriever(), + ) + return conversation_chain + except Exception as e: + print(f"Error in get_conversation_chain function: {str(e)}") + +@app.route("/", methods=["POST"]) +def answer_question(): + """Handle POST requests containing user questions.""" + try: + data = request.get_json() + if "question" in data: + user_question = data["question"] + response = handle_user_input(user_question) + print("answer", response) + return jsonify({"response": response}) + else: + return jsonify({"error": "Missing 'question' parameter in the request"}), 400 + except Exception as e: + print(f"Error in answer_question function: {str(e)}") + return jsonify({"error": "An error occurred while processing the request"}), 500 + +import re + +def handle_user_input(user_question): + try: + global chat_history + default_prompt = ("You are a SQL expert support bot. Given the following Employee database table schema, " + "write only the correct SQL query for the user's question without any explanation. " + "Just return the SQL inside triple backticks.\n\n" + f"{table_schema}\n\n" + "If it is a greeting like 'hi' or 'hello', reply politely like 'Hello! I'm here to assist you.'\n\n") + user_question_with_prompt = default_prompt + user_question + response = conversation_chain.invoke({'question': user_question_with_prompt, 'chat_history': chat_history}) + + if 'answer' in response: + sql_query = response['answer'] + + sql_query = sql_query.replace("```sql", "").replace("```", "").strip() + + sql_match = re.search(r"(SELECT .*?;)", sql_query, flags=re.IGNORECASE | re.DOTALL) + if sql_match: + sql_query = sql_match.group(1).strip() + else: + sql_query = sql_query.split('\n')[-1].strip() + + print(f"Cleaned SQL: {sql_query}") + + try: + cursor = db_connection.cursor() + cursor.execute(sql_query) + result = cursor.fetchall() + columns = [description[0] for description in cursor.description] + formatted_result = [dict(zip(columns, row)) for row in result] + + + return {"query": sql_query, "result": formatted_result} + except Exception as query_error: + print(f"Error executing SQL query: {str(query_error)}") + return "An error occurred while executing the SQL query." + else: + return "Error: Unexpected response format" + except Exception as e: + print(f"Error in handle_user_input function: {str(e)}") + return "An error occurred while processing the user input" + + + +if __name__ == '__main__': + csv_file_path = "employee_questions.csv" + main(csv_file_path) + app.run(port=8000) diff --git a/text-to-sql/employee_questions.csv b/text-to-sql/employee_questions.csv new file mode 100644 index 0000000..ec2b85d --- /dev/null +++ b/text-to-sql/employee_questions.csv @@ -0,0 +1,39 @@ +question,query,description +How many total employees are there?,SELECT COUNT(*) FROM Employee;,Counts the total number of employees in the Employee table. +List all employee names.,SELECT name FROM Employee;,Fetches the names of all employees. +Get the average age of employees.,SELECT AVG(age) FROM Employee;,Calculates the average age of all employees. +How many male employees are there?,SELECT COUNT(*) FROM Employee WHERE gender = 'Male';,Counts the number of male employees. +How many female employees are there?,SELECT COUNT(*) FROM Employee WHERE gender = 'Female';,Counts the number of female employees. +List employees with more than 10 years of experience.,SELECT * FROM Employee WHERE total_experience > 10;,Fetches employees whose total experience is greater than 10 years. +List employees whose blood group is 'O+'.,SELECT * FROM Employee WHERE blood_group = 'O+';,Fetches employees with blood group 'O+'. +Get employee count in each city.,"SELECT city, COUNT(*) FROM Employee GROUP BY city;",Counts the number of employees grouped by their city. +Find the youngest employee.,SELECT * FROM Employee ORDER BY age ASC LIMIT 1;,Fetches the details of the youngest employee. +Find the oldest employee.,SELECT * FROM Employee ORDER BY age DESC LIMIT 1;,Fetches the details of the oldest employee. +List employees ordered by experience (highest first).,SELECT * FROM Employee ORDER BY total_experience DESC;,Lists employees starting from the highest total experience. +Get a list of unique blood groups in the company.,SELECT DISTINCT blood_group FROM Employee;,Fetches the unique blood groups present among employees. +Get employees older than 40 years.,SELECT * FROM Employee WHERE age > 40;,Fetches employees whose age is greater than 40. +Find employees located in 'New York'.,SELECT * FROM Employee WHERE city = 'New York';,Fetches all employees located in New York. +Find employees who have less than 5 years of experience.,SELECT * FROM Employee WHERE total_experience < 5;,Fetches employees with less than 5 years of total experience. +Get employee details whose name starts with 'A'.,SELECT * FROM Employee WHERE name LIKE 'A%';,Fetches employees whose names start with the letter 'A'. +Get employee details ordered by name alphabetically.,SELECT * FROM Employee ORDER BY name ASC;,Lists employees ordered alphabetically by their names. +Get employee count for each blood group.,"SELECT blood_group, COUNT(*) FROM Employee GROUP BY blood_group;",Counts the number of employees grouped by blood group. +How many employees have more than 20 years of experience?,SELECT COUNT(*) FROM Employee WHERE total_experience > 20;,Counts employees with more than 20 years total experience. +List all employees with experience between 5 and 10 years.,SELECT * FROM Employee WHERE total_experience BETWEEN 5 AND 10;,Fetches employees with experience between 5 and 10 years. +Get average experience of employees in Dallas.,SELECT AVG(total_experience) FROM Employee WHERE city = 'Dallas';,Calculates average total experience for employees in Dallas. +List employees aged below 25.,SELECT * FROM Employee WHERE age < 25;,Fetches employees whose age is less than 25. +How many employees have blood group A+?,SELECT COUNT(*) FROM Employee WHERE blood_group = 'A+';,Counts employees with blood group A+. +Find the employee with the maximum total experience.,SELECT * FROM Employee ORDER BY total_experience DESC LIMIT 1;,Fetches the employee with the highest total experience. +Find the employee with the minimum total experience.,SELECT * FROM Employee ORDER BY total_experience ASC LIMIT 1;,Fetches the employee with the lowest total experience. +List all employees who live in New York or Los Angeles.,"SELECT * FROM Employee WHERE city IN ('New York', 'Los Angeles');",Fetches employees from either New York or Los Angeles. +List employees with names starting with letter 'S'.,SELECT * FROM Employee WHERE name LIKE 'S%';,Fetches employees whose names start with 'S'. +Find the total number of cities represented by employees.,SELECT COUNT(DISTINCT city) FROM Employee;,Counts how many unique cities employees belong to. +List employees sorted by age descending.,SELECT * FROM Employee ORDER BY age DESC;,Lists employees from oldest to youngest. +List employees sorted by total experience ascending.,SELECT * FROM Employee ORDER BY total_experience ASC;,Lists employees starting from the least experience. +How many male and female employees are there?,"SELECT gender, COUNT(*) FROM Employee GROUP BY gender;",Counts employees grouped by gender. +Get average age of female employees.,SELECT AVG(age) FROM Employee WHERE gender = 'Female';,Calculates average age of all female employees. +List employees whose name contains 'an'.,SELECT * FROM Employee WHERE name LIKE '%an%';,Fetches employees whose name contains 'an'. +Find all employees with experience less than 2 years.,SELECT * FROM Employee WHERE total_experience < 2;,Fetches employees with less than 2 years experience. +Get total number of employees in Phoenix.,SELECT COUNT(*) FROM Employee WHERE city = 'Phoenix';,Counts number of employees located in Phoenix. +Find employees aged between 30 and 40.,SELECT * FROM Employee WHERE age BETWEEN 30 AND 40;,Fetches employees aged between 30 and 40 years. +List employees ordered alphabetically by city.,SELECT * FROM Employee ORDER BY city ASC;,Lists employees sorted alphabetically by city name. +Find the distinct blood groups available.,SELECT DISTINCT blood_group FROM Employee;,Lists all distinct blood groups among employees. \ No newline at end of file