diff --git a/colpali_engine/interpretability/similarity_map_utils.py b/colpali_engine/interpretability/similarity_map_utils.py
index 9ab6f192..8b2f5e8d 100644
--- a/colpali_engine/interpretability/similarity_map_utils.py
+++ b/colpali_engine/interpretability/similarity_map_utils.py
@@ -76,14 +76,14 @@ def normalize_similarity_map(
     if value_range is None:
         # Compute the minimum values along the last two dimensions (n_patch_x, n_patch_y)
-        min_vals = similarity_map.min(dim=-1, keepdim=True)[0].min(
-            dim=-2, keepdim=True
-        )[0]  # (1, 1) or (batch_size, 1, 1)
+        min_vals = similarity_map.min(dim=-1, keepdim=True)[0].min(dim=-2, keepdim=True)[
+            0
+        ]  # (1, 1) or (batch_size, 1, 1)

         # Compute the maximum values along the last two dimensions (n_patch_x, n_patch_y)
-        max_vals = similarity_map.max(dim=-1, keepdim=True)[0].max(
-            dim=-2, keepdim=True
-        )[0]  # (1, 1) or (batch_size, 1, 1)
+        max_vals = similarity_map.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[
+            0
+        ]  # (1, 1) or (batch_size, 1, 1)
     else:
         min_vals, max_vals = value_range
         broadcast_shape = (1,) * similarity_map.ndim
diff --git a/colpali_engine/interpretability/similarity_maps.py b/colpali_engine/interpretability/similarity_maps.py
index ce95d653..477942b4 100644
--- a/colpali_engine/interpretability/similarity_maps.py
+++ b/colpali_engine/interpretability/similarity_maps.py
@@ -43,10 +43,7 @@ def plot_similarity_map(
     # Normalize the similarity map and convert it to Pillow image
     similarity_map_array = (
-        normalize_similarity_map(similarity_map, value_range=normalization_range)
-        .to(torch.float32)
-        .cpu()
-        .numpy()
+        normalize_similarity_map(similarity_map, value_range=normalization_range).to(torch.float32).cpu().numpy()
     )  # (n_patches_x, n_patches_y)

     # Reshape the similarity map to match the PIL shape convention
diff --git a/colpali_engine/models/idefics3/colidefics3/processing_colidefics3.py b/colpali_engine/models/idefics3/colidefics3/processing_colidefics3.py
index acd4b0e3..8dea4292 100644
--- a/colpali_engine/models/idefics3/colidefics3/processing_colidefics3.py
+++ b/colpali_engine/models/idefics3/colidefics3/processing_colidefics3.py
@@ -22,9 +22,7 @@ class ColIdefics3Processor(
     query_augmentation_token: ClassVar[str] = "<end_of_utterance>"
     image_token: ClassVar[str] = "<image>"
-    visual_prompt_prefix: ClassVar[str] = (
-        "<|im_start|>User:<image>Describe the image.<end_of_utterance>\nAssistant:"
-    )
+    visual_prompt_prefix: ClassVar[str] = "<|im_start|>User:<image>Describe the image.<end_of_utterance>\nAssistant:"

     def __init__(self, *args, image_seq_len=64, **kwargs):
         super().__init__(*args, image_seq_len=image_seq_len, **kwargs)
@@ -105,9 +103,7 @@ def get_n_patches(
         longest_edge = self.image_processor.size.get("longest_edge", 4 * patch_size)

         # Step 1: Calculate resized dimensions using the mixin helper method
-        height_new, width_new = self._calculate_resized_dimensions(
-            image_size, longest_edge
-        )
+        height_new, width_new = self._calculate_resized_dimensions(image_size, longest_edge)

         # Step 2: Calculate the number of patches in each direction
         # This mirrors the split_image logic from Idefics3ImageProcessor
diff --git a/colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py b/colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py
index 0aa42aee..786c9433 100644
--- a/colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py
+++ b/colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py
@@ -110,9 +110,7 @@ def get_n_patches(
         longest_edge = self.image_processor.size.get("longest_edge", 2048)

         # Step 1: Calculate resized dimensions using the mixin helper method
-        height_new, width_new = self._calculate_resized_dimensions(
-            image_size, longest_edge
-        )
+        height_new, width_new = self._calculate_resized_dimensions(image_size, longest_edge)

         # Step 2: Calculate number of sub-patches (512x512 patches)
         # This mirrors the split_image logic from Idefics3ImageProcessor
diff --git a/colpali_engine/utils/processing_utils.py b/colpali_engine/utils/processing_utils.py
index 86ef191b..a25779e6 100644
--- a/colpali_engine/utils/processing_utils.py
+++ b/colpali_engine/utils/processing_utils.py
@@ -495,9 +495,7 @@ def get_similarity_maps_from_embeddings(

         # query: (query_tokens, dim)
         # image_grid: (n_patches_x, n_patches_y, dim)
         # result: (query_tokens, n_patches_x, n_patches_y)
-        similarity_map = torch.einsum(
-            "nk,ijk->nij", query_embeddings[idx], image_embedding_grid
-        )
+        similarity_map = torch.einsum("nk,ijk->nij", query_embeddings[idx], image_embedding_grid)

         similarity_maps.append(similarity_map)
diff --git a/examples/interpretability/colmodernvbert/generate_interpretability_maps.py b/examples/interpretability/colmodernvbert/generate_interpretability_maps.py
index 24185ba1..c26a5463 100644
--- a/examples/interpretability/colmodernvbert/generate_interpretability_maps.py
+++ b/examples/interpretability/colmodernvbert/generate_interpretability_maps.py
@@ -9,9 +9,9 @@
 python examples/interpretability/colmodernvbert/simple_interpretability_example.py
 """

-from pathlib import Path
 import uuid
-from typing import cast, Any
+from pathlib import Path
+from typing import Any, cast

 import matplotlib.pyplot as plt
 import torch
@@ -33,9 +33,7 @@ def main():
     print("Loading a real document from DocVQA dataset...")
     from datasets import load_dataset

-    dataset = load_dataset(
-        "vidore/docvqa_test_subsampled", split="test", streaming=True
-    )
+    dataset = load_dataset("vidore/docvqa_test_subsampled", split="test", streaming=True)
     # streaming datasets may yield values that type checkers treat as Sequence;
     # cast to dict so string indexing (sample["image"]) is accepted by the type checker.
     sample = dict(next(iter(dataset)))
@@ -81,9 +79,7 @@ def main():
     )

     # Get the similarity map for our input image
-    similarity_maps = similarity_maps_batch[
-        0
-    ]  # (query_length, n_patches_x, n_patches_y)
+    similarity_maps = similarity_maps_batch[0]  # (query_length, n_patches_x, n_patches_y)
     print(f"Similarity map shape: {similarity_maps.shape}")

     # Get query tokens (filtering out special tokens)
@@ -105,9 +101,7 @@ def main():
     # Clean tokens for display (remove special characters that may cause encoding issues)
     display_tokens = [t.replace("Ġ", " ").replace("▁", " ") for t in filtered_tokens]
     print(f"\nQuery tokens: {display_tokens}")
-    print(
-        f"Similarity range: [{similarity_maps.min().item():.3f}, {similarity_maps.max().item():.3f}]"
-    )
+    print(f"Similarity range: [{similarity_maps.min().item():.3f}, {similarity_maps.max().item():.3f}]")

     # Generate all similarity maps
     print("\nGenerating similarity maps for all tokens...")
diff --git a/tests/models/idefics3/colidefics3/test_processing_colidefics3.py b/tests/models/idefics3/colidefics3/test_processing_colidefics3.py
index fce61a17..21a28d4c 100644
--- a/tests/models/idefics3/colidefics3/test_processing_colidefics3.py
+++ b/tests/models/idefics3/colidefics3/test_processing_colidefics3.py
@@ -74,15 +74,11 @@ def test_get_n_patches(processor_from_pretrained: ColIdefics3Processor):
     Test that get_n_patches returns the correct number of patches for various image sizes.
""" # Get the patch size from the image processor - patch_size = processor_from_pretrained.image_processor.max_image_size.get( - "longest_edge", 512 - ) + patch_size = processor_from_pretrained.image_processor.max_image_size.get("longest_edge", 512) # Test case 1: Small square image image_size = (100, 100) - n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( - image_size, patch_size - ) + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size) assert isinstance(n_patches_x, int) assert isinstance(n_patches_y, int) assert n_patches_x > 0 @@ -90,23 +86,17 @@ def test_get_n_patches(processor_from_pretrained: ColIdefics3Processor): # Test case 2: Wide image (width > height) image_size = (100, 200) - n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( - image_size, patch_size - ) + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size) assert n_patches_x >= n_patches_y # More patches along width # Test case 3: Tall image (height > width) image_size = (200, 100) - n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( - image_size, patch_size - ) + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size) assert n_patches_y >= n_patches_x # More patches along height # Test case 4: Square image image_size = (500, 500) - n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( - image_size, patch_size - ) + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size) assert n_patches_x == n_patches_y # Equal patches for square image @@ -126,22 +116,18 @@ def test_get_n_patches_matches_actual_processing( actual_num_patches = batch_feature["pixel_values"].shape[1] # Get the patch size from the image processor - patch_size = processor_from_pretrained.image_processor.max_image_size.get( - "longest_edge", 512 - ) + patch_size = processor_from_pretrained.image_processor.max_image_size.get("longest_edge", 512) # Calculate expected patches using get_n_patches # Note: image_size for get_n_patches is (height, width), but PIL uses (width, height) - n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( - (image_size[1], image_size[0]), patch_size - ) + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches((image_size[1], image_size[0]), patch_size) expected_num_patches = n_patches_x * n_patches_y # The actual number of patches includes the global image patch (+1) # So we compare with expected + 1 - assert ( - actual_num_patches == expected_num_patches + 1 - ), f"Expected {expected_num_patches + 1} patches (including global), got {actual_num_patches}" + assert actual_num_patches == expected_num_patches + 1, ( + f"Expected {expected_num_patches + 1} patches (including global), got {actual_num_patches}" + ) def test_get_image_mask(processor_from_pretrained: ColIdefics3Processor): diff --git a/tests/models/modernvbert/test_interpretability_colmodernvbert.py b/tests/models/modernvbert/test_interpretability_colmodernvbert.py index 31030009..cea4dd40 100644 --- a/tests/models/modernvbert/test_interpretability_colmodernvbert.py +++ b/tests/models/modernvbert/test_interpretability_colmodernvbert.py @@ -13,10 +13,10 @@ import torch from PIL import Image -from colpali_engine.models import ColModernVBert, ColModernVBertProcessor from colpali_engine.interpretability.similarity_map_utils import ( normalize_similarity_map, ) +from colpali_engine.models import ColModernVBert, ColModernVBertProcessor 
@pytest.fixture(scope="module") @@ -28,9 +28,7 @@ def model_name() -> str: def processor_from_pretrained( model_name: str, ) -> Generator[ColModernVBertProcessor, None, None]: - yield cast( - ColModernVBertProcessor, ColModernVBertProcessor.from_pretrained(model_name) - ) + yield cast(ColModernVBertProcessor, ColModernVBertProcessor.from_pretrained(model_name)) @pytest.fixture(scope="module") @@ -41,79 +39,55 @@ def model_from_pretrained(model_name: str) -> Generator[ColModernVBert, None, No class TestGetNPatches: """Test the get_n_patches method for calculating patch dimensions.""" - def test_get_n_patches_returns_integers( - self, processor_from_pretrained: ColModernVBertProcessor - ): + def test_get_n_patches_returns_integers(self, processor_from_pretrained: ColModernVBertProcessor): """Test that get_n_patches returns integer values.""" patch_size = 14 # Common patch size for vision transformers image_size = (100, 100) - n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( - image_size, patch_size - ) + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size) assert isinstance(n_patches_x, int) assert isinstance(n_patches_y, int) assert n_patches_x > 0 assert n_patches_y > 0 - def test_get_n_patches_wide_image( - self, processor_from_pretrained: ColModernVBertProcessor - ): + def test_get_n_patches_wide_image(self, processor_from_pretrained: ColModernVBertProcessor): """Test that wide images have more patches along width.""" patch_size = 14 image_size = (100, 200) # (height, width) - wider than tall - n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( - image_size, patch_size - ) + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size) # n_patches_x is along width, n_patches_y is along height - assert ( - n_patches_x >= n_patches_y - ), f"Expected more patches along width, got x={n_patches_x}, y={n_patches_y}" + assert n_patches_x >= n_patches_y, f"Expected more patches along width, got x={n_patches_x}, y={n_patches_y}" - def test_get_n_patches_tall_image( - self, processor_from_pretrained: ColModernVBertProcessor - ): + def test_get_n_patches_tall_image(self, processor_from_pretrained: ColModernVBertProcessor): """Test that tall images have more patches along height.""" patch_size = 14 image_size = (200, 100) # (height, width) - taller than wide - n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( - image_size, patch_size - ) + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size) - assert ( - n_patches_y >= n_patches_x - ), f"Expected more patches along height, got x={n_patches_x}, y={n_patches_y}" + assert n_patches_y >= n_patches_x, f"Expected more patches along height, got x={n_patches_x}, y={n_patches_y}" - def test_get_n_patches_square_image( - self, processor_from_pretrained: ColModernVBertProcessor - ): + def test_get_n_patches_square_image(self, processor_from_pretrained: ColModernVBertProcessor): """Test that square images have equal patches in both dimensions.""" patch_size = 14 image_size = (500, 500) - n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( - image_size, patch_size - ) + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size) - assert ( - n_patches_x == n_patches_y - ), f"Expected equal patches for square image, got x={n_patches_x}, y={n_patches_y}" + assert n_patches_x == n_patches_y, ( + f"Expected equal patches for square image, got x={n_patches_x}, 
y={n_patches_y}" + ) - def test_get_n_patches_aspect_ratio_preservation( - self, processor_from_pretrained: ColModernVBertProcessor - ): + def test_get_n_patches_aspect_ratio_preservation(self, processor_from_pretrained: ColModernVBertProcessor): """Test that aspect ratio is approximately preserved in patch dimensions.""" patch_size = 14 # Test with a 2:1 aspect ratio image image_size = (300, 600) # height=300, width=600 - n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( - image_size, patch_size - ) + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size) # The aspect ratio of patches should be close to 2:1 patch_ratio = n_patches_x / n_patches_y @@ -124,15 +98,13 @@ def test_get_n_patches_aspect_ratio_preservation( # 2. Even-dimension rounding in resize logic # 3. Ceiling division in patch calculations # These factors can cause ~25% deviation from the ideal aspect ratio - assert 1.5 <= patch_ratio <= 2.5, f"Expected ~2:1 ratio, got {patch_ratio:.2f}" + assert 1.5 <= patch_ratio <= 2.5, f"Expected ~{expected_ratio}:1 ratio, got {patch_ratio:.2f}" class TestGetImageMask: """Test the get_image_mask method for identifying image tokens.""" - def test_get_image_mask_shape( - self, processor_from_pretrained: ColModernVBertProcessor - ): + def test_get_image_mask_shape(self, processor_from_pretrained: ColModernVBertProcessor): """Test that image mask has the same shape as input_ids.""" image = Image.new("RGB", (64, 32), color="red") batch_feature = processor_from_pretrained.process_images([image]) @@ -142,9 +114,7 @@ def test_get_image_mask_shape( assert image_mask.shape == batch_feature.input_ids.shape assert image_mask.dtype == torch.bool - def test_get_image_mask_has_image_tokens( - self, processor_from_pretrained: ColModernVBertProcessor - ): + def test_get_image_mask_has_image_tokens(self, processor_from_pretrained: ColModernVBertProcessor): """Test that the mask identifies some image tokens.""" image = Image.new("RGB", (64, 32), color="blue") batch_feature = processor_from_pretrained.process_images([image]) @@ -152,13 +122,9 @@ def test_get_image_mask_has_image_tokens( image_mask = processor_from_pretrained.get_image_mask(batch_feature) # There should be image tokens present - assert ( - image_mask.sum() > 0 - ), "Expected to find image tokens in the processed batch" + assert image_mask.sum() > 0, "Expected to find image tokens in the processed batch" - def test_get_image_mask_batch_consistency( - self, processor_from_pretrained: ColModernVBertProcessor - ): + def test_get_image_mask_batch_consistency(self, processor_from_pretrained: ColModernVBertProcessor): """Test that image mask works correctly with batched images.""" images = [ Image.new("RGB", (64, 32), color="red"), @@ -197,14 +163,13 @@ def test_similarity_maps_shape( # Get patch size from the model or processor # ModernVBert uses patch_size from its config - patch_size = ( - 14 # Default for many vision transformers (unused but required for API) - ) + patch_size = 14 # Default for many vision transformers (unused but required for API) # Calculate expected patches # Note: image_size for get_n_patches is (height, width) n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( - (image_size_pil[1], image_size_pil[0]), patch_size # (height, width) + (image_size_pil[1], image_size_pil[0]), + patch_size, # (height, width) ) # Get embeddings @@ -230,9 +195,9 @@ def test_similarity_maps_shape( # similarity_maps[0] should have shape (query_tokens, n_patches_x, n_patches_y) 
         expected_shape = (query_length, n_patches_x, n_patches_y)
-        assert (
-            similarity_maps[0].shape == expected_shape
-        ), f"Expected shape {expected_shape}, got {similarity_maps[0].shape}"
+        assert similarity_maps[0].shape == expected_shape, (
+            f"Expected shape {expected_shape}, got {similarity_maps[0].shape}"
+        )

     @pytest.mark.slow
     def test_similarity_maps_values(
@@ -248,9 +213,7 @@ def test_similarity_maps_values(
         batch_queries = processor_from_pretrained.process_texts([query])
         patch_size = 14
-        n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(
-            (64, 64), patch_size
-        )
+        n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches((64, 64), patch_size)

         with torch.no_grad():
             image_embeddings = model_from_pretrained(**batch_images)
@@ -274,9 +237,7 @@ def test_similarity_maps_values(

         # After normalization, values should be in [0, 1]
         assert normalized_map.min() >= 0.0
         assert normalized_map.max() <= 1.0
-        assert (
-            normalized_map.max() == 1.0
-        )  # Max should be exactly 1.0 after normalization
+        assert normalized_map.max() == 1.0  # Max should be exactly 1.0 after normalization

     @pytest.mark.slow
     def test_patch_count_matches_mask_count(
@@ -303,9 +264,9 @@ def test_patch_count_matches_mask_count(

         expected_local_patches = n_patches_x * n_patches_y

         # LOCAL tokens should match exactly
-        assert (
-            actual_local_tokens == expected_local_patches
-        ), f"Expected {expected_local_patches} local image tokens, got {actual_local_tokens}"
+        assert actual_local_tokens == expected_local_patches, (
+            f"Expected {expected_local_patches} local image tokens, got {actual_local_tokens}"
+        )

     @pytest.mark.slow
     def test_global_patch_excluded(
@@ -326,9 +287,9 @@ def test_global_patch_excluded(

         # The difference should be exactly image_seq_len (global patch tokens)
         image_seq_len = processor_from_pretrained.image_seq_len
-        assert (
-            full_count - local_count == image_seq_len
-        ), f"Expected {image_seq_len} global patch tokens, got {full_count - local_count}"
+        assert full_count - local_count == image_seq_len, (
+            f"Expected {image_seq_len} global patch tokens, got {full_count - local_count}"
+        )


 class TestInterpretabilityConsistency: