1 change: 1 addition & 0 deletions colpali_engine/models/__init__.py
@@ -3,3 +3,4 @@
from .paligemma import BiPali, BiPaliProcessor, BiPaliProj, ColPali, ColPaliProcessor
from .qwen2 import BiQwen2, BiQwen2Processor, ColQwen2, ColQwen2Processor
from .qwen2_5 import BiQwen2_5, BiQwen2_5_Processor, ColQwen2_5, ColQwen2_5_Processor
from .qwen3 import ColQwen3, ColQwen3MoE, ColQwen3MoEProcessor, ColQwen3Processor
3 changes: 3 additions & 0 deletions colpali_engine/models/qwen3/__init__.py
@@ -0,0 +1,3 @@
from .colqwen3 import ColQwen3, ColQwen3Processor
from .colqwen3_moe import ColQwen3MoE, ColQwen3MoEProcessor

3 changes: 3 additions & 0 deletions colpali_engine/models/qwen3/colqwen3/__init__.py
@@ -0,0 +1,3 @@
from .modeling_colqwen3 import ColQwen3
from .processing_colqwen3 import ColQwen3Processor

93 changes: 93 additions & 0 deletions colpali_engine/models/qwen3/colqwen3/modeling_colqwen3.py
@@ -0,0 +1,93 @@
from typing import ClassVar

import torch
from torch import nn
from transformers.models.qwen3_vl import Qwen3VLConfig, Qwen3VLModel


class ColQwen3(Qwen3VLModel):
"""
ColQwen3 model implementation, following the architecture described in the
"ColPali: Efficient Document Retrieval with Vision Language Models" paper.
This wrapper adapts the Qwen3-VL backbone for multi-vector document retrieval.

Args:
config (Qwen3VLConfig): Model configuration.
        mask_non_image_embeddings (bool): When ``True``, only image embeddings are kept in the output vectors.
            Defaults to ``False``, meaning that all embeddings are returned.
"""

main_input_name: ClassVar[str] = "doc_input_ids"

def __init__(self, config: Qwen3VLConfig, mask_non_image_embeddings: bool = False):
super().__init__(config=config)
self.dim = 128
self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.dim)
self.padding_side = "left"
self.mask_non_image_embeddings = mask_non_image_embeddings
self.post_init()

@classmethod
def from_pretrained(cls, *args, **kwargs):
key_mapping = kwargs.pop("key_mapping", None)
if key_mapping is None:
key_mapping = super()._checkpoint_conversion_mapping
return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping)

def forward(self, *args, **kwargs) -> torch.Tensor:
attention_mask = kwargs.get("attention_mask")
has_pixel_values = "pixel_values" in kwargs and kwargs["pixel_values"] is not None

if has_pixel_values:
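            # `pixel_values` arrives zero-padded per image (see ColQwen3Processor.process_images);
            # strip the padding so the vision tower receives the flat patch sequence it expects.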
image_grid_thw = kwargs.get("image_grid_thw")
if image_grid_thw is None:
raise ValueError("`image_grid_thw` must be provided when `pixel_values` is passed.")

if not torch.is_tensor(image_grid_thw):
image_grid_thw = torch.as_tensor(image_grid_thw, device=kwargs["pixel_values"].device)

offsets = image_grid_thw.prod(dim=1)
unpadded = [
pixel_sequence[: int(offset.item())]
for pixel_sequence, offset in zip(kwargs["pixel_values"], offsets)
]

if unpadded:
kwargs["pixel_values"] = torch.cat(unpadded, dim=0)
else:
kwargs["pixel_values"] = None

kwargs.pop("return_dict", True)
kwargs.pop("output_hidden_states", None)
kwargs.pop("use_cache", None)

last_hidden_states = (
super()
.forward(*args, **kwargs, use_cache=False, output_hidden_states=True, return_dict=True)
.last_hidden_state
)

proj = self.custom_text_proj(last_hidden_states)
proj = proj / proj.norm(dim=-1, keepdim=True)

if attention_mask is not None:
proj = proj * attention_mask.unsqueeze(-1)

if has_pixel_values and self.mask_non_image_embeddings and kwargs.get("input_ids") is not None:
image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1)
proj = proj * image_mask

return proj

@property
def patch_size(self) -> int:
return self.visual.config.patch_size

@property
def spatial_merge_size(self) -> int:
return self.visual.config.spatial_merge_size

@property
def temporal_patch_size(self) -> int:
return getattr(self.visual.config, "temporal_patch_size", 1)
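
For reference, a minimal usage sketch for this wrapper (the checkpoint name is a placeholder; substitute any ColQwen3-style checkpoint, and note that `score` is defined on the processor below):

import torch
from PIL import Image

from colpali_engine.models import ColQwen3, ColQwen3Processor

model_name = "vidore/colqwen3-v0.1"  # placeholder checkpoint name

model = ColQwen3.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto").eval()
processor = ColQwen3Processor.from_pretrained(model_name)

images = [Image.new("RGB", (448, 448), "white")]  # stand-in for real document pages
queries = ["What does the figure show?"]

batch_images = processor.process_images(images).to(model.device)
batch_queries = processor.process_texts(queries).to(model.device)

with torch.no_grad():
    image_embeddings = model(**batch_images)   # (batch, seq_len, 128) multi-vectors
    query_embeddings = model(**batch_queries)

# MaxSim late-interaction scores, shape (num_queries, num_images).
scores = processor.score(list(query_embeddings), list(image_embeddings))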

129 changes: 129 additions & 0 deletions colpali_engine/models/qwen3/colqwen3/processing_colqwen3.py
@@ -0,0 +1,129 @@
from typing import ClassVar, List, Optional, Tuple, Union

import torch
from PIL import Image
from transformers import BatchEncoding, BatchFeature
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from transformers.models.qwen3_vl import Qwen3VLProcessor

from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor


class ColQwen3Processor(BaseVisualRetrieverProcessor, Qwen3VLProcessor):
"""
Processor for ColQwen3.

Args:
max_num_visual_tokens: Maximum number of visual tokens allowed during preprocessing.
*args: Variable positional arguments forwarded to :class:`~transformers.Qwen3VLProcessor`.
**kwargs: Keyword arguments forwarded to :class:`~transformers.Qwen3VLProcessor`.
"""

visual_prompt_prefix: ClassVar[str] = (
"<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|endoftext|>"
)
query_augmentation_token: ClassVar[str] = "<|endoftext|>"
image_token: ClassVar[str] = "<|image_pad|>"

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.tokenizer.padding_side = "left"

@classmethod
def from_pretrained(
cls,
*args,
device_map: Optional[str] = None,
**kwargs,
):
instance = super().from_pretrained(
*args,
device_map=device_map,
**kwargs,
)

if "max_num_visual_tokens" in kwargs:
patch_size = getattr(instance.image_processor, "patch_size", 28)
instance.image_processor.max_pixels = kwargs["max_num_visual_tokens"] * patch_size * patch_size
instance.image_processor.size["longest_edge"] = instance.image_processor.max_pixels

return instance

def process_images(self, images: List[Image.Image]) -> Union[BatchFeature, BatchEncoding]:
"""
Process a batch of PIL images for ColQwen3.
"""

images = [image.convert("RGB") for image in images]

batch_doc = self(
text=[self.visual_prompt_prefix] * len(images),
images=images,
padding="longest",
return_tensors="pt",
)

if batch_doc["pixel_values"].numel() == 0:
return batch_doc

offsets = batch_doc["image_grid_thw"].prod(dim=1)
pixel_values = list(
torch.split(batch_doc["pixel_values"], offsets.tolist())
) # [(num_patches_img_0, patch_dim), ..., (num_patches_img_n, patch_dim)]

batch_doc["pixel_values"] = torch.nn.utils.rnn.pad_sequence(pixel_values, batch_first=True)

return batch_doc

def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]:
"""
Process a batch of raw texts for ColQwen3.
"""
return self(
text=texts,
return_tensors="pt",
padding="longest",
)

def score(
self,
qs: List[torch.Tensor],
ps: List[torch.Tensor],
device: Optional[Union[str, torch.device]] = None,
**kwargs,
) -> torch.Tensor:
"""
Compute the MaxSim score (ColBERT-like) for query and passage embeddings.
"""
return self.score_multi_vector(qs, ps, device=device, **kwargs)

def get_n_patches(
self,
image_size: Tuple[int, int],
spatial_merge_size: int,
) -> Tuple[int, int]:
"""
Compute the number of spatial patches for an image of ``image_size``.
"""
patch_size = self.image_processor.patch_size
merge_size = getattr(self.image_processor, "merge_size", 1)

height_new, width_new = smart_resize(
width=image_size[0],
height=image_size[1],
factor=patch_size * merge_size,
min_pixels=self.image_processor.size["shortest_edge"],
max_pixels=self.image_processor.size["longest_edge"],
)

n_patches_x = width_new // patch_size // spatial_merge_size
n_patches_y = height_new // patch_size // spatial_merge_size

return n_patches_x, n_patches_y

def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor:
"""
Return a boolean tensor identifying image tokens inside ``batch_images``.
"""
return batch_images.input_ids == self.image_token_id
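
For reference, a minimal sketch of the visual-token budget and the patch-grid helper (the checkpoint name is a placeholder, and `max_num_visual_tokens=768` is an arbitrary example value):

from colpali_engine.models import ColQwen3Processor

processor = ColQwen3Processor.from_pretrained(
    "vidore/colqwen3-v0.1",  # placeholder checkpoint name
    max_num_visual_tokens=768,  # caps image_processor.max_pixels via the hook above
)

# Patch grid for a 1024x768 (width, height) page; 2 is the usual Qwen-VL merge size.
n_patches_x, n_patches_y = processor.get_n_patches((1024, 768), spatial_merge_size=2)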

3 changes: 3 additions & 0 deletions colpali_engine/models/qwen3/colqwen3_moe/__init__.py
@@ -0,0 +1,3 @@
from .modeling_colqwen3_moe import ColQwen3MoE
from .processing_colqwen3_moe import ColQwen3MoEProcessor

91 changes: 91 additions & 0 deletions colpali_engine/models/qwen3/colqwen3_moe/modeling_colqwen3_moe.py
@@ -0,0 +1,91 @@
from typing import ClassVar

import torch
from torch import nn
from transformers.models.qwen3_vl_moe import Qwen3VLMoeConfig, Qwen3VLMoeModel


class ColQwen3MoE(Qwen3VLMoeModel):
"""
ColQwen3-MoE model implementation. This adapts the Qwen3-VL-MoE backbone to the ColPali multi-vector
retrieval setting.

Args:
config (Qwen3VLMoeConfig): Model configuration.
        mask_non_image_embeddings (bool): When ``True``, only image embeddings are preserved.
"""

main_input_name: ClassVar[str] = "doc_input_ids"

def __init__(self, config: Qwen3VLMoeConfig, mask_non_image_embeddings: bool = False):
super().__init__(config=config)
self.dim = 128
self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.dim)
self.padding_side = "left"
self.mask_non_image_embeddings = mask_non_image_embeddings
self.post_init()

@classmethod
def from_pretrained(cls, *args, **kwargs):
key_mapping = kwargs.pop("key_mapping", None)
if key_mapping is None:
key_mapping = super()._checkpoint_conversion_mapping
return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping)

def forward(self, *args, **kwargs) -> torch.Tensor:
attention_mask = kwargs.get("attention_mask")
has_pixel_values = "pixel_values" in kwargs and kwargs["pixel_values"] is not None

if has_pixel_values:
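            # As in ColQwen3: the processor zero-pads each image's patch sequence, so strip the
            # padding back off before handing `pixel_values` to the vision tower.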
image_grid_thw = kwargs.get("image_grid_thw")
if image_grid_thw is None:
raise ValueError("`image_grid_thw` must be provided when `pixel_values` is passed.")

if not torch.is_tensor(image_grid_thw):
image_grid_thw = torch.as_tensor(image_grid_thw, device=kwargs["pixel_values"].device)

offsets = image_grid_thw.prod(dim=1)
unpadded = [
pixel_sequence[: int(offset.item())]
for pixel_sequence, offset in zip(kwargs["pixel_values"], offsets)
]

if unpadded:
kwargs["pixel_values"] = torch.cat(unpadded, dim=0)
else:
kwargs["pixel_values"] = None

kwargs.pop("return_dict", True)
kwargs.pop("output_hidden_states", None)
kwargs.pop("use_cache", None)

last_hidden_states = (
super()
.forward(*args, **kwargs, use_cache=False, output_hidden_states=True, return_dict=True)
.last_hidden_state
)

proj = self.custom_text_proj(last_hidden_states)
proj = proj / proj.norm(dim=-1, keepdim=True)

if attention_mask is not None:
proj = proj * attention_mask.unsqueeze(-1)

if has_pixel_values and self.mask_non_image_embeddings and kwargs.get("input_ids") is not None:
image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1)
proj = proj * image_mask

return proj

@property
def patch_size(self) -> int:
return self.visual.config.patch_size

@property
def spatial_merge_size(self) -> int:
return self.visual.config.spatial_merge_size

@property
def temporal_patch_size(self) -> int:
return getattr(self.visual.config, "temporal_patch_size", 1)
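
The MoE wrapper is loaded the same way as the dense one; a minimal sketch (placeholder checkpoint name):

import torch

from colpali_engine.models import ColQwen3MoE, ColQwen3MoEProcessor

model_name = "vidore/colqwen3-moe-v0.1"  # placeholder checkpoint name
model = ColQwen3MoE.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto").eval()
processor = ColQwen3MoEProcessor.from_pretrained(model_name)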

54 changes: 54 additions & 0 deletions colpali_engine/models/qwen3/colqwen3_moe/processing_colqwen3_moe.py
@@ -0,0 +1,54 @@
from typing import ClassVar, List, Optional, Tuple, Union

import torch
from PIL import Image
from transformers import BatchEncoding, BatchFeature

from colpali_engine.models.qwen3.colqwen3.processing_colqwen3 import ColQwen3Processor


class ColQwen3MoEProcessor(ColQwen3Processor):
"""
Processor for the ColQwen3-MoE model variant. The MoE backbone shares the same vision and text
preprocessing pipeline as the dense Qwen3-VL models, but exposing a dedicated class keeps the API
    symmetric with the other ColPali wrappers.
"""

moe_variant: ClassVar[bool] = True

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

@classmethod
def from_pretrained(
cls,
*args,
device_map: Optional[str] = None,
**kwargs,
):
return super().from_pretrained(*args, device_map=device_map, **kwargs)

def process_images(self, images: List[Image.Image]) -> Union[BatchFeature, BatchEncoding]:
return super().process_images(images)

def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]:
return super().process_texts(texts)

def score(
self,
qs: List[torch.Tensor],
ps: List[torch.Tensor],
device: Optional[Union[str, torch.device]] = None,
**kwargs,
) -> torch.Tensor:
return super().score(qs, ps, device=device, **kwargs)

def get_n_patches(
self,
image_size: Tuple[int, int],
spatial_merge_size: int,
) -> Tuple[int, int]:
return super().get_n_patches(image_size, spatial_merge_size)

def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor:
return super().get_image_mask(batch_images)
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -40,7 +40,7 @@ dependencies = [
"scipy",
"torch>=2.2.0,<2.9.0",
"torchvision",
"transformers>=4.53.1,<4.58.0",
"transformers>=4.57.1,<4.58.0",
]

[project.optional-dependencies]