Merged
17 commits
29cef8a
feat(gemma3): Add BiGemma3 and ColGemma3 models with processors
adithya-s-k Dec 3, 2025
4b902fe
feat(gemma3): Enhance BiGemma3 and ColGemma3 models with embedding di…
adithya-s-k Dec 3, 2025
563d659
Add BiGemma3 Matryoshka Embeddings Test Suite with vLLM support
adithya-s-k Dec 3, 2025
b2a7e4c
feat(gemma3): remove deprecated offline and serving scripts for Gemma…
adithya-s-k Dec 3, 2025
d26fc8d
Add HuggingFace and vLLM serving scripts with GPU memory snapshots
adithya-s-k Dec 3, 2025
c7701fc
Remove test script for BiGemma3 with Matryoshka embeddings using vLLM
Dec 4, 2025
a05d314
fix(gemma3): Update BiGemma3 initialization to remove embedding_dim p…
adithya-s-k Dec 6, 2025
89d0ddf
Merge pull request #1 from adithya-s-k/inference_demo
adithya-s-k Dec 6, 2025
0efedfa
feat(interpretability): add example for generating ColGemma3 interpre…
adithya-s-k Dec 11, 2025
da3ebc9
refactor(interpretability): streamline imports and improve code reada…
adithya-s-k Dec 11, 2025
6629a55
fix(tests): update model name references to Nayana-cognitivelab models
adithya-s-k Dec 18, 2025
2f0fd14
fix(tests): update model name references to Cognitive-Lab models
adithya-s-k Dec 18, 2025
00e284e
fix(tests): add interpretability test and fixed bigemma test and veri…
adithya-s-k Dec 18, 2025
8441a46
feat(docs): add Cognitive-Lab model references to README
adithya-s-k Dec 18, 2025
b890412
Merge branch 'illuin-tech:main' into main
adithya-s-k Dec 20, 2025
0d54dad
fix(tests): remove unused import of List in test_interpretability_wor…
adithya-s-k Dec 27, 2025
77fa873
Add interpretability tests and ColNetraEmbed/NetraEmbed models - Adit…
adithya-s-k Dec 30, 2025
1 change: 1 addition & 0 deletions colpali_engine/models/__init__.py
@@ -1,3 +1,4 @@
from .gemma3 import BiGemma3, BiGemmaProcessor3, ColGemma3, ColGemmaProcessor3
from .idefics3 import BiIdefics3, BiIdefics3Processor, ColIdefics3, ColIdefics3Processor
from .modernvbert import BiModernVBert, BiModernVBertProcessor, ColModernVBert, ColModernVBertProcessor
from .paligemma import BiPali, BiPaliProcessor, BiPaliProj, ColPali, ColPaliProcessor
2 changes: 2 additions & 0 deletions colpali_engine/models/gemma3/__init__.py
@@ -0,0 +1,2 @@
from .bigemma3 import BiGemma3, BiGemmaProcessor3
from .colgemma3 import ColGemma3, ColGemmaProcessor3
2 changes: 2 additions & 0 deletions colpali_engine/models/gemma3/bigemma3/__init__.py
@@ -0,0 +1,2 @@
from .modeling_bigemma import BiGemma3
from .processing_bigemma import BiGemmaProcessor3
88 changes: 88 additions & 0 deletions colpali_engine/models/gemma3/bigemma3/modeling_bigemma.py
@@ -0,0 +1,88 @@
from typing import ClassVar, Literal

import torch
from transformers.models.gemma3 import Gemma3Config, Gemma3Model


class BiGemma3(Gemma3Model): # noqa: N801
"""
BiGemma3 is an implementation from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper.
    Token representations are pooled to obtain a single vector per input. Based on the Gemma3 backbone.

Supports Matryoshka embeddings with dimensions: 768, 1536, or 2560 (full).
The embedding dimension can be specified at inference time via the forward() method.
"""

main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related

def __init__(self, config: Gemma3Config):
super().__init__(config=config)
self.padding_side = "left"
self.post_init()

@classmethod
def from_pretrained(cls, *args, **kwargs):
# Remove embedding_dim if passed (backward compatibility)
kwargs.pop("embedding_dim", None)

key_mapping = kwargs.pop("key_mapping", None)
if key_mapping is None:
key_mapping = super()._checkpoint_conversion_mapping

# Load the model without embedding_dim parameter
model = super().from_pretrained(*args, **kwargs, key_mapping=key_mapping)
return model

def forward(
self,
pooling_strategy: Literal["cls", "last", "mean"] = "last",
embedding_dim: int = 2560,
*args,
**kwargs,
) -> torch.Tensor:
"""
Forward pass for BiGemma3 model.

Args:
pooling_strategy: The strategy to use for pooling the hidden states.
embedding_dim: Matryoshka dimension (768, 1536, or 2560). Default: 2560 (full).
*args: Variable length argument list.
**kwargs: Additional keyword arguments.

Returns:
torch.Tensor: Dense embeddings (batch_size, embedding_dim).
"""
# Validate embedding_dim
if embedding_dim not in [768, 1536, 2560]:
raise ValueError(f"embedding_dim must be one of [768, 1536, 2560], got {embedding_dim}")

kwargs.pop("return_dict", True)
kwargs.pop("output_hidden_states", None)
kwargs.pop("use_cache", None)

outputs = super().forward(*args, **kwargs, use_cache=False, output_hidden_states=True, return_dict=True)
last_hidden_states = outputs.last_hidden_state # (batch_size, sequence_length, hidden_size)

# Get CLS token embedding, last token, or mean pool over sequence
if pooling_strategy == "cls":
# Use CLS token (first token) embedding
pooled_output = last_hidden_states[:, 0] # (batch_size, hidden_size)
elif pooling_strategy == "last":
            # Use the last token, since inputs are left-padded
pooled_output = last_hidden_states[:, -1] # (batch_size, hidden_size)
elif pooling_strategy == "mean":
# Mean pooling over sequence length
if "attention_mask" in kwargs:
mask = kwargs["attention_mask"].unsqueeze(-1) # (batch_size, sequence_length, 1)
pooled_output = (last_hidden_states * mask).sum(dim=1) / mask.sum(dim=1) # (batch_size, hidden_size)
else:
pooled_output = last_hidden_states.mean(dim=1) # (batch_size, hidden_size)
else:
raise ValueError(f"Invalid pooling strategy: {pooling_strategy}")

# Matryoshka: slice to the desired embedding dimension
pooled_output = pooled_output[:, :embedding_dim] # (batch_size, embedding_dim)

# L2 normalization
pooled_output = torch.nn.functional.normalize(pooled_output, p=2, dim=-1)
return pooled_output
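
A minimal usage sketch for the new forward pass (not part of the diff; the checkpoint name below is a placeholder for whichever fine-tuned BiGemma3 weights you load):

import torch
from colpali_engine.models import BiGemma3, BiGemmaProcessor3

model_name = "google/gemma-3-4b-it"  # placeholder; point this at a trained BiGemma3 checkpoint
model = BiGemma3.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto").eval()
processor = BiGemmaProcessor3.from_pretrained(model_name)

batch = processor.process_texts(["What is Matryoshka representation learning?"]).to(model.device)
with torch.no_grad():
    # Matryoshka truncation: the forward pass slices to 768 dims, then L2-normalizes
    embeddings = model(**batch, pooling_strategy="last", embedding_dim=768)
print(embeddings.shape)  # (1, 768)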
157 changes: 157 additions & 0 deletions colpali_engine/models/gemma3/bigemma3/processing_bigemma.py
@@ -0,0 +1,157 @@
from typing import ClassVar, List, Optional, Tuple, Union

import torch
from PIL import Image
from transformers import BatchEncoding, BatchFeature
from transformers.models.gemma3 import Gemma3Processor

from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor


class BiGemmaProcessor3(BaseVisualRetrieverProcessor, Gemma3Processor): # noqa: N801
"""
    Processor for BiGemma3.

Args:
*args: Variable length argument list to be passed to the parent `Gemma3Processor` class.
max_num_visual_tokens: The maximum number of visual tokens that can be processed by the model.
**kwargs: Arbitrary keyword arguments to be passed to the parent `Gemma3Processor` class.
"""

query_augmentation_token: ClassVar[str] = "<eos>"

def __init__(
self,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.tokenizer.padding_side = "left"

@classmethod
def from_pretrained(
cls,
*args,
device_map: Optional[str] = None,
**kwargs,
):
instance = super().from_pretrained(
*args,
device_map=device_map,
**kwargs,
)

if "max_num_visual_tokens" in kwargs:
instance.image_processor.max_pixels = kwargs["max_num_visual_tokens"] * 56 * 56
instance.image_processor.size["longest_edge"] = instance.image_processor.max_pixels

return instance

def process_images(
self,
images: List[Image.Image],
) -> Union[BatchFeature, BatchEncoding]:
"""
Process images for BiGemma3.

Args:
images: List of PIL images.
"""
images = [image.convert("RGB") for image in images]

# Process each image using chat template
batch_docs = []
for image in images:
# Create message in chat format
message = [
{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": "Describe this image"},
],
}
]

# Apply chat template to get formatted text
formatted_text = self.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=False)

# Process with formatted text
batch_doc = self(
text=[formatted_text],
images=[image],
padding="longest",
return_tensors="pt",
)
batch_docs.append(batch_doc)

if len(batch_docs) == 1:
return batch_docs[0]

# Concatenate results along batch dimension
concatenated = {}
for key in batch_docs[0].keys():
if isinstance(batch_docs[0][key], torch.Tensor):
concatenated[key] = torch.cat([doc[key] for doc in batch_docs], dim=0)
else:
                # For non-tensor values, take the entry from the first batch (assumed identical across batches)
concatenated[key] = batch_docs[0][key]
return BatchFeature(concatenated)

def process_texts(self, texts: List[str]) -> Union[BatchFeature, BatchEncoding]:
"""
Process texts for BiGemma3.

Args:
texts: List of input texts.

Returns:
Union[BatchFeature, BatchEncoding]: Processed texts.
"""
# Format each text using chat template
formatted_texts = []
for text in texts:
# Create message in chat format
message = [
{
"role": "user",
"content": [
{"type": "text", "text": f"Query: {text}"},
],
}
]

# Apply chat template to get formatted text
formatted_text = self.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=False)
formatted_texts.append(formatted_text)

return self(
text=formatted_texts,
return_tensors="pt",
padding="longest",
)

def get_n_patches(
self,
image_size: Tuple[int, int], # noqa: ARG002
patch_size: int, # noqa: ARG002
) -> Tuple[int, int]:
"""
Get the number of patches (n_patches_x, n_patches_y) for dense embedding.

For dense models like BiGemma, the entire image is embedded as a single vector,
so we return (1, 1) to indicate a single "patch" representing the whole image.
"""
return (1, 1)

def score(
self,
qs: List[torch.Tensor],
ps: List[torch.Tensor],
device: Optional[Union[str, torch.device]] = None,
**kwargs, # noqa: ARG002
) -> torch.Tensor:
"""
Compute the cosine similarity for the given query and passage embeddings.
"""
return self.score_single_vector(qs, ps, device=device)
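
For context (not part of the diff), a sketch of end-to-end retrieval with this processor, reusing the `model` and `processor` from the previous sketch; `score` delegates to single-vector cosine similarity between pooled query and page embeddings:

from PIL import Image
import torch

images = [Image.new("RGB", (448, 448), color="white")]  # stand-ins for real document pages
queries = ["Which quarter had the highest revenue?"]

batch_images = processor.process_images(images).to(model.device)
batch_queries = processor.process_texts(queries).to(model.device)

with torch.no_grad():
    image_embeddings = model(**batch_images)   # (n_images, 2560) at the full Matryoshka dimension
    query_embeddings = model(**batch_queries)  # (n_queries, 2560)

# Cosine similarity via BiGemmaProcessor3.score (single-vector scoring)
scores = processor.score(list(query_embeddings), list(image_embeddings))
print(scores.shape)  # (n_queries, n_images)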
4 changes: 4 additions & 0 deletions colpali_engine/models/gemma3/colgemma3/__init__.py
@@ -0,0 +1,4 @@
from .modeling_colgemma import ColGemma3
from .processing_colgemma import ColGemmaProcessor3

__all__ = ["ColGemma3", "ColGemmaProcessor3"]
94 changes: 94 additions & 0 deletions colpali_engine/models/gemma3/colgemma3/modeling_colgemma.py
@@ -0,0 +1,94 @@
"""
ColGemma3 Model - Implementation for late interaction retrieval.

This module implements ColGemma3 for late interaction retrieval, following
the ColQwen2 architecture pattern.

Key features:
- Direct inheritance from Gemma3Model for compatibility
- Custom projection layer for multi-vector embeddings
- MaxSim scoring support
"""

from typing import ClassVar

import torch
from torch import nn
from transformers.models.gemma3 import Gemma3Config, Gemma3Model


class ColGemma3(Gemma3Model):
"""
ColGemma3 model for late interaction retrieval.

This model extends Gemma3 to produce multi-vector embeddings suitable for
efficient document retrieval. Each input (image or text) is encoded into a
sequence of contextualized vectors, which can be compared using MaxSim scoring.

Args:
config (Gemma3Config): The model configuration.
mask_non_image_embeddings (bool, optional): Whether to ignore all tokens embeddings
except those of the image at inference. Defaults to False.

Example:
>>> model = ColGemma3.from_pretrained("google/gemma-3-4b-it")
>>> embeddings = model(input_ids=input_ids, attention_mask=attention_mask)
>>> print(embeddings.shape) # (batch_size, seq_len, 128)
"""

main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related

def __init__(
self,
config: Gemma3Config,
mask_non_image_embeddings: bool = False,
):
super().__init__(config=config)
self.dim = 128
self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.dim)
self.padding_side = "left"
self.mask_non_image_embeddings = mask_non_image_embeddings
self.post_init()

@classmethod
def from_pretrained(cls, *args, **kwargs):
key_mapping = kwargs.pop("key_mapping", None)
if key_mapping is None:
key_mapping = super()._checkpoint_conversion_mapping
return super().from_pretrained(*args, **kwargs, key_mapping=key_mapping)

@property
def device(self):
"""Get the device of the model."""
return next(self.parameters()).device

@property
def dtype(self):
"""Get the dtype of the model."""
return next(self.parameters()).dtype

def forward(self, *args, **kwargs) -> torch.Tensor:
kwargs.pop("return_dict", True)
kwargs.pop("output_hidden_states", None)
kwargs.pop("use_cache", None)
hidden_states = (
super()
.forward(*args, **kwargs, use_cache=False, output_hidden_states=True, return_dict=True)
.last_hidden_state
) # (batch_size, sequence_length, hidden_size)

proj = self.custom_text_proj(hidden_states) # (batch_size, sequence_length, dim)

# L2 normalization
proj = proj / proj.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
proj = proj * kwargs["attention_mask"].unsqueeze(-1) # (batch_size, sequence_length, dim)

if "pixel_values" in kwargs and self.mask_non_image_embeddings:
# Pools only the image embeddings
image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1)
proj = proj * image_mask
return proj

@property
    def patch_size(self) -> int:
        # Gemma3 exposes its SigLIP vision settings under `config.vision_config` (there is no `self.visual` attribute)
        return self.config.vision_config.patch_size
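
The docstring above references MaxSim scoring, but the scoring code is not in this file; here is a generic late-interaction sketch of how ColGemma3's (batch, seq_len, 128) outputs would be compared (an illustrative assumption, not code from this PR):

import torch

def maxsim_scores(query_embs: torch.Tensor, doc_embs: torch.Tensor) -> torch.Tensor:
    # query_embs: (n_queries, query_len, dim), doc_embs: (n_docs, doc_len, dim), both L2-normalized
    # Token-level similarities: (n_queries, n_docs, query_len, doc_len)
    sim = torch.einsum("qid,njd->qnij", query_embs, doc_embs)
    # For each query token keep its best-matching document token, then sum over query tokens
    return sim.max(dim=-1).values.sum(dim=-1)  # (n_queries, n_docs)

# Shapes matching ColGemma3 outputs (projection dim = 128)
q = torch.nn.functional.normalize(torch.randn(2, 16, 128), dim=-1)
d = torch.nn.functional.normalize(torch.randn(3, 700, 128), dim=-1)
print(maxsim_scores(q, d).shape)  # (2, 3)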