diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a5a82d88..3c0e624b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ## Unreleased +### Added + +- Add ColQwen3 and BiQwen3 support (model + processor). + +### Tests + +- Cover ColQwen3 processing and modeling with slow integration tests. + ## [0.3.13] - 2025-11-15 ### Added diff --git a/README.md b/README.md index 48355bfc0..98c8127a5 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,7 @@ Using ColPali removes the need for potentially complex and brittle layout recogn | [vidore/colqwen2-v1.0](https://huggingface.co/vidore/colqwen2-v1.0) | 89.3 | Apache 2.0 | • Similar to `vidore/colqwen2-v0.1`, but trained with more powerful GPUs and with a larger effective batch size (256). | ✅ | | [vidore/colqwen2.5-v0.1](https://huggingface.co/vidore/colqwen2.5-v0.1) | 88.8 | Apache 2.0 | • Based on `Qwen/Qwen2.5-VL-3B-Instruct`
• Supports dynamic resolution.
• Trained using 768 image patches per page and an effective batch size of 32. | ✅ | | [vidore/colqwen2.5-v0.2](https://huggingface.co/vidore/colqwen2.5-v0.2) | 89.4 | Apache 2.0 | • Similar to `vidore/colqwen2.5-v0.1`, but trained with slightly different hyperparameters. | ✅ | +| [TomoroAI/tomoro-colqwen3-embed-4b](https://huggingface.co/TomoroAI/tomoro-colqwen3-embed-4b) | 90.6 | Apache 2.0 | • Based on the Qwen3-VL backbone.
• 320-dim ColBERT-style embeddings with dynamic resolution.
• Trained for multi-vector document retrieval. | ✅ | | [vidore/colSmol-256M](https://huggingface.co/vidore/colSmol-256M) | 80.1 | Apache 2.0 | • Based on `HuggingFaceTB/SmolVLM-256M-Instruct`. | ✅ | | [vidore/colSmol-500M](https://huggingface.co/vidore/colSmol-500M) | 82.3 | Apache 2.0 | • Based on `HuggingFaceTB/SmolVLM-500M-Instruct`. | ✅ | diff --git a/colpali_engine/__init__.py b/colpali_engine/__init__.py index f72652aed..1f5098f74 100644 --- a/colpali_engine/__init__.py +++ b/colpali_engine/__init__.py @@ -7,6 +7,8 @@ BiQwen2_5, BiQwen2_5_Processor, BiQwen2Processor, + BiQwen3, + BiQwen3Processor, ColIdefics3, ColIdefics3Processor, ColModernVBert, @@ -19,4 +21,6 @@ ColQwen2_5Omni, ColQwen2_5OmniProcessor, ColQwen2Processor, + ColQwen3, + ColQwen3Processor, ) diff --git a/colpali_engine/models/__init__.py b/colpali_engine/models/__init__.py index 1129e1612..b98544fcb 100644 --- a/colpali_engine/models/__init__.py +++ b/colpali_engine/models/__init__.py @@ -3,4 +3,5 @@ from .paligemma import BiPali, BiPaliProcessor, BiPaliProj, ColPali, ColPaliProcessor from .qwen2 import BiQwen2, BiQwen2Processor, ColQwen2, ColQwen2Processor from .qwen2_5 import BiQwen2_5, BiQwen2_5_Processor, ColQwen2_5, ColQwen2_5_Processor +from .qwen3 import BiQwen3, BiQwen3Processor, ColQwen3, ColQwen3Processor from .qwen_omni import ColQwen2_5Omni, ColQwen2_5OmniProcessor diff --git a/colpali_engine/models/qwen3/__init__.py b/colpali_engine/models/qwen3/__init__.py new file mode 100644 index 000000000..efcee26f4 --- /dev/null +++ b/colpali_engine/models/qwen3/__init__.py @@ -0,0 +1,2 @@ +from .biqwen3 import BiQwen3, BiQwen3Processor +from .colqwen3 import ColQwen3, ColQwen3Processor diff --git a/colpali_engine/models/qwen3/biqwen3/__init__.py b/colpali_engine/models/qwen3/biqwen3/__init__.py new file mode 100644 index 000000000..1b7881d83 --- /dev/null +++ b/colpali_engine/models/qwen3/biqwen3/__init__.py @@ -0,0 +1,2 @@ +from .modeling_biqwen3 import BiQwen3 +from .processing_biqwen3 import BiQwen3Processor diff --git a/colpali_engine/models/qwen3/biqwen3/modeling_biqwen3.py b/colpali_engine/models/qwen3/biqwen3/modeling_biqwen3.py new file mode 100644 index 000000000..fc4804e4b --- /dev/null +++ b/colpali_engine/models/qwen3/biqwen3/modeling_biqwen3.py @@ -0,0 +1,94 @@ +from typing import ClassVar, Literal + +import torch +from transformers.models.qwen3_vl import Qwen3VLConfig, Qwen3VLModel + + +class BiQwen3(Qwen3VLModel): + """ + BiQwen3 implementation from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper. + Representations are pooled to obtain a single vector representation. Based on the Qwen3-VL backbone. 
+ """ + + main_input_name: ClassVar[str] = "doc_input_ids" + _checkpoint_conversion_mapping = { + r"^model\.visual": "visual", + r"^model\.language_model": "language_model", + r"^model\.": "", + } + + def __init__(self, config: Qwen3VLConfig, **kwargs): + dtype = kwargs.pop("dtype", kwargs.pop("torch_dtype", None)) + attn_impl = kwargs.pop("attn_implementation", None) + use_cache = kwargs.pop("use_cache", None) + + super().__init__(config=config) + self.padding_side = "left" + self.post_init() + + if dtype is not None: + self.to(dtype=dtype) + if use_cache is not None: + self.config.use_cache = use_cache + if attn_impl is not None and hasattr(self, "set_attn_implementation"): + self.set_attn_implementation(attn_impl) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + key_mapping = kwargs.pop("key_mapping", None) + if key_mapping is None: + key_mapping = getattr(cls, "_checkpoint_conversion_mapping", None) + return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping) + + def forward( + self, + pooling_strategy: Literal["cls", "last", "mean"] = "last", + *args, + **kwargs, + ) -> torch.Tensor: + """ + Forward pass for BiQwen3 model. + + Args: + pooling_strategy: The strategy to use for pooling the hidden states. + *args: Variable length argument list. + **kwargs: Additional keyword arguments. + + Returns: + torch.Tensor: Dense embeddings (batch_size, hidden_size). + """ + if "pixel_values" in kwargs: + offsets = kwargs["image_grid_thw"].prod(dim=1).tolist() + kwargs["pixel_values"] = torch.cat( + [pixel_sequence[:offset] for pixel_sequence, offset in zip(kwargs["pixel_values"], offsets)], + dim=0, + ) + kwargs.pop("return_dict", True) + kwargs.pop("output_hidden_states", None) + kwargs.pop("use_cache", None) + + last_hidden_states = ( + super() + .forward(*args, **kwargs, use_cache=False, output_hidden_states=True, return_dict=True) + .last_hidden_state + ) # (batch_size, sequence_length, hidden_size) + + if pooling_strategy == "cls": + pooled = last_hidden_states[:, 0] + elif pooling_strategy == "last": + pooled = last_hidden_states[:, -1] + elif pooling_strategy == "mean": + mask = kwargs["attention_mask"].unsqueeze(-1) + pooled = (last_hidden_states * mask).sum(dim=1) / mask.sum(dim=1) + else: + raise ValueError(f"Invalid pooling strategy: {pooling_strategy}") + + return pooled / pooled.norm(dim=-1, keepdim=True) + + @property + def patch_size(self) -> int: + return self.visual.config.patch_size + + @property + def spatial_merge_size(self) -> int: + return self.visual.config.spatial_merge_size diff --git a/colpali_engine/models/qwen3/biqwen3/processing_biqwen3.py b/colpali_engine/models/qwen3/biqwen3/processing_biqwen3.py new file mode 100644 index 000000000..c13eaff0d --- /dev/null +++ b/colpali_engine/models/qwen3/biqwen3/processing_biqwen3.py @@ -0,0 +1,37 @@ +from typing import List, Optional, Union + +import torch +from transformers import BatchEncoding, BatchFeature + +from colpali_engine.models.qwen3.colqwen3 import ColQwen3Processor + + +class BiQwen3Processor(ColQwen3Processor): + """ + Processor for BiQwen3. + """ + + def process_texts( + self, + texts: List[str], + ) -> Union[BatchFeature, BatchEncoding]: + """ + Process texts for BiQwen3. 
+ """ + return self( + text=texts, + return_tensors="pt", + padding="longest", + ) + + def score( + self, + qs: List[torch.Tensor], + ps: List[torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + **kwargs, + ) -> torch.Tensor: + """ + Compute the cosine similarity for the given query and passage embeddings. + """ + return self.score_single_vector(qs, ps, device=device) diff --git a/colpali_engine/models/qwen3/colqwen3/__init__.py b/colpali_engine/models/qwen3/colqwen3/__init__.py new file mode 100644 index 000000000..6369cb69b --- /dev/null +++ b/colpali_engine/models/qwen3/colqwen3/__init__.py @@ -0,0 +1,2 @@ +from .modeling_colqwen3 import ColQwen3 +from .processing_colqwen3 import ColQwen3Processor diff --git a/colpali_engine/models/qwen3/colqwen3/modeling_colqwen3.py b/colpali_engine/models/qwen3/colqwen3/modeling_colqwen3.py new file mode 100644 index 000000000..019c87bd5 --- /dev/null +++ b/colpali_engine/models/qwen3/colqwen3/modeling_colqwen3.py @@ -0,0 +1,101 @@ +from typing import ClassVar + +import torch +from torch import nn +from transformers.models.qwen3_vl import Qwen3VLConfig, Qwen3VLModel + + +class ColQwen3(Qwen3VLModel): + """ + ColQwen3 model implementation, following the architecture from the article "ColPali: Efficient Document Retrieval + with Vision Language Models" paper. Based on the Qwen3-VL backbone. + + Args: + config (Qwen3VLConfig): The model configuration. + mask_non_image_embeddings (Optional[bool]): Whether to ignore all tokens embeddings + except those of the image at inference. + Defaults to False --> Do not mask any embeddings during forward pass. + """ + + main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related + _checkpoint_conversion_mapping = { + r"^model\.visual": "visual", + r"^model\.language_model": "language_model", + r"^model\.": "", + } + + def __init__( + self, + config: Qwen3VLConfig, + mask_non_image_embeddings: bool = False, + **kwargs, + ): + dtype = kwargs.pop("dtype", kwargs.pop("torch_dtype", None)) + attn_impl = kwargs.pop("attn_implementation", None) + use_cache = kwargs.pop("use_cache", None) + + super().__init__(config=config) + + hidden_size = getattr(self.config, "hidden_size", None) + if hidden_size is None and hasattr(self.config, "text_config"): + hidden_size = self.config.text_config.hidden_size + if hidden_size is None: + raise ValueError("Unable to determine text hidden size for Qwen3VLConfig.") + + self.dim = 320 + self.custom_text_proj = nn.Linear(hidden_size, self.dim) + self.padding_side = "left" + self.mask_non_image_embeddings = mask_non_image_embeddings + self.post_init() + + if dtype is not None: + self.to(dtype=dtype) + if use_cache is not None: + self.config.use_cache = use_cache + if attn_impl is not None and hasattr(self, "set_attn_implementation"): + self.set_attn_implementation(attn_impl) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + key_mapping = kwargs.pop("key_mapping", None) + if key_mapping is None: + key_mapping = getattr(cls, "_checkpoint_conversion_mapping", None) + return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping) + + def forward(self, *args, **kwargs) -> torch.Tensor: + # Handle the custom "pixel_values" input obtained with `ColQwen3Processor` through unpadding + if "pixel_values" in kwargs: + offsets = kwargs["image_grid_thw"][:, 1] * kwargs["image_grid_thw"][:, 2] # (batch_size,) + kwargs["pixel_values"] = torch.cat( + [pixel_sequence[:offset] for pixel_sequence, offset in zip(kwargs["pixel_values"], offsets)], + dim=0, + ) 
+ + kwargs.pop("return_dict", True) + kwargs.pop("output_hidden_states", None) + kwargs.pop("use_cache", None) + last_hidden_states = ( + super() + .forward(*args, **kwargs, use_cache=False, output_hidden_states=True, return_dict=True) + .last_hidden_state + ) # (batch_size, sequence_length, hidden_size) + + proj = self.custom_text_proj(last_hidden_states) # (batch_size, sequence_length, dim) + + # L2 normalization + proj = proj / proj.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim) + proj = proj * kwargs["attention_mask"].unsqueeze(-1) # (batch_size, sequence_length, dim) + + if "pixel_values" in kwargs and self.mask_non_image_embeddings: + # Pools only the image embeddings + image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1) + proj = proj * image_mask + return proj + + @property + def patch_size(self) -> int: + return self.visual.config.patch_size + + @property + def spatial_merge_size(self) -> int: + return self.visual.config.spatial_merge_size diff --git a/colpali_engine/models/qwen3/colqwen3/processing_colqwen3.py b/colpali_engine/models/qwen3/colqwen3/processing_colqwen3.py new file mode 100644 index 000000000..436c7336e --- /dev/null +++ b/colpali_engine/models/qwen3/colqwen3/processing_colqwen3.py @@ -0,0 +1,154 @@ +from typing import ClassVar, List, Optional, Tuple, Union + +import torch +from PIL import Image +from transformers import BatchEncoding, BatchFeature +from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize +from transformers.models.qwen3_vl import Qwen3VLProcessor + +from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor + + +class ColQwen3Processor(BaseVisualRetrieverProcessor, Qwen3VLProcessor): + """ + Processor for ColQwen3. + + Args: + *args: Variable length argument list to be passed to the parent `Qwen3VLProcessor` class. + max_num_visual_tokens: The maximum number of visual tokens that can be processed by the model. + **kwargs: Arbitrary keyword arguments to be passed to the parent `Qwen3VLProcessor` class. + """ + + visual_prompt_prefix: ClassVar[str] = ( + "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|endoftext|>" + ) + query_augmentation_token: ClassVar[str] = "<|endoftext|>" + image_token: ClassVar[str] = "<|image_pad|>" + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.tokenizer.padding_side = "left" + + @classmethod + def from_pretrained( + cls, + *args, + device_map: Optional[str] = None, + **kwargs, + ): + instance = super().from_pretrained( + *args, + device_map=device_map, + **kwargs, + ) + + if "max_num_visual_tokens" in kwargs: + patch_size = getattr(instance.image_processor, "patch_size", None) + merge_size = getattr(instance.image_processor, "merge_size", None) + if patch_size is None or merge_size is None: + raise ValueError("Qwen3VL image processor is missing `patch_size` or `merge_size`.") + tile = patch_size * merge_size + instance.image_processor.max_pixels = kwargs["max_num_visual_tokens"] * tile * tile + instance.image_processor.size["longest_edge"] = instance.image_processor.max_pixels + + return instance + + def process_images( + self, + images: List[Image.Image], + ) -> Union[BatchFeature, BatchEncoding]: + """ + Process images for ColQwen3. + + Args: + images: List of PIL images. 
+ """ + + images = [image.convert("RGB") for image in images] + + batch_doc = self( + text=[self.visual_prompt_prefix] * len(images), + images=images, + padding="longest", + return_tensors="pt", + ) + + # NOTE: The following adjustment ensures correct behavior with DDP on multiple GPUs. + offsets = batch_doc["image_grid_thw"][:, 1] * batch_doc["image_grid_thw"][:, 2] # (batch_size,) + + # Split the pixel_values tensor into a list of tensors, one per image + pixel_values = list( + torch.split(batch_doc["pixel_values"], offsets.tolist()) + ) # [(num_patches_image_0, pixel_values), ..., (num_patches_image_n, pixel_values)] + + # Pad the list of pixel_value tensors to the same length along the sequence dimension + batch_doc["pixel_values"] = torch.nn.utils.rnn.pad_sequence( + pixel_values, batch_first=True + ) # (batch_size, max_num_patches, pixel_values) + + return batch_doc + + def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]: + """ + Process texts for ColQwen3. + + Args: + texts: List of input texts. + + Returns: + Union[BatchFeature, BatchEncoding]: Processed texts. + """ + return self( + text=texts, + return_tensors="pt", + padding="longest", + ) + + def score( + self, + qs: List[torch.Tensor], + ps: List[torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + **kwargs, + ) -> torch.Tensor: + """ + Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings. + """ + return self.score_multi_vector(qs, ps, device=device, **kwargs) + + def get_n_patches( + self, + image_size: Tuple[int, int], + spatial_merge_size: int, + ) -> Tuple[int, int]: + """ + Get the number of patches (n_patches_x, n_patches_y) that will be used to process an image of + size (height, width) with the given patch size. + + The `spatial_merge_size` is the number of patches that will be merged spatially. It is stored in + as a `Qwen3VLForConditionalGeneration` attribute under `model.spatial_merge_size`. + """ + patch_size = self.image_processor.patch_size + + height_new, width_new = smart_resize( + width=image_size[0], + height=image_size[1], + factor=patch_size * self.image_processor.merge_size, + min_pixels=self.image_processor.size["shortest_edge"], + max_pixels=self.image_processor.size["longest_edge"], + ) + + n_patches_x = width_new // patch_size // spatial_merge_size + n_patches_y = height_new // patch_size // spatial_merge_size + + return n_patches_x, n_patches_y + + def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor: + """ + Get a tensor mask that identifies the image tokens in the batch. 
+ """ + return batch_images.input_ids == self.image_token_id diff --git a/pyproject.toml b/pyproject.toml index 7fb2ee391..999dc7150 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ "scipy", "torch>=2.2.0,<2.9.0", "torchvision", - "transformers>=4.53.1,<4.58.0", + "transformers>=4.57.0,<4.58.0", ] [project.optional-dependencies] diff --git a/scripts/configs/qwen3/train_colqwen3_model.py b/scripts/configs/qwen3/train_colqwen3_model.py new file mode 100644 index 000000000..bef8f3df8 --- /dev/null +++ b/scripts/configs/qwen3/train_colqwen3_model.py @@ -0,0 +1,100 @@ +import argparse +import shutil +from pathlib import Path + +import torch +from datasets import load_dataset +from peft import LoraConfig +from transformers import TrainingArguments + +from colpali_engine.data.dataset import ColPaliEngineDataset +from colpali_engine.loss.late_interaction_losses import ColbertLoss, ColbertPairwiseCELoss +from colpali_engine.models import ColQwen3, ColQwen3Processor +from colpali_engine.trainer.colmodel_torch_training import ColModelTorchTraining +from colpali_engine.trainer.colmodel_training import ColModelTraining, ColModelTrainingConfig +from colpali_engine.utils.dataset_transformation import load_train_set + + +def parse_args(): + p = argparse.ArgumentParser() + p.add_argument("--output-dir", type=str, required=True, help="where to write model + script copy") + p.add_argument("--lr", type=float, default=2e-4, help="learning rate") + p.add_argument("--tau", type=float, default=0.02, help="temperature for loss function") + p.add_argument("--trainer", type=str, default="hf", choices=["torch", "hf"], help="trainer to use") + p.add_argument("--loss", type=str, default="ce", choices=["ce", "pairwise"], help="loss function to use") + p.add_argument("--peft", action="store_true", help="use PEFT for training") + + return p.parse_args() + + +if __name__ == "__main__": + args = parse_args() + + if args.loss == "ce": + loss_func = ColbertLoss( + temperature=args.tau, + normalize_scores=True, + use_smooth_max=False, + pos_aware_negative_filtering=False, + ) + elif args.loss == "pairwise": + loss_func = ColbertPairwiseCELoss( + normalize_scores=False, + ) + else: + raise ValueError(f"Unknown loss function: {args.loss}") + + config = ColModelTrainingConfig( + output_dir=args.output_dir, + processor=ColQwen3Processor.from_pretrained( + pretrained_model_name_or_path="./models/base_models/colqwen3-base", + max_num_visual_tokens=768, + ), + model=ColQwen3.from_pretrained( + pretrained_model_name_or_path="./models/base_models/colqwen3-base", + torch_dtype=torch.bfloat16, + use_cache=False, + attn_implementation="flash_attention_2", + ), + train_dataset=load_train_set(), + eval_dataset=ColPaliEngineDataset( + load_dataset("./data_dir/colpali_train_set", split="test"), pos_target_column_name="image" + ), + run_eval=True, + loss_func=loss_func, + tr_args=TrainingArguments( + output_dir=None, + overwrite_output_dir=True, + num_train_epochs=5, + per_device_train_batch_size=64, + gradient_checkpointing=True, + gradient_checkpointing_kwargs={"use_reentrant": False}, + per_device_eval_batch_size=16, + eval_strategy="steps", + dataloader_num_workers=8, + save_steps=500, + logging_steps=10, + eval_steps=100, + warmup_steps=100, + learning_rate=args.lr, + save_total_limit=1, + ), + peft_config=LoraConfig( + r=32, + lora_alpha=32, + lora_dropout=0.1, + init_lora_weights="gaussian", + bias="none", + task_type="FEATURE_EXTRACTION", + 
target_modules="(.*(model)(?!.*visual).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)", + ) + if args.peft + else None, + ) + + Path(config.output_dir).mkdir(parents=True, exist_ok=True) + shutil.copy(Path(__file__), Path(config.output_dir) / Path(__file__).name) + + trainer = ColModelTraining(config) if args.trainer == "hf" else ColModelTorchTraining(config) + trainer.train() + trainer.save() diff --git a/tests/models/qwen3/colqwen3/test_modeling_colqwen3.py b/tests/models/qwen3/colqwen3/test_modeling_colqwen3.py new file mode 100644 index 000000000..239713d6a --- /dev/null +++ b/tests/models/qwen3/colqwen3/test_modeling_colqwen3.py @@ -0,0 +1,135 @@ +import logging +from typing import Generator, cast + +import pytest +import torch +from datasets import load_dataset +from PIL import Image +from transformers.utils.import_utils import is_flash_attn_2_available + +from colpali_engine.models import ColQwen3, ColQwen3Processor +from colpali_engine.utils.torch_utils import get_torch_device, tear_down_torch + +logger = logging.getLogger(__name__) + + +@pytest.fixture(scope="module") +def model_name() -> str: + return "TomoroAI/tomoro-colqwen3-embed-4b" + + +@pytest.fixture(scope="module") +def model_without_mask(model_name: str) -> Generator[ColQwen3, None, None]: + device = get_torch_device("auto") + logger.info(f"Device used: {device}") + + yield cast( + ColQwen3, + ColQwen3.from_pretrained( + model_name, + torch_dtype=torch.bfloat16, + device_map=device, + attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None, + mask_non_image_embeddings=False, + ).eval(), + ) + tear_down_torch() + + +@pytest.fixture(scope="module") +def model_with_mask(model_name: str) -> Generator[ColQwen3, None, None]: + device = get_torch_device("auto") + logger.info(f"Device used: {device}") + + yield cast( + ColQwen3, + ColQwen3.from_pretrained( + model_name, + torch_dtype=torch.bfloat16, + device_map=device, + attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None, + mask_non_image_embeddings=True, + ).eval(), + ) + tear_down_torch() + + +@pytest.fixture(scope="module") +def processor(model_name: str) -> Generator[ColQwen3Processor, None, None]: + yield cast(ColQwen3Processor, ColQwen3Processor.from_pretrained(model_name)) + + +class TestColQwen3Model: + @pytest.mark.slow + def test_load_model_from_pretrained(self, model_without_mask: ColQwen3): + assert isinstance(model_without_mask, ColQwen3) + + +class TestColQwen3ModelIntegration: + @pytest.mark.slow + def test_forward_images_integration( + self, + model_without_mask: ColQwen3, + processor: ColQwen3Processor, + ): + images = [ + Image.new("RGB", (64, 64), color="white"), + Image.new("RGB", (32, 32), color="black"), + ] + batch_images = processor.process_images(images).to(model_without_mask.device) + + with torch.no_grad(): + outputs = model_without_mask(**batch_images) + + assert isinstance(outputs, torch.Tensor) + assert outputs.dim() == 3 + batch_size, n_visual_tokens, emb_dim = outputs.shape + assert batch_size == len(images) + assert n_visual_tokens >= 1 + assert emb_dim == model_without_mask.dim + + @pytest.mark.slow + def test_forward_queries_integration( + self, + model_without_mask: ColQwen3, + processor: ColQwen3Processor, + ): + queries = [ + "Is attention really all you need?", + "Are Benjamin, Antoine, Merve, and Jo best friends?", + ] + batch_queries = processor.process_queries(queries).to(model_without_mask.device) + + with torch.no_grad(): + outputs = 
model_without_mask(**batch_queries) + + assert isinstance(outputs, torch.Tensor) + assert outputs.dim() == 3 + batch_size, n_query_tokens, emb_dim = outputs.shape + assert batch_size == len(queries) + assert n_query_tokens >= 1 + assert emb_dim == model_without_mask.dim + + @pytest.mark.slow + def test_retrieval_integration( + self, + model_without_mask: ColQwen3, + processor: ColQwen3Processor, + ): + ds = load_dataset("hf-internal-testing/document-visual-retrieval-test", split="test") + + batch_images = processor.process_images(images=ds["image"]).to(model_without_mask.device) + batch_queries = processor.process_queries(queries=ds["query"]).to(model_without_mask.device) + + with torch.inference_mode(): + image_embeddings = model_without_mask(**batch_images) + query_embeddings = model_without_mask(**batch_queries) + + scores = processor.score_multi_vector( + qs=query_embeddings, + ps=image_embeddings, + ) + + assert scores.ndim == 2, f"Expected 2D tensor, got {scores.ndim}" + assert scores.shape == (len(ds), len(ds)), f"Expected shape {(len(ds), len(ds))}, got {scores.shape}" + assert (scores.argmax(dim=1) == torch.arange(len(ds), device=scores.device)).all() diff --git a/tests/models/qwen3/colqwen3/test_processing_colqwen3.py b/tests/models/qwen3/colqwen3/test_processing_colqwen3.py new file mode 100644 index 000000000..971afa578 --- /dev/null +++ b/tests/models/qwen3/colqwen3/test_processing_colqwen3.py @@ -0,0 +1,61 @@ +from typing import Generator, cast + +import pytest +import torch +from PIL import Image + +from colpali_engine.models import ColQwen3Processor + + +@pytest.fixture(scope="module") +def model_name() -> str: + return "TomoroAI/tomoro-colqwen3-embed-4b" + + +@pytest.fixture(scope="module") +def processor_from_pretrained(model_name: str) -> Generator[ColQwen3Processor, None, None]: + yield cast(ColQwen3Processor, ColQwen3Processor.from_pretrained(model_name)) + + +def test_load_processor_from_pretrained(processor_from_pretrained: ColQwen3Processor): + assert isinstance(processor_from_pretrained, ColQwen3Processor) + + +def test_process_images(processor_from_pretrained: ColQwen3Processor): + image_size = (64, 32) + image = Image.new("RGB", image_size, color="black") + images = [image] + + batch_feature = processor_from_pretrained.process_images(images) + + assert "pixel_values" in batch_feature + assert isinstance(batch_feature["pixel_values"], torch.Tensor) + assert batch_feature["pixel_values"].shape[0] == len(images) + assert batch_feature["pixel_values"].shape[1] >= 1 + assert batch_feature["pixel_values"].shape[-1] > 0 + + +def test_process_texts(processor_from_pretrained: ColQwen3Processor): + queries = [ + "Is attention really all you need?", + "Are Benjamin, Antoine, Merve, and Jo best friends?", + ] + + batch_encoding = processor_from_pretrained.process_texts(queries) + + assert "input_ids" in batch_encoding + assert isinstance(batch_encoding["input_ids"], torch.Tensor) + assert cast(torch.Tensor, batch_encoding["input_ids"]).shape[0] == len(queries) + + +def test_process_queries(processor_from_pretrained: ColQwen3Processor): + queries = [ + "Is attention really all you need?", + "Are Benjamin, Antoine, Merve, and Jo best friends?", + ] + + batch_encoding = processor_from_pretrained.process_queries(queries) + + assert "input_ids" in batch_encoding + assert isinstance(batch_encoding["input_ids"], torch.Tensor) + assert cast(torch.Tensor, batch_encoding["input_ids"]).shape[0] == len(queries)
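
Below is a minimal usage sketch of the new ColQwen3 classes introduced in this diff, mirroring the slow retrieval integration test above. The checkpoint name, processor methods, and embedding dimension come from the diff itself; the sample image, query, and CPU/GPU selection are illustrative assumptions rather than recommended settings.

```python
# Minimal sketch: embed one document page and one query with the new ColQwen3
# classes, then compute a ColBERT-style MaxSim score (mirrors
# tests/models/qwen3/colqwen3/test_modeling_colqwen3.py).
import torch
from PIL import Image

from colpali_engine.models import ColQwen3, ColQwen3Processor

model_name = "TomoroAI/tomoro-colqwen3-embed-4b"
device = "cuda" if torch.cuda.is_available() else "cpu"

model = ColQwen3.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=device,
).eval()
processor = ColQwen3Processor.from_pretrained(model_name)

images = [Image.new("RGB", (64, 64), color="white")]  # placeholder for real document pages
queries = ["Is attention really all you need?"]

batch_images = processor.process_images(images).to(model.device)
batch_queries = processor.process_queries(queries).to(model.device)

with torch.inference_mode():
    image_embeddings = model(**batch_images)   # (n_docs, n_image_tokens, 320)
    query_embeddings = model(**batch_queries)  # (n_queries, n_query_tokens, 320)

# Late-interaction (MaxSim) scores between every query and every page.
scores = processor.score_multi_vector(qs=query_embeddings, ps=image_embeddings)
print(scores.shape)  # (n_queries, n_docs)
```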