diff --git a/raganything/processor.py b/raganything/processor.py index 2155a75c..13eb0f11 100644 --- a/raganything/processor.py +++ b/raganything/processor.py @@ -1547,19 +1547,15 @@ async def process_document_complete( self.logger.info(f"Starting complete document processing: {file_path}") - # Step 1: Parse document content_list, content_based_doc_id = await self.parse_document( file_path, output_dir, parse_method, display_stats, **kwargs ) - # Use provided doc_id or fall back to content-based doc_id if doc_id is None: doc_id = content_based_doc_id - # Step 2: Separate text and multimodal content text_content, multimodal_items = separate_content(content_list) - # Step 2.5: Set content source for context extraction in multimodal processing if hasattr(self, "set_content_source_for_context") and multimodal_items: self.logger.info( "Setting content source for context-aware multimodal processing..." @@ -1568,12 +1564,14 @@ async def process_document_complete( content_list, self.config.content_format ) - # Step 3: Insert pure text content with all parameters - stage = "text_insert" + if file_name is None: + file_name = self._get_file_reference(file_path) + + stage = "parallel_processing" + tasks = [] + task_names = [] + if text_content.strip(): - if file_name is None: - # Use full path or basename based on config - file_name = self._get_file_reference(file_path) if callback_manager is not None: callback_manager.dispatch( "on_text_insert_start", @@ -1582,42 +1580,50 @@ async def process_document_complete( doc_id=doc_id, ) insert_start = time.time() - await insert_text_content( - self.lightrag, - input=text_content, - file_paths=file_name, - split_by_character=split_by_character, - split_by_character_only=split_by_character_only, - ids=doc_id, - ) - if callback_manager is not None: - insert_duration = time.time() - insert_start - callback_manager.dispatch( - "on_text_insert_complete", - file_path=file_name, - duration_seconds=insert_duration, - doc_id=doc_id, + + async def _insert_text_with_callback(): + await insert_text_content( + self.lightrag, + input=text_content, + file_paths=file_name, + split_by_character=split_by_character, + split_by_character_only=split_by_character_only, + ids=doc_id, ) - else: - # Determine file reference even if no text content - if file_name is None: - file_name = self._get_file_reference(file_path) + if callback_manager is not None: + insert_duration = time.time() - insert_start + callback_manager.dispatch( + "on_text_insert_complete", + file_path=file_name, + duration_seconds=insert_duration, + doc_id=doc_id, + ) + + tasks.append(_insert_text_with_callback()) + task_names.append("text insertion") - # Step 4: Process multimodal content (using specialized processors) - stage = "multimodal" if multimodal_items: - await self._process_multimodal_content( - multimodal_items, file_name, doc_id + tasks.append( + self._process_multimodal_content( + multimodal_items, file_name, doc_id + ) ) + task_names.append("multimodal processing") else: - # If no multimodal content, mark multimodal processing as complete - # This ensures the document status properly reflects completion of all processing await self._mark_multimodal_processing_complete(doc_id) self.logger.debug( f"No multimodal content found in document {doc_id}, " "marked multimodal processing as complete", ) + if tasks: + results = await asyncio.gather(*tasks, return_exceptions=True) + for name, result in zip(task_names, results): + if isinstance(result, Exception): + self.logger.error( + f"{name} failed for {file_path}: {result}" + ) + except Exception as exc: if callback_manager is not None: callback_manager.dispatch( diff --git a/tests/test_parallel_processing.py b/tests/test_parallel_processing.py new file mode 100644 index 00000000..87ffb8a4 --- /dev/null +++ b/tests/test_parallel_processing.py @@ -0,0 +1,135 @@ +"""Parallel text + multimodal processing tests.""" +import asyncio +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + + +@pytest.fixture +def processor(): + from raganything import RAGAnything, RAGAnythingConfig + + proc = RAGAnything.__new__(RAGAnything) + proc.config = RAGAnythingConfig() + proc.logger = MagicMock() + proc.lightrag = AsyncMock() + proc._ensure_lightrag_initialized = AsyncMock() + proc._mark_multimodal_processing_complete = AsyncMock() + proc.set_content_source_for_context = MagicMock() + return proc + + +@pytest.mark.asyncio +async def test_text_and_multimodal_run_concurrently(processor): + execution_log = [] + + async def fake_parse(*args, **kwargs): + return [ + {"type": "text", "text": "Hello world"}, + {"type": "image", "data": "base64data"}, + ], "doc123" + + async def fake_insert_text(*args, **kwargs): + execution_log.append(("text_start", asyncio.get_event_loop().time())) + await asyncio.sleep(0.1) + execution_log.append(("text_end", asyncio.get_event_loop().time())) + + async def fake_process_multimodal(*args, **kwargs): + execution_log.append(("mm_start", asyncio.get_event_loop().time())) + await asyncio.sleep(0.1) + execution_log.append(("mm_end", asyncio.get_event_loop().time())) + + processor.parse_document = fake_parse + processor._process_multimodal_content = fake_process_multimodal + + with patch("raganything.processor.separate_content") as mock_sep, \ + patch("raganything.processor.insert_text_content", new=fake_insert_text): + mock_sep.return_value = ("Hello world", [{"type": "image", "data": "base64data"}]) + + await processor.process_document_complete("test.pdf") + + starts = {e[0].replace("_start", ""): e[1] for e in execution_log if "start" in e[0]} + ends = {e[0].replace("_end", ""): e[1] for e in execution_log if "end" in e[0]} + + assert starts["mm"] < ends["text"], ( + "Multimodal processing should start before text insertion finishes (parallel)" + ) + + +@pytest.mark.parametrize("failing_branch", ["text", "multimodal"]) +@pytest.mark.asyncio +async def test_one_branch_failing_does_not_block_the_other(processor, failing_branch): + survived = False + + async def fake_parse(*args, **kwargs): + return [ + {"type": "text", "text": "Hello"}, + {"type": "image", "data": "img"}, + ], "doc123" + + async def fake_insert_text(*args, **kwargs): + nonlocal survived + if failing_branch == "text": + raise RuntimeError("Text insertion failed") + survived = True + + async def fake_process_multimodal(*args, **kwargs): + nonlocal survived + if failing_branch == "multimodal": + raise RuntimeError("Multimodal processing failed") + survived = True + + processor.parse_document = fake_parse + processor._process_multimodal_content = fake_process_multimodal + + with patch("raganything.processor.separate_content") as mock_sep, \ + patch("raganything.processor.insert_text_content", new=fake_insert_text): + mock_sep.return_value = ("Hello", [{"type": "image", "data": "img"}]) + + await processor.process_document_complete("test.pdf") + + assert survived, f"Non-failing branch should complete when {failing_branch} fails" + + +@pytest.mark.asyncio +async def test_text_only_document(processor): + insert_called = False + + async def fake_parse(*args, **kwargs): + return [{"type": "text", "text": "Hello world"}], "doc123" + + async def fake_insert_text(*args, **kwargs): + nonlocal insert_called + insert_called = True + + processor.parse_document = fake_parse + + with patch("raganything.processor.separate_content") as mock_sep, \ + patch("raganything.processor.insert_text_content", new=fake_insert_text): + mock_sep.return_value = ("Hello world", []) + + await processor.process_document_complete("test.pdf") + + assert insert_called + processor._mark_multimodal_processing_complete.assert_called_once() + + +@pytest.mark.asyncio +async def test_multimodal_only_document(processor): + mm_called = False + + async def fake_parse(*args, **kwargs): + return [{"type": "image", "data": "img"}], "doc123" + + async def fake_process_multimodal(*args, **kwargs): + nonlocal mm_called + mm_called = True + + processor.parse_document = fake_parse + processor._process_multimodal_content = fake_process_multimodal + + with patch("raganything.processor.separate_content") as mock_sep: + mock_sep.return_value = ("", [{"type": "image", "data": "img"}]) + + await processor.process_document_complete("test.pdf") + + assert mm_called