Merged
1 change: 1 addition & 0 deletions src/vidore_benchmark/__init__.py
@@ -9,6 +9,7 @@
ColPaliRetriever,
ColQwen2_5_Retriever,
ColQwen2Retriever,
ColQwenOmniRetriever,
DSEQwen2Retriever,
DummyVisionRetriever,
HFEndpointRetriever,
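This one-line change re-exports the new class from the package root, so downstream code can import it directly:

```python
from vidore_benchmark import ColQwenOmniRetriever
```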
1 change: 1 addition & 0 deletions src/vidore_benchmark/retrievers/__init__.py
@@ -8,6 +8,7 @@
from .colpali_retriever import ColPaliRetriever
from .colqwen2_5_retriever import ColQwen2_5_Retriever
from .colqwen2_retriever import ColQwen2Retriever
from .colqwenomni_retriever import ColQwenOmniRetriever
from .dse_qwen2_retriever import DSEQwen2Retriever
from .dummy_vision_retriever import DummyVisionRetriever
from .hf_api_retriever import HFEndpointRetriever
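Beyond re-exporting the symbol, this import has a side effect: the `@register_vision_retriever("colqwen-omni")` decorator in the new module only runs when the module is imported, so this line is what actually adds the retriever to the registry. A minimal sketch of how such a string-keyed registry decorator typically works (an illustration, not the actual `registry_utils` implementation):

```python
from typing import Dict

# Hypothetical module-level registry mapping a name to a retriever class.
VISION_RETRIEVER_REGISTRY: Dict[str, type] = {}


def register_vision_retriever(name: str):
    """Register a retriever class under `name` and return it unchanged."""

    def decorator(cls: type) -> type:
        VISION_RETRIEVER_REGISTRY[name] = cls
        return cls

    return decorator
```

With this pattern, importing `vidore_benchmark.retrievers` is enough for the string `"colqwen-omni"` to resolve to `ColQwenOmniRetriever`.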
2 changes: 1 addition & 1 deletion src/vidore_benchmark/retrievers/colqwen2_5_retriever.py
@@ -21,7 +21,7 @@


@register_vision_retriever("colqwen2_5")
-class ColQwen2_5_Retriever(BaseVisionRetriever): #noqa: N801
+class ColQwen2_5_Retriever(BaseVisionRetriever): # noqa: N801
"""
ColQwen2.5 retriever that implements the model from "ColPali: Efficient Document Retrieval
with Vision Language Models".
117 changes: 117 additions & 0 deletions src/vidore_benchmark/retrievers/colqwenomni_retriever.py
@@ -0,0 +1,117 @@
from __future__ import annotations

import logging
from typing import List, Optional, Union

import torch
from dotenv import load_dotenv
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers.utils.import_utils import is_flash_attn_2_available

from vidore_benchmark.retrievers.base_vision_retriever import BaseVisionRetriever
from vidore_benchmark.retrievers.registry_utils import register_vision_retriever
from vidore_benchmark.utils.data_utils import ListDataset
from vidore_benchmark.utils.torch_utils import get_torch_device

logger = logging.getLogger(__name__)

load_dotenv(override=True)


@register_vision_retriever("colqwen-omni")
class ColQwenOmniRetriever(BaseVisionRetriever):
"""
ColQwenOmni retriever that implements the model from "ColPali: Efficient Document Retrieval
with Vision Language Models". Based on the ColQwen2.5 Omni model.
"""

def __init__(
self,
pretrained_model_name_or_path: str = "vidore/colqwen-omni-v0.1",
device: str = "auto",
num_workers: int = 0,
**kwargs,
):
super().__init__(use_visual_embedding=True)

try:
from colpali_engine.models import ColQwen2_5Omni, ColQwen2_5OmniProcessor
except ImportError:
raise ImportError(
'Install the missing dependencies with `pip install "vidore-benchmark[colpali-engine]"` '
"to use ColQwenOmniRetriever."
)

self.device = get_torch_device(device)
self.num_workers = num_workers

# Load the model and LoRA adapter
self.model = ColQwen2_5Omni.from_pretrained(
pretrained_model_name_or_path,
torch_dtype=torch.bfloat16,
device_map=self.device,
attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
).eval()

# Load the processor
self.processor = ColQwen2_5OmniProcessor.from_pretrained(pretrained_model_name_or_path)

def process_images(self, images: List[Image.Image], **kwargs):
return self.processor.process_images(images=images).to(self.device)

def process_queries(self, queries: List[str], **kwargs):
return self.processor.process_queries(queries=queries).to(self.device)

def forward_queries(self, queries: List[str], batch_size: int, **kwargs) -> List[torch.Tensor]:
dataloader = DataLoader(
dataset=ListDataset[str](queries),
batch_size=batch_size,
shuffle=False,
collate_fn=self.process_queries,
num_workers=self.num_workers,
)

query_embeddings: List[torch.Tensor] = []

with torch.no_grad():
for batch_query in tqdm(dataloader, desc="Forward pass queries...", leave=False):
embeddings_query = self.model(**batch_query).to("cpu")
query_embeddings.extend(list(torch.unbind(embeddings_query)))

return query_embeddings

def forward_passages(self, passages: List[Image.Image], batch_size: int, **kwargs) -> List[torch.Tensor]:
dataloader = DataLoader(
dataset=ListDataset[Image.Image](passages),
batch_size=batch_size,
shuffle=False,
collate_fn=self.process_images,
num_workers=self.num_workers,
)

passage_embeddings: List[torch.Tensor] = []

with torch.no_grad():
for batch_doc in tqdm(dataloader, desc="Forward pass documents...", leave=False):
embeddings_doc = self.model(**batch_doc).to("cpu")
passage_embeddings.extend(list(torch.unbind(embeddings_doc)))

return passage_embeddings

def get_scores(
self,
query_embeddings: Union[torch.Tensor, List[torch.Tensor]],
passage_embeddings: Union[torch.Tensor, List[torch.Tensor]],
batch_size: Optional[int] = 128,
) -> torch.Tensor:
if batch_size is None:
raise ValueError("`batch_size` must be provided for ColQwenOmniRetriever's scoring")
scores = self.processor.score(
query_embeddings,
passage_embeddings,
batch_size=batch_size,
device="cpu",
)
return scores
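For reviewers who want to exercise the new retriever end to end, here is a minimal usage sketch. It relies only on the methods defined above; the query and the blank placeholder image are made-up example data, and it assumes the `colpali-engine` extra is installed and the default checkpoint `vidore/colqwen-omni-v0.1` can be downloaded:

```python
from PIL import Image

from vidore_benchmark.retrievers import ColQwenOmniRetriever

retriever = ColQwenOmniRetriever()  # loads vidore/colqwen-omni-v0.1 by default

queries = ["What was the Q3 revenue?"]
# Placeholder page; in practice these are rendered document page images.
pages = [Image.new("RGB", (448, 448), color="white")]

query_embeddings = retriever.forward_queries(queries, batch_size=1)
passage_embeddings = retriever.forward_passages(pages, batch_size=1)

# Shape (num_queries, num_passages).
scores = retriever.get_scores(query_embeddings, passage_embeddings, batch_size=128)
print(scores)
```

Under the hood, `processor.score` applies ColPali-style late interaction (MaxSim): for each query token embedding, take the maximum dot product over all passage token embeddings, then sum over the query tokens.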