8 changes: 8 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](http://semver.org/).

## Unreleased

### Added

- Add ColQwen3 and BiQwen3 support (model + processor).

### Tests

- Cover ColQwen3 processing and modeling with slow integration tests.

## [0.3.13] - 2025-11-15

### Added
1 change: 1 addition & 0 deletions README.md
@@ -40,6 +40,7 @@ Using ColPali removes the need for potentially complex and brittle layout recogn
| [vidore/colqwen2-v1.0](https://huggingface.co/vidore/colqwen2-v1.0) | 89.3 | Apache 2.0 | • Similar to `vidore/colqwen2-v0.1`, but trained with more powerful GPUs and with a larger effective batch size (256). | ✅ |
| [vidore/colqwen2.5-v0.1](https://huggingface.co/vidore/colqwen2.5-v0.1) | 88.8 | Apache 2.0 | • Based on `Qwen/Qwen2.5-VL-3B-Instruct`<br />• Supports dynamic resolution.<br />• Trained using 768 image patches per page and an effective batch size of 32. | ✅ |
| [vidore/colqwen2.5-v0.2](https://huggingface.co/vidore/colqwen2.5-v0.2) | 89.4 | Apache 2.0 | • Similar to `vidore/colqwen2.5-v0.1`, but trained with slightly different hyperparameters. | ✅ |
| [TomoroAI/tomoro-colqwen3-embed-4b](https://huggingface.co/TomoroAI/tomoro-colqwen3-embed-4b) | 90.6 | Apache 2.0 | • Based on the Qwen3-VL backbone.<br />• 320-dim ColBERT-style embeddings with dynamic resolution.<br />• Trained for multi-vector document retrieval. | ✅ |
| [vidore/colSmol-256M](https://huggingface.co/vidore/colSmol-256M) | 80.1 | Apache 2.0 | • Based on `HuggingFaceTB/SmolVLM-256M-Instruct`. | ✅ |
| [vidore/colSmol-500M](https://huggingface.co/vidore/colSmol-500M) | 82.3 | Apache 2.0 | • Based on `HuggingFaceTB/SmolVLM-500M-Instruct`. | ✅ |
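
A minimal usage sketch for the ColQwen3 entry above. It assumes the `process_images`/`process_queries` and `score_multi_vector` helpers provided by colpali-engine's shared base processor and uses the checkpoint name from the table; the model and processor classes are the ones introduced in this PR.

```python
import torch
from PIL import Image

from colpali_engine.models import ColQwen3, ColQwen3Processor

model_name = "TomoroAI/tomoro-colqwen3-embed-4b"  # checkpoint listed in the table above

# Load the multi-vector retriever and its processor
model = ColQwen3.from_pretrained(
    model_name,
    dtype=torch.bfloat16,
    device_map="cuda:0",
).eval()
processor = ColQwen3Processor.from_pretrained(model_name)

images = [Image.new("RGB", (448, 448), color="white")]  # stand-ins for document page images
queries = ["What is the projected revenue for 2024?"]

batch_images = processor.process_images(images).to(model.device)
batch_queries = processor.process_queries(queries).to(model.device)

with torch.no_grad():
    image_embeddings = model(**batch_images)   # (n_images, n_image_tokens, 320)
    query_embeddings = model(**batch_queries)  # (n_queries, n_query_tokens, 320)

# Late-interaction (MaxSim) scores between every query and every page
scores = processor.score_multi_vector(query_embeddings, image_embeddings)
```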

4 changes: 4 additions & 0 deletions colpali_engine/__init__.py
@@ -7,6 +7,8 @@
BiQwen2_5,
BiQwen2_5_Processor,
BiQwen2Processor,
BiQwen3,
BiQwen3Processor,
ColIdefics3,
ColIdefics3Processor,
ColModernVBert,
@@ -19,4 +21,6 @@
ColQwen2_5Omni,
ColQwen2_5OmniProcessor,
ColQwen2Processor,
ColQwen3,
ColQwen3Processor,
)
1 change: 1 addition & 0 deletions colpali_engine/models/__init__.py
@@ -3,4 +3,5 @@
from .paligemma import BiPali, BiPaliProcessor, BiPaliProj, ColPali, ColPaliProcessor
from .qwen2 import BiQwen2, BiQwen2Processor, ColQwen2, ColQwen2Processor
from .qwen2_5 import BiQwen2_5, BiQwen2_5_Processor, ColQwen2_5, ColQwen2_5_Processor
from .qwen3 import BiQwen3, BiQwen3Processor, ColQwen3, ColQwen3Processor
from .qwen_omni import ColQwen2_5Omni, ColQwen2_5OmniProcessor
2 changes: 2 additions & 0 deletions colpali_engine/models/qwen3/__init__.py
@@ -0,0 +1,2 @@
from .biqwen3 import BiQwen3, BiQwen3Processor
from .colqwen3 import ColQwen3, ColQwen3Processor
2 changes: 2 additions & 0 deletions colpali_engine/models/qwen3/biqwen3/__init__.py
@@ -0,0 +1,2 @@
from .modeling_biqwen3 import BiQwen3
from .processing_biqwen3 import BiQwen3Processor
94 changes: 94 additions & 0 deletions colpali_engine/models/qwen3/biqwen3/modeling_biqwen3.py
@@ -0,0 +1,94 @@
from typing import ClassVar, Literal

import torch
from transformers.models.qwen3_vl import Qwen3VLConfig, Qwen3VLModel


class BiQwen3(Qwen3VLModel):
"""
BiQwen3 implementation from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper.
Representations are pooled to obtain a single vector representation. Based on the Qwen3-VL backbone.
"""

main_input_name: ClassVar[str] = "doc_input_ids"
_checkpoint_conversion_mapping = {
r"^model\.visual": "visual",
r"^model\.language_model": "language_model",
r"^model\.": "",
}

def __init__(self, config: Qwen3VLConfig, **kwargs):
dtype = kwargs.pop("dtype", kwargs.pop("torch_dtype", None))
attn_impl = kwargs.pop("attn_implementation", None)
use_cache = kwargs.pop("use_cache", None)

super().__init__(config=config)
self.padding_side = "left"
self.post_init()

if dtype is not None:
self.to(dtype=dtype)
if use_cache is not None:
self.config.use_cache = use_cache
if attn_impl is not None and hasattr(self, "set_attn_implementation"):
self.set_attn_implementation(attn_impl)

@classmethod
def from_pretrained(cls, *args, **kwargs):
key_mapping = kwargs.pop("key_mapping", None)
if key_mapping is None:
key_mapping = getattr(cls, "_checkpoint_conversion_mapping", None)
return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping)

def forward(
self,
pooling_strategy: Literal["cls", "last", "mean"] = "last",
*args,
**kwargs,
) -> torch.Tensor:
"""
Forward pass for BiQwen3 model.

Args:
pooling_strategy: The strategy to use for pooling the hidden states.
*args: Variable length argument list.
**kwargs: Additional keyword arguments.

Returns:
torch.Tensor: Dense embeddings (batch_size, hidden_size).
"""
if "pixel_values" in kwargs:
offsets = kwargs["image_grid_thw"].prod(dim=1).tolist()
kwargs["pixel_values"] = torch.cat(
[pixel_sequence[:offset] for pixel_sequence, offset in zip(kwargs["pixel_values"], offsets)],
dim=0,
)
kwargs.pop("return_dict", True)
kwargs.pop("output_hidden_states", None)
kwargs.pop("use_cache", None)

last_hidden_states = (
super()
.forward(*args, **kwargs, use_cache=False, output_hidden_states=True, return_dict=True)
.last_hidden_state
) # (batch_size, sequence_length, hidden_size)

if pooling_strategy == "cls":
pooled = last_hidden_states[:, 0]
elif pooling_strategy == "last":
pooled = last_hidden_states[:, -1]
elif pooling_strategy == "mean":
mask = kwargs["attention_mask"].unsqueeze(-1)
pooled = (last_hidden_states * mask).sum(dim=1) / mask.sum(dim=1)
else:
raise ValueError(f"Invalid pooling strategy: {pooling_strategy}")

return pooled / pooled.norm(dim=-1, keepdim=True)

@property
def patch_size(self) -> int:
return self.visual.config.patch_size

@property
def spatial_merge_size(self) -> int:
return self.visual.config.spatial_merge_size
37 changes: 37 additions & 0 deletions colpali_engine/models/qwen3/biqwen3/processing_biqwen3.py
@@ -0,0 +1,37 @@
from typing import List, Optional, Union

import torch
from transformers import BatchEncoding, BatchFeature

from colpali_engine.models.qwen3.colqwen3 import ColQwen3Processor


class BiQwen3Processor(ColQwen3Processor):
"""
Processor for BiQwen3.
"""

def process_texts(
self,
texts: List[str],
) -> Union[BatchFeature, BatchEncoding]:
"""
Process texts for BiQwen3.
"""
return self(
text=texts,
return_tensors="pt",
padding="longest",
)

def score(
self,
qs: List[torch.Tensor],
ps: List[torch.Tensor],
device: Optional[Union[str, torch.device]] = None,
**kwargs,
) -> torch.Tensor:
"""
Compute the cosine similarity for the given query and passage embeddings.
"""
return self.score_single_vector(qs, ps, device=device)
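
A minimal sketch of single-vector retrieval with the BiQwen3 model and processor above. The checkpoint name is hypothetical, and `process_images` is assumed to come from the shared base processor; `process_texts` and `score` are the methods defined in this file.

```python
import torch
from PIL import Image

from colpali_engine.models import BiQwen3, BiQwen3Processor

model_name = "your-org/biqwen3-checkpoint"  # hypothetical checkpoint name

model = BiQwen3.from_pretrained(model_name, dtype=torch.bfloat16, device_map="cuda:0").eval()
processor = BiQwen3Processor.from_pretrained(model_name)

docs = processor.process_images([Image.new("RGB", (448, 448), color="white")]).to(model.device)
queries = processor.process_texts(["What is the projected revenue for 2024?"]).to(model.device)

with torch.no_grad():
    # Each call returns one L2-normalized (hidden_size,) vector per input,
    # pooled with the default last-token strategy
    doc_embeddings = model(**docs)
    query_embeddings = model(**queries)

# Cosine similarity between every query and every document
scores = processor.score(list(query_embeddings), list(doc_embeddings))
```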
2 changes: 2 additions & 0 deletions colpali_engine/models/qwen3/colqwen3/__init__.py
@@ -0,0 +1,2 @@
from .modeling_colqwen3 import ColQwen3
from .processing_colqwen3 import ColQwen3Processor
101 changes: 101 additions & 0 deletions colpali_engine/models/qwen3/colqwen3/modeling_colqwen3.py
@@ -0,0 +1,101 @@
from typing import ClassVar

import torch
from torch import nn
from transformers.models.qwen3_vl import Qwen3VLConfig, Qwen3VLModel


class ColQwen3(Qwen3VLModel):
"""
ColQwen3 model implementation, following the architecture described in the "ColPali: Efficient Document Retrieval
with Vision Language Models" paper. Based on the Qwen3-VL backbone.

Args:
config (Qwen3VLConfig): The model configuration.
mask_non_image_embeddings (bool): Whether to ignore all token embeddings except the image ones
at inference. Defaults to False (no masking is applied during the forward pass).
"""

main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related
_checkpoint_conversion_mapping = {
r"^model\.visual": "visual",
r"^model\.language_model": "language_model",
r"^model\.": "",
}

def __init__(
self,
config: Qwen3VLConfig,
mask_non_image_embeddings: bool = False,
**kwargs,
):
dtype = kwargs.pop("dtype", kwargs.pop("torch_dtype", None))
attn_impl = kwargs.pop("attn_implementation", None)
use_cache = kwargs.pop("use_cache", None)

super().__init__(config=config)

hidden_size = getattr(self.config, "hidden_size", None)
if hidden_size is None and hasattr(self.config, "text_config"):
hidden_size = self.config.text_config.hidden_size
if hidden_size is None:
raise ValueError("Unable to determine text hidden size for Qwen3VLConfig.")

self.dim = 320
self.custom_text_proj = nn.Linear(hidden_size, self.dim)
self.padding_side = "left"
self.mask_non_image_embeddings = mask_non_image_embeddings
self.post_init()

if dtype is not None:
self.to(dtype=dtype)
if use_cache is not None:
self.config.use_cache = use_cache
if attn_impl is not None and hasattr(self, "set_attn_implementation"):
self.set_attn_implementation(attn_impl)

@classmethod
def from_pretrained(cls, *args, **kwargs):
key_mapping = kwargs.pop("key_mapping", None)
if key_mapping is None:
key_mapping = getattr(cls, "_checkpoint_conversion_mapping", None)
return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping)

def forward(self, *args, **kwargs) -> torch.Tensor:
# Handle the custom "pixel_values" input obtained with `ColQwen3Processor` through unpadding
if "pixel_values" in kwargs:
offsets = kwargs["image_grid_thw"][:, 1] * kwargs["image_grid_thw"][:, 2] # (batch_size,)
kwargs["pixel_values"] = torch.cat(
[pixel_sequence[:offset] for pixel_sequence, offset in zip(kwargs["pixel_values"], offsets)],
dim=0,
)

kwargs.pop("return_dict", True)
kwargs.pop("output_hidden_states", None)
kwargs.pop("use_cache", None)
last_hidden_states = (
super()
.forward(*args, **kwargs, use_cache=False, output_hidden_states=True, return_dict=True)
.last_hidden_state
) # (batch_size, sequence_length, hidden_size)

proj = self.custom_text_proj(last_hidden_states) # (batch_size, sequence_length, dim)

# L2 normalization
proj = proj / proj.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
proj = proj * kwargs["attention_mask"].unsqueeze(-1) # (batch_size, sequence_length, dim)

if "pixel_values" in kwargs and self.mask_non_image_embeddings:
# Keep only the image-token embeddings; all other token embeddings are zeroed out
image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1)
proj = proj * image_mask
return proj

@property
def patch_size(self) -> int:
return self.visual.config.patch_size

@property
def spatial_merge_size(self) -> int:
return self.visual.config.spatial_merge_size
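
The forward pass above yields one 320-dimensional vector per token. As a reference for how these multi-vector embeddings are consumed, here is a minimal sketch of late-interaction (MaxSim) scoring; in colpali-engine the same computation is typically exposed through the processor's `score_multi_vector` helper (an assumption about the shared base class, which is not part of this diff).

```python
import torch


def maxsim_scores(query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor) -> torch.Tensor:
    """Late-interaction scoring over L2-normalized, padding-zeroed embeddings.

    query_embeddings: (n_queries, n_query_tokens, dim)
    doc_embeddings:   (n_docs, n_doc_tokens, dim)
    Returns:          (n_queries, n_docs)
    """
    # Token-level cosine similarities: (n_queries, n_docs, n_query_tokens, n_doc_tokens)
    sim = torch.einsum("qnd,pmd->qpnm", query_embeddings, doc_embeddings)
    # For each query token, keep its best-matching document token, then sum over query tokens
    return sim.max(dim=3).values.sum(dim=2)
```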