From 231a1f44900fa83aae33b6a710aac21f25e37858 Mon Sep 17 00:00:00 2001
From: Roy Quesada
Date: Mon, 30 Sep 2024 14:40:55 -0600
Subject: [PATCH] Fix indexing to work multithreaded

---
 backend/managers/RagManager.py                     | 235 +++++++++++-------
 backend/models.py                                  |  15 +-
 backend/requirements.txt                           |   3 +-
 backend/schemas.py                                 |   3 +-
 .../versions/29df33c77244_added_file_table.py      |   2 -
 .../versions/5a1a6050b8f0_added_chunk_table.py     |  35 +++
 .../versions/c189fb6eda90_added_page_table.py      |  33 +++
 7 files changed, 224 insertions(+), 102 deletions(-)
 create mode 100644 migrations/versions/5a1a6050b8f0_added_chunk_table.py
 create mode 100644 migrations/versions/c189fb6eda90_added_page_table.py

diff --git a/backend/managers/RagManager.py b/backend/managers/RagManager.py
index 4526b62f..74f8318b 100644
--- a/backend/managers/RagManager.py
+++ b/backend/managers/RagManager.py
@@ -12,7 +12,7 @@
 import shutil
 from starlette.datastructures import UploadFile
 from uuid import uuid4
-from backend.models import File
+from backend.models import File, Page, Chunk
 from backend.db import db_session_context
 from sqlalchemy import delete, select, func
 from pathlib import Path
@@ -22,10 +22,14 @@
 import os
 import logging
 from enum import Enum
+import aiofiles
+import asyncio
 
 logger = logging.getLogger(__name__)
 
 class FileStatus(Enum):
+    WAITING = 'waiting'
+    UPLOADED = 'uploaded'
     SPLITTING = 'splitting'
     SPLIT = 'split'
     INDEXING = 'indexing'
@@ -49,8 +53,8 @@ def __init__(self):
         if not hasattr(self, '_initialized'):
             self._initialized = True
 
-    async def create_index(self, resource_id: str, path_files: List[str]) -> List[dict]:
-        # Define the text splitter
+    async def create_index(self, resource_id: str, path_files: List[str], files_ids: List[str]) -> List[dict]:
+        loop = asyncio.get_running_loop()
         text_splitter = RecursiveCharacterTextSplitter(
             chunk_size=int(os.environ.get('CHUNK_SIZE')),
             chunk_overlap=int(os.environ.get('CHUNK_OVERLAP')),
@@ -58,55 +62,42 @@ async def create_index(self, resource_id: str, path_files: List[str]) -> List[di
         )
 
         file_info_list = []
-
-        # Initialize the vector store once
         vectorstore = await self.initialize_chroma(resource_id)
 
-        # Iterate over all the files
-        for path in path_files:
-            file_id = str(uuid4())
+        for path, file_id in zip(path_files, files_ids):
             file_name = Path(path).name
             try:
                 # Load the PDF
                 loader = PyPDFLoader(path)
-                docs = loader.load()  # Load all pages at once
+                docs = await loop.run_in_executor(None, loader.load)  # Load all pages at once
 
                 split_documents = []
                 split_ids = []
 
                 # Process each page in the PDF
                 for doc in docs:
                     page_id = str(uuid4())  # Unique ID for each page
-
-                    # Create a File entry in the database with status 'splitting'
-                    await self.create_file(
-                        assistant_id=resource_id,
-                        file_id=file_id,
-                        file_name=file_name,
-                        page_id=page_id,  # Unique page ID
-                        indexing_status=FileStatus.SPLITTING.value
-                    )
                     file_info_list.append({"file_id": file_id, "file_name": file_name, "page_id": page_id})
+                    await self.update_file_status(file_id, FileStatus.SPLITTING.value)
 
                     # Split the document into smaller chunks
                     splits = text_splitter.split_documents([doc])
+                    await self.create_page(page_id, file_id, resource_id)
 
-                    # Update num_chunks for all File records with the current file_id
-                    await self.update_file_num_chunks(file_id, len(splits))
-
-                    for i, split in enumerate(splits):
+                    for split in splits:
+                        chunk_id = str(uuid4())  # Unique ID for each chunk
+                        await self.create_chunk(chunk_id, page_id, file_id, resource_id)
                         split.metadata["original_id"] = page_id
                         split_documents.append(split)
split_ids.append(f"{page_id}-{i}") + split_ids.append(chunk_id) # Update status to 'split' after splitting await self.update_file_status(file_id, FileStatus.SPLIT.value) # Add the split documents to the vectorstore and update status to 'indexing' await self.update_file_status(file_id, FileStatus.INDEXING.value) - vectorstore.add_documents(documents=split_documents, ids=split_ids) + await loop.run_in_executor(None, lambda: vectorstore.add_documents(documents=split_documents, ids=split_ids)) # Update status to 'done' once indexing is complete await self.update_file_status(file_id, FileStatus.DONE.value) @@ -115,47 +106,46 @@ async def create_index(self, resource_id: str, path_files: List[str]) -> List[di # Update status to 'failed' if an error occurs await self.update_file_status(file_id, FileStatus.FAILED.value) raise e - - return file_info_list - async def update_file_num_chunks(self, file_id: str, num_chunks: int): - async with db_session_context() as session: - stmt = select(File).filter(File.file_id == file_id) - files = (await session.execute(stmt)).scalars().all() - - for file in files: - file.num_chunks = str(num_chunks) - - await session.commit() + return file_info_list async def update_file_status(self, file_id: str, status: str): async with db_session_context() as session: - stmt = select(File).filter(File.file_id == file_id) + stmt = select(File).filter(File.id == file_id) files = (await session.execute(stmt)).scalars().all() for file in files: file.indexing_status = status - await session.commit() - - async def create_files_for_resource(self, resource_id: str, file_info_list: List[dict]): - for file_info in file_info_list: - await self.create_file( - assistant_id=resource_id, - file_id=file_info['file_id'], - file_name=file_info['file_name'], - page_id=file_info['page_id'], - indexing_status=file_info.get('indexing_status', 'initializing'), - num_chunks=file_info.get('num_chunks', 0) # Default to 0 if not set - ) + await session.commit() - async def create_file(self, assistant_id: str, file_id: str, file_name: str, page_id: str, indexing_status: str, num_chunks: int = 0): + async def create_file(self, file_id: str, assistant_id: str, file_name: str, indexing_status: str): async with db_session_context() as session: try: - new_file = File(id=page_id, name=file_name, assistant_id=assistant_id, file_id=file_id, num_chunks=str(num_chunks), indexing_status=indexing_status) + new_file = File(id=file_id, name=file_name, assistant_id=assistant_id, indexing_status=indexing_status) session.add(new_file) await session.commit() await session.refresh(new_file) except Exception as e: - print(f"An error occurred: {e}") + print(f"An error occurred creating a file: {e}") + + async def create_page(self, page_id: str, file_id: str, assistant_id: str): + async with db_session_context() as session: + try: + new_page = Page(id=page_id, file_id=file_id, assistant_id=assistant_id) + session.add(new_page) + await session.commit() + await session.refresh(new_page) + except Exception as e: + print(f"An error occurred creating a page: {e}") + + async def create_chunk(self, chunk_id:str, page_id: str, file_id: str, assistant_id: str): + async with db_session_context() as session: + try: + new_chunk = Chunk(id=chunk_id, page_id=page_id, file_id=file_id, assistant_id=assistant_id) + session.add(new_chunk) + await session.commit() + await session.refresh(new_chunk) + except Exception as e: + print(f"An error occurred creating a chunk: {e}") async def initialize_chroma(self, collection_name: str): @@ -196,30 
         print("Query: ", query)
         # Invoke the RAG chain with query as input
         response = rag_chain.invoke({"input": query})
-        return response
+        return response
+
+    async def save_file(self, file: UploadFile, directory: Path):
+        file_path = directory / file.filename
+        async with aiofiles.open(file_path, 'wb') as out_file:
+            while content := await file.read(1024):  # Read in chunks
+                await out_file.write(content)
+        return str(file_path.absolute())
 
     async def upload_file(self, resource_id: str, files: List[UploadFile]) -> Union[List[dict], str]:
+        all_files_paths = []
+        all_files_ids = []
+        for file in files:
+            file_id = str(uuid4())
+            # Create a File entry in the database with status 'waiting'
+            await self.create_file(
+                file_id=file_id,
+                assistant_id=resource_id,
+                file_name=file.filename,
+                indexing_status=FileStatus.WAITING.value
+            )
+            all_files_ids.append(file_id)
+
         try:
-            all_docs = []
-            for file in files:
-                # Define the directory where files will be saved
-                directory = Path(f"./uploads/{resource_id}")
+            for file, file_id in zip(files, all_files_ids):
+                directory = Path(f"./uploads/{resource_id}")
                 directory.mkdir(parents=True, exist_ok=True)
 
-                # Save the file
-                file_path = directory / file.filename
-                path = str(file_path.absolute())
-                all_docs.append(path)
-                with open(file_path, "wb") as buffer:
-                    shutil.copyfileobj(file.file, buffer)
+                path = await self.save_file(file, directory)
+                await self.update_file_status(file_id, FileStatus.UPLOADED.value)
+                all_files_paths.append(path)
 
-            result = await self.create_index(resource_id, all_docs)
+            result = await self.create_index(resource_id, all_files_paths, all_files_ids)
             await self.delete_tmp_files(resource_id)
             return result
         except Exception as e:
             print(f"An error occurred while uploading files: {e}")
             return "File upload failed"
+
     async def delete_tmp_files(self, assistant_id: str):
         try:
             # Define the directory path using the assistant_id
@@ -236,52 +246,94 @@ async def delete_tmp_files(self, assistant_id: str):
 
     async def retrieve_file(self, file_id:str) -> Optional[List[FileSchema]]:
         async with db_session_context() as session:
-            result = await session.execute(select(File).filter(File.file_id == file_id))
+            result = await session.execute(select(File).filter(File.id == file_id))
             files = [FileSchema.from_orm(file) for file in result.scalars().all()]
             if files:
                 return files
             return None
+
+    # Method to retrieve pages for a given file_id
+    async def retrieve_pages(self, file_id: str) -> List[Page]:
+        async with db_session_context() as session:
+            result = await session.execute(select(Page).filter(Page.file_id == file_id))
+            pages = result.scalars().all()
+            return pages
+
+    # Method to retrieve chunks for a given page_id
+    async def retrieve_chunks(self, page_id: str) -> List[Chunk]:
+        async with db_session_context() as session:
+            result = await session.execute(select(Chunk).filter(Chunk.page_id == page_id))
+            chunks = result.scalars().all()
+            return chunks
 
-    async def delete_documents_from_chroma(self, resource_id: str, file_ids=List[str]):
+    async def delete_documents_from_chroma(self, resource_id: str, file_ids: List[str]):
         vectorstore = await self.initialize_chroma(resource_id)
         for file_id in file_ids:
+            # Retrieve the File records for this file_id
             files = await self.retrieve_file(file_id)
 
             if files:
-                page_ids = []
-                for file in files:
-                    num_chunks = int(file.num_chunks)
-                    page_id = file.id
-                    page_ids.append(page_id)
+                list_chunks_id = []
 
-                    list_chunks_id = []
-                    for n in range(0, num_chunks):
-                        chunk_id = f"{page_id}-{n}"
-                        list_chunks_id.append(chunk_id)
+                for file in files:
+                    # Retrieve pages for the file
+                    pages = await self.retrieve_pages(file.id)
+
+                    for page in pages:
+                        # Retrieve chunks for each page
+                        chunks = await self.retrieve_chunks(page.id)
+
+                        for chunk in chunks:
+                            # Collect chunk IDs for deletion
+                            list_chunks_id.append(chunk.id)
+
+                # Delete all chunks from ChromaDB
+                if list_chunks_id:
                     vectorstore.delete(ids=list_chunks_id)
             else:
                 return None
         return "Documents deleted"
 
     async def delete_file_from_db(self, file_ids: List[str]):
-        page_ids = []
-        for file_id in file_ids:
-            files = await self.retrieve_file(file_id)
-            if files:
-                for file in files:
-                    page_id = file.id
-                    page_ids.append(page_id)
         async with db_session_context() as session:
-            try:
-                stmt = delete(File).where(File.id.in_(page_ids))
-                result = await session.execute(stmt)
+            try:
+                for file_id in file_ids:
+                    # Retrieve the File records for this file_id
+                    files = await self.retrieve_file(file_id)
+
+                    if files:
+                        for file in files:
+                            # Retrieve pages for the file
+                            pages = await self.retrieve_pages(file.id)
+
+                            for page in pages:
+                                # Retrieve chunks for each page
+                                chunks = await self.retrieve_chunks(page.id)
+
+                                # Delete all chunks for the page
+                                if chunks:
+                                    stmt_delete_chunks = delete(Chunk).where(Chunk.id.in_([chunk.id for chunk in chunks]))
+                                    await session.execute(stmt_delete_chunks)
+
+                            # Delete all pages for the file
+                            if pages:
+                                stmt_delete_pages = delete(Page).where(Page.id.in_([page.id for page in pages]))
+                                await session.execute(stmt_delete_pages)
+
+                            # Finally, delete the file itself
+                            stmt_delete_file = delete(File).where(File.id == file.id)
+                            await session.execute(stmt_delete_file)
+
+                # Commit the transaction
                 await session.commit()
-                return result.rowcount > 0
+                return True
             except Exception as e:
-                print("error in delete from db",e)
+                print("Error in delete from db:", e)
                 return None
+
     async def retrieve_files(self, resource_id: str, offset: int = 0, limit: int = 100, sort_by: Optional[str] = None,
                              sort_order: str = 'asc', filters: Optional[Dict[str, Any]] = None) -> Tuple[List[FileSchema], int]:
         async with db_session_context() as session:
@@ -291,16 +343,9 @@ async def retrieve_files(self, resource_id: str, offset: int = 0, limit: int = 1
             result = await session.execute(query)
             files = [FileSchema.from_orm(file) for file in result.scalars().all()]
 
-            seen_file_ids = set()
-            unique_files = []
-            for file in files:
-                if file.file_id not in seen_file_ids:
-                    unique_files.append(file)
-                    seen_file_ids.add(file.file_id)
-
             total_count = await self._get_total_count(filters)
 
-            return unique_files, total_count
+            return files, total_count
 
     def _apply_filters(self, query, filters: Optional[Dict[str, Any]]):
         if filters:
diff --git a/backend/models.py b/backend/models.py
index bf8a8b3d..0c2cff17 100644
--- a/backend/models.py
+++ b/backend/models.py
@@ -61,10 +61,21 @@ class File(Base):
     id = Column(String, primary_key=True)
     name = Column(String, nullable=False)
     assistant_id = Column(String, nullable=False)
-    num_chunks = Column(String, nullable=False)
-    file_id = Column(String, nullable=False)
     indexing_status = Column(String, nullable=False)
 
+class Page(Base):
+    __tablename__ = "page"
+    id = Column(String, primary_key=True)
+    file_id = Column(String, nullable=False)
+    assistant_id = Column(String, nullable=False)
+
+class Chunk(Base):
+    __tablename__ = "chunk"
+    id = Column(String, primary_key=True)
+    page_id = Column(String, nullable=False)
+    file_id = Column(String, nullable=False)
+    assistant_id = Column(String, nullable=False)
+
 class Message(Base):
     __tablename__ = "message"
     id = Column(String, primary_key=True)
diff --git a/backend/requirements.txt b/backend/requirements.txt
index d8988564..8f7597f9 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -21,4 +21,5 @@ langchain_community
 langchain_chroma
 langchain_openai
 pypdf
-httpx
\ No newline at end of file
+httpx
+langchain_ollama
\ No newline at end of file
diff --git a/backend/schemas.py b/backend/schemas.py
index d5c2b336..8bd74351 100644
--- a/backend/schemas.py
+++ b/backend/schemas.py
@@ -154,8 +154,7 @@ class ConversationSchema(ConversationBaseSchema):
 # File schemas
 class FileBaseSchema(BaseModel):
     name: str
-    num_chunks: str
-    file_id: str
+    assistant_id: str
     indexing_status: str
     class Config:
         orm_mode = True
diff --git a/migrations/versions/29df33c77244_added_file_table.py b/migrations/versions/29df33c77244_added_file_table.py
index 933735ae..f54dd800 100644
--- a/migrations/versions/29df33c77244_added_file_table.py
+++ b/migrations/versions/29df33c77244_added_file_table.py
@@ -22,8 +22,6 @@ def upgrade() -> None:
     sa.Column('id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
     sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
     sa.Column('assistant_id', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
-    sa.Column('file_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
-    sa.Column('num_chunks', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
     sa.Column('indexing_status', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
     sa.PrimaryKeyConstraint('id')
     )
diff --git a/migrations/versions/5a1a6050b8f0_added_chunk_table.py b/migrations/versions/5a1a6050b8f0_added_chunk_table.py
new file mode 100644
index 00000000..680b6212
--- /dev/null
+++ b/migrations/versions/5a1a6050b8f0_added_chunk_table.py
@@ -0,0 +1,35 @@
+"""added chunk table
+
+Revision ID: 5a1a6050b8f0
+Revises: c189fb6eda90
+Create Date: 2024-09-26 15:28:51.352477
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+import sqlmodel
+
+
+# revision identifiers, used by Alembic.
+revision: str = '5a1a6050b8f0'
+down_revision: Union[str, None] = 'c189fb6eda90'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    op.create_table('chunk',
+    sa.Column('id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+    sa.Column('page_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+    sa.Column('file_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+    sa.Column('assistant_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+    sa.PrimaryKeyConstraint('id')
+    )
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_table('chunk')
+    # ### end Alembic commands ###
diff --git a/migrations/versions/c189fb6eda90_added_page_table.py b/migrations/versions/c189fb6eda90_added_page_table.py
new file mode 100644
index 00000000..ef80d766
--- /dev/null
+++ b/migrations/versions/c189fb6eda90_added_page_table.py
@@ -0,0 +1,33 @@
+"""added page table
+
+Revision ID: c189fb6eda90
+Revises: dcaf2be4345d
+Create Date: 2024-09-26 14:36:50.706957
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+import sqlmodel
+
+
+# revision identifiers, used by Alembic.
+revision: str = 'c189fb6eda90'
+down_revision: Union[str, None] = 'dcaf2be4345d'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    op.create_table('page',
+    sa.Column('id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+    sa.Column('assistant_id', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
+    sa.Column('file_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+    sa.PrimaryKeyConstraint('id')
+    )
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_table('page')
+    # ### end Alembic commands ###
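
Reviewer note: the heart of this patch is moving the two blocking calls, PyPDFLoader.load() and Chroma's add_documents(), off the event loop with run_in_executor, so one large upload no longer stalls other requests. Below is a minimal sketch of the pattern in isolation, assuming the same langchain_community loader the manager uses; the helper name and the example path are hypothetical, not part of this patch:

import asyncio

from langchain_community.document_loaders import PyPDFLoader

async def load_pdf_off_loop(path: str):
    # PyPDFLoader.load() is synchronous and blocks while it parses the PDF;
    # running it in the default ThreadPoolExecutor (executor=None) keeps the
    # asyncio event loop free to serve other coroutines in the meantime.
    loop = asyncio.get_running_loop()
    loader = PyPDFLoader(path)
    return await loop.run_in_executor(None, loader.load)

# Hypothetical usage:
# docs = asyncio.run(load_pdf_off_loop("./uploads/example.pdf"))

Passing None selects the loop's default ThreadPoolExecutor; if PDF parsing ever becomes CPU-bound enough to fight the GIL, a dedicated ProcessPoolExecutor could be swapped in without changing the call sites.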
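A second note on the schema change: vector-store IDs were previously reconstructed as f"{page_id}-{n}" from a num_chunks counter on File, which is fragile once indexing runs concurrently. The patch instead persists one Page row per PDF page and one Chunk row per split, with the Chunk primary key doubling as the ChromaDB document ID, so deletion simply walks File to Page to Chunk. A sketch of the resulting call sequence, using the methods added in this patch; the IDs are placeholders:

import asyncio

from backend.managers.RagManager import RagManager

async def remove_file(resource_id: str, file_id: str) -> None:
    rag = RagManager()
    # Collect the file's chunk IDs via its Page and Chunk rows and delete the
    # matching vectors from the Chroma collection for this resource.
    await rag.delete_documents_from_chroma(resource_id, [file_id])
    # Then drop the Chunk, Page, and File rows themselves.
    await rag.delete_file_from_db([file_id])

# Hypothetical usage with placeholder UUIDs:
# asyncio.run(remove_file("assistant-uuid", "file-uuid"))

The two new Alembic revisions chain page (c189fb6eda90, revises dcaf2be4345d) before chunk (5a1a6050b8f0, revises c189fb6eda90), so a single alembic upgrade head applies both in order.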