"""
This module implements the core RAG (Retrieval-Augmented Generation) system.
It combines document processing, embedding generation, and query handling.
"""
import os
from typing import List
import fire
import PyPDF2
from openai_ops import OpenAIOperations
from postgres_ops import PostgresOperations

class DocumentProcessor:
    """A utility class for processing documents."""

    @staticmethod
    def clean_text(text: str) -> str:
        """
        Clean the input text by removing problematic characters.

        Args:
            text (str): The input text to clean.

        Returns:
            str: The cleaned text.
        """
        # Remove null characters
        text = text.replace('\x00', '')
        # Remove other potentially problematic control characters,
        # keeping common whitespace
        text = ''.join(char for char in text if ord(char) >= 32 or char in '\n\r\t')
        return text

    @staticmethod
    def extract_text_from_pdf(file_path: str) -> str:
        """
        Extract text content from a PDF file and clean it.

        Args:
            file_path (str): The path to the PDF file.

        Returns:
            str: The extracted and cleaned text content.
        """
        with open(file_path, 'rb') as file:
            pdf_reader = PyPDF2.PdfReader(file)
            text = ""
            for page in pdf_reader.pages:
                text += page.extract_text()
        return DocumentProcessor.clean_text(text)

    @staticmethod
    def chunk_text(text: str, chunk_size: int = 1000, overlap: int = 100) -> List[str]:
        """
        Split text into overlapping chunks.

        Args:
            text (str): The input text to chunk.
            chunk_size (int): The size of each chunk, in characters.
            overlap (int): The overlap between consecutive chunks, in characters.

        Returns:
            List[str]: A list of text chunks.
        """
        if overlap >= chunk_size:
            raise ValueError("overlap must be smaller than chunk_size")
        chunks = []
        start = 0
        while start < len(text):
            end = start + chunk_size
            chunks.append(text[start:end])
            # Step back by `overlap` so consecutive chunks share context
            start = end - overlap
        return chunks
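    # For example, with chunk_size=1000 and overlap=100 the chunks cover the
    # character ranges [0:1000], [900:1900], [1800:2800], ..., so neighbouring
    # chunks share 100 characters and a sentence that straddles a boundary
    # still appears intact in at least one chunk.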

    @staticmethod
    def estimate_tokens(text: str) -> int:
        """
        Estimate the number of tokens in a text.

        Args:
            text (str): The input text.

        Returns:
            int: An estimate of the number of tokens.
        """
        # Rough word-count heuristic; a real tokenizer (e.g. tiktoken) will
        # differ, typically reporting more tokens than words.
        return len(text.split())
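
    # NOTE: RAGSystem.query below calls DocumentProcessor.truncate_text_to_tokens,
    # but the method was missing from this module. The sketch here is an assumed
    # implementation that mirrors the word-count heuristic of estimate_tokens.
    @staticmethod
    def truncate_text_to_tokens(text: str, max_tokens: int) -> str:
        """
        Truncate text to an approximate token budget.

        Args:
            text (str): The input text.
            max_tokens (int): The maximum number of estimated tokens to keep.

        Returns:
            str: The truncated text.
        """
        words = text.split()
        if len(words) <= max_tokens:
            return text
        return ' '.join(words[:max_tokens])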


class RAGSystem:
    """The main RAG system class that orchestrates document processing and querying."""

    def __init__(self, postgres_host='postgres', postgres_port=5432):
        """
        Initialize the RAG system.

        Args:
            postgres_host (str): The host for the PostgreSQL database.
            postgres_port (int): The port for the PostgreSQL database.
        """
        self.openai_ops = OpenAIOperations(api_key=os.getenv('AZURE_OPENAI_KEY'))
        self.postgres_ops = PostgresOperations(
            dbname=os.getenv('POSTGRES_DB'),
            user=os.getenv('POSTGRES_USER'),
            password=os.getenv('POSTGRES_PASSWORD'),
            host=postgres_host,  # Should match the service name in docker-compose.yml
            port=postgres_port
        )
        # Approximate context budget for a 16k-context model; the prompt and
        # completion share this window, so this is an upper bound, not a target.
        self.max_tokens = 16000

    def upsert_document(self, file_path: str):
        """
        Process and store a document in the system.

        Args:
            file_path (str): The path to the document file.

        Returns:
            dict: A message indicating the success of the operation.
        """
        filename = os.path.basename(file_path)
        content = DocumentProcessor.extract_text_from_pdf(file_path)
        chunks = DocumentProcessor.chunk_text(content, chunk_size=1000, overlap=100)
        embeddings = [self.openai_ops.get_embedding(chunk) for chunk in chunks]
        self.postgres_ops.store_document(filename, content, chunks, embeddings)
        return {"message": f"Document {filename} uploaded and embedded successfully"}

    def query(self, question: str):
        """
        Process a query and generate a response.

        Args:
            question (str): The query text.

        Returns:
            dict: The answer to the query and related information.
        """
        query_embedding = self.openai_ops.get_embedding(question)
        similar_chunks = self.postgres_ops.search_similar_chunks(query_embedding, top_k=5)
        context = " ".join(chunk["chunk"] for chunk in similar_chunks)
        # Ensure the retrieved context fits within the token limit
        if DocumentProcessor.estimate_tokens(context) > self.max_tokens:
            context = DocumentProcessor.truncate_text_to_tokens(context, self.max_tokens)
        # Generate the response from the retrieved context
        answer = self.openai_ops.generate_response(context, question)
        # return {"answer": answer, "source_documents": [chunk["filename"] for chunk in similar_chunks]}
        return {"answer": answer}

    def clear_db(self):
        """
        Clear all data from the database.

        Utility function for testing.

        Returns:
            dict: A message indicating the success of the operation.
        """
        self.postgres_ops.clear_db()
        return {"message": "Database cleared successfully"}

    def print_db_contents(self):
        """
        Retrieve and print the contents of the database.

        Returns:
            List[Dict[str, Any]]: A summary of all documents in the database.
        """
        return self.postgres_ops.print_db_contents()


if __name__ == "__main__":
    # Initialize the RAG system and expose its methods as a command line interface
    rag_system = RAGSystem(postgres_host='postgres', postgres_port=5432)
    fire.Fire({"upsert": rag_system.upsert_document, "query": rag_system.query})
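
# Example invocations via python-fire (the PDF path and question below are
# placeholders; any file path and query string work the same way):
#
#   python rag_system.py upsert --file_path=./docs/example.pdf
#   python rag_system.py query --question="What does the document say about X?"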