
Commit ef1ddad

add colqwen2_5 (#103)
* add colqwen2_5
* Apply suggestions from code review

Co-authored-by: QuentinJGMace <[email protected]>
1 parent: dca3489

2 files changed: +124 −0 lines changed

src/vidore_benchmark/retrievers/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -7,6 +7,7 @@
 from .colidefics3_retriever import ColIdefics3Retriever
 from .colpali_retriever import ColPaliRetriever
 from .colqwen2_retriever import ColQwen2Retriever
+from .colqwen2_5_retriever import ColQwen2_5_Retriever
 from .dse_qwen2_retriever import DSEQwen2Retriever
 from .dummy_vision_retriever import DummyVisionRetriever
 from .hf_api_retriever import HFEndpointRetriever
```
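With this import in place, importing the retrievers package triggers the `@register_vision_retriever("colqwen2_5")` decorator in the new module below, so the retriever becomes resolvable by name. A minimal sketch of the name-based lookup, assuming the registry exposes a loader such as `load_vision_retriever_from_registry` (a hypothetical name; this commit only shows the registration side):

```python
# Hypothetical loader name: this commit only shows the
# `register_vision_retriever` decorator side of the registry.
from vidore_benchmark.retrievers.registry_utils import load_vision_retriever_from_registry

retriever_cls = load_vision_retriever_from_registry("colqwen2_5")
retriever = retriever_cls(device="auto", num_workers=0)
```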
src/vidore_benchmark/retrievers/colqwen2_5_retriever.py (new file)

Lines changed: 123 additions & 0 deletions

```python
from __future__ import annotations

import logging
from typing import List, Optional, Union, cast

import torch
from dotenv import load_dotenv
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers.utils.import_utils import is_flash_attn_2_available

from vidore_benchmark.retrievers.base_vision_retriever import BaseVisionRetriever
from vidore_benchmark.retrievers.registry_utils import register_vision_retriever
from vidore_benchmark.utils.data_utils import ListDataset
from vidore_benchmark.utils.torch_utils import get_torch_device

logger = logging.getLogger(__name__)

load_dotenv(override=True)


@register_vision_retriever("colqwen2_5")
class ColQwen2_5_Retriever(BaseVisionRetriever):
    """
    ColQwen2.5 retriever that implements the model from "ColPali: Efficient Document Retrieval
    with Vision Language Models".
    """

    def __init__(
        self,
        pretrained_model_name_or_path: str = "vidore/colqwen2.5-v0.2",
        device: str = "auto",
        num_workers: int = 0,
        **kwargs,
    ):
        super().__init__(use_visual_embedding=True)

        try:
            from colpali_engine.models import ColQwen2_5, ColQwen2_5_Processor
        except ImportError:
            raise ImportError(
                'Install the missing dependencies with `pip install "vidore-benchmark[colpali-engine]"` '
                "to use ColQwen2_5_Retriever."
            )

        self.device = get_torch_device(device)
        self.num_workers = num_workers

        # Load the model and LoRA adapter, in bfloat16 and with flash-attention 2
        # when it is available.
        self.model = cast(
            ColQwen2_5,
            ColQwen2_5.from_pretrained(
                pretrained_model_name_or_path,
                torch_dtype=torch.bfloat16,
                device_map=self.device,
                attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
            ).eval(),
        )

        # Load the processor
        self.processor = cast(
            ColQwen2_5_Processor,
            ColQwen2_5_Processor.from_pretrained(pretrained_model_name_or_path),
        )

    def process_images(self, images: List[Image.Image], **kwargs):
        return self.processor.process_images(images=images).to(self.device)

    def process_queries(self, queries: List[str], **kwargs):
        return self.processor.process_queries(queries=queries).to(self.device)

    def forward_queries(self, queries: List[str], batch_size: int, **kwargs) -> List[torch.Tensor]:
        dataloader = DataLoader(
            dataset=ListDataset[str](queries),
            batch_size=batch_size,
            shuffle=False,
            collate_fn=self.process_queries,
            num_workers=self.num_workers,
        )

        query_embeddings: List[torch.Tensor] = []

        with torch.no_grad():
            for batch_query in tqdm(dataloader, desc="Forward pass queries...", leave=False):
                embeddings_query = self.model(**batch_query).to("cpu")
                # Unbind the batch dimension so each query keeps its own
                # (num_tokens, dim) multi-vector embedding.
                query_embeddings.extend(list(torch.unbind(embeddings_query)))

        return query_embeddings

    def forward_passages(self, passages: List[Image.Image], batch_size: int, **kwargs) -> List[torch.Tensor]:
        dataloader = DataLoader(
            dataset=ListDataset[Image.Image](passages),
            batch_size=batch_size,
            shuffle=False,
            collate_fn=self.process_images,
            num_workers=self.num_workers,
        )

        passage_embeddings: List[torch.Tensor] = []

        with torch.no_grad():
            for batch_doc in tqdm(dataloader, desc="Forward pass documents...", leave=False):
                embeddings_doc = self.model(**batch_doc).to("cpu")
                passage_embeddings.extend(list(torch.unbind(embeddings_doc)))

        return passage_embeddings

    def get_scores(
        self,
        query_embeddings: Union[torch.Tensor, List[torch.Tensor]],
        passage_embeddings: Union[torch.Tensor, List[torch.Tensor]],
        batch_size: Optional[int] = 128,
    ) -> torch.Tensor:
        if batch_size is None:
            raise ValueError("`batch_size` must be provided for ColQwen2_5_Retriever's scoring")
        scores = self.processor.score(
            query_embeddings,
            passage_embeddings,
            batch_size=batch_size,
            device="cpu",
        )
        return scores
```
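Taken together, the retriever embeds queries and page images separately and scores the two multi-vector sets afterwards. A minimal end-to-end sketch using only the methods added in this commit (file paths and the query are placeholders):

```python
from PIL import Image

from vidore_benchmark.retrievers import ColQwen2_5_Retriever

# Placeholder inputs: swap in real page images and queries.
pages = [Image.open("page_1.png"), Image.open("page_2.png")]
queries = ["What is the reported revenue for 2023?"]

retriever = ColQwen2_5_Retriever()  # default checkpoint, device="auto"

query_embeddings = retriever.forward_queries(queries, batch_size=4)
passage_embeddings = retriever.forward_passages(pages, batch_size=4)

# (num_queries, num_passages) score matrix, computed on CPU.
scores = retriever.get_scores(query_embeddings, passage_embeddings, batch_size=128)
best_page_per_query = scores.argmax(dim=1)
```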
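`get_scores` delegates to `processor.score`, which in colpali-engine computes late-interaction (MaxSim) scores over the multi-vector embeddings. As a reference for what that operation does, a minimal sketch on already-padded tensors (the real implementation additionally batches and pads variable-length token sequences):

```python
import torch

def maxsim_scores(query_embs: torch.Tensor, passage_embs: torch.Tensor) -> torch.Tensor:
    """Late-interaction (MaxSim) scoring sketch.

    query_embs:   (num_queries, num_query_tokens, dim)
    passage_embs: (num_passages, num_passage_tokens, dim)
    Returns a (num_queries, num_passages) score matrix.
    """
    # Dot product between every query token and every passage token.
    sim = torch.einsum("qnd,pmd->qpnm", query_embs, passage_embs)
    # For each query token, keep its best-matching passage token, then sum
    # those maxima over the query tokens.
    return sim.max(dim=-1).values.sum(dim=-1)
```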
