diff --git a/colpali_engine/models/__init__.py b/colpali_engine/models/__init__.py index cb9a71ace..3664b41e5 100644 --- a/colpali_engine/models/__init__.py +++ b/colpali_engine/models/__init__.py @@ -3,3 +3,4 @@ from .paligemma import BiPali, BiPaliProcessor, BiPaliProj, ColPali, ColPaliProcessor from .qwen2 import BiQwen2, BiQwen2Processor, ColQwen2, ColQwen2Processor from .qwen2_5 import BiQwen2_5, BiQwen2_5_Processor, ColQwen2_5, ColQwen2_5_Processor +from .qwen3 import ColQwen3, ColQwen3MoE, ColQwen3MoEProcessor, ColQwen3Processor diff --git a/colpali_engine/models/qwen3/__init__.py b/colpali_engine/models/qwen3/__init__.py new file mode 100644 index 000000000..2dc947091 --- /dev/null +++ b/colpali_engine/models/qwen3/__init__.py @@ -0,0 +1,3 @@ +from .colqwen3 import ColQwen3, ColQwen3Processor +from .colqwen3_moe import ColQwen3MoE, ColQwen3MoEProcessor + diff --git a/colpali_engine/models/qwen3/colqwen3/__init__.py b/colpali_engine/models/qwen3/colqwen3/__init__.py new file mode 100644 index 000000000..5cb065754 --- /dev/null +++ b/colpali_engine/models/qwen3/colqwen3/__init__.py @@ -0,0 +1,3 @@ +from .modeling_colqwen3 import ColQwen3 +from .processing_colqwen3 import ColQwen3Processor + diff --git a/colpali_engine/models/qwen3/colqwen3/modeling_colqwen3.py b/colpali_engine/models/qwen3/colqwen3/modeling_colqwen3.py new file mode 100644 index 000000000..3455a3070 --- /dev/null +++ b/colpali_engine/models/qwen3/colqwen3/modeling_colqwen3.py @@ -0,0 +1,93 @@ +from typing import ClassVar + +import torch +from torch import nn +from transformers.models.qwen3_vl import Qwen3VLConfig, Qwen3VLModel + + +class ColQwen3(Qwen3VLModel): + """ + ColQwen3 model implementation, following the architecture described in the + "ColPali: Efficient Document Retrieval with Vision Language Models" paper. + This wrapper adapts the Qwen3-VL backbone for multi-vector document retrieval. + + Args: + config (Qwen3VLConfig): Model configuration. + mask_non_image_embeddings (bool): When ``True`` only image embeddings are kept in the output vectors. + Defaults to ``False`` meaning that all embeddings are returned. 
+ """ + + main_input_name: ClassVar[str] = "doc_input_ids" + + def __init__(self, config: Qwen3VLConfig, mask_non_image_embeddings: bool = False): + super().__init__(config=config) + self.dim = 128 + self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.dim) + self.padding_side = "left" + self.mask_non_image_embeddings = mask_non_image_embeddings + self.post_init() + + @classmethod + def from_pretrained(cls, *args, **kwargs): + key_mapping = kwargs.pop("key_mapping", None) + if key_mapping is None: + key_mapping = super()._checkpoint_conversion_mapping + return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping) + + def forward(self, *args, **kwargs) -> torch.Tensor: + attention_mask = kwargs.get("attention_mask") + has_pixel_values = "pixel_values" in kwargs and kwargs["pixel_values"] is not None + + if has_pixel_values: + image_grid_thw = kwargs.get("image_grid_thw") + if image_grid_thw is None: + raise ValueError("`image_grid_thw` must be provided when `pixel_values` is passed.") + + if not torch.is_tensor(image_grid_thw): + image_grid_thw = torch.as_tensor(image_grid_thw, device=kwargs["pixel_values"].device) + + offsets = image_grid_thw.prod(dim=1) + unpadded = [ + pixel_sequence[: int(offset.item())] + for pixel_sequence, offset in zip(kwargs["pixel_values"], offsets) + ] + + if unpadded: + kwargs["pixel_values"] = torch.cat(unpadded, dim=0) + else: + kwargs["pixel_values"] = None + + kwargs.pop("return_dict", True) + kwargs.pop("output_hidden_states", None) + kwargs.pop("use_cache", None) + + last_hidden_states = ( + super() + .forward(*args, **kwargs, use_cache=False, output_hidden_states=True, return_dict=True) + .last_hidden_state + ) + + proj = self.custom_text_proj(last_hidden_states) + proj = proj / proj.norm(dim=-1, keepdim=True) + + if attention_mask is not None: + proj = proj * attention_mask.unsqueeze(-1) + + if has_pixel_values and self.mask_non_image_embeddings and kwargs.get("input_ids") is not None: + image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1) + proj = proj * image_mask + + return proj + + @property + def patch_size(self) -> int: + return self.visual.config.patch_size + + @property + def spatial_merge_size(self) -> int: + return self.visual.config.spatial_merge_size + + @property + def temporal_patch_size(self) -> int: + return getattr(self.visual.config, "temporal_patch_size", 1) + diff --git a/colpali_engine/models/qwen3/colqwen3/processing_colqwen3.py b/colpali_engine/models/qwen3/colqwen3/processing_colqwen3.py new file mode 100644 index 000000000..5c8d9d6b1 --- /dev/null +++ b/colpali_engine/models/qwen3/colqwen3/processing_colqwen3.py @@ -0,0 +1,129 @@ +from typing import ClassVar, List, Optional, Tuple, Union + +import torch +from PIL import Image +from transformers import BatchEncoding, BatchFeature +from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize +from transformers.models.qwen3_vl import Qwen3VLProcessor + +from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor + + +class ColQwen3Processor(BaseVisualRetrieverProcessor, Qwen3VLProcessor): + """ + Processor for ColQwen3. + + Args: + max_num_visual_tokens: Maximum number of visual tokens allowed during preprocessing. + *args: Variable positional arguments forwarded to :class:`~transformers.Qwen3VLProcessor`. + **kwargs: Keyword arguments forwarded to :class:`~transformers.Qwen3VLProcessor`. 
+ """ + + visual_prompt_prefix: ClassVar[str] = ( + "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|endoftext|>" + ) + query_augmentation_token: ClassVar[str] = "<|endoftext|>" + image_token: ClassVar[str] = "<|image_pad|>" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.tokenizer.padding_side = "left" + + @classmethod + def from_pretrained( + cls, + *args, + device_map: Optional[str] = None, + **kwargs, + ): + instance = super().from_pretrained( + *args, + device_map=device_map, + **kwargs, + ) + + if "max_num_visual_tokens" in kwargs: + patch_size = getattr(instance.image_processor, "patch_size", 28) + instance.image_processor.max_pixels = kwargs["max_num_visual_tokens"] * patch_size * patch_size + instance.image_processor.size["longest_edge"] = instance.image_processor.max_pixels + + return instance + + def process_images(self, images: List[Image.Image]) -> Union[BatchFeature, BatchEncoding]: + """ + Process a batch of PIL images for ColQwen3. + """ + + images = [image.convert("RGB") for image in images] + + batch_doc = self( + text=[self.visual_prompt_prefix] * len(images), + images=images, + padding="longest", + return_tensors="pt", + ) + + if batch_doc["pixel_values"].numel() == 0: + return batch_doc + + offsets = batch_doc["image_grid_thw"].prod(dim=1) + pixel_values = list( + torch.split(batch_doc["pixel_values"], offsets.tolist()) + ) # [(num_patches_img_0, patch_dim), ..., (num_patches_img_n, patch_dim)] + + batch_doc["pixel_values"] = torch.nn.utils.rnn.pad_sequence(pixel_values, batch_first=True) + + return batch_doc + + def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]: + """ + Process a batch of raw texts for ColQwen3. + """ + return self( + text=texts, + return_tensors="pt", + padding="longest", + ) + + def score( + self, + qs: List[torch.Tensor], + ps: List[torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + **kwargs, + ) -> torch.Tensor: + """ + Compute the MaxSim score (ColBERT-like) for query and passage embeddings. + """ + return self.score_multi_vector(qs, ps, device=device, **kwargs) + + def get_n_patches( + self, + image_size: Tuple[int, int], + spatial_merge_size: int, + ) -> Tuple[int, int]: + """ + Compute the number of spatial patches for an image of ``image_size``. + """ + patch_size = self.image_processor.patch_size + merge_size = getattr(self.image_processor, "merge_size", 1) + + height_new, width_new = smart_resize( + width=image_size[0], + height=image_size[1], + factor=patch_size * merge_size, + min_pixels=self.image_processor.size["shortest_edge"], + max_pixels=self.image_processor.size["longest_edge"], + ) + + n_patches_x = width_new // patch_size // spatial_merge_size + n_patches_y = height_new // patch_size // spatial_merge_size + + return n_patches_x, n_patches_y + + def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor: + """ + Return a boolean tensor identifying image tokens inside ``batch_images``. 
+ """ + return batch_images.input_ids == self.image_token_id + diff --git a/colpali_engine/models/qwen3/colqwen3_moe/__init__.py b/colpali_engine/models/qwen3/colqwen3_moe/__init__.py new file mode 100644 index 000000000..f76e87b80 --- /dev/null +++ b/colpali_engine/models/qwen3/colqwen3_moe/__init__.py @@ -0,0 +1,3 @@ +from .modeling_colqwen3_moe import ColQwen3MoE +from .processing_colqwen3_moe import ColQwen3MoEProcessor + diff --git a/colpali_engine/models/qwen3/colqwen3_moe/modeling_colqwen3_moe.py b/colpali_engine/models/qwen3/colqwen3_moe/modeling_colqwen3_moe.py new file mode 100644 index 000000000..1c6d074c0 --- /dev/null +++ b/colpali_engine/models/qwen3/colqwen3_moe/modeling_colqwen3_moe.py @@ -0,0 +1,91 @@ +from typing import ClassVar + +import torch +from torch import nn +from transformers.models.qwen3_vl_moe import Qwen3VLMoeConfig, Qwen3VLMoeModel + + +class ColQwen3MoE(Qwen3VLMoeModel): + """ + ColQwen3-MoE model implementation. This adapts the Qwen3-VL-MoE backbone to the ColPali multi-vector + retrieval setting. + + Args: + config (Qwen3VLMoeConfig): Model configuration. + mask_non_image_embeddings (bool): When ``True`` only image embeddings are preserved. + """ + + main_input_name: ClassVar[str] = "doc_input_ids" + + def __init__(self, config: Qwen3VLMoeConfig, mask_non_image_embeddings: bool = False): + super().__init__(config=config) + self.dim = 128 + self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.dim) + self.padding_side = "left" + self.mask_non_image_embeddings = mask_non_image_embeddings + self.post_init() + + @classmethod + def from_pretrained(cls, *args, **kwargs): + key_mapping = kwargs.pop("key_mapping", None) + if key_mapping is None: + key_mapping = super()._checkpoint_conversion_mapping + return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping) + + def forward(self, *args, **kwargs) -> torch.Tensor: + attention_mask = kwargs.get("attention_mask") + has_pixel_values = "pixel_values" in kwargs and kwargs["pixel_values"] is not None + + if has_pixel_values: + image_grid_thw = kwargs.get("image_grid_thw") + if image_grid_thw is None: + raise ValueError("`image_grid_thw` must be provided when `pixel_values` is passed.") + + if not torch.is_tensor(image_grid_thw): + image_grid_thw = torch.as_tensor(image_grid_thw, device=kwargs["pixel_values"].device) + + offsets = image_grid_thw.prod(dim=1) + unpadded = [ + pixel_sequence[: int(offset.item())] + for pixel_sequence, offset in zip(kwargs["pixel_values"], offsets) + ] + + if unpadded: + kwargs["pixel_values"] = torch.cat(unpadded, dim=0) + else: + kwargs["pixel_values"] = None + + kwargs.pop("return_dict", True) + kwargs.pop("output_hidden_states", None) + kwargs.pop("use_cache", None) + + last_hidden_states = ( + super() + .forward(*args, **kwargs, use_cache=False, output_hidden_states=True, return_dict=True) + .last_hidden_state + ) + + proj = self.custom_text_proj(last_hidden_states) + proj = proj / proj.norm(dim=-1, keepdim=True) + + if attention_mask is not None: + proj = proj * attention_mask.unsqueeze(-1) + + if has_pixel_values and self.mask_non_image_embeddings and kwargs.get("input_ids") is not None: + image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1) + proj = proj * image_mask + + return proj + + @property + def patch_size(self) -> int: + return self.visual.config.patch_size + + @property + def spatial_merge_size(self) -> int: + return self.visual.config.spatial_merge_size + + @property + def temporal_patch_size(self) -> int: + 
return getattr(self.visual.config, "temporal_patch_size", 1) + diff --git a/colpali_engine/models/qwen3/colqwen3_moe/processing_colqwen3_moe.py b/colpali_engine/models/qwen3/colqwen3_moe/processing_colqwen3_moe.py new file mode 100644 index 000000000..c947e02e4 --- /dev/null +++ b/colpali_engine/models/qwen3/colqwen3_moe/processing_colqwen3_moe.py @@ -0,0 +1,54 @@ +from typing import ClassVar, List, Optional, Tuple, Union + +import torch +from PIL import Image +from transformers import BatchEncoding, BatchFeature + +from colpali_engine.models.qwen3.colqwen3.processing_colqwen3 import ColQwen3Processor + + +class ColQwen3MoEProcessor(ColQwen3Processor): + """ + Processor for the ColQwen3-MoE model variant. The MoE backbone shares the same vision and text + preprocessing pipeline as the dense Qwen3-VL models, but exposing a dedicated class keeps the API + symmetric with the available ColPali wrappers. + """ + + moe_variant: ClassVar[bool] = True + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + @classmethod + def from_pretrained( + cls, + *args, + device_map: Optional[str] = None, + **kwargs, + ): + return super().from_pretrained(*args, device_map=device_map, **kwargs) + + def process_images(self, images: List[Image.Image]) -> Union[BatchFeature, BatchEncoding]: + return super().process_images(images) + + def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]: + return super().process_texts(texts) + + def score( + self, + qs: List[torch.Tensor], + ps: List[torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + **kwargs, + ) -> torch.Tensor: + return super().score(qs, ps, device=device, **kwargs) + + def get_n_patches( + self, + image_size: Tuple[int, int], + spatial_merge_size: int, + ) -> Tuple[int, int]: + return super().get_n_patches(image_size, spatial_merge_size) + + def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor: + return super().get_image_mask(batch_images) diff --git a/pyproject.toml b/pyproject.toml index 7fb2ee391..cb5ece124 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ "scipy", "torch>=2.2.0,<2.9.0", "torchvision", - "transformers>=4.53.1,<4.58.0", + "transformers>=4.57.1,<4.58.0", ] [project.optional-dependencies] diff --git a/scripts/configs/qwen3/train_colqwen3_model.py b/scripts/configs/qwen3/train_colqwen3_model.py new file mode 100644 index 000000000..3eb7d9ddd --- /dev/null +++ b/scripts/configs/qwen3/train_colqwen3_model.py @@ -0,0 +1,101 @@ +import argparse +import shutil +from pathlib import Path + +import torch +from datasets import load_dataset +from peft import LoraConfig +from transformers import TrainingArguments + +from colpali_engine.data.dataset import ColPaliEngineDataset +from colpali_engine.loss.late_interaction_losses import ColbertLoss, ColbertPairwiseCELoss +from colpali_engine.models import ColQwen3, ColQwen3Processor +from colpali_engine.trainer.colmodel_torch_training import ColModelTorchTraining +from colpali_engine.trainer.colmodel_training import ColModelTraining, ColModelTrainingConfig +from colpali_engine.utils.dataset_transformation import load_train_set + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=str, required=True, help="Where to write model + script copy") + parser.add_argument("--lr", type=float, default=2e-4, help="Learning rate") + parser.add_argument("--tau", type=float, default=0.02, help="Temperature for loss function") + 
parser.add_argument("--trainer", type=str, default="hf", choices=["torch", "hf"], help="Trainer to use") + parser.add_argument("--loss", type=str, default="ce", choices=["ce", "pairwise"], help="Loss function to use") + parser.add_argument("--peft", action="store_true", help="Use PEFT for training") + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + + if args.loss == "ce": + loss_func = ColbertLoss( + temperature=args.tau, + normalize_scores=True, + use_smooth_max=False, + pos_aware_negative_filtering=False, + ) + elif args.loss == "pairwise": + loss_func = ColbertPairwiseCELoss( + normalize_scores=False, + ) + else: + raise ValueError(f"Unknown loss function: {args.loss}") + + config = ColModelTrainingConfig( + output_dir=args.output_dir, + processor=ColQwen3Processor.from_pretrained( + pretrained_model_name_or_path="./models/base_models/colqwen3-base", + max_num_visual_tokens=1024, + ), + model=ColQwen3.from_pretrained( + pretrained_model_name_or_path="./models/base_models/colqwen3-base", + torch_dtype=torch.bfloat16, + use_cache=False, + attn_implementation="flash_attention_2", + ), + train_dataset=load_train_set(), + eval_dataset=ColPaliEngineDataset( + load_dataset("./data_dir/colpali_train_set", split="test"), + pos_target_column_name="image", + ), + run_eval=True, + loss_func=loss_func, + tr_args=TrainingArguments( + output_dir=None, + overwrite_output_dir=True, + num_train_epochs=5, + per_device_train_batch_size=64, + gradient_checkpointing=True, + gradient_checkpointing_kwargs={"use_reentrant": False}, + per_device_eval_batch_size=16, + eval_strategy="steps", + dataloader_num_workers=8, + save_steps=500, + logging_steps=10, + eval_steps=100, + warmup_steps=100, + learning_rate=args.lr, + save_total_limit=1, + ), + peft_config=LoraConfig( + r=32, + lora_alpha=32, + lora_dropout=0.1, + init_lora_weights="gaussian", + bias="none", + task_type="FEATURE_EXTRACTION", + target_modules="(.*(model)(?!.*visual).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)", + ) + if args.peft + else None, + ) + + Path(config.output_dir).mkdir(parents=True, exist_ok=True) + shutil.copy(Path(__file__), Path(config.output_dir) / Path(__file__).name) + + trainer = ColModelTraining(config) if args.trainer == "hf" else ColModelTorchTraining(config) + trainer.train() + trainer.save() +