Skip to content
Merged

lint #368

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions colpali_engine/interpretability/similarity_map_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,14 @@ def normalize_similarity_map(

if value_range is None:
# Compute the minimum values along the last two dimensions (n_patch_x, n_patch_y)
min_vals = similarity_map.min(dim=-1, keepdim=True)[0].min(
dim=-2, keepdim=True
)[0] # (1, 1) or (batch_size, 1, 1)
min_vals = similarity_map.min(dim=-1, keepdim=True)[0].min(dim=-2, keepdim=True)[
0
] # (1, 1) or (batch_size, 1, 1)

# Compute the maximum values along the last two dimensions (n_patch_x, n_patch_y)
max_vals = similarity_map.max(dim=-1, keepdim=True)[0].max(
dim=-2, keepdim=True
)[0] # (1, 1) or (batch_size, 1, 1)
max_vals = similarity_map.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[
0
] # (1, 1) or (batch_size, 1, 1)
else:
min_vals, max_vals = value_range
broadcast_shape = (1,) * similarity_map.ndim
Expand Down
5 changes: 1 addition & 4 deletions colpali_engine/interpretability/similarity_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,7 @@ def plot_similarity_map(

# Normalize the similarity map and convert it to Pillow image
similarity_map_array = (
normalize_similarity_map(similarity_map, value_range=normalization_range)
.to(torch.float32)
.cpu()
.numpy()
normalize_similarity_map(similarity_map, value_range=normalization_range).to(torch.float32).cpu().numpy()
) # (n_patches_x, n_patches_y)

# Reshape the similarity map to match the PIL shape convention
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@ class ColIdefics3Processor(

query_augmentation_token: ClassVar[str] = "<end_of_utterance>"
image_token: ClassVar[str] = "<image>"
visual_prompt_prefix: ClassVar[str] = (
"<|im_start|>User:<image>Describe the image.<end_of_utterance>\nAssistant:"
)
visual_prompt_prefix: ClassVar[str] = "<|im_start|>User:<image>Describe the image.<end_of_utterance>\nAssistant:"

def __init__(self, *args, image_seq_len=64, **kwargs):
super().__init__(*args, image_seq_len=image_seq_len, **kwargs)
Expand Down Expand Up @@ -105,9 +103,7 @@ def get_n_patches(
longest_edge = self.image_processor.size.get("longest_edge", 4 * patch_size)

# Step 1: Calculate resized dimensions using the mixin helper method
height_new, width_new = self._calculate_resized_dimensions(
image_size, longest_edge
)
height_new, width_new = self._calculate_resized_dimensions(image_size, longest_edge)

# Step 2: Calculate the number of patches in each direction
# This mirrors the split_image logic from Idefics3ImageProcessor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,7 @@ def get_n_patches(
longest_edge = self.image_processor.size.get("longest_edge", 2048)

# Step 1: Calculate resized dimensions using the mixin helper method
height_new, width_new = self._calculate_resized_dimensions(
image_size, longest_edge
)
height_new, width_new = self._calculate_resized_dimensions(image_size, longest_edge)

# Step 2: Calculate number of sub-patches (512x512 patches)
# This mirrors the split_image logic from Idefics3ImageProcessor
Expand Down
4 changes: 1 addition & 3 deletions colpali_engine/utils/processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,9 +495,7 @@ def get_similarity_maps_from_embeddings(
# query: (query_tokens, dim)
# image_grid: (n_patches_x, n_patches_y, dim)
# result: (query_tokens, n_patches_x, n_patches_y)
similarity_map = torch.einsum(
"nk,ijk->nij", query_embeddings[idx], image_embedding_grid
)
similarity_map = torch.einsum("nk,ijk->nij", query_embeddings[idx], image_embedding_grid)

similarity_maps.append(similarity_map)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
python examples/interpretability/colmodernvbert/simple_interpretability_example.py
"""

from pathlib import Path
import uuid
from typing import cast, Any
from pathlib import Path
from typing import Any, cast

import matplotlib.pyplot as plt
import torch
Expand All @@ -33,9 +33,7 @@ def main():
print("Loading a real document from DocVQA dataset...")
from datasets import load_dataset

dataset = load_dataset(
"vidore/docvqa_test_subsampled", split="test", streaming=True
)
dataset = load_dataset("vidore/docvqa_test_subsampled", split="test", streaming=True)
# streaming datasets may yield values that type checkers treat as Sequence;
# cast to dict so string indexing (sample["image"]) is accepted by the type checker.
sample = dict(next(iter(dataset)))
Expand Down Expand Up @@ -81,9 +79,7 @@ def main():
)

# Get the similarity map for our input image
similarity_maps = similarity_maps_batch[
0
] # (query_length, n_patches_x, n_patches_y)
similarity_maps = similarity_maps_batch[0] # (query_length, n_patches_x, n_patches_y)
print(f"Similarity map shape: {similarity_maps.shape}")

# Get query tokens (filtering out special tokens)
Expand All @@ -105,9 +101,7 @@ def main():
# Clean tokens for display (remove special characters that may cause encoding issues)
display_tokens = [t.replace("Ġ", " ").replace("▁", " ") for t in filtered_tokens]
print(f"\nQuery tokens: {display_tokens}")
print(
f"Similarity range: [{similarity_maps.min().item():.3f}, {similarity_maps.max().item():.3f}]"
)
print(f"Similarity range: [{similarity_maps.min().item():.3f}, {similarity_maps.max().item():.3f}]")

# Generate all similarity maps
print("\nGenerating similarity maps for all tokens...")
Expand Down
34 changes: 10 additions & 24 deletions tests/models/idefics3/colidefics3/test_processing_colidefics3.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,39 +74,29 @@ def test_get_n_patches(processor_from_pretrained: ColIdefics3Processor):
Test that get_n_patches returns the correct number of patches for various image sizes.
"""
# Get the patch size from the image processor
patch_size = processor_from_pretrained.image_processor.max_image_size.get(
"longest_edge", 512
)
patch_size = processor_from_pretrained.image_processor.max_image_size.get("longest_edge", 512)

# Test case 1: Small square image
image_size = (100, 100)
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(
image_size, patch_size
)
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size)
assert isinstance(n_patches_x, int)
assert isinstance(n_patches_y, int)
assert n_patches_x > 0
assert n_patches_y > 0

# Test case 2: Wide image (width > height)
image_size = (100, 200)
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(
image_size, patch_size
)
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size)
assert n_patches_x >= n_patches_y # More patches along width

# Test case 3: Tall image (height > width)
image_size = (200, 100)
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(
image_size, patch_size
)
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size)
assert n_patches_y >= n_patches_x # More patches along height

# Test case 4: Square image
image_size = (500, 500)
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(
image_size, patch_size
)
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size)
assert n_patches_x == n_patches_y # Equal patches for square image


Expand All @@ -126,22 +116,18 @@ def test_get_n_patches_matches_actual_processing(
actual_num_patches = batch_feature["pixel_values"].shape[1]

# Get the patch size from the image processor
patch_size = processor_from_pretrained.image_processor.max_image_size.get(
"longest_edge", 512
)
patch_size = processor_from_pretrained.image_processor.max_image_size.get("longest_edge", 512)

# Calculate expected patches using get_n_patches
# Note: image_size for get_n_patches is (height, width), but PIL uses (width, height)
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(
(image_size[1], image_size[0]), patch_size
)
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches((image_size[1], image_size[0]), patch_size)
expected_num_patches = n_patches_x * n_patches_y

# The actual number of patches includes the global image patch (+1)
# So we compare with expected + 1
assert (
actual_num_patches == expected_num_patches + 1
), f"Expected {expected_num_patches + 1} patches (including global), got {actual_num_patches}"
assert actual_num_patches == expected_num_patches + 1, (
f"Expected {expected_num_patches + 1} patches (including global), got {actual_num_patches}"
)


def test_get_image_mask(processor_from_pretrained: ColIdefics3Processor):
Expand Down
Loading