diff --git a/aidial_rag/app.py b/aidial_rag/app.py index 17e33dc..83c948a 100644 --- a/aidial_rag/app.py +++ b/aidial_rag/app.py @@ -32,7 +32,6 @@ RequestType, get_configuration, ) -from aidial_rag.dial_api_client import create_dial_api_client from aidial_rag.document_record import Chunk, DocumentRecord from aidial_rag.documents import load_documents from aidial_rag.index_record import ChunkMetadata, RetrievalType @@ -248,7 +247,7 @@ async def chat_completion( self, request: Request, response: Response ) -> None: loop = asyncio.get_running_loop() - with create_request_context( + async with create_request_context( self.app_config.dial_url, request, response ) as request_context: choice = request_context.choice @@ -274,21 +273,19 @@ async def chat_completion( get_attachment_links(request_context, messages) ) - dial_api_client = await create_dial_api_client(request_context) index_storage = self.index_storage_holder.get_storage( - dial_api_client + request_context.dial_api_client ) indexing_tasks = create_indexing_tasks( attachment_links, - dial_api_client, + request_context.dial_api_client, ) indexing_results = await load_documents( request_context, indexing_tasks, index_storage, - dial_api_client, config=request_config, ) diff --git a/aidial_rag/attachment_link.py b/aidial_rag/attachment_link.py index 53643d4..94b701d 100644 --- a/aidial_rag/attachment_link.py +++ b/aidial_rag/attachment_link.py @@ -61,9 +61,10 @@ def to_dial_metadata_url( if not request_context.is_dial_url(absolute_url): return None - return urljoin( + absolute_metadata_url = urljoin( request_context.dial_metadata_base_url, link, allow_fragments=True ) + return to_dial_relative_url(request_context, absolute_metadata_url) class AttachmentLink(BaseModel): diff --git a/aidial_rag/dial_api_client.py b/aidial_rag/dial_api_client.py index 4e1a46f..e0bdc00 100644 --- a/aidial_rag/dial_api_client.py +++ b/aidial_rag/dial_api_client.py @@ -1,19 +1,20 @@ import io +from contextlib import asynccontextmanager +from typing import AsyncGenerator import aiohttp -from aidial_rag.request_context import RequestContext +from aidial_rag.dial_config import DialConfig -async def _get_bucket_id(dial_base_url, headers: dict) -> str: +async def _get_bucket_id(session: aiohttp.ClientSession, headers: dict) -> str: relative_url = ( "bucket" # /v1/ is already included in the base url for the Dial API ) - async with aiohttp.ClientSession(base_url=dial_base_url) as session: - async with session.get(relative_url, headers=headers) as response: - response.raise_for_status() - data = await response.json() - return data["bucket"] + async with session.get(relative_url, headers=headers) as response: + response.raise_for_status() + data = await response.json() + return data["bucket"] def _to_form_data(key: str, data: bytes, content_type: str) -> aiohttp.FormData: @@ -25,39 +26,35 @@ def _to_form_data(key: str, data: bytes, content_type: str) -> aiohttp.FormData: class DialApiClient: - def __init__(self, dial_api_base_url: str, headers: dict, bucket_id: str): + def __init__(self, client_session: aiohttp.ClientSession, bucket_id: str): + self._client_session = client_session self.bucket_id = bucket_id - self._dial_api_base_url = dial_api_base_url - self._headers = headers + @property + def session(self) -> aiohttp.ClientSession: + return self._client_session async def get_file(self, relative_url: str) -> bytes | None: - async with aiohttp.ClientSession( - base_url=self._dial_api_base_url - ) as session: - async with session.get( - relative_url, headers=self._headers - ) as response: - response.raise_for_status() - return await response.read() + async with self.session.get(relative_url) as response: + response.raise_for_status() + return await response.read() async def put_file( self, relative_url: str, data: bytes, content_type: str ) -> dict: - async with aiohttp.ClientSession( - base_url=self._dial_api_base_url - ) as session: - form_data = _to_form_data(relative_url, data, content_type) - async with session.put( - relative_url, data=form_data, headers=self._headers - ) as response: - response.raise_for_status() - return await response.json() + form_data = _to_form_data(relative_url, data, content_type) + async with self.session.put(relative_url, data=form_data) as response: + response.raise_for_status() + return await response.json() +@asynccontextmanager async def create_dial_api_client( - request_context: RequestContext, -) -> DialApiClient: - headers = request_context.get_api_key_headers() - bucket_id = await _get_bucket_id(request_context.dial_base_url, headers) - return DialApiClient(request_context.dial_base_url, headers, bucket_id) + config: DialConfig, +) -> AsyncGenerator[DialApiClient, None]: + headers = {"api-key": config.api_key.get_secret_value()} + async with aiohttp.ClientSession( + base_url=config.dial_base_url, headers=headers + ) as session: + bucket_id = await _get_bucket_id(session, headers) + yield DialApiClient(session, bucket_id) diff --git a/aidial_rag/dial_config.py b/aidial_rag/dial_config.py index d533fd6..42ecaca 100644 --- a/aidial_rag/dial_config.py +++ b/aidial_rag/dial_config.py @@ -6,3 +6,7 @@ class DialConfig(BaseConfig): dial_url: str api_key: SecretStr + + @property + def dial_base_url(self) -> str: + return f"{self.dial_url}/v1/" diff --git a/aidial_rag/dial_user_limits.py b/aidial_rag/dial_user_limits.py index 502ff1e..0d4d128 100644 --- a/aidial_rag/dial_user_limits.py +++ b/aidial_rag/dial_user_limits.py @@ -1,7 +1,6 @@ -import aiohttp from pydantic import BaseModel, Field -from aidial_rag.dial_config import DialConfig +from aidial_rag.dial_api_client import DialApiClient class TokenStats(BaseModel): @@ -20,18 +19,15 @@ class UserLimitsForModel(BaseModel): async def get_user_limits_for_model( - dial_config: DialConfig, deployment_name: str + dial_api_client: DialApiClient, deployment_name: str ) -> UserLimitsForModel: """Returns the user limits for the specified model deployment. See https://epam-rail.com/dial_api#tag/Limits for the API documentation. """ - headers = {"Api-Key": dial_config.api_key.get_secret_value()} - limits_url = ( - f"{dial_config.dial_url}/v1/deployments/{deployment_name}/limits" - ) - async with aiohttp.ClientSession() as session: - async with session.get(limits_url, headers=headers) as response: - response.raise_for_status() - limits_json = await response.json() - return UserLimitsForModel.model_validate(limits_json) + + limits_relative_url = f"deployments/{deployment_name}/limits" + async with dial_api_client.session.get(limits_relative_url) as response: + response.raise_for_status() + limits_json = await response.json() + return UserLimitsForModel.model_validate(limits_json) diff --git a/aidial_rag/document_loaders.py b/aidial_rag/document_loaders.py index 568a1db..8d07198 100644 --- a/aidial_rag/document_loaders.py +++ b/aidial_rag/document_loaders.py @@ -19,13 +19,13 @@ from aidial_rag.attachment_link import AttachmentLink from aidial_rag.base_config import BaseConfig, IndexRebuildTrigger from aidial_rag.content_stream import SupportsWriteStr +from aidial_rag.dial_api_client import DialApiClient from aidial_rag.errors import InvalidDocumentError from aidial_rag.image_processor.extract_pages import ( are_image_pages_supported, extract_number_of_pages, ) from aidial_rag.print_stats import print_documents_stats -from aidial_rag.request_context import RequestContext from aidial_rag.resources.cpu_pools import run_in_indexing_cpu_pool from aidial_rag.utils import format_size, get_bytes_length, timed_block @@ -85,18 +85,17 @@ class ParserConfig(BaseConfig): async def download_attachment( - url, headers, download_config: HttpClientConfig + url: str, session: aiohttp.ClientSession, download_config: HttpClientConfig ) -> tuple[str, bytes]: - async with aiohttp.ClientSession() as session: - async with session.get( - url, headers=headers, timeout=download_config.get_client_timeout() - ) as response: - response.raise_for_status() - content_type = response.headers.get("Content-Type", "") + async with session.get( + url, timeout=download_config.get_client_timeout() + ) as response: + response.raise_for_status() + content_type = response.headers.get("Content-Type", "") - content = await response.read() # Await the coroutine - logging.debug(f"Downloaded {url}: {len(content)} bytes") - return content_type, content + content = await response.read() + logging.debug(f"Downloaded {url}: {len(content)} bytes") + return content_type, content def add_source_metadata( @@ -121,7 +120,7 @@ def add_pdf_source_metadata( async def load_dial_document_metadata( - request_context: RequestContext, + dial_api_client: DialApiClient, attachment_link: AttachmentLink, config: HttpClientConfig, ) -> dict: @@ -131,29 +130,36 @@ async def load_dial_document_metadata( metadata_url = attachment_link.dial_metadata_url assert metadata_url is not None - headers = request_context.get_file_access_headers(metadata_url) - async with aiohttp.ClientSession( - timeout=config.get_client_timeout() - ) as session: - async with session.get(metadata_url, headers=headers) as response: - if not response.ok: - error_message = f"{response.status} {response.reason}" - raise InvalidDocumentError(error_message) - return await response.json() + async with dial_api_client.session.get( + metadata_url, timeout=config.get_client_timeout() + ) as response: + if not response.ok: + error_message = f"{response.status} {response.reason}" + raise InvalidDocumentError(error_message) + return await response.json() async def load_attachment( + dial_api_client: DialApiClient, attachment_link: AttachmentLink, - headers: dict, download_config: HttpClientConfig | None = None, ) -> tuple[str, str, bytes]: if download_config is None: download_config = HttpClientConfig() - absolute_url = attachment_link.absolute_url file_name = attachment_link.display_name - content_type, attachment_bytes = await download_attachment( - absolute_url, headers, download_config - ) + + if attachment_link.is_dial_document: + content_type, attachment_bytes = await download_attachment( + attachment_link.dial_link, dial_api_client.session, download_config + ) + else: + # Use separate session for non-Dial documents + # to avoid passing Dial headers to non-Dial servers + async with aiohttp.ClientSession() as session: + content_type, attachment_bytes = await download_attachment( + attachment_link.absolute_url, session, download_config + ) + if attachment_bytes: return file_name, content_type, attachment_bytes raise InvalidDocumentError( diff --git a/aidial_rag/documents.py b/aidial_rag/documents.py index fedc566..f67546d 100644 --- a/aidial_rag/documents.py +++ b/aidial_rag/documents.py @@ -78,7 +78,9 @@ async def check_document_access( ) as access_stage: try: await load_dial_document_metadata( - request_context, attachment_link, config.check_access + request_context.dial_api_client, + attachment_link, + config.check_access, ) except InvalidDocumentError as e: access_stage.append_content(e.message) @@ -102,6 +104,7 @@ def get_default_image_chunk(attachment_link: AttachmentLink): async def load_document_impl( + dial_api_client: DialApiClient, dial_config: DialConfig, dial_limited_resources: DialLimitedResources, attachment_link: AttachmentLink, @@ -116,16 +119,9 @@ async def load_document_impl( ) io_stream = MultiStream(MarkdownStream(stage_stream), logger_stream) - absolute_url = attachment_link.absolute_url - headers = ( - {"api-key": dial_config.api_key.get_secret_value()} - if absolute_url.startswith(dial_config.dial_url) - else {} - ) - file_name, content_type, original_doc_bytes = await load_attachment( + dial_api_client, attachment_link, - headers, download_config=config.download, ) logger.debug(f"Successfully loaded document {file_name} of {content_type}") @@ -235,10 +231,10 @@ async def load_document( request_context: RequestContext, task: IndexingTask, index_storage: IndexStorage, - dial_api_client: DialApiClient, config: RequestConfig, ) -> DocumentRecord: attachment_link = task.attachment_link + dial_api_client = request_context.dial_api_client with handle_document_processing_error( attachment_link, config.log_document_links ): @@ -247,7 +243,6 @@ async def load_document( choice = request_context.choice - # TODO: Move check_document_access to the DialApiClient await check_document_access(request_context, attachment_link, config) doc_record = None @@ -270,6 +265,7 @@ async def load_document( io_stream = doc_stage.content_stream try: doc_record = await load_document_impl( + dial_api_client, request_context.dial_config, request_context.dial_limited_resources, attachment_link, @@ -295,12 +291,11 @@ async def load_document_task( request_context: RequestContext, task: IndexingTask, index_storage: IndexStorage, - dial_api_client: DialApiClient, config: RequestConfig, ) -> DocumentIndexingResult: try: doc_record = await load_document( - request_context, task, index_storage, dial_api_client, config + request_context, task, index_storage, config ) return DocumentIndexingSuccess( task=task, @@ -318,16 +313,13 @@ async def load_documents( request_context: RequestContext, tasks: Iterable[IndexingTask], index_storage: IndexStorage, - dial_api_client: DialApiClient, config: RequestConfig, ) -> List[DocumentIndexingResult]: # TODO: Rewrite this function using TaskGroup to cancel all tasks if one of them fails # if ignore_document_loading_errors is not set in the config return await asyncio.gather( *[ - load_document_task( - request_context, task, index_storage, dial_api_client, config - ) + load_document_task(request_context, task, index_storage, config) for task in tasks ], ) diff --git a/aidial_rag/request_context.py b/aidial_rag/request_context.py index 5c85314..d6c2130 100644 --- a/aidial_rag/request_context.py +++ b/aidial_rag/request_context.py @@ -1,8 +1,9 @@ -from contextlib import contextmanager +from contextlib import asynccontextmanager from aidial_sdk.chat_completion import Choice, Request, Response from pydantic import BaseModel, SecretStr +from aidial_rag.dial_api_client import DialApiClient, create_dial_api_client from aidial_rag.dial_config import DialConfig from aidial_rag.dial_user_limits import get_user_limits_for_model from aidial_rag.errors import convert_and_log_exceptions @@ -13,6 +14,7 @@ class RequestContext(BaseModel): dial_url: str api_key: SecretStr choice: Choice + dial_api_client: DialApiClient dial_limited_resources: DialLimitedResources class Config: @@ -44,22 +46,26 @@ def get_api_key_headers(self) -> dict: return {"api-key": self.api_key.get_secret_value()} -@contextmanager -def create_request_context(dial_url: str, request: Request, response: Response): +@asynccontextmanager +async def create_request_context( + dial_url: str, request: Request, response: Response +): with convert_and_log_exceptions(): with response.create_single_choice() as choice: dial_config = DialConfig( dial_url=dial_url, api_key=SecretStr(request.api_key) ) - request_context = RequestContext( - dial_url=dial_url, - api_key=dial_config.api_key, - choice=choice, - dial_limited_resources=DialLimitedResources( - lambda model_name: get_user_limits_for_model( - dial_config, model_name - ) - ), - ) - yield request_context + async with create_dial_api_client(dial_config) as dial_api_client: + request_context = RequestContext( + dial_url=dial_url, + api_key=dial_config.api_key, + choice=choice, + dial_api_client=dial_api_client, + dial_limited_resources=DialLimitedResources( + lambda model_name: get_user_limits_for_model( + dial_api_client, model_name + ) + ), + ) + yield request_context diff --git a/tests/test_attachment_link.py b/tests/test_attachment_link.py index 10cdad3..b3bb77b 100644 --- a/tests/test_attachment_link.py +++ b/tests/test_attachment_link.py @@ -1,10 +1,12 @@ from unittest.mock import MagicMock +import aiohttp import pytest from aidial_sdk.chat_completion import Choice from pydantic import SecretStr from aidial_rag.attachment_link import AttachmentLink +from aidial_rag.dial_api_client import DialApiClient from aidial_rag.errors import InvalidAttachmentError from aidial_rag.request_context import RequestContext from aidial_rag.resources.dial_limited_resources import DialLimitedResources @@ -12,15 +14,19 @@ @pytest.fixture -def request_context(): +async def request_context_coro(): return RequestContext( dial_url="http://core.dial", api_key=SecretStr(""), choice=Choice(queue=MagicMock(), choice_index=0), + dial_api_client=DialApiClient( + client_session=aiohttp.ClientSession(), bucket_id="" + ), dial_limited_resources=DialLimitedResources(user_limits_mock()), ) +@pytest.mark.asyncio @pytest.mark.parametrize( "link, expected_absolute_url, expected_display_name", [ @@ -92,15 +98,17 @@ def request_context(): ), ], ) -def test_attachment_link_from_link( - request_context, link, expected_absolute_url, expected_display_name +async def test_attachment_link_from_link( + request_context_coro, link, expected_absolute_url, expected_display_name ): + request_context = await request_context_coro attachment_link = AttachmentLink.from_link(request_context, link) assert attachment_link.dial_link == link assert attachment_link.absolute_url == expected_absolute_url assert attachment_link.display_name == expected_display_name +@pytest.mark.asyncio @pytest.mark.parametrize( "link", [ @@ -108,11 +116,13 @@ def test_attachment_link_from_link( "file.txt", ], ) -def test_attachment_link_errors(request_context, link): +async def test_attachment_link_errors(request_context_coro, link): + request_context = await request_context_coro with pytest.raises(InvalidAttachmentError): AttachmentLink.from_link(request_context, link) +@pytest.mark.asyncio @pytest.mark.parametrize( "link, expected_dial_link, expected_absolute_url, expected_metadata_url", [ @@ -126,23 +136,24 @@ def test_attachment_link_errors(request_context, link): "files/bucket/file.txt", "files/bucket/file.txt", "http://core.dial/v1/files/bucket/file.txt", - "http://core.dial/v1/metadata/files/bucket/file.txt", + "metadata/files/bucket/file.txt", ), ( "http://core.dial/v1/files/bucket/file.txt", "files/bucket/file.txt", "http://core.dial/v1/files/bucket/file.txt", - "http://core.dial/v1/metadata/files/bucket/file.txt", + "metadata/files/bucket/file.txt", ), ], ) -def test_metadata_url( - request_context, +async def test_metadata_url( + request_context_coro, link, expected_dial_link, expected_absolute_url, expected_metadata_url, ): + request_context = await request_context_coro attachment_link = AttachmentLink.from_link(request_context, link) assert attachment_link.dial_link == expected_dial_link assert attachment_link.absolute_url == expected_absolute_url diff --git a/tests/test_attachment_stored.py b/tests/test_attachment_stored.py index 4373f8f..8960c76 100644 --- a/tests/test_attachment_stored.py +++ b/tests/test_attachment_stored.py @@ -24,11 +24,17 @@ @pytest.fixture -def request_context(): +def dial_api_client(): + return MockDialApiClient() + + +@pytest.fixture +def request_context(dial_api_client): return RequestContext( dial_url="http://localhost:8080", api_key=SecretStr("ABRAKADABRA"), choice=Choice(queue=MagicMock(), choice_index=0), + dial_api_client=dial_api_client, dial_limited_resources=DialLimitedResources(user_limits_mock()), ) @@ -37,6 +43,7 @@ class MockDialApiClient(DialApiClient): def __init__(self): self.bucket_id = "test_bucket" self.storage = {} + self._client_session = None async def get_file(self, relative_url): if relative_url in self.storage: @@ -48,11 +55,6 @@ async def put_file(self, relative_url, data, content_type): return {} -@pytest.fixture -def dial_api_client(): - return MockDialApiClient() - - @pytest.fixture def index_storage(dial_api_client): return IndexStorageHolder().get_storage(dial_api_client) @@ -85,7 +87,7 @@ async def test_attachment_test(mock_fetch, request_context, attachment_link): headers = request_context.get_file_access_headers(absolute_url) filename, _content_type, bytes_value = await load_attachment( - attachment_link, headers + request_context.dial_api_client, attachment_link, headers ) assert filename == "folder 1/file-example_PDF 500_kB.pdf" @@ -104,7 +106,6 @@ async def test_load_document_success( mock_fetch, mock_check_document_access, request_context, - dial_api_client, index_storage, attachment_link, ): @@ -115,9 +116,10 @@ async def test_load_document_success( MagicMock(), 0, 0, name ) + bucket_id = request_context.dial_api_client.bucket_id indexing_task = IndexingTask( attachment_link=attachment_link, - index_url=link_to_index_url(attachment_link, dial_api_client.bucket_id), + index_url=link_to_index_url(attachment_link, bucket_id), ) # Download and store @@ -125,7 +127,6 @@ async def test_load_document_success( request_context, indexing_task, index_storage, - dial_api_client, config=request_config, ) assert isinstance(doc_record, DocumentRecord) @@ -151,7 +152,6 @@ async def test_load_document_invalid_document( mock_fetch, mock_check_document_access, request_context, - dial_api_client, index_storage, attachment_link, ): @@ -162,8 +162,8 @@ async def test_load_document_invalid_document( MagicMock(), 0, 0, name ) - dial_api_client = MockDialApiClient() - index_url = link_to_index_url(attachment_link, dial_api_client.bucket_id) + bucket_id = request_context.dial_api_client.bucket_id + index_url = link_to_index_url(attachment_link, bucket_id) with pytest.raises(DocumentProcessingError) as exc_info: await load_document( @@ -173,7 +173,6 @@ async def test_load_document_invalid_document( index_url=index_url, ), index_storage, - dial_api_client, config=request_config, ) assert isinstance(exc_info.value.__cause__, InvalidDocumentError) diff --git a/tests/test_load_documents.py b/tests/test_load_documents.py index 4344c03..f7c70ee 100644 --- a/tests/test_load_documents.py +++ b/tests/test_load_documents.py @@ -1,4 +1,5 @@ import sys +from unittest.mock import MagicMock import pytest @@ -20,8 +21,9 @@ async def load_document(name): display_name=name, ) + dial_api_client = MagicMock() _file_name, content_type, buffer = await load_attachment( - attachment_link, {} + dial_api_client, attachment_link ) mime_type, _ = parse_content_type(content_type) chunks = await parse_document( diff --git a/tests/test_multimodal_retriever.py b/tests/test_multimodal_retriever.py index 14a9cf2..aa77de6 100644 --- a/tests/test_multimodal_retriever.py +++ b/tests/test_multimodal_retriever.py @@ -7,6 +7,7 @@ from pydantic import SecretStr from aidial_rag.attachment_link import AttachmentLink +from aidial_rag.dial_api_client import create_dial_api_client from aidial_rag.dial_config import DialConfig from aidial_rag.dial_user_limits import get_user_limits_for_model from aidial_rag.document_loaders import load_attachment, parse_document @@ -30,26 +31,6 @@ PORT = 5008 -async def load_document(name): - document_link = f"http://localhost:{PORT}/{name}" - - attachment_link = AttachmentLink( - dial_link=document_link, - absolute_url=document_link, - display_name=name, - ) - - _file_name, content_type, buffer = await load_attachment( - attachment_link, {} - ) - mime_type, _ = parse_content_type(content_type) - document = await parse_document( - sys.stderr, buffer, mime_type, attachment_link, mime_type - ) - assert document - return document - - @pytest.fixture def local_server(): with start_local_server(data_dir=DATA_DIR, port=PORT) as server: @@ -76,81 +57,85 @@ def has_dial_access(): async def run_test_retrievers( - local_server, multimodal_index_config: MultimodalIndexConfig + local_server, + multimodal_index_config: MultimodalIndexConfig, ): dial_config = DialConfig( dial_url=os.environ.get("DIAL_URL", "http://localhost:8080"), api_key=SecretStr(os.environ.get("DIAL_RAG_API_KEY", "dial_api_key")), ) - name = "alps_wiki.pdf" - document_link = f"http://localhost:{PORT}/{name}" + async with create_dial_api_client(dial_config) as dial_api_client: + name = "alps_wiki.pdf" + document_link = f"http://localhost:{PORT}/{name}" - attachment_link = AttachmentLink( - dial_link=document_link, - absolute_url=document_link, - display_name=name, - ) + attachment_link = AttachmentLink( + dial_link=document_link, + absolute_url=document_link, + display_name=name, + ) - _file_name, content_type, buffer = await load_attachment( - attachment_link, {} - ) - mime_type, _ = parse_content_type(content_type) - text_chunks = await parse_document( - sys.stderr, buffer, mime_type, attachment_link, mime_type - ) + _file_name, content_type, buffer = await load_attachment( + dial_api_client, attachment_link + ) + mime_type, _ = parse_content_type(content_type) + text_chunks = await parse_document( + sys.stderr, buffer, mime_type, attachment_link, mime_type + ) - index_config = IndexingConfig( - multimodal_index=multimodal_index_config, - description_index=None, - ) + index_config = IndexingConfig( + multimodal_index=multimodal_index_config, + description_index=None, + ) - chunks = await build_chunks_list(text_chunks) - multimodal_index = await MultimodalRetriever.build_index( - dial_config=dial_config, - dial_limited_resources=DialLimitedResources( - lambda model_name: get_user_limits_for_model( - dial_config, model_name - ) - ), - index_config=multimodal_index_config, - mime_type=mime_type, - original_document=buffer, - stageio=sys.stderr, - ) + chunks = await build_chunks_list(text_chunks) + multimodal_index = await MultimodalRetriever.build_index( + dial_config=dial_config, + dial_limited_resources=DialLimitedResources( + lambda model_name: get_user_limits_for_model( + dial_api_client, model_name + ) + ), + index_config=multimodal_index_config, + mime_type=mime_type, + original_document=buffer, + stageio=sys.stderr, + ) - doc_record = DocumentRecord( - format_version=FORMAT_VERSION, - index_settings=index_config.collect_fields_that_rebuild_index(), - chunks=chunks, - text_index=None, - embeddings_index=None, - multimodal_embeddings_index=multimodal_index, - description_embeddings_index=None, - document_bytes=buffer, - mime_type=mime_type, - ) - doc_records = [doc_record] + doc_record = DocumentRecord( + format_version=FORMAT_VERSION, + index_settings=index_config.collect_fields_that_rebuild_index(), + chunks=chunks, + text_index=None, + embeddings_index=None, + multimodal_embeddings_index=multimodal_index, + description_embeddings_index=None, + document_bytes=buffer, + mime_type=mime_type, + ) + doc_records = [doc_record] - multimodal_retriever = MultimodalRetriever.from_doc_records( - dial_config=dial_config, - index_config=multimodal_index_config, - document_records=doc_records, - k=7, - ) + multimodal_retriever = MultimodalRetriever.from_doc_records( + dial_config=dial_config, + index_config=multimodal_index_config, + document_records=doc_records, + k=7, + ) - res = await run_retrevier( - multimodal_retriever, doc_records, "image of butterfly" - ) - assert len(res) - assert res[0].metadata["page_number"] == 13 + res = await run_retrevier( + multimodal_retriever, doc_records, "image of butterfly" + ) + assert len(res) + assert res[0].metadata["page_number"] == 13 @pytest.mark.skipif( not has_dial_access(), reason="DIAL_URL and DIAL_RAG_API_KEY are not set" ) @pytest.mark.asyncio -async def test_multimodalembedding_001(local_server): +async def test_multimodalembedding_001( + local_server, +): await run_test_retrievers( local_server, multimodal_index_config=MultimodalIndexConfig( diff --git a/tests/test_retrievers.py b/tests/test_retrievers.py index 44b6af9..9e211a4 100644 --- a/tests/test_retrievers.py +++ b/tests/test_retrievers.py @@ -1,10 +1,14 @@ import sys from operator import itemgetter +import aiohttp import pytest from langchain.schema.runnable import RunnablePassthrough +from pydantic import SecretStr from aidial_rag.attachment_link import AttachmentLink +from aidial_rag.dial_api_client import DialApiClient +from aidial_rag.dial_config import DialConfig from aidial_rag.document_loaders import load_attachment, parse_document from aidial_rag.document_record import ( FORMAT_VERSION, @@ -52,9 +56,18 @@ async def test_retrievers(local_server): display_name=name, ) - _file_name, content_type, buffer = await load_attachment( - attachment_link, {} + dial_config = DialConfig( + dial_url=f"http://localhost:{PORT}", api_key=SecretStr("") ) + + async with aiohttp.ClientSession( + base_url=dial_config.dial_base_url + ) as session: + _file_name, content_type, buffer = await load_attachment( + DialApiClient(session, bucket_id=""), + attachment_link, + ) + mime_type, _ = parse_content_type(content_type) text_chunks = await parse_document( sys.stderr, buffer, mime_type, attachment_link, mime_type @@ -115,9 +128,18 @@ async def test_pdf_with_no_text(local_server): display_name=name, ) - _file_name, content_type, buffer = await load_attachment( - attachment_link, {} + dial_config = DialConfig( + dial_url=f"http://localhost:{PORT}", api_key=SecretStr("") ) + + async with aiohttp.ClientSession( + base_url=dial_config.dial_base_url + ) as session: + _file_name, content_type, buffer = await load_attachment( + DialApiClient(session, bucket_id=""), + attachment_link, + ) + mime_type, _ = parse_content_type(content_type) text_chunks = await parse_document( sys.stderr, buffer, mime_type, attachment_link, mime_type