8 changes: 8 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](http://semver.org/).

## Unreleased

### Added

- Add ColQwen3 and BiQwen3 support (model + processor).

### Tests

- Cover ColQwen3 processing and modeling with slow integration tests.

## [0.3.13] - 2025-11-15

### Added
1 change: 1 addition & 0 deletions README.md
@@ -40,6 +40,7 @@ Using ColPali removes the need for potentially complex and brittle layout recognition
| [vidore/colqwen2-v1.0](https://huggingface.co/vidore/colqwen2-v1.0) | 89.3 | Apache 2.0 | • Similar to `vidore/colqwen2-v0.1`, but trained with more powerful GPUs and with a larger effective batch size (256). | ✅ |
| [vidore/colqwen2.5-v0.1](https://huggingface.co/vidore/colqwen2.5-v0.1) | 88.8 | Apache 2.0 | • Based on `Qwen/Qwen2.5-VL-3B-Instruct`.<br />• Supports dynamic resolution.<br />• Trained using 768 image patches per page and an effective batch size of 32. | ✅ |
| [vidore/colqwen2.5-v0.2](https://huggingface.co/vidore/colqwen2.5-v0.2) | 89.4 | Apache 2.0 | • Similar to `vidore/colqwen2.5-v0.1`, but trained with slightly different hyperparameters. | ✅ |
| [TomoroAI/tomoro-colqwen3-embed-4b](https://huggingface.co/TomoroAI/tomoro-colqwen3-embed-4b) | TBD | Apache 2.0 | • Based on the Qwen3-VL backbone.<br />• 320-dim ColBERT-style embeddings with dynamic resolution.<br />• Trained for multi-vector document retrieval. | ✅ |
| [vidore/colSmol-256M](https://huggingface.co/vidore/colSmol-256M) | 80.1 | Apache 2.0 | • Based on `HuggingFaceTB/SmolVLM-256M-Instruct`. | ✅ |
| [vidore/colSmol-500M](https://huggingface.co/vidore/colSmol-500M) | 82.3 | Apache 2.0 | • Based on `HuggingFaceTB/SmolVLM-500M-Instruct`. | ✅ |

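For reference, a minimal usage sketch for the new ColQwen3 entry (not part of this PR; `process_queries` is assumed to mirror the existing ColQwen2.5 processor API):

```python
# Hedged sketch: load the ColQwen3 checkpoint listed above and embed a query.
# `process_queries` is an assumed helper mirroring the ColQwen2.5 processors.
import torch

from colpali_engine.models import ColQwen3, ColQwen3Processor

model_name = "TomoroAI/tomoro-colqwen3-embed-4b"
model = ColQwen3.from_pretrained(model_name, dtype=torch.bfloat16, device_map="auto").eval()
processor = ColQwen3Processor.from_pretrained(model_name)

batch = processor.process_queries(["What is the revenue growth in 2023?"]).to(model.device)
with torch.no_grad():
    query_embeddings = model(**batch)  # (batch_size, seq_len, 320) multi-vector embeddings
```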
4 changes: 4 additions & 0 deletions colpali_engine/__init__.py
@@ -7,6 +7,8 @@
    BiQwen2_5,
    BiQwen2_5_Processor,
    BiQwen2Processor,
    BiQwen3,
    BiQwen3Processor,
    ColIdefics3,
    ColIdefics3Processor,
    ColModernVBert,
@@ -19,4 +21,6 @@
    ColQwen2_5Omni,
    ColQwen2_5OmniProcessor,
    ColQwen2Processor,
    ColQwen3,
    ColQwen3Processor,
)
1 change: 1 addition & 0 deletions colpali_engine/models/__init__.py
@@ -3,4 +3,5 @@
from .paligemma import BiPali, BiPaliProcessor, BiPaliProj, ColPali, ColPaliProcessor
from .qwen2 import BiQwen2, BiQwen2Processor, ColQwen2, ColQwen2Processor
from .qwen2_5 import BiQwen2_5, BiQwen2_5_Processor, ColQwen2_5, ColQwen2_5_Processor
from .qwen3 import BiQwen3, BiQwen3Processor, ColQwen3, ColQwen3Processor
from .qwen_omni import ColQwen2_5Omni, ColQwen2_5OmniProcessor
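The two `__init__.py` diffs above expose the new classes from both the package root and the `models` subpackage; a minimal import check (illustrative, not part of the diff):

```python
# Sanity check for the exports added in this PR.
import colpali_engine
from colpali_engine.models import BiQwen3, BiQwen3Processor, ColQwen3, ColQwen3Processor

# The package root re-exports the same classes.
assert colpali_engine.ColQwen3 is ColQwen3
assert colpali_engine.BiQwen3Processor is BiQwen3Processor
```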
2 changes: 2 additions & 0 deletions colpali_engine/models/qwen3/__init__.py
@@ -0,0 +1,2 @@
from .biqwen3 import BiQwen3, BiQwen3Processor
from .colqwen3 import ColQwen3, ColQwen3Processor
2 changes: 2 additions & 0 deletions colpali_engine/models/qwen3/biqwen3/__init__.py
@@ -0,0 +1,2 @@
from .modeling_biqwen3 import BiQwen3
from .processing_biqwen3 import BiQwen3Processor
94 changes: 94 additions & 0 deletions colpali_engine/models/qwen3/biqwen3/modeling_biqwen3.py
@@ -0,0 +1,94 @@
from typing import ClassVar, Literal

import torch
from transformers.models.qwen3_vl import Qwen3VLConfig, Qwen3VLModel


class BiQwen3(Qwen3VLModel):
    """
    BiQwen3 implementation from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper.
    Representations are pooled to obtain a single dense vector per input. Based on the Qwen3-VL backbone.
    """

    main_input_name: ClassVar[str] = "doc_input_ids"
    _checkpoint_conversion_mapping = {
        r"^model\.visual": "visual",
        r"^model\.language_model": "language_model",
        r"^model\.": "",
    }

    def __init__(self, config: Qwen3VLConfig, **kwargs):
        dtype = kwargs.pop("dtype", kwargs.pop("torch_dtype", None))
        attn_impl = kwargs.pop("attn_implementation", None)
        use_cache = kwargs.pop("use_cache", None)

        super().__init__(config=config)
        self.padding_side = "left"
        self.post_init()

        if dtype is not None:
            self.to(dtype=dtype)
        if use_cache is not None:
            self.config.use_cache = use_cache
        if attn_impl is not None and hasattr(self, "set_attn_implementation"):
            self.set_attn_implementation(attn_impl)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        key_mapping = kwargs.pop("key_mapping", None)
        if key_mapping is None:
            key_mapping = getattr(cls, "_checkpoint_conversion_mapping", None)
        return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping)

    def forward(
        self,
        pooling_strategy: Literal["cls", "last", "mean"] = "last",
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        Forward pass for the BiQwen3 model.

        Args:
            pooling_strategy: The strategy to use for pooling the hidden states.
            *args: Variable length argument list.
            **kwargs: Additional keyword arguments.

        Returns:
            torch.Tensor: L2-normalized dense embeddings of shape (batch_size, hidden_size).
        """
        if "pixel_values" in kwargs:
            offsets = kwargs["image_grid_thw"].prod(dim=1).tolist()
            kwargs["pixel_values"] = torch.cat(
                [pixel_sequence[:offset] for pixel_sequence, offset in zip(kwargs["pixel_values"], offsets)],
                dim=0,
            )
        kwargs.pop("return_dict", True)
        kwargs.pop("output_hidden_states", None)
        kwargs.pop("use_cache", None)

        last_hidden_states = (
            super()
            .forward(*args, **kwargs, use_cache=False, output_hidden_states=True, return_dict=True)
            .last_hidden_state
        )  # (batch_size, sequence_length, hidden_size)

        if pooling_strategy == "cls":
            pooled = last_hidden_states[:, 0]
        elif pooling_strategy == "last":
            pooled = last_hidden_states[:, -1]
        elif pooling_strategy == "mean":
            mask = kwargs["attention_mask"].unsqueeze(-1)
            pooled = (last_hidden_states * mask).sum(dim=1) / mask.sum(dim=1)
        else:
            raise ValueError(f"Invalid pooling strategy: {pooling_strategy}")

        return pooled / pooled.norm(dim=-1, keepdim=True)

    @property
    def patch_size(self) -> int:
        return self.visual.config.patch_size

    @property
    def spatial_merge_size(self) -> int:
        return self.visual.config.spatial_merge_size
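A minimal usage sketch for the model above (the checkpoint name is a placeholder, and `process_texts` is the helper added in the processor file below):

```python
# Hedged sketch: BiQwen3 returns one L2-normalized vector per input sequence.
import torch

from colpali_engine.models import BiQwen3, BiQwen3Processor

ckpt = "org/biqwen3-checkpoint"  # hypothetical checkpoint name
model = BiQwen3.from_pretrained(ckpt, dtype=torch.bfloat16).eval()
processor = BiQwen3Processor.from_pretrained(ckpt)

batch = processor.process_texts(["quarterly revenue report"]).to(model.device)
with torch.no_grad():
    embeddings = model(**batch, pooling_strategy="mean")  # (batch_size, hidden_size)
print(embeddings.norm(dim=-1))  # ~1.0, embeddings are unit-normalized
```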
37 changes: 37 additions & 0 deletions colpali_engine/models/qwen3/biqwen3/processing_biqwen3.py
@@ -0,0 +1,37 @@
from typing import List, Optional, Union

import torch
from transformers import BatchEncoding, BatchFeature

from colpali_engine.models.qwen3.colqwen3 import ColQwen3Processor


class BiQwen3Processor(ColQwen3Processor):
    """
    Processor for BiQwen3.
    """

    def process_texts(
        self,
        texts: List[str],
    ) -> Union[BatchFeature, BatchEncoding]:
        """
        Process texts for BiQwen3.
        """
        return self(
            text=texts,
            return_tensors="pt",
            padding="longest",
        )

    def score(
        self,
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Compute the cosine similarity between the given query and passage embeddings.
        """
        return self.score_single_vector(qs, ps, device=device)
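Because the pooled embeddings are unit-normalized, the cosine similarity computed by `score` reduces to a dot product; a toy illustration (not the library implementation, dimensions are made up):

```python
# Toy equivalent of single-vector scoring: cosine similarity of unit vectors is a dot product.
import torch

qs = torch.nn.functional.normalize(torch.randn(2, 2560), dim=-1)  # 2 query embeddings
ps = torch.nn.functional.normalize(torch.randn(3, 2560), dim=-1)  # 3 passage embeddings
scores = qs @ ps.T  # (2, 3) cosine similarity matrix
```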
2 changes: 2 additions & 0 deletions colpali_engine/models/qwen3/colqwen3/__init__.py
@@ -0,0 +1,2 @@
from .modeling_colqwen3 import ColQwen3
from .processing_colqwen3 import ColQwen3Processor
101 changes: 101 additions & 0 deletions colpali_engine/models/qwen3/colqwen3/modeling_colqwen3.py
@@ -0,0 +1,101 @@
from typing import ClassVar

import torch
from torch import nn
from transformers.models.qwen3_vl import Qwen3VLConfig, Qwen3VLModel


class ColQwen3(Qwen3VLModel):
    """
    ColQwen3 model implementation, following the architecture described in the "ColPali: Efficient Document
    Retrieval with Vision Language Models" paper. Based on the Qwen3-VL backbone.

    Args:
        config (Qwen3VLConfig): The model configuration.
        mask_non_image_embeddings (bool): Whether to ignore all token embeddings except those of the image
            at inference. Defaults to False, i.e. no embeddings are masked during the forward pass.
    """

    main_input_name: ClassVar[str] = "doc_input_ids"  # transformers-related
    _checkpoint_conversion_mapping = {
        r"^model\.visual": "visual",
        r"^model\.language_model": "language_model",
        r"^model\.": "",
    }

    def __init__(
        self,
        config: Qwen3VLConfig,
        mask_non_image_embeddings: bool = False,
        **kwargs,
    ):
        dtype = kwargs.pop("dtype", kwargs.pop("torch_dtype", None))
        attn_impl = kwargs.pop("attn_implementation", None)
        use_cache = kwargs.pop("use_cache", None)

        super().__init__(config=config)

        hidden_size = getattr(self.config, "hidden_size", None)
        if hidden_size is None and hasattr(self.config, "text_config"):
            hidden_size = self.config.text_config.hidden_size
        if hidden_size is None:
            raise ValueError("Unable to determine text hidden size for Qwen3VLConfig.")

        self.dim = 320
        self.custom_text_proj = nn.Linear(hidden_size, self.dim)
        self.padding_side = "left"
        self.mask_non_image_embeddings = mask_non_image_embeddings
        self.post_init()

        if dtype is not None:
            self.to(dtype=dtype)
        if use_cache is not None:
            self.config.use_cache = use_cache
        if attn_impl is not None and hasattr(self, "set_attn_implementation"):
            self.set_attn_implementation(attn_impl)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        key_mapping = kwargs.pop("key_mapping", None)
        if key_mapping is None:
            key_mapping = getattr(cls, "_checkpoint_conversion_mapping", None)
        return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping)

    def forward(self, *args, **kwargs) -> torch.Tensor:
        # Handle the custom "pixel_values" input obtained with `ColQwen3Processor` through unpadding
        if "pixel_values" in kwargs:
            offsets = kwargs["image_grid_thw"][:, 1] * kwargs["image_grid_thw"][:, 2]  # (batch_size,)
            kwargs["pixel_values"] = torch.cat(
                [pixel_sequence[:offset] for pixel_sequence, offset in zip(kwargs["pixel_values"], offsets)],
                dim=0,
            )

        kwargs.pop("return_dict", True)
        kwargs.pop("output_hidden_states", None)
        kwargs.pop("use_cache", None)
        last_hidden_states = (
            super()
            .forward(*args, **kwargs, use_cache=False, output_hidden_states=True, return_dict=True)
            .last_hidden_state
        )  # (batch_size, sequence_length, hidden_size)

        proj = self.custom_text_proj(last_hidden_states)  # (batch_size, sequence_length, dim)

        # L2 normalization
        proj = proj / proj.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)
        proj = proj * kwargs["attention_mask"].unsqueeze(-1)  # (batch_size, sequence_length, dim)

        if "pixel_values" in kwargs and self.mask_non_image_embeddings:
            # Keep only the image-token embeddings (zero out all other token embeddings)
            image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1)
            proj = proj * image_mask
        return proj

    @property
    def patch_size(self) -> int:
        return self.visual.config.patch_size

    @property
    def spatial_merge_size(self) -> int:
        return self.visual.config.spatial_merge_size
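Downstream, the per-token 320-dim embeddings returned by `forward` are scored with the ColBERT-style MaxSim late interaction described in the ColPali paper; a self-contained sketch with illustrative shapes:

```python
# MaxSim late-interaction scoring over multi-vector embeddings (shapes are illustrative).
import torch

q = torch.nn.functional.normalize(torch.randn(1, 32, 320), dim=-1)   # query: 32 token vectors
d = torch.nn.functional.normalize(torch.randn(1, 768, 320), dim=-1)  # document: 768 patch vectors
sim = torch.einsum("bqe,bpe->bqp", q, d)    # token-to-patch similarities
score = sim.max(dim=-1).values.sum(dim=-1)  # best patch per query token, summed over tokens
```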