From 27de473e5b2de0ec4df5e677c2dbd58e24b3fa0e Mon Sep 17 00:00:00 2001 From: Huang Xin Date: Tue, 16 Dec 2025 17:09:00 +0800 Subject: [PATCH 1/7] Add ColQwen3 Support --- CHANGELOG.md | 8 + README.md | 1 + colpali_engine/__init__.py | 4 + colpali_engine/models/__init__.py | 1 + colpali_engine/models/qwen3/__init__.py | 2 + .../models/qwen3/biqwen3/__init__.py | 2 + .../models/qwen3/biqwen3/modeling_biqwen3.py | 94 +++++++++++ .../qwen3/biqwen3/processing_biqwen3.py | 37 +++++ .../models/qwen3/colqwen3/__init__.py | 2 + .../qwen3/colqwen3/modeling_colqwen3.py | 101 ++++++++++++ .../qwen3/colqwen3/processing_colqwen3.py | 154 ++++++++++++++++++ scripts/configs/qwen3/train_colqwen3_model.py | 100 ++++++++++++ .../qwen3/colqwen3/test_modeling_colqwen3.py | 135 +++++++++++++++ .../colqwen3/test_processing_colqwen3.py | 61 +++++++ 14 files changed, 702 insertions(+) create mode 100644 colpali_engine/models/qwen3/__init__.py create mode 100644 colpali_engine/models/qwen3/biqwen3/__init__.py create mode 100644 colpali_engine/models/qwen3/biqwen3/modeling_biqwen3.py create mode 100644 colpali_engine/models/qwen3/biqwen3/processing_biqwen3.py create mode 100644 colpali_engine/models/qwen3/colqwen3/__init__.py create mode 100644 colpali_engine/models/qwen3/colqwen3/modeling_colqwen3.py create mode 100644 colpali_engine/models/qwen3/colqwen3/processing_colqwen3.py create mode 100644 scripts/configs/qwen3/train_colqwen3_model.py create mode 100644 tests/models/qwen3/colqwen3/test_modeling_colqwen3.py create mode 100644 tests/models/qwen3/colqwen3/test_processing_colqwen3.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a5a82d88..3c0e624b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ## Unreleased +### Added + +- Add ColQwen3 and BiQwen3 support (model + processor). + +### Tests + +- Cover ColQwen3 processing and modeling with slow integration tests. + ## [0.3.13] - 2025-11-15 ### Added diff --git a/README.md b/README.md index 48355bfc0..113055464 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,7 @@ Using ColPali removes the need for potentially complex and brittle layout recogn | [vidore/colqwen2-v1.0](https://huggingface.co/vidore/colqwen2-v1.0) | 89.3 | Apache 2.0 | • Similar to `vidore/colqwen2-v0.1`, but trained with more powerful GPUs and with a larger effective batch size (256). | ✅ | | [vidore/colqwen2.5-v0.1](https://huggingface.co/vidore/colqwen2.5-v0.1) | 88.8 | Apache 2.0 | • Based on `Qwen/Qwen2 5-VL-3B-Instruct`
• Supports dynamic resolution.
• Trained using 768 image patches per page and an effective batch size of 32. | ✅ | | [vidore/colqwen2.5-v0.2](https://huggingface.co/vidore/colqwen2.5-v0.2) | 89.4 | Apache 2.0 | • Similar to `vidore/colqwen2.5-v0.1`, but trained with slightly different hyper parameters | ✅ | +| [TomoroAI/tomoro-colqwen3-embed-4b](https://huggingface.co/TomoroAI/tomoro-colqwen3-embed-4b) | TBD | Apache 2.0 | • Based on the Qwen3-VL backbone.
• 320-dim ColBERT-style embeddings with dynamic resolution.
• Trained for multi-vector document retrieval. | ✅ | | [vidore/colSmol-256M](https://huggingface.co/vidore/colSmol-256M) | 80.1 | Apache 2.0 | • Based on `HuggingFaceTB/SmolVLM-256M-Instruct`. | ✅ | | [vidore/colSmol-500M](https://huggingface.co/vidore/colSmol-500M) | 82.3 | Apache 2.0 | • Based on `HuggingFaceTB/SmolVLM-500M-Instruct`. | ✅ | diff --git a/colpali_engine/__init__.py b/colpali_engine/__init__.py index 9108043c5..159bd8cba 100644 --- a/colpali_engine/__init__.py +++ b/colpali_engine/__init__.py @@ -7,6 +7,8 @@ BiQwen2_5, BiQwen2_5_Processor, BiQwen2Processor, + BiQwen3, + BiQwen3Processor, ColIdefics3, ColIdefics3Processor, ColModernVBert, @@ -16,6 +18,8 @@ ColQwen2, ColQwen2_5, ColQwen2_5_Processor, + ColQwen3, + ColQwen3Processor, # ColQwen2_5Omni, # ColQwen2_5OmniProcessor, ColQwen2Processor, diff --git a/colpali_engine/models/__init__.py b/colpali_engine/models/__init__.py index cb9a71ace..576f90af7 100644 --- a/colpali_engine/models/__init__.py +++ b/colpali_engine/models/__init__.py @@ -3,3 +3,4 @@ from .paligemma import BiPali, BiPaliProcessor, BiPaliProj, ColPali, ColPaliProcessor from .qwen2 import BiQwen2, BiQwen2Processor, ColQwen2, ColQwen2Processor from .qwen2_5 import BiQwen2_5, BiQwen2_5_Processor, ColQwen2_5, ColQwen2_5_Processor +from .qwen3 import BiQwen3, BiQwen3Processor, ColQwen3, ColQwen3Processor diff --git a/colpali_engine/models/qwen3/__init__.py b/colpali_engine/models/qwen3/__init__.py new file mode 100644 index 000000000..efcee26f4 --- /dev/null +++ b/colpali_engine/models/qwen3/__init__.py @@ -0,0 +1,2 @@ +from .biqwen3 import BiQwen3, BiQwen3Processor +from .colqwen3 import ColQwen3, ColQwen3Processor diff --git a/colpali_engine/models/qwen3/biqwen3/__init__.py b/colpali_engine/models/qwen3/biqwen3/__init__.py new file mode 100644 index 000000000..1b7881d83 --- /dev/null +++ b/colpali_engine/models/qwen3/biqwen3/__init__.py @@ -0,0 +1,2 @@ +from .modeling_biqwen3 import BiQwen3 +from .processing_biqwen3 import BiQwen3Processor diff --git a/colpali_engine/models/qwen3/biqwen3/modeling_biqwen3.py b/colpali_engine/models/qwen3/biqwen3/modeling_biqwen3.py new file mode 100644 index 000000000..fc4804e4b --- /dev/null +++ b/colpali_engine/models/qwen3/biqwen3/modeling_biqwen3.py @@ -0,0 +1,94 @@ +from typing import ClassVar, Literal + +import torch +from transformers.models.qwen3_vl import Qwen3VLConfig, Qwen3VLModel + + +class BiQwen3(Qwen3VLModel): + """ + BiQwen3 implementation from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper. + Representations are pooled to obtain a single vector representation. Based on the Qwen3-VL backbone. 
+ """ + + main_input_name: ClassVar[str] = "doc_input_ids" + _checkpoint_conversion_mapping = { + r"^model\.visual": "visual", + r"^model\.language_model": "language_model", + r"^model\.": "", + } + + def __init__(self, config: Qwen3VLConfig, **kwargs): + dtype = kwargs.pop("dtype", kwargs.pop("torch_dtype", None)) + attn_impl = kwargs.pop("attn_implementation", None) + use_cache = kwargs.pop("use_cache", None) + + super().__init__(config=config) + self.padding_side = "left" + self.post_init() + + if dtype is not None: + self.to(dtype=dtype) + if use_cache is not None: + self.config.use_cache = use_cache + if attn_impl is not None and hasattr(self, "set_attn_implementation"): + self.set_attn_implementation(attn_impl) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + key_mapping = kwargs.pop("key_mapping", None) + if key_mapping is None: + key_mapping = getattr(cls, "_checkpoint_conversion_mapping", None) + return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping) + + def forward( + self, + pooling_strategy: Literal["cls", "last", "mean"] = "last", + *args, + **kwargs, + ) -> torch.Tensor: + """ + Forward pass for BiQwen3 model. + + Args: + pooling_strategy: The strategy to use for pooling the hidden states. + *args: Variable length argument list. + **kwargs: Additional keyword arguments. + + Returns: + torch.Tensor: Dense embeddings (batch_size, hidden_size). + """ + if "pixel_values" in kwargs: + offsets = kwargs["image_grid_thw"].prod(dim=1).tolist() + kwargs["pixel_values"] = torch.cat( + [pixel_sequence[:offset] for pixel_sequence, offset in zip(kwargs["pixel_values"], offsets)], + dim=0, + ) + kwargs.pop("return_dict", True) + kwargs.pop("output_hidden_states", None) + kwargs.pop("use_cache", None) + + last_hidden_states = ( + super() + .forward(*args, **kwargs, use_cache=False, output_hidden_states=True, return_dict=True) + .last_hidden_state + ) # (batch_size, sequence_length, hidden_size) + + if pooling_strategy == "cls": + pooled = last_hidden_states[:, 0] + elif pooling_strategy == "last": + pooled = last_hidden_states[:, -1] + elif pooling_strategy == "mean": + mask = kwargs["attention_mask"].unsqueeze(-1) + pooled = (last_hidden_states * mask).sum(dim=1) / mask.sum(dim=1) + else: + raise ValueError(f"Invalid pooling strategy: {pooling_strategy}") + + return pooled / pooled.norm(dim=-1, keepdim=True) + + @property + def patch_size(self) -> int: + return self.visual.config.patch_size + + @property + def spatial_merge_size(self) -> int: + return self.visual.config.spatial_merge_size diff --git a/colpali_engine/models/qwen3/biqwen3/processing_biqwen3.py b/colpali_engine/models/qwen3/biqwen3/processing_biqwen3.py new file mode 100644 index 000000000..c13eaff0d --- /dev/null +++ b/colpali_engine/models/qwen3/biqwen3/processing_biqwen3.py @@ -0,0 +1,37 @@ +from typing import List, Optional, Union + +import torch +from transformers import BatchEncoding, BatchFeature + +from colpali_engine.models.qwen3.colqwen3 import ColQwen3Processor + + +class BiQwen3Processor(ColQwen3Processor): + """ + Processor for BiQwen3. + """ + + def process_texts( + self, + texts: List[str], + ) -> Union[BatchFeature, BatchEncoding]: + """ + Process texts for BiQwen3. 
+ """ + return self( + text=texts, + return_tensors="pt", + padding="longest", + ) + + def score( + self, + qs: List[torch.Tensor], + ps: List[torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + **kwargs, + ) -> torch.Tensor: + """ + Compute the cosine similarity for the given query and passage embeddings. + """ + return self.score_single_vector(qs, ps, device=device) diff --git a/colpali_engine/models/qwen3/colqwen3/__init__.py b/colpali_engine/models/qwen3/colqwen3/__init__.py new file mode 100644 index 000000000..6369cb69b --- /dev/null +++ b/colpali_engine/models/qwen3/colqwen3/__init__.py @@ -0,0 +1,2 @@ +from .modeling_colqwen3 import ColQwen3 +from .processing_colqwen3 import ColQwen3Processor diff --git a/colpali_engine/models/qwen3/colqwen3/modeling_colqwen3.py b/colpali_engine/models/qwen3/colqwen3/modeling_colqwen3.py new file mode 100644 index 000000000..019c87bd5 --- /dev/null +++ b/colpali_engine/models/qwen3/colqwen3/modeling_colqwen3.py @@ -0,0 +1,101 @@ +from typing import ClassVar + +import torch +from torch import nn +from transformers.models.qwen3_vl import Qwen3VLConfig, Qwen3VLModel + + +class ColQwen3(Qwen3VLModel): + """ + ColQwen3 model implementation, following the architecture from the article "ColPali: Efficient Document Retrieval + with Vision Language Models" paper. Based on the Qwen3-VL backbone. + + Args: + config (Qwen3VLConfig): The model configuration. + mask_non_image_embeddings (Optional[bool]): Whether to ignore all tokens embeddings + except those of the image at inference. + Defaults to False --> Do not mask any embeddings during forward pass. + """ + + main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related + _checkpoint_conversion_mapping = { + r"^model\.visual": "visual", + r"^model\.language_model": "language_model", + r"^model\.": "", + } + + def __init__( + self, + config: Qwen3VLConfig, + mask_non_image_embeddings: bool = False, + **kwargs, + ): + dtype = kwargs.pop("dtype", kwargs.pop("torch_dtype", None)) + attn_impl = kwargs.pop("attn_implementation", None) + use_cache = kwargs.pop("use_cache", None) + + super().__init__(config=config) + + hidden_size = getattr(self.config, "hidden_size", None) + if hidden_size is None and hasattr(self.config, "text_config"): + hidden_size = self.config.text_config.hidden_size + if hidden_size is None: + raise ValueError("Unable to determine text hidden size for Qwen3VLConfig.") + + self.dim = 320 + self.custom_text_proj = nn.Linear(hidden_size, self.dim) + self.padding_side = "left" + self.mask_non_image_embeddings = mask_non_image_embeddings + self.post_init() + + if dtype is not None: + self.to(dtype=dtype) + if use_cache is not None: + self.config.use_cache = use_cache + if attn_impl is not None and hasattr(self, "set_attn_implementation"): + self.set_attn_implementation(attn_impl) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + key_mapping = kwargs.pop("key_mapping", None) + if key_mapping is None: + key_mapping = getattr(cls, "_checkpoint_conversion_mapping", None) + return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping) + + def forward(self, *args, **kwargs) -> torch.Tensor: + # Handle the custom "pixel_values" input obtained with `ColQwen3Processor` through unpadding + if "pixel_values" in kwargs: + offsets = kwargs["image_grid_thw"][:, 1] * kwargs["image_grid_thw"][:, 2] # (batch_size,) + kwargs["pixel_values"] = torch.cat( + [pixel_sequence[:offset] for pixel_sequence, offset in zip(kwargs["pixel_values"], offsets)], + dim=0, + ) 
+ + kwargs.pop("return_dict", True) + kwargs.pop("output_hidden_states", None) + kwargs.pop("use_cache", None) + last_hidden_states = ( + super() + .forward(*args, **kwargs, use_cache=False, output_hidden_states=True, return_dict=True) + .last_hidden_state + ) # (batch_size, sequence_length, hidden_size) + + proj = self.custom_text_proj(last_hidden_states) # (batch_size, sequence_length, dim) + + # L2 normalization + proj = proj / proj.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim) + proj = proj * kwargs["attention_mask"].unsqueeze(-1) # (batch_size, sequence_length, dim) + + if "pixel_values" in kwargs and self.mask_non_image_embeddings: + # Pools only the image embeddings + image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1) + proj = proj * image_mask + return proj + + @property + def patch_size(self) -> int: + return self.visual.config.patch_size + + @property + def spatial_merge_size(self) -> int: + return self.visual.config.spatial_merge_size diff --git a/colpali_engine/models/qwen3/colqwen3/processing_colqwen3.py b/colpali_engine/models/qwen3/colqwen3/processing_colqwen3.py new file mode 100644 index 000000000..436c7336e --- /dev/null +++ b/colpali_engine/models/qwen3/colqwen3/processing_colqwen3.py @@ -0,0 +1,154 @@ +from typing import ClassVar, List, Optional, Tuple, Union + +import torch +from PIL import Image +from transformers import BatchEncoding, BatchFeature +from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize +from transformers.models.qwen3_vl import Qwen3VLProcessor + +from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor + + +class ColQwen3Processor(BaseVisualRetrieverProcessor, Qwen3VLProcessor): + """ + Processor for ColQwen3. + + Args: + *args: Variable length argument list to be passed to the parent `Qwen3VLProcessor` class. + max_num_visual_tokens: The maximum number of visual tokens that can be processed by the model. + **kwargs: Arbitrary keyword arguments to be passed to the parent `Qwen3VLProcessor` class. + """ + + visual_prompt_prefix: ClassVar[str] = ( + "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|endoftext|>" + ) + query_augmentation_token: ClassVar[str] = "<|endoftext|>" + image_token: ClassVar[str] = "<|image_pad|>" + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.tokenizer.padding_side = "left" + + @classmethod + def from_pretrained( + cls, + *args, + device_map: Optional[str] = None, + **kwargs, + ): + instance = super().from_pretrained( + *args, + device_map=device_map, + **kwargs, + ) + + if "max_num_visual_tokens" in kwargs: + patch_size = getattr(instance.image_processor, "patch_size", None) + merge_size = getattr(instance.image_processor, "merge_size", None) + if patch_size is None or merge_size is None: + raise ValueError("Qwen3VL image processor is missing `patch_size` or `merge_size`.") + tile = patch_size * merge_size + instance.image_processor.max_pixels = kwargs["max_num_visual_tokens"] * tile * tile + instance.image_processor.size["longest_edge"] = instance.image_processor.max_pixels + + return instance + + def process_images( + self, + images: List[Image.Image], + ) -> Union[BatchFeature, BatchEncoding]: + """ + Process images for ColQwen3. + + Args: + images: List of PIL images. 
+ """ + + images = [image.convert("RGB") for image in images] + + batch_doc = self( + text=[self.visual_prompt_prefix] * len(images), + images=images, + padding="longest", + return_tensors="pt", + ) + + # NOTE: The following adjustment ensures correct behavior with DDP on multiple GPUs. + offsets = batch_doc["image_grid_thw"][:, 1] * batch_doc["image_grid_thw"][:, 2] # (batch_size,) + + # Split the pixel_values tensor into a list of tensors, one per image + pixel_values = list( + torch.split(batch_doc["pixel_values"], offsets.tolist()) + ) # [(num_patches_image_0, pixel_values), ..., (num_patches_image_n, pixel_values)] + + # Pad the list of pixel_value tensors to the same length along the sequence dimension + batch_doc["pixel_values"] = torch.nn.utils.rnn.pad_sequence( + pixel_values, batch_first=True + ) # (batch_size, max_num_patches, pixel_values) + + return batch_doc + + def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]: + """ + Process texts for ColQwen3. + + Args: + texts: List of input texts. + + Returns: + Union[BatchFeature, BatchEncoding]: Processed texts. + """ + return self( + text=texts, + return_tensors="pt", + padding="longest", + ) + + def score( + self, + qs: List[torch.Tensor], + ps: List[torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + **kwargs, + ) -> torch.Tensor: + """ + Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings. + """ + return self.score_multi_vector(qs, ps, device=device, **kwargs) + + def get_n_patches( + self, + image_size: Tuple[int, int], + spatial_merge_size: int, + ) -> Tuple[int, int]: + """ + Get the number of patches (n_patches_x, n_patches_y) that will be used to process an image of + size (height, width) with the given patch size. + + The `spatial_merge_size` is the number of patches that will be merged spatially. It is stored in + as a `Qwen3VLForConditionalGeneration` attribute under `model.spatial_merge_size`. + """ + patch_size = self.image_processor.patch_size + + height_new, width_new = smart_resize( + width=image_size[0], + height=image_size[1], + factor=patch_size * self.image_processor.merge_size, + min_pixels=self.image_processor.size["shortest_edge"], + max_pixels=self.image_processor.size["longest_edge"], + ) + + n_patches_x = width_new // patch_size // spatial_merge_size + n_patches_y = height_new // patch_size // spatial_merge_size + + return n_patches_x, n_patches_y + + def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor: + """ + Get a tensor mask that identifies the image tokens in the batch. 
+ """ + return batch_images.input_ids == self.image_token_id diff --git a/scripts/configs/qwen3/train_colqwen3_model.py b/scripts/configs/qwen3/train_colqwen3_model.py new file mode 100644 index 000000000..bef8f3df8 --- /dev/null +++ b/scripts/configs/qwen3/train_colqwen3_model.py @@ -0,0 +1,100 @@ +import argparse +import shutil +from pathlib import Path + +import torch +from datasets import load_dataset +from peft import LoraConfig +from transformers import TrainingArguments + +from colpali_engine.data.dataset import ColPaliEngineDataset +from colpali_engine.loss.late_interaction_losses import ColbertLoss, ColbertPairwiseCELoss +from colpali_engine.models import ColQwen3, ColQwen3Processor +from colpali_engine.trainer.colmodel_torch_training import ColModelTorchTraining +from colpali_engine.trainer.colmodel_training import ColModelTraining, ColModelTrainingConfig +from colpali_engine.utils.dataset_transformation import load_train_set + + +def parse_args(): + p = argparse.ArgumentParser() + p.add_argument("--output-dir", type=str, required=True, help="where to write model + script copy") + p.add_argument("--lr", type=float, default=2e-4, help="learning rate") + p.add_argument("--tau", type=float, default=0.02, help="temperature for loss function") + p.add_argument("--trainer", type=str, default="hf", choices=["torch", "hf"], help="trainer to use") + p.add_argument("--loss", type=str, default="ce", choices=["ce", "pairwise"], help="loss function to use") + p.add_argument("--peft", action="store_true", help="use PEFT for training") + + return p.parse_args() + + +if __name__ == "__main__": + args = parse_args() + + if args.loss == "ce": + loss_func = ColbertLoss( + temperature=args.tau, + normalize_scores=True, + use_smooth_max=False, + pos_aware_negative_filtering=False, + ) + elif args.loss == "pairwise": + loss_func = ColbertPairwiseCELoss( + normalize_scores=False, + ) + else: + raise ValueError(f"Unknown loss function: {args.loss}") + + config = ColModelTrainingConfig( + output_dir=args.output_dir, + processor=ColQwen3Processor.from_pretrained( + pretrained_model_name_or_path="./models/base_models/colqwen3-base", + max_num_visual_tokens=768, + ), + model=ColQwen3.from_pretrained( + pretrained_model_name_or_path="./models/base_models/colqwen3-base", + torch_dtype=torch.bfloat16, + use_cache=False, + attn_implementation="flash_attention_2", + ), + train_dataset=load_train_set(), + eval_dataset=ColPaliEngineDataset( + load_dataset("./data_dir/colpali_train_set", split="test"), pos_target_column_name="image" + ), + run_eval=True, + loss_func=loss_func, + tr_args=TrainingArguments( + output_dir=None, + overwrite_output_dir=True, + num_train_epochs=5, + per_device_train_batch_size=64, + gradient_checkpointing=True, + gradient_checkpointing_kwargs={"use_reentrant": False}, + per_device_eval_batch_size=16, + eval_strategy="steps", + dataloader_num_workers=8, + save_steps=500, + logging_steps=10, + eval_steps=100, + warmup_steps=100, + learning_rate=args.lr, + save_total_limit=1, + ), + peft_config=LoraConfig( + r=32, + lora_alpha=32, + lora_dropout=0.1, + init_lora_weights="gaussian", + bias="none", + task_type="FEATURE_EXTRACTION", + target_modules="(.*(model)(?!.*visual).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)", + ) + if args.peft + else None, + ) + + Path(config.output_dir).mkdir(parents=True, exist_ok=True) + shutil.copy(Path(__file__), Path(config.output_dir) / Path(__file__).name) + + trainer = ColModelTraining(config) if args.trainer == "hf" 
else ColModelTorchTraining(config) + trainer.train() + trainer.save() diff --git a/tests/models/qwen3/colqwen3/test_modeling_colqwen3.py b/tests/models/qwen3/colqwen3/test_modeling_colqwen3.py new file mode 100644 index 000000000..239713d6a --- /dev/null +++ b/tests/models/qwen3/colqwen3/test_modeling_colqwen3.py @@ -0,0 +1,135 @@ +import logging +from typing import Generator, cast + +import pytest +import torch +from datasets import load_dataset +from PIL import Image +from transformers.utils.import_utils import is_flash_attn_2_available + +from colpali_engine.models import ColQwen3, ColQwen3Processor +from colpali_engine.utils.torch_utils import get_torch_device, tear_down_torch + +logger = logging.getLogger(__name__) + + +@pytest.fixture(scope="module") +def model_name() -> str: + return "TomoroAI/tomoro-colqwen3-embed-4b" + + +@pytest.fixture(scope="module") +def model_without_mask(model_name: str) -> Generator[ColQwen3, None, None]: + device = get_torch_device("auto") + logger.info(f"Device used: {device}") + + yield cast( + ColQwen3, + ColQwen3.from_pretrained( + model_name, + torch_dtype=torch.bfloat16, + device_map=device, + attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None, + mask_non_image_embeddings=False, + ).eval(), + ) + tear_down_torch() + + +@pytest.fixture(scope="module") +def model_with_mask(model_name: str) -> Generator[ColQwen3, None, None]: + device = get_torch_device("auto") + logger.info(f"Device used: {device}") + + yield cast( + ColQwen3, + ColQwen3.from_pretrained( + model_name, + torch_dtype=torch.bfloat16, + device_map=device, + attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None, + mask_non_image_embeddings=True, + ).eval(), + ) + tear_down_torch() + + +@pytest.fixture(scope="module") +def processor(model_name: str) -> Generator[ColQwen3Processor, None, None]: + yield cast(ColQwen3Processor, ColQwen3Processor.from_pretrained(model_name)) + + +class TestColQwen3Model: + @pytest.mark.slow + def test_load_model_from_pretrained(self, model_without_mask: ColQwen3): + assert isinstance(model_without_mask, ColQwen3) + + +class TestColQwen3ModelIntegration: + @pytest.mark.slow + def test_forward_images_integration( + self, + model_without_mask: ColQwen3, + processor: ColQwen3Processor, + ): + images = [ + Image.new("RGB", (64, 64), color="white"), + Image.new("RGB", (32, 32), color="black"), + ] + batch_images = processor.process_images(images).to(model_without_mask.device) + + with torch.no_grad(): + outputs = model_without_mask(**batch_images) + + assert isinstance(outputs, torch.Tensor) + assert outputs.dim() == 3 + batch_size, n_visual_tokens, emb_dim = outputs.shape + assert batch_size == len(images) + assert n_visual_tokens >= 1 + assert emb_dim == model_without_mask.dim + + @pytest.mark.slow + def test_forward_queries_integration( + self, + model_without_mask: ColQwen3, + processor: ColQwen3Processor, + ): + queries = [ + "Is attention really all you need?", + "Are Benjamin, Antoine, Merve, and Jo best friends?", + ] + batch_queries = processor.process_queries(queries).to(model_without_mask.device) + + with torch.no_grad(): + outputs = model_without_mask(**batch_queries) + + assert isinstance(outputs, torch.Tensor) + assert outputs.dim() == 3 + batch_size, n_query_tokens, emb_dim = outputs.shape + assert batch_size == len(queries) + assert n_query_tokens >= 1 + assert emb_dim == model_without_mask.dim + + @pytest.mark.slow + def test_retrieval_integration( + self, + model_without_mask: ColQwen3, + 
processor: ColQwen3Processor, + ): + ds = load_dataset("hf-internal-testing/document-visual-retrieval-test", split="test") + + batch_images = processor.process_images(images=ds["image"]).to(model_without_mask.device) + batch_queries = processor.process_queries(queries=ds["query"]).to(model_without_mask.device) + + with torch.inference_mode(): + image_embeddings = model_without_mask(**batch_images) + query_embeddings = model_without_mask(**batch_queries) + + scores = processor.score_multi_vector( + qs=query_embeddings, + ps=image_embeddings, + ) + + assert scores.ndim == 2, f"Expected 2D tensor, got {scores.ndim}" + assert scores.shape == (len(ds), len(ds)), f"Expected shape {(len(ds), len(ds))}, got {scores.shape}" + assert (scores.argmax(dim=1) == torch.arange(len(ds), device=scores.device)).all() diff --git a/tests/models/qwen3/colqwen3/test_processing_colqwen3.py b/tests/models/qwen3/colqwen3/test_processing_colqwen3.py new file mode 100644 index 000000000..971afa578 --- /dev/null +++ b/tests/models/qwen3/colqwen3/test_processing_colqwen3.py @@ -0,0 +1,61 @@ +from typing import Generator, cast + +import pytest +import torch +from PIL import Image + +from colpali_engine.models import ColQwen3Processor + + +@pytest.fixture(scope="module") +def model_name() -> str: + return "TomoroAI/tomoro-colqwen3-embed-4b" + + +@pytest.fixture(scope="module") +def processor_from_pretrained(model_name: str) -> Generator[ColQwen3Processor, None, None]: + yield cast(ColQwen3Processor, ColQwen3Processor.from_pretrained(model_name)) + + +def test_load_processor_from_pretrained(processor_from_pretrained: ColQwen3Processor): + assert isinstance(processor_from_pretrained, ColQwen3Processor) + + +def test_process_images(processor_from_pretrained: ColQwen3Processor): + image_size = (64, 32) + image = Image.new("RGB", image_size, color="black") + images = [image] + + batch_feature = processor_from_pretrained.process_images(images) + + assert "pixel_values" in batch_feature + assert isinstance(batch_feature["pixel_values"], torch.Tensor) + assert batch_feature["pixel_values"].shape[0] == len(images) + assert batch_feature["pixel_values"].shape[1] >= 1 + assert batch_feature["pixel_values"].shape[-1] > 0 + + +def test_process_texts(processor_from_pretrained: ColQwen3Processor): + queries = [ + "Is attention really all you need?", + "Are Benjamin, Antoine, Merve, and Jo best friends?", + ] + + batch_encoding = processor_from_pretrained.process_texts(queries) + + assert "input_ids" in batch_encoding + assert isinstance(batch_encoding["input_ids"], torch.Tensor) + assert cast(torch.Tensor, batch_encoding["input_ids"]).shape[0] == len(queries) + + +def test_process_queries(processor_from_pretrained: ColQwen3Processor): + queries = [ + "Is attention really all you need?", + "Are Benjamin, Antoine, Merve, and Jo best friends?", + ] + + batch_encoding = processor_from_pretrained.process_queries(queries) + + assert "input_ids" in batch_encoding + assert isinstance(batch_encoding["input_ids"], torch.Tensor) + assert cast(torch.Tensor, batch_encoding["input_ids"]).shape[0] == len(queries) From 0bef0c811c4c17b0a616c307f9e1144b4658edb7 Mon Sep 17 00:00:00 2001 From: tankm Date: Wed, 17 Dec 2025 10:45:41 +0800 Subject: [PATCH 2/7] fix(colqwen3): fix ruff lint errors --- colpali_engine/__init__.py | 4 ++-- .../colmodernvbert/generate_interpretability_maps.py | 4 ++-- .../modernvbert/test_interpretability_colmodernvbert.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/colpali_engine/__init__.py 
b/colpali_engine/__init__.py index 159bd8cba..1044b1484 100644 --- a/colpali_engine/__init__.py +++ b/colpali_engine/__init__.py @@ -18,9 +18,9 @@ ColQwen2, ColQwen2_5, ColQwen2_5_Processor, - ColQwen3, - ColQwen3Processor, # ColQwen2_5Omni, # ColQwen2_5OmniProcessor, ColQwen2Processor, + ColQwen3, + ColQwen3Processor, ) diff --git a/examples/interpretability/colmodernvbert/generate_interpretability_maps.py b/examples/interpretability/colmodernvbert/generate_interpretability_maps.py index 24185ba1d..6af55a695 100644 --- a/examples/interpretability/colmodernvbert/generate_interpretability_maps.py +++ b/examples/interpretability/colmodernvbert/generate_interpretability_maps.py @@ -9,9 +9,9 @@ python examples/interpretability/colmodernvbert/simple_interpretability_example.py """ -from pathlib import Path import uuid -from typing import cast, Any +from pathlib import Path +from typing import Any, cast import matplotlib.pyplot as plt import torch diff --git a/tests/models/modernvbert/test_interpretability_colmodernvbert.py b/tests/models/modernvbert/test_interpretability_colmodernvbert.py index 310300092..e0b945b37 100644 --- a/tests/models/modernvbert/test_interpretability_colmodernvbert.py +++ b/tests/models/modernvbert/test_interpretability_colmodernvbert.py @@ -13,10 +13,10 @@ import torch from PIL import Image -from colpali_engine.models import ColModernVBert, ColModernVBertProcessor from colpali_engine.interpretability.similarity_map_utils import ( normalize_similarity_map, ) +from colpali_engine.models import ColModernVBert, ColModernVBertProcessor @pytest.fixture(scope="module") @@ -117,7 +117,7 @@ def test_get_n_patches_aspect_ratio_preservation( # The aspect ratio of patches should be close to 2:1 patch_ratio = n_patches_x / n_patches_y - expected_ratio = 2.0 + # expected_ratio = 2.0 # Allow tolerance due to: # 1. Image splitting into 512x512 sub-patches (quantization effects) From e4d96433e1ff9366499b1a2f44235c884a78c39b Mon Sep 17 00:00:00 2001 From: Amrit Bath Date: Thu, 18 Dec 2025 09:12:40 -0800 Subject: [PATCH 3/7] Re-enable colqwen 2.5 Omni; fix resize token embeddings (#367) * looks like colqwen 2.5 omni support was accidentally removed in https://github.com/illuin-tech/colpali/pull/339 EDIT: that was based upon just looking at the main __init__.py. looking at the other files, perhaps it was intentionally removed... 
* found & fixed resize_token_embeddings() breakage --- colpali_engine/__init__.py | 4 ++-- colpali_engine/models/__init__.py | 1 + colpali_engine/models/qwen_omni/__init__.py | 1 + .../models/qwen_omni/colqwen_omni/__init__.py | 2 ++ .../qwen_omni/colqwen_omni/modeling_colqwen_omni.py | 10 +++++++++- 5 files changed, 15 insertions(+), 3 deletions(-) create mode 100644 colpali_engine/models/qwen_omni/__init__.py create mode 100644 colpali_engine/models/qwen_omni/colqwen_omni/__init__.py diff --git a/colpali_engine/__init__.py b/colpali_engine/__init__.py index 1044b1484..1f5098f74 100644 --- a/colpali_engine/__init__.py +++ b/colpali_engine/__init__.py @@ -18,8 +18,8 @@ ColQwen2, ColQwen2_5, ColQwen2_5_Processor, - # ColQwen2_5Omni, - # ColQwen2_5OmniProcessor, + ColQwen2_5Omni, + ColQwen2_5OmniProcessor, ColQwen2Processor, ColQwen3, ColQwen3Processor, diff --git a/colpali_engine/models/__init__.py b/colpali_engine/models/__init__.py index 576f90af7..b98544fcb 100644 --- a/colpali_engine/models/__init__.py +++ b/colpali_engine/models/__init__.py @@ -4,3 +4,4 @@ from .qwen2 import BiQwen2, BiQwen2Processor, ColQwen2, ColQwen2Processor from .qwen2_5 import BiQwen2_5, BiQwen2_5_Processor, ColQwen2_5, ColQwen2_5_Processor from .qwen3 import BiQwen3, BiQwen3Processor, ColQwen3, ColQwen3Processor +from .qwen_omni import ColQwen2_5Omni, ColQwen2_5OmniProcessor diff --git a/colpali_engine/models/qwen_omni/__init__.py b/colpali_engine/models/qwen_omni/__init__.py new file mode 100644 index 000000000..7dd081290 --- /dev/null +++ b/colpali_engine/models/qwen_omni/__init__.py @@ -0,0 +1 @@ +from .colqwen_omni import ColQwen2_5Omni, ColQwen2_5OmniProcessor diff --git a/colpali_engine/models/qwen_omni/colqwen_omni/__init__.py b/colpali_engine/models/qwen_omni/colqwen_omni/__init__.py new file mode 100644 index 000000000..b754b5527 --- /dev/null +++ b/colpali_engine/models/qwen_omni/colqwen_omni/__init__.py @@ -0,0 +1,2 @@ +from .modeling_colqwen_omni import ColQwen2_5Omni +from .processing_colqwen_omni import ColQwen2_5OmniProcessor diff --git a/colpali_engine/models/qwen_omni/colqwen_omni/modeling_colqwen_omni.py b/colpali_engine/models/qwen_omni/colqwen_omni/modeling_colqwen_omni.py index 0a77471e2..300efcc1c 100644 --- a/colpali_engine/models/qwen_omni/colqwen_omni/modeling_colqwen_omni.py +++ b/colpali_engine/models/qwen_omni/colqwen_omni/modeling_colqwen_omni.py @@ -21,9 +21,17 @@ def __init__(self, config: Qwen2_5OmniThinkerConfig, mask_non_image_embeddings: self.lm_head = nn.Identity() # Disable the original lm_head self.padding_side = "left" self.mask_non_image_embeddings = mask_non_image_embeddings - self.lm_head = nn.Identity() # Disable the original lm_head self.post_init() + def get_output_embeddings(self) -> None: # -> None | Any: + """ + Transformers >=4.54.0 fails during resize_token_embeddings() due to a new get_output_embeddings() + impl. The latter used to return None unless overridden, but they made it try harder to return + *something*. Of course, this was not flagged as a breaking change. 
Eventually I found the + responsible PR, and it endorses this change: https://github.com/huggingface/transformers/pull/39339 + """ + return None + def forward(self, *args, **kwargs) -> torch.Tensor: # # Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding # if "pixel_values" in kwargs: From 1edd6a990f580e523aede4b5fc967d660d8529f7 Mon Sep 17 00:00:00 2001 From: Manuel Faysse <43467008+ManuelFay@users.noreply.github.com> Date: Fri, 19 Dec 2025 11:15:42 +0100 Subject: [PATCH 4/7] lint (#368) * lint * lint examples --- .../interpretability/similarity_map_utils.py | 12 +- .../interpretability/similarity_maps.py | 5 +- .../colidefics3/processing_colidefics3.py | 8 +- .../colvbert/processing_colmodernvbert.py | 4 +- colpali_engine/utils/processing_utils.py | 4 +- .../generate_interpretability_maps.py | 12 +- .../test_processing_colidefics3.py | 34 ++---- .../test_interpretability_colmodernvbert.py | 109 ++++++------------ 8 files changed, 59 insertions(+), 129 deletions(-) diff --git a/colpali_engine/interpretability/similarity_map_utils.py b/colpali_engine/interpretability/similarity_map_utils.py index 9ab6f192d..8b2f5e8d5 100644 --- a/colpali_engine/interpretability/similarity_map_utils.py +++ b/colpali_engine/interpretability/similarity_map_utils.py @@ -76,14 +76,14 @@ def normalize_similarity_map( if value_range is None: # Compute the minimum values along the last two dimensions (n_patch_x, n_patch_y) - min_vals = similarity_map.min(dim=-1, keepdim=True)[0].min( - dim=-2, keepdim=True - )[0] # (1, 1) or (batch_size, 1, 1) + min_vals = similarity_map.min(dim=-1, keepdim=True)[0].min(dim=-2, keepdim=True)[ + 0 + ] # (1, 1) or (batch_size, 1, 1) # Compute the maximum values along the last two dimensions (n_patch_x, n_patch_y) - max_vals = similarity_map.max(dim=-1, keepdim=True)[0].max( - dim=-2, keepdim=True - )[0] # (1, 1) or (batch_size, 1, 1) + max_vals = similarity_map.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[ + 0 + ] # (1, 1) or (batch_size, 1, 1) else: min_vals, max_vals = value_range broadcast_shape = (1,) * similarity_map.ndim diff --git a/colpali_engine/interpretability/similarity_maps.py b/colpali_engine/interpretability/similarity_maps.py index ce95d653e..477942b40 100644 --- a/colpali_engine/interpretability/similarity_maps.py +++ b/colpali_engine/interpretability/similarity_maps.py @@ -43,10 +43,7 @@ def plot_similarity_map( # Normalize the similarity map and convert it to Pillow image similarity_map_array = ( - normalize_similarity_map(similarity_map, value_range=normalization_range) - .to(torch.float32) - .cpu() - .numpy() + normalize_similarity_map(similarity_map, value_range=normalization_range).to(torch.float32).cpu().numpy() ) # (n_patches_x, n_patches_y) # Reshape the similarity map to match the PIL shape convention diff --git a/colpali_engine/models/idefics3/colidefics3/processing_colidefics3.py b/colpali_engine/models/idefics3/colidefics3/processing_colidefics3.py index acd4b0e37..8dea42928 100644 --- a/colpali_engine/models/idefics3/colidefics3/processing_colidefics3.py +++ b/colpali_engine/models/idefics3/colidefics3/processing_colidefics3.py @@ -22,9 +22,7 @@ class ColIdefics3Processor( query_augmentation_token: ClassVar[str] = "" image_token: ClassVar[str] = "" - visual_prompt_prefix: ClassVar[str] = ( - "<|im_start|>User:Describe the image.\nAssistant:" - ) + visual_prompt_prefix: ClassVar[str] = "<|im_start|>User:Describe the image.\nAssistant:" def __init__(self, *args, image_seq_len=64, **kwargs): 
super().__init__(*args, image_seq_len=image_seq_len, **kwargs) @@ -105,9 +103,7 @@ def get_n_patches( longest_edge = self.image_processor.size.get("longest_edge", 4 * patch_size) # Step 1: Calculate resized dimensions using the mixin helper method - height_new, width_new = self._calculate_resized_dimensions( - image_size, longest_edge - ) + height_new, width_new = self._calculate_resized_dimensions(image_size, longest_edge) # Step 2: Calculate the number of patches in each direction # This mirrors the split_image logic from Idefics3ImageProcessor diff --git a/colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py b/colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py index 0aa42aee4..786c94339 100644 --- a/colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py +++ b/colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py @@ -110,9 +110,7 @@ def get_n_patches( longest_edge = self.image_processor.size.get("longest_edge", 2048) # Step 1: Calculate resized dimensions using the mixin helper method - height_new, width_new = self._calculate_resized_dimensions( - image_size, longest_edge - ) + height_new, width_new = self._calculate_resized_dimensions(image_size, longest_edge) # Step 2: Calculate number of sub-patches (512x512 patches) # This mirrors the split_image logic from Idefics3ImageProcessor diff --git a/colpali_engine/utils/processing_utils.py b/colpali_engine/utils/processing_utils.py index 86ef191bd..a25779e63 100644 --- a/colpali_engine/utils/processing_utils.py +++ b/colpali_engine/utils/processing_utils.py @@ -495,9 +495,7 @@ def get_similarity_maps_from_embeddings( # query: (query_tokens, dim) # image_grid: (n_patches_x, n_patches_y, dim) # result: (query_tokens, n_patches_x, n_patches_y) - similarity_map = torch.einsum( - "nk,ijk->nij", query_embeddings[idx], image_embedding_grid - ) + similarity_map = torch.einsum("nk,ijk->nij", query_embeddings[idx], image_embedding_grid) similarity_maps.append(similarity_map) diff --git a/examples/interpretability/colmodernvbert/generate_interpretability_maps.py b/examples/interpretability/colmodernvbert/generate_interpretability_maps.py index 6af55a695..c26a5463b 100644 --- a/examples/interpretability/colmodernvbert/generate_interpretability_maps.py +++ b/examples/interpretability/colmodernvbert/generate_interpretability_maps.py @@ -33,9 +33,7 @@ def main(): print("Loading a real document from DocVQA dataset...") from datasets import load_dataset - dataset = load_dataset( - "vidore/docvqa_test_subsampled", split="test", streaming=True - ) + dataset = load_dataset("vidore/docvqa_test_subsampled", split="test", streaming=True) # streaming datasets may yield values that type checkers treat as Sequence; # cast to dict so string indexing (sample["image"]) is accepted by the type checker. 
sample = dict(next(iter(dataset))) @@ -81,9 +79,7 @@ def main(): ) # Get the similarity map for our input image - similarity_maps = similarity_maps_batch[ - 0 - ] # (query_length, n_patches_x, n_patches_y) + similarity_maps = similarity_maps_batch[0] # (query_length, n_patches_x, n_patches_y) print(f"Similarity map shape: {similarity_maps.shape}") # Get query tokens (filtering out special tokens) @@ -105,9 +101,7 @@ def main(): # Clean tokens for display (remove special characters that may cause encoding issues) display_tokens = [t.replace("Ġ", " ").replace("▁", " ") for t in filtered_tokens] print(f"\nQuery tokens: {display_tokens}") - print( - f"Similarity range: [{similarity_maps.min().item():.3f}, {similarity_maps.max().item():.3f}]" - ) + print(f"Similarity range: [{similarity_maps.min().item():.3f}, {similarity_maps.max().item():.3f}]") # Generate all similarity maps print("\nGenerating similarity maps for all tokens...") diff --git a/tests/models/idefics3/colidefics3/test_processing_colidefics3.py b/tests/models/idefics3/colidefics3/test_processing_colidefics3.py index fce61a173..21a28d4c7 100644 --- a/tests/models/idefics3/colidefics3/test_processing_colidefics3.py +++ b/tests/models/idefics3/colidefics3/test_processing_colidefics3.py @@ -74,15 +74,11 @@ def test_get_n_patches(processor_from_pretrained: ColIdefics3Processor): Test that get_n_patches returns the correct number of patches for various image sizes. """ # Get the patch size from the image processor - patch_size = processor_from_pretrained.image_processor.max_image_size.get( - "longest_edge", 512 - ) + patch_size = processor_from_pretrained.image_processor.max_image_size.get("longest_edge", 512) # Test case 1: Small square image image_size = (100, 100) - n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( - image_size, patch_size - ) + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size) assert isinstance(n_patches_x, int) assert isinstance(n_patches_y, int) assert n_patches_x > 0 @@ -90,23 +86,17 @@ def test_get_n_patches(processor_from_pretrained: ColIdefics3Processor): # Test case 2: Wide image (width > height) image_size = (100, 200) - n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( - image_size, patch_size - ) + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size) assert n_patches_x >= n_patches_y # More patches along width # Test case 3: Tall image (height > width) image_size = (200, 100) - n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( - image_size, patch_size - ) + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size) assert n_patches_y >= n_patches_x # More patches along height # Test case 4: Square image image_size = (500, 500) - n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( - image_size, patch_size - ) + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size) assert n_patches_x == n_patches_y # Equal patches for square image @@ -126,22 +116,18 @@ def test_get_n_patches_matches_actual_processing( actual_num_patches = batch_feature["pixel_values"].shape[1] # Get the patch size from the image processor - patch_size = processor_from_pretrained.image_processor.max_image_size.get( - "longest_edge", 512 - ) + patch_size = processor_from_pretrained.image_processor.max_image_size.get("longest_edge", 512) # Calculate expected patches using get_n_patches # Note: image_size for get_n_patches is 
(height, width), but PIL uses (width, height) - n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( - (image_size[1], image_size[0]), patch_size - ) + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches((image_size[1], image_size[0]), patch_size) expected_num_patches = n_patches_x * n_patches_y # The actual number of patches includes the global image patch (+1) # So we compare with expected + 1 - assert ( - actual_num_patches == expected_num_patches + 1 - ), f"Expected {expected_num_patches + 1} patches (including global), got {actual_num_patches}" + assert actual_num_patches == expected_num_patches + 1, ( + f"Expected {expected_num_patches + 1} patches (including global), got {actual_num_patches}" + ) def test_get_image_mask(processor_from_pretrained: ColIdefics3Processor): diff --git a/tests/models/modernvbert/test_interpretability_colmodernvbert.py b/tests/models/modernvbert/test_interpretability_colmodernvbert.py index e0b945b37..9913d4a30 100644 --- a/tests/models/modernvbert/test_interpretability_colmodernvbert.py +++ b/tests/models/modernvbert/test_interpretability_colmodernvbert.py @@ -28,9 +28,7 @@ def model_name() -> str: def processor_from_pretrained( model_name: str, ) -> Generator[ColModernVBertProcessor, None, None]: - yield cast( - ColModernVBertProcessor, ColModernVBertProcessor.from_pretrained(model_name) - ) + yield cast(ColModernVBertProcessor, ColModernVBertProcessor.from_pretrained(model_name)) @pytest.fixture(scope="module") @@ -41,79 +39,55 @@ def model_from_pretrained(model_name: str) -> Generator[ColModernVBert, None, No class TestGetNPatches: """Test the get_n_patches method for calculating patch dimensions.""" - def test_get_n_patches_returns_integers( - self, processor_from_pretrained: ColModernVBertProcessor - ): + def test_get_n_patches_returns_integers(self, processor_from_pretrained: ColModernVBertProcessor): """Test that get_n_patches returns integer values.""" patch_size = 14 # Common patch size for vision transformers image_size = (100, 100) - n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( - image_size, patch_size - ) + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size) assert isinstance(n_patches_x, int) assert isinstance(n_patches_y, int) assert n_patches_x > 0 assert n_patches_y > 0 - def test_get_n_patches_wide_image( - self, processor_from_pretrained: ColModernVBertProcessor - ): + def test_get_n_patches_wide_image(self, processor_from_pretrained: ColModernVBertProcessor): """Test that wide images have more patches along width.""" patch_size = 14 image_size = (100, 200) # (height, width) - wider than tall - n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( - image_size, patch_size - ) + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size) # n_patches_x is along width, n_patches_y is along height - assert ( - n_patches_x >= n_patches_y - ), f"Expected more patches along width, got x={n_patches_x}, y={n_patches_y}" + assert n_patches_x >= n_patches_y, f"Expected more patches along width, got x={n_patches_x}, y={n_patches_y}" - def test_get_n_patches_tall_image( - self, processor_from_pretrained: ColModernVBertProcessor - ): + def test_get_n_patches_tall_image(self, processor_from_pretrained: ColModernVBertProcessor): """Test that tall images have more patches along height.""" patch_size = 14 image_size = (200, 100) # (height, width) - taller than wide - n_patches_x, n_patches_y = 
processor_from_pretrained.get_n_patches( - image_size, patch_size - ) + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size) - assert ( - n_patches_y >= n_patches_x - ), f"Expected more patches along height, got x={n_patches_x}, y={n_patches_y}" + assert n_patches_y >= n_patches_x, f"Expected more patches along height, got x={n_patches_x}, y={n_patches_y}" - def test_get_n_patches_square_image( - self, processor_from_pretrained: ColModernVBertProcessor - ): + def test_get_n_patches_square_image(self, processor_from_pretrained: ColModernVBertProcessor): """Test that square images have equal patches in both dimensions.""" patch_size = 14 image_size = (500, 500) - n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( - image_size, patch_size - ) + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size) - assert ( - n_patches_x == n_patches_y - ), f"Expected equal patches for square image, got x={n_patches_x}, y={n_patches_y}" + assert n_patches_x == n_patches_y, ( + f"Expected equal patches for square image, got x={n_patches_x}, y={n_patches_y}" + ) - def test_get_n_patches_aspect_ratio_preservation( - self, processor_from_pretrained: ColModernVBertProcessor - ): + def test_get_n_patches_aspect_ratio_preservation(self, processor_from_pretrained: ColModernVBertProcessor): """Test that aspect ratio is approximately preserved in patch dimensions.""" patch_size = 14 # Test with a 2:1 aspect ratio image image_size = (300, 600) # height=300, width=600 - n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( - image_size, patch_size - ) + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size) # The aspect ratio of patches should be close to 2:1 patch_ratio = n_patches_x / n_patches_y @@ -124,15 +98,13 @@ def test_get_n_patches_aspect_ratio_preservation( # 2. Even-dimension rounding in resize logic # 3. 
Ceiling division in patch calculations # These factors can cause ~25% deviation from the ideal aspect ratio - assert 1.5 <= patch_ratio <= 2.5, f"Expected ~2:1 ratio, got {patch_ratio:.2f}" + assert 1.5 <= patch_ratio <= 2.5, f"Expected ~{expected_ratio}:1 ratio, got {patch_ratio:.2f}" class TestGetImageMask: """Test the get_image_mask method for identifying image tokens.""" - def test_get_image_mask_shape( - self, processor_from_pretrained: ColModernVBertProcessor - ): + def test_get_image_mask_shape(self, processor_from_pretrained: ColModernVBertProcessor): """Test that image mask has the same shape as input_ids.""" image = Image.new("RGB", (64, 32), color="red") batch_feature = processor_from_pretrained.process_images([image]) @@ -142,9 +114,7 @@ def test_get_image_mask_shape( assert image_mask.shape == batch_feature.input_ids.shape assert image_mask.dtype == torch.bool - def test_get_image_mask_has_image_tokens( - self, processor_from_pretrained: ColModernVBertProcessor - ): + def test_get_image_mask_has_image_tokens(self, processor_from_pretrained: ColModernVBertProcessor): """Test that the mask identifies some image tokens.""" image = Image.new("RGB", (64, 32), color="blue") batch_feature = processor_from_pretrained.process_images([image]) @@ -152,13 +122,9 @@ def test_get_image_mask_has_image_tokens( image_mask = processor_from_pretrained.get_image_mask(batch_feature) # There should be image tokens present - assert ( - image_mask.sum() > 0 - ), "Expected to find image tokens in the processed batch" + assert image_mask.sum() > 0, "Expected to find image tokens in the processed batch" - def test_get_image_mask_batch_consistency( - self, processor_from_pretrained: ColModernVBertProcessor - ): + def test_get_image_mask_batch_consistency(self, processor_from_pretrained: ColModernVBertProcessor): """Test that image mask works correctly with batched images.""" images = [ Image.new("RGB", (64, 32), color="red"), @@ -197,14 +163,13 @@ def test_similarity_maps_shape( # Get patch size from the model or processor # ModernVBert uses patch_size from its config - patch_size = ( - 14 # Default for many vision transformers (unused but required for API) - ) + patch_size = 14 # Default for many vision transformers (unused but required for API) # Calculate expected patches # Note: image_size for get_n_patches is (height, width) n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( - (image_size_pil[1], image_size_pil[0]), patch_size # (height, width) + (image_size_pil[1], image_size_pil[0]), + patch_size, # (height, width) ) # Get embeddings @@ -230,9 +195,9 @@ def test_similarity_maps_shape( # similarity_maps[0] should have shape (query_tokens, n_patches_x, n_patches_y) expected_shape = (query_length, n_patches_x, n_patches_y) - assert ( - similarity_maps[0].shape == expected_shape - ), f"Expected shape {expected_shape}, got {similarity_maps[0].shape}" + assert similarity_maps[0].shape == expected_shape, ( + f"Expected shape {expected_shape}, got {similarity_maps[0].shape}" + ) @pytest.mark.slow def test_similarity_maps_values( @@ -248,9 +213,7 @@ def test_similarity_maps_values( batch_queries = processor_from_pretrained.process_texts([query]) patch_size = 14 - n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( - (64, 64), patch_size - ) + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches((64, 64), patch_size) with torch.no_grad(): image_embeddings = model_from_pretrained(**batch_images) @@ -274,9 +237,7 @@ def test_similarity_maps_values( # After 
normalization, values should be in [0, 1] assert normalized_map.min() >= 0.0 assert normalized_map.max() <= 1.0 - assert ( - normalized_map.max() == 1.0 - ) # Max should be exactly 1.0 after normalization + assert normalized_map.max() == 1.0 # Max should be exactly 1.0 after normalization @pytest.mark.slow def test_patch_count_matches_mask_count( @@ -303,9 +264,9 @@ def test_patch_count_matches_mask_count( expected_local_patches = n_patches_x * n_patches_y # LOCAL tokens should match exactly - assert ( - actual_local_tokens == expected_local_patches - ), f"Expected {expected_local_patches} local image tokens, got {actual_local_tokens}" + assert actual_local_tokens == expected_local_patches, ( + f"Expected {expected_local_patches} local image tokens, got {actual_local_tokens}" + ) @pytest.mark.slow def test_global_patch_excluded( @@ -326,9 +287,9 @@ def test_global_patch_excluded( # The difference should be exactly image_seq_len (global patch tokens) image_seq_len = processor_from_pretrained.image_seq_len - assert ( - full_count - local_count == image_seq_len - ), f"Expected {image_seq_len} global patch tokens, got {full_count - local_count}" + assert full_count - local_count == image_seq_len, ( + f"Expected {image_seq_len} global patch tokens, got {full_count - local_count}" + ) class TestInterpretabilityConsistency: From d8ee1ec1b7d67a5f64e7a8f85547e7668eaf4627 Mon Sep 17 00:00:00 2001 From: Huang Xin Date: Fri, 26 Dec 2025 20:30:38 +0800 Subject: [PATCH 5/7] Bump the transformer versions for ColQwen3 support --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7fb2ee391..999dc7150 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ "scipy", "torch>=2.2.0,<2.9.0", "torchvision", - "transformers>=4.53.1,<4.58.0", + "transformers>=4.57.0,<4.58.0", ] [project.optional-dependencies] From f68b85f338645f90b73065c53261ba7861158873 Mon Sep 17 00:00:00 2001 From: Huang Xin Date: Fri, 26 Dec 2025 23:13:20 +0800 Subject: [PATCH 6/7] Fix lint error --- .../models/modernvbert/test_interpretability_colmodernvbert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/modernvbert/test_interpretability_colmodernvbert.py b/tests/models/modernvbert/test_interpretability_colmodernvbert.py index 9913d4a30..cea4dd407 100644 --- a/tests/models/modernvbert/test_interpretability_colmodernvbert.py +++ b/tests/models/modernvbert/test_interpretability_colmodernvbert.py @@ -91,7 +91,7 @@ def test_get_n_patches_aspect_ratio_preservation(self, processor_from_pretrained # The aspect ratio of patches should be close to 2:1 patch_ratio = n_patches_x / n_patches_y - # expected_ratio = 2.0 + expected_ratio = 2.0 # Allow tolerance due to: # 1. Image splitting into 512x512 sub-patches (quantization effects) From c5244fee81772f2050319e4a27e589ecf6e7a98e Mon Sep 17 00:00:00 2001 From: Huang Xin Date: Sat, 27 Dec 2025 11:48:55 +0800 Subject: [PATCH 7/7] Update performance number --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 113055464..98c8127a5 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ Using ColPali removes the need for potentially complex and brittle layout recogn | [vidore/colqwen2-v1.0](https://huggingface.co/vidore/colqwen2-v1.0) | 89.3 | Apache 2.0 | • Similar to `vidore/colqwen2-v0.1`, but trained with more powerful GPUs and with a larger effective batch size (256). 
| ✅ | | [vidore/colqwen2.5-v0.1](https://huggingface.co/vidore/colqwen2.5-v0.1) | 88.8 | Apache 2.0 | • Based on `Qwen/Qwen2 5-VL-3B-Instruct`
• Supports dynamic resolution.
• Trained using 768 image patches per page and an effective batch size of 32. | ✅ | | [vidore/colqwen2.5-v0.2](https://huggingface.co/vidore/colqwen2.5-v0.2) | 89.4 | Apache 2.0 | • Similar to `vidore/colqwen2.5-v0.1`, but trained with slightly different hyper parameters | ✅ | -| [TomoroAI/tomoro-colqwen3-embed-4b](https://huggingface.co/TomoroAI/tomoro-colqwen3-embed-4b) | TBD | Apache 2.0 | • Based on the Qwen3-VL backbone.
• 320-dim ColBERT-style embeddings with dynamic resolution.
• Trained for multi-vector document retrieval. | ✅ | +| [TomoroAI/tomoro-colqwen3-embed-4b](https://huggingface.co/TomoroAI/tomoro-colqwen3-embed-4b) | 90.6 | Apache 2.0 | • Based on the Qwen3-VL backbone.
• 320-dim ColBERT-style embeddings with dynamic resolution.
• Trained for multi-vector document retrieval. | ✅ | | [vidore/colSmol-256M](https://huggingface.co/vidore/colSmol-256M) | 80.1 | Apache 2.0 | • Based on `HuggingFaceTB/SmolVLM-256M-Instruct`. | ✅ | | [vidore/colSmol-500M](https://huggingface.co/vidore/colSmol-500M) | 82.3 | Apache 2.0 | • Based on `HuggingFaceTB/SmolVLM-500M-Instruct`. | ✅ |
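
---

For reviewers who want to exercise the new classes end to end, below is a minimal retrieval sketch distilled from the slow integration test added in `tests/models/qwen3/colqwen3/test_modeling_colqwen3.py`. The blank PIL images are stand-ins for rendered document pages, and the device/dtype choices simply mirror the test fixtures; this is an illustration, not part of the patch itself.

```python
import torch
from PIL import Image

from colpali_engine.models import ColQwen3, ColQwen3Processor
from colpali_engine.utils.torch_utils import get_torch_device

model_name = "TomoroAI/tomoro-colqwen3-embed-4b"
device = get_torch_device("auto")

model = ColQwen3.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=device,
).eval()
processor = ColQwen3Processor.from_pretrained(model_name)

# Stand-ins for rendered document pages.
images = [Image.new("RGB", (448, 448), color="white")]
queries = ["Is attention really all you need?"]

batch_images = processor.process_images(images).to(model.device)
batch_queries = processor.process_queries(queries).to(model.device)

with torch.inference_mode():
    image_embeddings = model(**batch_images)   # (n_pages, n_image_tokens, 320)
    query_embeddings = model(**batch_queries)  # (n_queries, n_query_tokens, 320)

# Late-interaction scores between every query and every page.
scores = processor.score_multi_vector(qs=query_embeddings, ps=image_embeddings)
print(scores.shape)  # (n_queries, n_pages)
```

As in the `model_with_mask` test fixture, `mask_non_image_embeddings=True` can be passed to `ColQwen3.from_pretrained` to keep only the image-token embeddings for document pages.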
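`ColQwen3Processor.score` delegates to `score_multi_vector`, which is inherited from `BaseVisualRetrieverProcessor` and therefore does not appear in this diff. For readers unfamiliar with the ColBERT-style MaxSim score the docstring refers to, here is a small reference sketch of the computation (an illustration only; the library's implementation also handles batching and padding):

```python
import torch


def maxsim_scores(qs: torch.Tensor, ps: torch.Tensor) -> torch.Tensor:
    """Reference late-interaction (MaxSim) scoring.

    qs: (n_queries, n_query_tokens, dim) query embeddings
    ps: (n_pages, n_page_tokens, dim) page embeddings
    Returns a (n_queries, n_pages) score matrix.
    """
    # Token-level similarities for every query/page pair:
    # (n_queries, n_pages, n_query_tokens, n_page_tokens)
    sim = torch.einsum("qnd,pmd->qpnm", qs, ps)
    # Each query token keeps its best-matching page token,
    # then scores are summed over query tokens.
    return sim.amax(dim=-1).sum(dim=-1)
```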
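One processor detail worth calling out: when `max_num_visual_tokens` is passed to `ColQwen3Processor.from_pretrained` (the training config uses 768), it is converted into a pixel budget as `max_pixels = max_num_visual_tokens * (patch_size * merge_size) ** 2`. A worked example follows; the `patch_size` and `merge_size` values are illustrative assumptions, the real ones are read from the checkpoint's image processor.

```python
# Illustrative values only -- the actual patch_size / merge_size come from the
# checkpoint's image processor, not from this snippet.
patch_size = 16
merge_size = 2
max_num_visual_tokens = 768  # value used in scripts/configs/qwen3/train_colqwen3_model.py

tile = patch_size * merge_size                    # pixel side length of one merged visual token
max_pixels = max_num_visual_tokens * tile * tile  # 768 * 32 * 32 = 786_432
print(max_pixels)  # the image processor resizes pages to fit within this pixel budget
```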