src/vidore_benchmark/__init__.py (1 addition, 0 deletions)
@@ -8,6 +8,7 @@
     ColIdefics3Retriever,
     ColPaliRetriever,
     ColQwen2Retriever,
+    ColQwenOmniRetriever,
     DSEQwen2Retriever,
     DummyVisionRetriever,
     HFEndpointRetriever,

src/vidore_benchmark/retrievers/__init__.py (1 addition, 0 deletions)
@@ -7,6 +7,7 @@
 from .colidefics3_retriever import ColIdefics3Retriever
 from .colpali_retriever import ColPaliRetriever
 from .colqwen2_retriever import ColQwen2Retriever
+from .colqwenomni_retriever import ColQwenOmniRetriever
 from .dse_qwen2_retriever import DSEQwen2Retriever
 from .dummy_vision_retriever import DummyVisionRetriever
 from .hf_api_retriever import HFEndpointRetriever
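
With both __init__ additions in place, ColQwenOmniRetriever is exported from the package top level as well as from the retrievers subpackage. A quick import check (a sketch, assuming the package is installed with its colpali-engine extra):

    # Both paths are exposed by the two __init__ additions above.
    from vidore_benchmark import ColQwenOmniRetriever
    from vidore_benchmark.retrievers import ColQwenOmniRetriever  # same class, subpackage path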

src/vidore_benchmark/retrievers/colqwenomni_retriever.py (new file, 117 additions)
@@ -0,0 +1,117 @@
from __future__ import annotations

import logging
from typing import List, Optional, Union

import torch
from dotenv import load_dotenv
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers.utils.import_utils import is_flash_attn_2_available

from vidore_benchmark.retrievers.base_vision_retriever import BaseVisionRetriever
from vidore_benchmark.retrievers.registry_utils import register_vision_retriever
from vidore_benchmark.utils.data_utils import ListDataset
from vidore_benchmark.utils.torch_utils import get_torch_device

logger = logging.getLogger(__name__)

load_dotenv(override=True)


@register_vision_retriever("colqwen-omni")
class ColQwenOmniRetriever(BaseVisionRetriever):
    """
    ColQwen2 retriever that implements the model from "ColPali: Efficient Document Retrieval
    with Vision Language Models".
    """

Review comment (Collaborator), suggested change:

 class ColQwenOmniRetriever(BaseVisionRetriever):
     """
-    ColQwen2 retriever that implements the model from "ColPali: Efficient Document Retrieval
-    with Vision Language Models".
+    ColQwenOmni retriever that implements the model from "ColPali: Efficient Document Retrieval
+    with Vision Language Models" for a Qwen-2.5 Omni backbone.
     """

    def __init__(
        self,
        pretrained_model_name_or_path: str = "vidore/colqwen-omni-v0.1",
        device: str = "auto",
        num_workers: int = 0,
        **kwargs,
    ):
        super().__init__(use_visual_embedding=True)

        try:
            from colpali_engine.models import ColQwen2_5Omni, ColQwen2_5OmniProcessor
        except ImportError:
            raise ImportError(
                'Install the missing dependencies with `pip install "vidore-benchmark[colpali-engine]"` '
                "to use ColQwen2Retriever."

Review comment (Collaborator), suggested change:

 'Install the missing dependencies with `pip install "vidore-benchmark[colpali-engine]"` '
-"to use ColQwen2Retriever."
+"to use ColQwenOmniRetriever."

            )

        self.device = get_torch_device(device)
        self.num_workers = num_workers

        # Load the model and LoRA adapter
        self.model = ColQwen2_5Omni.from_pretrained(
            pretrained_model_name_or_path,
            torch_dtype=torch.bfloat16,
            device_map=self.device,
            attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
        ).eval()

        # Load the processor
        self.processor = ColQwen2_5OmniProcessor.from_pretrained(pretrained_model_name_or_path)

    def process_images(self, images: List[Image.Image], **kwargs):
        return self.processor.process_images(images=images).to(self.device)

    def process_queries(self, queries: List[str], **kwargs):
        return self.processor.process_queries(queries=queries).to(self.device)

    def forward_queries(self, queries: List[str], batch_size: int, **kwargs) -> List[torch.Tensor]:
        dataloader = DataLoader(
            dataset=ListDataset[str](queries),
            batch_size=batch_size,
            shuffle=False,
            collate_fn=self.process_queries,
            num_workers=self.num_workers,
        )

        query_embeddings: List[torch.Tensor] = []

        with torch.no_grad():
            for batch_query in tqdm(dataloader, desc="Forward pass queries...", leave=False):
                embeddings_query = self.model(**batch_query).to("cpu")
                query_embeddings.extend(list(torch.unbind(embeddings_query)))

        return query_embeddings

    def forward_passages(self, passages: List[Image.Image], batch_size: int, **kwargs) -> List[torch.Tensor]:
        dataloader = DataLoader(
            dataset=ListDataset[Image.Image](passages),
            batch_size=batch_size,
            shuffle=False,
            collate_fn=self.process_images,
            num_workers=self.num_workers,
        )

        passage_embeddings: List[torch.Tensor] = []

        with torch.no_grad():
            for batch_doc in tqdm(dataloader, desc="Forward pass documents...", leave=False):
                embeddings_doc = self.model(**batch_doc).to("cpu")
                passage_embeddings.extend(list(torch.unbind(embeddings_doc)))

        return passage_embeddings

    def get_scores(
        self,
        query_embeddings: Union[torch.Tensor, List[torch.Tensor]],
        passage_embeddings: Union[torch.Tensor, List[torch.Tensor]],
        batch_size: Optional[int] = 128,
    ) -> torch.Tensor:
        if batch_size is None:
            raise ValueError("`batch_size` must be provided for ColQwenOmniRetriever's scoring")
        scores = self.processor.score(
            query_embeddings,
            passage_embeddings,
            batch_size=batch_size,
            device="cpu",
        )
        return scores
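
For reference, a minimal end-to-end sketch of the new retriever, using only the methods defined in this file; the checkpoint is the class default (vidore/colqwen-omni-v0.1), and the query string and page image path are placeholders:

    from PIL import Image

    from vidore_benchmark.retrievers import ColQwenOmniRetriever

    retriever = ColQwenOmniRetriever()  # defaults to "vidore/colqwen-omni-v0.1"

    queries = ["What was the total revenue in 2023?"]  # placeholder query
    passages = [Image.open("page_1.png")]  # hypothetical page image

    query_embeddings = retriever.forward_queries(queries, batch_size=1)
    passage_embeddings = retriever.forward_passages(passages, batch_size=1)

    # Late-interaction scoring on CPU, as wired up in get_scores above.
    scores = retriever.get_scores(query_embeddings, passage_embeddings, batch_size=128)
    print(scores)  # expected shape: (n_queries, n_passages)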