Skip to content

Commit

Permalink
refactoring delete and create index
Browse files Browse the repository at this point in the history
  • Loading branch information
AnniePacheco committed Oct 7, 2024
1 parent 87e15cc commit 9735047
Showing 1 changed file with 75 additions and 53 deletions.
128 changes: 75 additions & 53 deletions backend/managers/RagManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,17 @@ async def retrieve_and_generate(self, collection_name, query, llm) -> str:
persona = await personas_m.retrieve_persona(persona_id)
personality_prompt = persona.description
# Combine the system prompt and context
system_prompt = (os.environ.get('SYSTEM_PROMPT') +
"\n\n{context}" +
"\n\nHere is some information about the assistant expertise to help you answer your questions: " + personality_prompt +
".\n\nIf the user asks you a question about the assistant information, example: 'What can you tell me about the assistant?', 'What is the name of the assistant?', 'Who is the assistant?', etc." +
"\n\nPlease answer the question using the following information:\n\n" + personality_prompt
)
# Combine the system prompt and context
system_prompt = (os.environ.get('SYSTEM_PROMPT') + "\n\n{context}" +
"\n\nHere is some information about the assistant expertise to help you answer your questions: " +
personality_prompt)
# system_prompt = (os.environ.get('SYSTEM_PROMPT') +
# "\n\n{context}" +
# "\n\nHere is some information about the assistant expertise to help you answer your questions: " + personality_prompt +
# ".\n\nIf the user asks you a question about the assistant information, example: 'What can you tell me about the assistant?', 'What is the name of the assistant?', 'Who is the assistant?'. "+
# "\n\nPlease answer the question using the following information:\n\n" + personality_prompt + "."
# "\n\nAlthough if the user asks about who is other person which is not the assistant, please look into the context {context} to answer."
# )

prompt = ChatPromptTemplate.from_messages([
("system", system_prompt),
Expand Down Expand Up @@ -271,71 +276,88 @@ async def retrieve_chunks(self, page_id: str) -> List[Chunk]:

async def delete_documents_from_chroma(self, resource_id: str, file_ids: List[str]):
vectorstore = await self.initialize_chroma(resource_id)

for file_id in file_ids:
# Retrieve the file (you already have this step)
# Retrieve the files
files = await self.retrieve_file(file_id)

if files:
list_chunks_id = []

for file in files:
# Retrieve pages for the file
pages = await self.retrieve_pages(file.id)

for page in pages:
# Retrieve chunks for each page
chunks = await self.retrieve_chunks(page.id)

for chunk in chunks:
# Collect chunk IDs for deletion
list_chunks_id.append(chunk.id)

# Delete all chunks from ChromaDB
if list_chunks_id:
vectorstore.delete(ids=list_chunks_id)
else:
if not files:
return None

# Collect and delete chunks for the files
await self._process_files_for_deletion(files, vectorstore)

return "Documents deleted"

async def _process_files_for_deletion(self, files: List[File], vectorstore):
list_chunks_id = []

for file in files:
# Collect chunk IDs for each file
chunk_ids = await self._collect_chunk_ids(file.id)
list_chunks_id.extend(chunk_ids)

# Delete all chunks from ChromaDB
if list_chunks_id:
vectorstore.delete(ids=list_chunks_id)

async def _collect_chunk_ids(self, file_id: int) -> List[str]:
list_chunks_id = []

# Retrieve pages for the file
pages = await self.retrieve_pages(file_id)

for page in pages:
# Retrieve chunks for each page and collect their IDs
chunks = await self.retrieve_chunks(page.id)
list_chunks_id.extend([chunk.id for chunk in chunks])

return list_chunks_id


async def delete_file_from_db(self, file_ids: List[str]):
async with db_session_context() as session:
try:
for file_id in file_ids:
# Retrieve the file (you already have this step)
files = await self.retrieve_file(file_id)

if files:
for file in files:
# Retrieve pages for the file
pages = await self.retrieve_pages(file.id)

for page in pages:
# Retrieve chunks for each page
chunks = await self.retrieve_chunks(page.id)

# Delete all chunks for the page
if chunks:
stmt_delete_chunks = delete(Chunk).where(Chunk.id.in_([chunk.id for chunk in chunks]))
print("delete chunk")
await session.execute(stmt_delete_chunks)

# Delete all pages for the file
if pages:
stmt_delete_pages = delete(Page).where(Page.id.in_([page.id for page in pages]))
await session.execute(stmt_delete_pages)

# Finally, delete the file itself
stmt_delete_file = delete(File).where(File.id == file.id)
await session.execute(stmt_delete_file)

# Commit the transaction

if not files:
continue

await self._delete_files(files, session)

await session.commit()
return True
except Exception as e:
print("Error in delete from db:", e)
return None

async def _delete_files(self, files: List[FileSchema], session: db_session_context):
for file in files:
await self._delete_pages_and_chunks(file.id, session)
stmt_delete_file = delete(File).where(File.id == file.id)
await session.execute(stmt_delete_file)

async def _delete_pages_and_chunks(self, file_id: str, session: db_session_context):
pages = await self.retrieve_pages(file_id)
if not pages:
return

for page in pages:
# Retrieve and delete chunks
await self._delete_chunks(page.id, session)

# Delete all pages for the file
stmt_delete_pages = delete(Page).where(Page.id.in_([page.id for page in pages]))
await session.execute(stmt_delete_pages)

async def _delete_chunks(self, page_id: int, session):
chunks = await self.retrieve_chunks(page_id)
if chunks:
stmt_delete_chunks = delete(Chunk).where(Chunk.id.in_([chunk.id for chunk in chunks]))
print("delete chunk")
await session.execute(stmt_delete_chunks)

async def retrieve_files(self, resource_id: str, offset: int = 0, limit: int = 100, sort_by: Optional[str] = None,
sort_order: str = 'asc', filters: Optional[Dict[str, Any]] = None) -> Tuple[List[FileSchema], int]:
Expand Down Expand Up @@ -373,4 +395,4 @@ async def _get_total_count(self, filters: Optional[Dict[str, Any]]) -> int:
count_query = self._apply_filters(count_query, filters)

total_count = await session.execute(count_query)
return total_count.scalar()
return total_count.scalar()

0 comments on commit 9735047

Please sign in to comment.