diff --git a/src/vidore_benchmark/__init__.py b/src/vidore_benchmark/__init__.py
index fdae5877..657467d4 100644
--- a/src/vidore_benchmark/__init__.py
+++ b/src/vidore_benchmark/__init__.py
@@ -9,6 +9,7 @@
     ColPaliRetriever,
     ColQwen2_5_Retriever,
     ColQwen2Retriever,
+    ColQwenOmniRetriever,
     DSEQwen2Retriever,
     DummyVisionRetriever,
     HFEndpointRetriever,
diff --git a/src/vidore_benchmark/retrievers/__init__.py b/src/vidore_benchmark/retrievers/__init__.py
index 3969189d..2e4ed805 100644
--- a/src/vidore_benchmark/retrievers/__init__.py
+++ b/src/vidore_benchmark/retrievers/__init__.py
@@ -8,6 +8,7 @@
 from .colpali_retriever import ColPaliRetriever
 from .colqwen2_5_retriever import ColQwen2_5_Retriever
 from .colqwen2_retriever import ColQwen2Retriever
+from .colqwenomni_retriever import ColQwenOmniRetriever
 from .dse_qwen2_retriever import DSEQwen2Retriever
 from .dummy_vision_retriever import DummyVisionRetriever
 from .hf_api_retriever import HFEndpointRetriever
diff --git a/src/vidore_benchmark/retrievers/colqwen2_5_retriever.py b/src/vidore_benchmark/retrievers/colqwen2_5_retriever.py
index 5be55421..5428d1d9 100644
--- a/src/vidore_benchmark/retrievers/colqwen2_5_retriever.py
+++ b/src/vidore_benchmark/retrievers/colqwen2_5_retriever.py
@@ -21,7 +21,7 @@
 
 
 @register_vision_retriever("colqwen2_5")
-class ColQwen2_5_Retriever(BaseVisionRetriever): #noqa: N801
+class ColQwen2_5_Retriever(BaseVisionRetriever): # noqa: N801
     """
     ColQwen2.5 retriever that implements the model from "ColPali: Efficient Document Retrieval
     with Vision Language Models".
diff --git a/src/vidore_benchmark/retrievers/colqwenomni_retriever.py b/src/vidore_benchmark/retrievers/colqwenomni_retriever.py
new file mode 100644
index 00000000..55cfe9b5
--- /dev/null
+++ b/src/vidore_benchmark/retrievers/colqwenomni_retriever.py
@@ -0,0 +1,117 @@
+from __future__ import annotations
+
+import logging
+from typing import List, Optional, Union
+
+import torch
+from dotenv import load_dotenv
+from PIL import Image
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers.utils.import_utils import is_flash_attn_2_available
+
+from vidore_benchmark.retrievers.base_vision_retriever import BaseVisionRetriever
+from vidore_benchmark.retrievers.registry_utils import register_vision_retriever
+from vidore_benchmark.utils.data_utils import ListDataset
+from vidore_benchmark.utils.torch_utils import get_torch_device
+
+logger = logging.getLogger(__name__)
+
+load_dotenv(override=True)
+
+
+@register_vision_retriever("colqwen-omni")
+class ColQwenOmniRetriever(BaseVisionRetriever):
+    """
+    ColQwenOmni retriever that implements the model from "ColPali: Efficient Document Retrieval
+    with Vision Language Models". Based on the ColQwen2.5 Omni model.
+    """
+
+    def __init__(
+        self,
+        pretrained_model_name_or_path: str = "vidore/colqwen-omni-v0.1",
+        device: str = "auto",
+        num_workers: int = 0,
+        **kwargs,
+    ):
+        super().__init__(use_visual_embedding=True)
+
+        try:
+            from colpali_engine.models import ColQwen2_5Omni, ColQwen2_5OmniProcessor
+        except ImportError:
+            raise ImportError(
+                'Install the missing dependencies with `pip install "vidore-benchmark[colpali-engine]"` '
+                "to use ColQwenOmniRetriever."
+            )
+
+        self.device = get_torch_device(device)
+        self.num_workers = num_workers
+
+        # Load the model and LoRA adapter
+        self.model = ColQwen2_5Omni.from_pretrained(
+            pretrained_model_name_or_path,
+            torch_dtype=torch.bfloat16,
+            device_map=self.device,
+            attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
+        ).eval()
+
+        # Load the processor
+        self.processor = ColQwen2_5OmniProcessor.from_pretrained(pretrained_model_name_or_path)
+
+    def process_images(self, images: List[Image.Image], **kwargs):
+        return self.processor.process_images(images=images).to(self.device)
+
+    def process_queries(self, queries: List[str], **kwargs):
+        return self.processor.process_queries(queries=queries).to(self.device)
+
+    def forward_queries(self, queries: List[str], batch_size: int, **kwargs) -> List[torch.Tensor]:
+        dataloader = DataLoader(
+            dataset=ListDataset[str](queries),
+            batch_size=batch_size,
+            shuffle=False,
+            collate_fn=self.process_queries,
+            num_workers=self.num_workers,
+        )
+
+        query_embeddings: List[torch.Tensor] = []
+
+        with torch.no_grad():
+            for batch_query in tqdm(dataloader, desc="Forward pass queries...", leave=False):
+                embeddings_query = self.model(**batch_query).to("cpu")
+                query_embeddings.extend(list(torch.unbind(embeddings_query)))
+
+        return query_embeddings
+
+    def forward_passages(self, passages: List[Image.Image], batch_size: int, **kwargs) -> List[torch.Tensor]:
+        dataloader = DataLoader(
+            dataset=ListDataset[Image.Image](passages),
+            batch_size=batch_size,
+            shuffle=False,
+            collate_fn=self.process_images,
+            num_workers=self.num_workers,
+        )
+
+        passage_embeddings: List[torch.Tensor] = []
+
+        with torch.no_grad():
+            for batch_doc in tqdm(dataloader, desc="Forward pass documents...", leave=False):
+                embeddings_doc = self.model(**batch_doc).to("cpu")
+                passage_embeddings.extend(list(torch.unbind(embeddings_doc)))
+
+        return passage_embeddings
+
+    def get_scores(
+        self,
+        query_embeddings: Union[torch.Tensor, List[torch.Tensor]],
+        passage_embeddings: Union[torch.Tensor, List[torch.Tensor]],
+        batch_size: Optional[int] = 128,
+    ) -> torch.Tensor:
+        if batch_size is None:
+            raise ValueError("`batch_size` must be provided for ColQwenOmniRetriever's scoring")
+        scores = self.processor.score(
+            query_embeddings,
+            passage_embeddings,
+            batch_size=batch_size,
+            device="cpu",
+        )
+        return scores
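
Usage sketch (reviewer note, not part of the diff): a minimal example of how the new retriever is expected to be driven, based only on the interface shown above. It assumes the colpali-engine extra is installed and a suitable GPU is available; the query string and the "page_1.png"/"page_2.png" images are placeholder inputs, not files from this repository.

# Minimal usage sketch for ColQwenOmniRetriever; inputs below are placeholders.
from PIL import Image

from vidore_benchmark.retrievers import ColQwenOmniRetriever

# Defaults to the "vidore/colqwen-omni-v0.1" checkpoint declared in __init__.
retriever = ColQwenOmniRetriever()

queries = ["What is the total revenue reported for 2023?"]  # hypothetical query
pages = [Image.open("page_1.png"), Image.open("page_2.png")]  # placeholder page screenshots

# Embed queries and document pages with the batched forward passes defined above.
query_embeddings = retriever.forward_queries(queries, batch_size=4)
passage_embeddings = retriever.forward_passages(pages, batch_size=4)

# Late-interaction scores; expected to be an (n_queries, n_passages) tensor.
scores = retriever.get_scores(query_embeddings, passage_embeddings, batch_size=128)
print(scores.argmax(dim=1))  # index of the best-matching page for each query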