|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import logging |
| 4 | +from typing import List, Optional, Union, cast |
| 5 | + |
| 6 | +import torch |
| 7 | +from dotenv import load_dotenv |
| 8 | +from PIL import Image |
| 9 | +from torch.utils.data import DataLoader |
| 10 | +from tqdm import tqdm |
| 11 | +from transformers.utils.import_utils import is_flash_attn_2_available |
| 12 | + |
| 13 | +from vidore_benchmark.retrievers.base_vision_retriever import BaseVisionRetriever |
| 14 | +from vidore_benchmark.retrievers.registry_utils import register_vision_retriever |
| 15 | +from vidore_benchmark.utils.data_utils import ListDataset |
| 16 | +from vidore_benchmark.utils.torch_utils import get_torch_device |
| 17 | + |
| 18 | +logger = logging.getLogger(__name__) |
| 19 | + |
| 20 | +load_dotenv(override=True) |
| 21 | + |
| 22 | + |
| 23 | +@register_vision_retriever("colqwen2_5") |
| 24 | +class ColQwen2_5_Retriever(BaseVisionRetriever): |
| 25 | + """ |
| 26 | +    ColQwen2.5 retriever that implements the late-interaction retrieval approach from "ColPali: Efficient |
| 27 | +    Document Retrieval with Vision Language Models", using a Qwen2.5-VL backbone. |
| 28 | + """ |
| 29 | + |
| 30 | + def __init__( |
| 31 | + self, |
| 32 | + pretrained_model_name_or_path: str = "vidore/colqwen2-v1.0", |
| 33 | + device: str = "auto", |
| 34 | + num_workers: int = 0, |
| 35 | + **kwargs, |
| 36 | + ): |
| 37 | + super().__init__(use_visual_embedding=True) |
| 38 | + |
| 39 | + try: |
| 40 | + from colpali_engine.models import ColQwen2_5, ColQwen2_5_Processor |
| 41 | + except ImportError: |
| 42 | + raise ImportError( |
| 43 | + 'Install the missing dependencies with `pip install "vidore-benchmark[colpali-engine]"` ' |
| 44 | + "to use ColQwen2_5_Retriever." |
| 45 | + ) |
| 46 | + |
| 47 | + self.device = get_torch_device(device) |
| 48 | + self.num_workers = num_workers |
| 49 | + |
| 50 | +        # Load the model and LoRA adapter |
| 51 | + self.model = cast( |
| 52 | + ColQwen2_5, |
| 53 | + ColQwen2_5.from_pretrained( |
| 54 | + pretrained_model_name_or_path, |
| 55 | + torch_dtype=torch.bfloat16, |
| 56 | + device_map=self.device, |
| 57 | + attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None, |
| 58 | + ).eval(), |
| 59 | + ) |
| 60 | + |
| 61 | + # Load the processor |
| 62 | + self.processor = cast( |
| 63 | + ColQwen2_5_Processor, |
| 64 | + ColQwen2_5_Processor.from_pretrained(pretrained_model_name_or_path), |
| 65 | + ) |
| 66 | + |
| 67 | + def process_images(self, images: List[Image.Image], **kwargs): |
| 68 | + return self.processor.process_images(images=images).to(self.device) |
| 69 | + |
| 70 | + def process_queries(self, queries: List[str], **kwargs): |
| 71 | + return self.processor.process_queries(queries=queries).to(self.device) |
| 72 | + |
| 73 | + def forward_queries(self, queries: List[str], batch_size: int, **kwargs) -> List[torch.Tensor]: |
| 74 | + dataloader = DataLoader( |
| 75 | + dataset=ListDataset[str](queries), |
| 76 | + batch_size=batch_size, |
| 77 | + shuffle=False, |
| 78 | + collate_fn=self.process_queries, |
| 79 | + num_workers=self.num_workers, |
| 80 | + ) |
| 81 | + |
| 82 | + query_embeddings: List[torch.Tensor] = [] |
| 83 | + |
| 84 | + with torch.no_grad(): |
| 85 | + for batch_query in tqdm(dataloader, desc="Forward pass queries...", leave=False): |
| 86 | + embeddings_query = self.model(**batch_query).to("cpu") |
| 87 | + query_embeddings.extend(list(torch.unbind(embeddings_query))) |
| 88 | + |
| 89 | + return query_embeddings |
| 90 | + |
| 91 | + def forward_passages(self, passages: List[Image.Image], batch_size: int, **kwargs) -> List[torch.Tensor]: |
| 92 | + dataloader = DataLoader( |
| 93 | + dataset=ListDataset[Image.Image](passages), |
| 94 | + batch_size=batch_size, |
| 95 | + shuffle=False, |
| 96 | + collate_fn=self.process_images, |
| 97 | + num_workers=self.num_workers, |
| 98 | + ) |
| 99 | + |
| 100 | + passage_embeddings: List[torch.Tensor] = [] |
| 101 | + |
| 102 | + with torch.no_grad(): |
| 103 | + for batch_doc in tqdm(dataloader, desc="Forward pass documents...", leave=False): |
| 104 | + embeddings_doc = self.model(**batch_doc).to("cpu") |
| 105 | + passage_embeddings.extend(list(torch.unbind(embeddings_doc))) |
| 106 | + |
| 107 | + return passage_embeddings |
| 108 | + |
| 109 | + def get_scores( |
| 110 | + self, |
| 111 | + query_embeddings: Union[torch.Tensor, List[torch.Tensor]], |
| 112 | + passage_embeddings: Union[torch.Tensor, List[torch.Tensor]], |
| 113 | + batch_size: Optional[int] = 128, |
| 114 | + ) -> torch.Tensor: |
| 115 | + if batch_size is None: |
| 116 | + raise ValueError("`batch_size` must be provided for ColQwenRetriever's scoring") |
| 117 | + scores = self.processor.score( |
| 118 | + query_embeddings, |
| 119 | + passage_embeddings, |
| 120 | + batch_size=batch_size, |
| 121 | + device="cpu", |
| 122 | + ) |
| 123 | + return scores |