diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3a5a82d88..3c0e624b1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
## Unreleased
+### Added
+
+- Add ColQwen3 and BiQwen3 support (model + processor).
+
+### Tests
+
+- Cover ColQwen3 processing and modeling with slow integration tests.
+
## [0.3.13] - 2025-11-15
### Added
diff --git a/README.md b/README.md
index 48355bfc0..98c8127a5 100644
--- a/README.md
+++ b/README.md
@@ -40,6 +40,7 @@ Using ColPali removes the need for potentially complex and brittle layout recogn
| [vidore/colqwen2-v1.0](https://huggingface.co/vidore/colqwen2-v1.0) | 89.3 | Apache 2.0 | • Similar to `vidore/colqwen2-v0.1`, but trained with more powerful GPUs and with a larger effective batch size (256). | ✅ |
| [vidore/colqwen2.5-v0.1](https://huggingface.co/vidore/colqwen2.5-v0.1) | 88.8 | Apache 2.0 | • Based on `Qwen/Qwen2.5-VL-3B-Instruct`.<br>• Supports dynamic resolution.<br>• Trained using 768 image patches per page and an effective batch size of 32. | ✅ |
| [vidore/colqwen2.5-v0.2](https://huggingface.co/vidore/colqwen2.5-v0.2) | 89.4 | Apache 2.0 | • Similar to `vidore/colqwen2.5-v0.1`, but trained with slightly different hyperparameters. | ✅ |
+| [TomoroAI/tomoro-colqwen3-embed-4b](https://huggingface.co/TomoroAI/tomoro-colqwen3-embed-4b) | 90.6 | Apache 2.0 | • Based on the Qwen3-VL backbone.<br>• 320-dim ColBERT-style embeddings with dynamic resolution.<br>• Trained for multi-vector document retrieval. | ✅ |
| [vidore/colSmol-256M](https://huggingface.co/vidore/colSmol-256M) | 80.1 | Apache 2.0 | • Based on `HuggingFaceTB/SmolVLM-256M-Instruct`. | ✅ |
| [vidore/colSmol-500M](https://huggingface.co/vidore/colSmol-500M) | 82.3 | Apache 2.0 | • Based on `HuggingFaceTB/SmolVLM-500M-Instruct`. | ✅ |
diff --git a/colpali_engine/__init__.py b/colpali_engine/__init__.py
index f72652aed..1f5098f74 100644
--- a/colpali_engine/__init__.py
+++ b/colpali_engine/__init__.py
@@ -7,6 +7,8 @@
BiQwen2_5,
BiQwen2_5_Processor,
BiQwen2Processor,
+ BiQwen3,
+ BiQwen3Processor,
ColIdefics3,
ColIdefics3Processor,
ColModernVBert,
@@ -19,4 +21,6 @@
ColQwen2_5Omni,
ColQwen2_5OmniProcessor,
ColQwen2Processor,
+ ColQwen3,
+ ColQwen3Processor,
)
diff --git a/colpali_engine/models/__init__.py b/colpali_engine/models/__init__.py
index 1129e1612..b98544fcb 100644
--- a/colpali_engine/models/__init__.py
+++ b/colpali_engine/models/__init__.py
@@ -3,4 +3,5 @@
from .paligemma import BiPali, BiPaliProcessor, BiPaliProj, ColPali, ColPaliProcessor
from .qwen2 import BiQwen2, BiQwen2Processor, ColQwen2, ColQwen2Processor
from .qwen2_5 import BiQwen2_5, BiQwen2_5_Processor, ColQwen2_5, ColQwen2_5_Processor
+from .qwen3 import BiQwen3, BiQwen3Processor, ColQwen3, ColQwen3Processor
from .qwen_omni import ColQwen2_5Omni, ColQwen2_5OmniProcessor
diff --git a/colpali_engine/models/qwen3/__init__.py b/colpali_engine/models/qwen3/__init__.py
new file mode 100644
index 000000000..efcee26f4
--- /dev/null
+++ b/colpali_engine/models/qwen3/__init__.py
@@ -0,0 +1,2 @@
+from .biqwen3 import BiQwen3, BiQwen3Processor
+from .colqwen3 import ColQwen3, ColQwen3Processor
diff --git a/colpali_engine/models/qwen3/biqwen3/__init__.py b/colpali_engine/models/qwen3/biqwen3/__init__.py
new file mode 100644
index 000000000..1b7881d83
--- /dev/null
+++ b/colpali_engine/models/qwen3/biqwen3/__init__.py
@@ -0,0 +1,2 @@
+from .modeling_biqwen3 import BiQwen3
+from .processing_biqwen3 import BiQwen3Processor
diff --git a/colpali_engine/models/qwen3/biqwen3/modeling_biqwen3.py b/colpali_engine/models/qwen3/biqwen3/modeling_biqwen3.py
new file mode 100644
index 000000000..fc4804e4b
--- /dev/null
+++ b/colpali_engine/models/qwen3/biqwen3/modeling_biqwen3.py
@@ -0,0 +1,94 @@
+from typing import ClassVar, Literal
+
+import torch
+from transformers.models.qwen3_vl import Qwen3VLConfig, Qwen3VLModel
+
+
+class BiQwen3(Qwen3VLModel):
+ """
+    BiQwen3 implementation, following the bi-encoder variant described in the "ColPali: Efficient Document
+    Retrieval with Vision Language Models" paper: token representations are pooled into a single dense vector.
+    Based on the Qwen3-VL backbone.
+ """
+
+ main_input_name: ClassVar[str] = "doc_input_ids"
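+    # Remap checkpoint keys saved with a `model.` prefix (e.g. exported from `Qwen3VLForConditionalGeneration`)
+    # onto this `Qwen3VLModel` subclass.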
+ _checkpoint_conversion_mapping = {
+ r"^model\.visual": "visual",
+ r"^model\.language_model": "language_model",
+ r"^model\.": "",
+ }
+
+ def __init__(self, config: Qwen3VLConfig, **kwargs):
+ dtype = kwargs.pop("dtype", kwargs.pop("torch_dtype", None))
+ attn_impl = kwargs.pop("attn_implementation", None)
+ use_cache = kwargs.pop("use_cache", None)
+
+ super().__init__(config=config)
+ self.padding_side = "left"
+ self.post_init()
+
+ if dtype is not None:
+ self.to(dtype=dtype)
+ if use_cache is not None:
+ self.config.use_cache = use_cache
+ if attn_impl is not None and hasattr(self, "set_attn_implementation"):
+ self.set_attn_implementation(attn_impl)
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ key_mapping = kwargs.pop("key_mapping", None)
+ if key_mapping is None:
+ key_mapping = getattr(cls, "_checkpoint_conversion_mapping", None)
+ return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping)
+
+ def forward(
+ self,
+ pooling_strategy: Literal["cls", "last", "mean"] = "last",
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ Forward pass for BiQwen3 model.
+
+ Args:
+ pooling_strategy: The strategy to use for pooling the hidden states.
+ *args: Variable length argument list.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ torch.Tensor: Dense embeddings (batch_size, hidden_size).
+ """
+ if "pixel_values" in kwargs:
+ offsets = kwargs["image_grid_thw"].prod(dim=1).tolist()
+ kwargs["pixel_values"] = torch.cat(
+ [pixel_sequence[:offset] for pixel_sequence, offset in zip(kwargs["pixel_values"], offsets)],
+ dim=0,
+ )
+ kwargs.pop("return_dict", True)
+ kwargs.pop("output_hidden_states", None)
+ kwargs.pop("use_cache", None)
+
+ last_hidden_states = (
+ super()
+ .forward(*args, **kwargs, use_cache=False, output_hidden_states=True, return_dict=True)
+ .last_hidden_state
+ ) # (batch_size, sequence_length, hidden_size)
+
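+        # Pool the sequence into a single embedding. With left padding, index -1 is always a real token,
+        # whereas index 0 ("cls") may fall on a padding token for shorter sequences in the batch.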
+ if pooling_strategy == "cls":
+ pooled = last_hidden_states[:, 0]
+ elif pooling_strategy == "last":
+ pooled = last_hidden_states[:, -1]
+ elif pooling_strategy == "mean":
+ mask = kwargs["attention_mask"].unsqueeze(-1)
+ pooled = (last_hidden_states * mask).sum(dim=1) / mask.sum(dim=1)
+ else:
+ raise ValueError(f"Invalid pooling strategy: {pooling_strategy}")
+
+ return pooled / pooled.norm(dim=-1, keepdim=True)
+
+ @property
+ def patch_size(self) -> int:
+ return self.visual.config.patch_size
+
+ @property
+ def spatial_merge_size(self) -> int:
+ return self.visual.config.spatial_merge_size
diff --git a/colpali_engine/models/qwen3/biqwen3/processing_biqwen3.py b/colpali_engine/models/qwen3/biqwen3/processing_biqwen3.py
new file mode 100644
index 000000000..c13eaff0d
--- /dev/null
+++ b/colpali_engine/models/qwen3/biqwen3/processing_biqwen3.py
@@ -0,0 +1,37 @@
+from typing import List, Optional, Union
+
+import torch
+from transformers import BatchEncoding, BatchFeature
+
+from colpali_engine.models.qwen3.colqwen3 import ColQwen3Processor
+
+
+class BiQwen3Processor(ColQwen3Processor):
+ """
+ Processor for BiQwen3.
+ """
+
+ def process_texts(
+ self,
+ texts: List[str],
+ ) -> Union[BatchFeature, BatchEncoding]:
+ """
+ Process texts for BiQwen3.
+ """
+ return self(
+ text=texts,
+ return_tensors="pt",
+ padding="longest",
+ )
+
+ def score(
+ self,
+ qs: List[torch.Tensor],
+ ps: List[torch.Tensor],
+ device: Optional[Union[str, torch.device]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ Compute the cosine similarity for the given query and passage embeddings.
+ """
+ return self.score_single_vector(qs, ps, device=device)
diff --git a/colpali_engine/models/qwen3/colqwen3/__init__.py b/colpali_engine/models/qwen3/colqwen3/__init__.py
new file mode 100644
index 000000000..6369cb69b
--- /dev/null
+++ b/colpali_engine/models/qwen3/colqwen3/__init__.py
@@ -0,0 +1,2 @@
+from .modeling_colqwen3 import ColQwen3
+from .processing_colqwen3 import ColQwen3Processor
diff --git a/colpali_engine/models/qwen3/colqwen3/modeling_colqwen3.py b/colpali_engine/models/qwen3/colqwen3/modeling_colqwen3.py
new file mode 100644
index 000000000..019c87bd5
--- /dev/null
+++ b/colpali_engine/models/qwen3/colqwen3/modeling_colqwen3.py
@@ -0,0 +1,101 @@
+from typing import ClassVar
+
+import torch
+from torch import nn
+from transformers.models.qwen3_vl import Qwen3VLConfig, Qwen3VLModel
+
+
+class ColQwen3(Qwen3VLModel):
+ """
+    ColQwen3 model implementation, following the architecture described in the "ColPali: Efficient Document
+    Retrieval with Vision Language Models" paper. Based on the Qwen3-VL backbone.
+
+ Args:
+ config (Qwen3VLConfig): The model configuration.
+        mask_non_image_embeddings (bool): Whether to ignore all token embeddings except those of the image
+            at inference time. Defaults to False, i.e. no embeddings are masked during the forward pass.
+ """
+
+ main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related
+ _checkpoint_conversion_mapping = {
+ r"^model\.visual": "visual",
+ r"^model\.language_model": "language_model",
+ r"^model\.": "",
+ }
+
+ def __init__(
+ self,
+ config: Qwen3VLConfig,
+ mask_non_image_embeddings: bool = False,
+ **kwargs,
+ ):
+ dtype = kwargs.pop("dtype", kwargs.pop("torch_dtype", None))
+ attn_impl = kwargs.pop("attn_implementation", None)
+ use_cache = kwargs.pop("use_cache", None)
+
+ super().__init__(config=config)
+
+ hidden_size = getattr(self.config, "hidden_size", None)
+ if hidden_size is None and hasattr(self.config, "text_config"):
+ hidden_size = self.config.text_config.hidden_size
+ if hidden_size is None:
+ raise ValueError("Unable to determine text hidden size for Qwen3VLConfig.")
+
+ self.dim = 320
+ self.custom_text_proj = nn.Linear(hidden_size, self.dim)
+ self.padding_side = "left"
+ self.mask_non_image_embeddings = mask_non_image_embeddings
+ self.post_init()
+
+ if dtype is not None:
+ self.to(dtype=dtype)
+ if use_cache is not None:
+ self.config.use_cache = use_cache
+ if attn_impl is not None and hasattr(self, "set_attn_implementation"):
+ self.set_attn_implementation(attn_impl)
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ key_mapping = kwargs.pop("key_mapping", None)
+ if key_mapping is None:
+ key_mapping = getattr(cls, "_checkpoint_conversion_mapping", None)
+ return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping)
+
+ def forward(self, *args, **kwargs) -> torch.Tensor:
+ # Handle the custom "pixel_values" input obtained with `ColQwen3Processor` through unpadding
+ if "pixel_values" in kwargs:
+ offsets = kwargs["image_grid_thw"][:, 1] * kwargs["image_grid_thw"][:, 2] # (batch_size,)
+ kwargs["pixel_values"] = torch.cat(
+ [pixel_sequence[:offset] for pixel_sequence, offset in zip(kwargs["pixel_values"], offsets)],
+ dim=0,
+ )
+
+ kwargs.pop("return_dict", True)
+ kwargs.pop("output_hidden_states", None)
+ kwargs.pop("use_cache", None)
+ last_hidden_states = (
+ super()
+ .forward(*args, **kwargs, use_cache=False, output_hidden_states=True, return_dict=True)
+ .last_hidden_state
+ ) # (batch_size, sequence_length, hidden_size)
+
+ proj = self.custom_text_proj(last_hidden_states) # (batch_size, sequence_length, dim)
+
+ # L2 normalization
+ proj = proj / proj.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
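+        # Zero out the embeddings of padding tokens so they do not contribute to downstream MaxSim scoring.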
+ proj = proj * kwargs["attention_mask"].unsqueeze(-1) # (batch_size, sequence_length, dim)
+
+ if "pixel_values" in kwargs and self.mask_non_image_embeddings:
+            # Keep only the image-token embeddings; all other positions are zeroed out.
+ image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1)
+ proj = proj * image_mask
+ return proj
+
+ @property
+ def patch_size(self) -> int:
+ return self.visual.config.patch_size
+
+ @property
+ def spatial_merge_size(self) -> int:
+ return self.visual.config.spatial_merge_size
diff --git a/colpali_engine/models/qwen3/colqwen3/processing_colqwen3.py b/colpali_engine/models/qwen3/colqwen3/processing_colqwen3.py
new file mode 100644
index 000000000..436c7336e
--- /dev/null
+++ b/colpali_engine/models/qwen3/colqwen3/processing_colqwen3.py
@@ -0,0 +1,154 @@
+from typing import ClassVar, List, Optional, Tuple, Union
+
+import torch
+from PIL import Image
+from transformers import BatchEncoding, BatchFeature
+from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
+from transformers.models.qwen3_vl import Qwen3VLProcessor
+
+from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
+
+
+class ColQwen3Processor(BaseVisualRetrieverProcessor, Qwen3VLProcessor):
+ """
+ Processor for ColQwen3.
+
+ Args:
+ *args: Variable length argument list to be passed to the parent `Qwen3VLProcessor` class.
+        max_num_visual_tokens: The maximum number of visual tokens per image. When passed to `from_pretrained`,
+            the image processor's `max_pixels` is capped accordingly.
+ **kwargs: Arbitrary keyword arguments to be passed to the parent `Qwen3VLProcessor` class.
+ """
+
+ visual_prompt_prefix: ClassVar[str] = (
+ "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|endoftext|>"
+ )
+ query_augmentation_token: ClassVar[str] = "<|endoftext|>"
+ image_token: ClassVar[str] = "<|image_pad|>"
+
+ def __init__(
+ self,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+ self.tokenizer.padding_side = "left"
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ *args,
+ device_map: Optional[str] = None,
+ **kwargs,
+ ):
+ instance = super().from_pretrained(
+ *args,
+ device_map=device_map,
+ **kwargs,
+ )
+
+ if "max_num_visual_tokens" in kwargs:
+ patch_size = getattr(instance.image_processor, "patch_size", None)
+ merge_size = getattr(instance.image_processor, "merge_size", None)
+ if patch_size is None or merge_size is None:
+ raise ValueError("Qwen3VL image processor is missing `patch_size` or `merge_size`.")
+ tile = patch_size * merge_size
+ instance.image_processor.max_pixels = kwargs["max_num_visual_tokens"] * tile * tile
+ instance.image_processor.size["longest_edge"] = instance.image_processor.max_pixels
+
+ return instance
+
+ def process_images(
+ self,
+ images: List[Image.Image],
+ ) -> Union[BatchFeature, BatchEncoding]:
+ """
+ Process images for ColQwen3.
+
+ Args:
+ images: List of PIL images.
+ """
+
+ images = [image.convert("RGB") for image in images]
+
+ batch_doc = self(
+ text=[self.visual_prompt_prefix] * len(images),
+ images=images,
+ padding="longest",
+ return_tensors="pt",
+ )
+
+ # NOTE: The following adjustment ensures correct behavior with DDP on multiple GPUs.
+ offsets = batch_doc["image_grid_thw"][:, 1] * batch_doc["image_grid_thw"][:, 2] # (batch_size,)
+
+ # Split the pixel_values tensor into a list of tensors, one per image
+ pixel_values = list(
+ torch.split(batch_doc["pixel_values"], offsets.tolist())
+ ) # [(num_patches_image_0, pixel_values), ..., (num_patches_image_n, pixel_values)]
+
+ # Pad the list of pixel_value tensors to the same length along the sequence dimension
+ batch_doc["pixel_values"] = torch.nn.utils.rnn.pad_sequence(
+ pixel_values, batch_first=True
+ ) # (batch_size, max_num_patches, pixel_values)
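+        # The matching unpadding happens in `ColQwen3.forward` / `BiQwen3.forward`, which re-flatten
+        # `pixel_values` using `image_grid_thw` before calling the backbone.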
+
+ return batch_doc
+
+ def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]:
+ """
+ Process texts for ColQwen3.
+
+ Args:
+ texts: List of input texts.
+
+ Returns:
+ Union[BatchFeature, BatchEncoding]: Processed texts.
+ """
+ return self(
+ text=texts,
+ return_tensors="pt",
+ padding="longest",
+ )
+
+ def score(
+ self,
+ qs: List[torch.Tensor],
+ ps: List[torch.Tensor],
+ device: Optional[Union[str, torch.device]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
+ """
+ return self.score_multi_vector(qs, ps, device=device, **kwargs)
+
+ def get_n_patches(
+ self,
+ image_size: Tuple[int, int],
+ spatial_merge_size: int,
+ ) -> Tuple[int, int]:
+ """
+        Get the number of patches (n_patches_x, n_patches_y) used to process an image of size (width, height);
+        the patch size is read from the image processor.
+
+        The `spatial_merge_size` is the number of patches that are merged spatially. It is exposed as the
+        `spatial_merge_size` property on `ColQwen3` and `BiQwen3`.
+ """
+ patch_size = self.image_processor.patch_size
+
+ height_new, width_new = smart_resize(
+ width=image_size[0],
+ height=image_size[1],
+ factor=patch_size * self.image_processor.merge_size,
+ min_pixels=self.image_processor.size["shortest_edge"],
+ max_pixels=self.image_processor.size["longest_edge"],
+ )
+
+ n_patches_x = width_new // patch_size // spatial_merge_size
+ n_patches_y = height_new // patch_size // spatial_merge_size
+
+ return n_patches_x, n_patches_y
+
+ def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor:
+ """
+ Get a tensor mask that identifies the image tokens in the batch.
+ """
+ return batch_images.input_ids == self.image_token_id
diff --git a/pyproject.toml b/pyproject.toml
index 7fb2ee391..999dc7150 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,7 +40,7 @@ dependencies = [
"scipy",
"torch>=2.2.0,<2.9.0",
"torchvision",
- "transformers>=4.53.1,<4.58.0",
+ "transformers>=4.57.0,<4.58.0",
]
[project.optional-dependencies]
diff --git a/scripts/configs/qwen3/train_colqwen3_model.py b/scripts/configs/qwen3/train_colqwen3_model.py
new file mode 100644
index 000000000..bef8f3df8
--- /dev/null
+++ b/scripts/configs/qwen3/train_colqwen3_model.py
@@ -0,0 +1,100 @@
+import argparse
+import shutil
+from pathlib import Path
+
+import torch
+from datasets import load_dataset
+from peft import LoraConfig
+from transformers import TrainingArguments
+
+from colpali_engine.data.dataset import ColPaliEngineDataset
+from colpali_engine.loss.late_interaction_losses import ColbertLoss, ColbertPairwiseCELoss
+from colpali_engine.models import ColQwen3, ColQwen3Processor
+from colpali_engine.trainer.colmodel_torch_training import ColModelTorchTraining
+from colpali_engine.trainer.colmodel_training import ColModelTraining, ColModelTrainingConfig
+from colpali_engine.utils.dataset_transformation import load_train_set
+
+
+def parse_args():
+ p = argparse.ArgumentParser()
+ p.add_argument("--output-dir", type=str, required=True, help="where to write model + script copy")
+ p.add_argument("--lr", type=float, default=2e-4, help="learning rate")
+ p.add_argument("--tau", type=float, default=0.02, help="temperature for loss function")
+ p.add_argument("--trainer", type=str, default="hf", choices=["torch", "hf"], help="trainer to use")
+ p.add_argument("--loss", type=str, default="ce", choices=["ce", "pairwise"], help="loss function to use")
+ p.add_argument("--peft", action="store_true", help="use PEFT for training")
+
+ return p.parse_args()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+
+ if args.loss == "ce":
+ loss_func = ColbertLoss(
+ temperature=args.tau,
+ normalize_scores=True,
+ use_smooth_max=False,
+ pos_aware_negative_filtering=False,
+ )
+ elif args.loss == "pairwise":
+ loss_func = ColbertPairwiseCELoss(
+ normalize_scores=False,
+ )
+ else:
+ raise ValueError(f"Unknown loss function: {args.loss}")
+
+ config = ColModelTrainingConfig(
+ output_dir=args.output_dir,
+ processor=ColQwen3Processor.from_pretrained(
+ pretrained_model_name_or_path="./models/base_models/colqwen3-base",
+ max_num_visual_tokens=768,
+ ),
+ model=ColQwen3.from_pretrained(
+ pretrained_model_name_or_path="./models/base_models/colqwen3-base",
+ torch_dtype=torch.bfloat16,
+ use_cache=False,
+ attn_implementation="flash_attention_2",
+ ),
+ train_dataset=load_train_set(),
+ eval_dataset=ColPaliEngineDataset(
+ load_dataset("./data_dir/colpali_train_set", split="test"), pos_target_column_name="image"
+ ),
+ run_eval=True,
+ loss_func=loss_func,
+ tr_args=TrainingArguments(
+ output_dir=None,
+ overwrite_output_dir=True,
+ num_train_epochs=5,
+ per_device_train_batch_size=64,
+ gradient_checkpointing=True,
+ gradient_checkpointing_kwargs={"use_reentrant": False},
+ per_device_eval_batch_size=16,
+ eval_strategy="steps",
+ dataloader_num_workers=8,
+ save_steps=500,
+ logging_steps=10,
+ eval_steps=100,
+ warmup_steps=100,
+ learning_rate=args.lr,
+ save_total_limit=1,
+ ),
+ peft_config=LoraConfig(
+ r=32,
+ lora_alpha=32,
+ lora_dropout=0.1,
+ init_lora_weights="gaussian",
+ bias="none",
+ task_type="FEATURE_EXTRACTION",
+ target_modules="(.*(model)(?!.*visual).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)",
+ )
+ if args.peft
+ else None,
+ )
+
+ Path(config.output_dir).mkdir(parents=True, exist_ok=True)
+ shutil.copy(Path(__file__), Path(config.output_dir) / Path(__file__).name)
+
+ trainer = ColModelTraining(config) if args.trainer == "hf" else ColModelTorchTraining(config)
+ trainer.train()
+ trainer.save()
diff --git a/tests/models/qwen3/colqwen3/test_modeling_colqwen3.py b/tests/models/qwen3/colqwen3/test_modeling_colqwen3.py
new file mode 100644
index 000000000..239713d6a
--- /dev/null
+++ b/tests/models/qwen3/colqwen3/test_modeling_colqwen3.py
@@ -0,0 +1,135 @@
+import logging
+from typing import Generator, cast
+
+import pytest
+import torch
+from datasets import load_dataset
+from PIL import Image
+from transformers.utils.import_utils import is_flash_attn_2_available
+
+from colpali_engine.models import ColQwen3, ColQwen3Processor
+from colpali_engine.utils.torch_utils import get_torch_device, tear_down_torch
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.fixture(scope="module")
+def model_name() -> str:
+ return "TomoroAI/tomoro-colqwen3-embed-4b"
+
+
+@pytest.fixture(scope="module")
+def model_without_mask(model_name: str) -> Generator[ColQwen3, None, None]:
+ device = get_torch_device("auto")
+ logger.info(f"Device used: {device}")
+
+ yield cast(
+ ColQwen3,
+ ColQwen3.from_pretrained(
+ model_name,
+ torch_dtype=torch.bfloat16,
+ device_map=device,
+ attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
+ mask_non_image_embeddings=False,
+ ).eval(),
+ )
+ tear_down_torch()
+
+
+@pytest.fixture(scope="module")
+def model_with_mask(model_name: str) -> Generator[ColQwen3, None, None]:
+ device = get_torch_device("auto")
+ logger.info(f"Device used: {device}")
+
+ yield cast(
+ ColQwen3,
+ ColQwen3.from_pretrained(
+ model_name,
+ torch_dtype=torch.bfloat16,
+ device_map=device,
+ attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
+ mask_non_image_embeddings=True,
+ ).eval(),
+ )
+ tear_down_torch()
+
+
+@pytest.fixture(scope="module")
+def processor(model_name: str) -> Generator[ColQwen3Processor, None, None]:
+ yield cast(ColQwen3Processor, ColQwen3Processor.from_pretrained(model_name))
+
+
+class TestColQwen3Model:
+ @pytest.mark.slow
+ def test_load_model_from_pretrained(self, model_without_mask: ColQwen3):
+ assert isinstance(model_without_mask, ColQwen3)
+
+
+class TestColQwen3ModelIntegration:
+ @pytest.mark.slow
+ def test_forward_images_integration(
+ self,
+ model_without_mask: ColQwen3,
+ processor: ColQwen3Processor,
+ ):
+ images = [
+ Image.new("RGB", (64, 64), color="white"),
+ Image.new("RGB", (32, 32), color="black"),
+ ]
+ batch_images = processor.process_images(images).to(model_without_mask.device)
+
+ with torch.no_grad():
+ outputs = model_without_mask(**batch_images)
+
+ assert isinstance(outputs, torch.Tensor)
+ assert outputs.dim() == 3
+ batch_size, n_visual_tokens, emb_dim = outputs.shape
+ assert batch_size == len(images)
+ assert n_visual_tokens >= 1
+ assert emb_dim == model_without_mask.dim
+
+ @pytest.mark.slow
+ def test_forward_queries_integration(
+ self,
+ model_without_mask: ColQwen3,
+ processor: ColQwen3Processor,
+ ):
+ queries = [
+ "Is attention really all you need?",
+ "Are Benjamin, Antoine, Merve, and Jo best friends?",
+ ]
+ batch_queries = processor.process_queries(queries).to(model_without_mask.device)
+
+ with torch.no_grad():
+ outputs = model_without_mask(**batch_queries)
+
+ assert isinstance(outputs, torch.Tensor)
+ assert outputs.dim() == 3
+ batch_size, n_query_tokens, emb_dim = outputs.shape
+ assert batch_size == len(queries)
+ assert n_query_tokens >= 1
+ assert emb_dim == model_without_mask.dim
+
+ @pytest.mark.slow
+ def test_retrieval_integration(
+ self,
+ model_without_mask: ColQwen3,
+ processor: ColQwen3Processor,
+ ):
+ ds = load_dataset("hf-internal-testing/document-visual-retrieval-test", split="test")
+
+ batch_images = processor.process_images(images=ds["image"]).to(model_without_mask.device)
+ batch_queries = processor.process_queries(queries=ds["query"]).to(model_without_mask.device)
+
+ with torch.inference_mode():
+ image_embeddings = model_without_mask(**batch_images)
+ query_embeddings = model_without_mask(**batch_queries)
+
+ scores = processor.score_multi_vector(
+ qs=query_embeddings,
+ ps=image_embeddings,
+ )
+
+ assert scores.ndim == 2, f"Expected 2D tensor, got {scores.ndim}"
+ assert scores.shape == (len(ds), len(ds)), f"Expected shape {(len(ds), len(ds))}, got {scores.shape}"
+ assert (scores.argmax(dim=1) == torch.arange(len(ds), device=scores.device)).all()
diff --git a/tests/models/qwen3/colqwen3/test_processing_colqwen3.py b/tests/models/qwen3/colqwen3/test_processing_colqwen3.py
new file mode 100644
index 000000000..971afa578
--- /dev/null
+++ b/tests/models/qwen3/colqwen3/test_processing_colqwen3.py
@@ -0,0 +1,61 @@
+from typing import Generator, cast
+
+import pytest
+import torch
+from PIL import Image
+
+from colpali_engine.models import ColQwen3Processor
+
+
+@pytest.fixture(scope="module")
+def model_name() -> str:
+ return "TomoroAI/tomoro-colqwen3-embed-4b"
+
+
+@pytest.fixture(scope="module")
+def processor_from_pretrained(model_name: str) -> Generator[ColQwen3Processor, None, None]:
+ yield cast(ColQwen3Processor, ColQwen3Processor.from_pretrained(model_name))
+
+
+def test_load_processor_from_pretrained(processor_from_pretrained: ColQwen3Processor):
+ assert isinstance(processor_from_pretrained, ColQwen3Processor)
+
+
+def test_process_images(processor_from_pretrained: ColQwen3Processor):
+ image_size = (64, 32)
+ image = Image.new("RGB", image_size, color="black")
+ images = [image]
+
+ batch_feature = processor_from_pretrained.process_images(images)
+
+ assert "pixel_values" in batch_feature
+ assert isinstance(batch_feature["pixel_values"], torch.Tensor)
+ assert batch_feature["pixel_values"].shape[0] == len(images)
+ assert batch_feature["pixel_values"].shape[1] >= 1
+ assert batch_feature["pixel_values"].shape[-1] > 0
+
+
+def test_process_texts(processor_from_pretrained: ColQwen3Processor):
+ queries = [
+ "Is attention really all you need?",
+ "Are Benjamin, Antoine, Merve, and Jo best friends?",
+ ]
+
+ batch_encoding = processor_from_pretrained.process_texts(queries)
+
+ assert "input_ids" in batch_encoding
+ assert isinstance(batch_encoding["input_ids"], torch.Tensor)
+ assert cast(torch.Tensor, batch_encoding["input_ids"]).shape[0] == len(queries)
+
+
+def test_process_queries(processor_from_pretrained: ColQwen3Processor):
+ queries = [
+ "Is attention really all you need?",
+ "Are Benjamin, Antoine, Merve, and Jo best friends?",
+ ]
+
+ batch_encoding = processor_from_pretrained.process_queries(queries)
+
+ assert "input_ids" in batch_encoding
+ assert isinstance(batch_encoding["input_ids"], torch.Tensor)
+ assert cast(torch.Tensor, batch_encoding["input_ids"]).shape[0] == len(queries)