Skip to content

Commit 60f3656

Browse files
committed
lint
1 parent 3bbbf67 commit 60f3656

File tree

7 files changed

+57
-121
lines changed

7 files changed

+57
-121
lines changed

colpali_engine/interpretability/similarity_map_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,14 @@ def normalize_similarity_map(
7676

7777
if value_range is None:
7878
# Compute the minimum values along the last two dimensions (n_patch_x, n_patch_y)
79-
min_vals = similarity_map.min(dim=-1, keepdim=True)[0].min(
80-
dim=-2, keepdim=True
81-
)[0] # (1, 1) or (batch_size, 1, 1)
79+
min_vals = similarity_map.min(dim=-1, keepdim=True)[0].min(dim=-2, keepdim=True)[
80+
0
81+
] # (1, 1) or (batch_size, 1, 1)
8282

8383
# Compute the maximum values along the last two dimensions (n_patch_x, n_patch_y)
84-
max_vals = similarity_map.max(dim=-1, keepdim=True)[0].max(
85-
dim=-2, keepdim=True
86-
)[0] # (1, 1) or (batch_size, 1, 1)
84+
max_vals = similarity_map.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[
85+
0
86+
] # (1, 1) or (batch_size, 1, 1)
8787
else:
8888
min_vals, max_vals = value_range
8989
broadcast_shape = (1,) * similarity_map.ndim

colpali_engine/interpretability/similarity_maps.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,7 @@ def plot_similarity_map(
4343

4444
# Normalize the similarity map and convert it to Pillow image
4545
similarity_map_array = (
46-
normalize_similarity_map(similarity_map, value_range=normalization_range)
47-
.to(torch.float32)
48-
.cpu()
49-
.numpy()
46+
normalize_similarity_map(similarity_map, value_range=normalization_range).to(torch.float32).cpu().numpy()
5047
) # (n_patches_x, n_patches_y)
5148

5249
# Reshape the similarity map to match the PIL shape convention

colpali_engine/models/idefics3/colidefics3/processing_colidefics3.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@ class ColIdefics3Processor(
2222

2323
query_augmentation_token: ClassVar[str] = "<end_of_utterance>"
2424
image_token: ClassVar[str] = "<image>"
25-
visual_prompt_prefix: ClassVar[str] = (
26-
"<|im_start|>User:<image>Describe the image.<end_of_utterance>\nAssistant:"
27-
)
25+
visual_prompt_prefix: ClassVar[str] = "<|im_start|>User:<image>Describe the image.<end_of_utterance>\nAssistant:"
2826

2927
def __init__(self, *args, image_seq_len=64, **kwargs):
3028
super().__init__(*args, image_seq_len=image_seq_len, **kwargs)
@@ -105,9 +103,7 @@ def get_n_patches(
105103
longest_edge = self.image_processor.size.get("longest_edge", 4 * patch_size)
106104

107105
# Step 1: Calculate resized dimensions using the mixin helper method
108-
height_new, width_new = self._calculate_resized_dimensions(
109-
image_size, longest_edge
110-
)
106+
height_new, width_new = self._calculate_resized_dimensions(image_size, longest_edge)
111107

112108
# Step 2: Calculate the number of patches in each direction
113109
# This mirrors the split_image logic from Idefics3ImageProcessor

colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,7 @@ def get_n_patches(
110110
longest_edge = self.image_processor.size.get("longest_edge", 2048)
111111

112112
# Step 1: Calculate resized dimensions using the mixin helper method
113-
height_new, width_new = self._calculate_resized_dimensions(
114-
image_size, longest_edge
115-
)
113+
height_new, width_new = self._calculate_resized_dimensions(image_size, longest_edge)
116114

117115
# Step 2: Calculate number of sub-patches (512x512 patches)
118116
# This mirrors the split_image logic from Idefics3ImageProcessor

colpali_engine/utils/processing_utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -495,9 +495,7 @@ def get_similarity_maps_from_embeddings(
495495
# query: (query_tokens, dim)
496496
# image_grid: (n_patches_x, n_patches_y, dim)
497497
# result: (query_tokens, n_patches_x, n_patches_y)
498-
similarity_map = torch.einsum(
499-
"nk,ijk->nij", query_embeddings[idx], image_embedding_grid
500-
)
498+
similarity_map = torch.einsum("nk,ijk->nij", query_embeddings[idx], image_embedding_grid)
501499

502500
similarity_maps.append(similarity_map)
503501

tests/models/idefics3/colidefics3/test_processing_colidefics3.py

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -74,39 +74,29 @@ def test_get_n_patches(processor_from_pretrained: ColIdefics3Processor):
7474
Test that get_n_patches returns the correct number of patches for various image sizes.
7575
"""
7676
# Get the patch size from the image processor
77-
patch_size = processor_from_pretrained.image_processor.max_image_size.get(
78-
"longest_edge", 512
79-
)
77+
patch_size = processor_from_pretrained.image_processor.max_image_size.get("longest_edge", 512)
8078

8179
# Test case 1: Small square image
8280
image_size = (100, 100)
83-
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(
84-
image_size, patch_size
85-
)
81+
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size)
8682
assert isinstance(n_patches_x, int)
8783
assert isinstance(n_patches_y, int)
8884
assert n_patches_x > 0
8985
assert n_patches_y > 0
9086

9187
# Test case 2: Wide image (width > height)
9288
image_size = (100, 200)
93-
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(
94-
image_size, patch_size
95-
)
89+
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size)
9690
assert n_patches_x >= n_patches_y # More patches along width
9791

9892
# Test case 3: Tall image (height > width)
9993
image_size = (200, 100)
100-
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(
101-
image_size, patch_size
102-
)
94+
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size)
10395
assert n_patches_y >= n_patches_x # More patches along height
10496

10597
# Test case 4: Square image
10698
image_size = (500, 500)
107-
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(
108-
image_size, patch_size
109-
)
99+
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size)
110100
assert n_patches_x == n_patches_y # Equal patches for square image
111101

112102

@@ -126,22 +116,18 @@ def test_get_n_patches_matches_actual_processing(
126116
actual_num_patches = batch_feature["pixel_values"].shape[1]
127117

128118
# Get the patch size from the image processor
129-
patch_size = processor_from_pretrained.image_processor.max_image_size.get(
130-
"longest_edge", 512
131-
)
119+
patch_size = processor_from_pretrained.image_processor.max_image_size.get("longest_edge", 512)
132120

133121
# Calculate expected patches using get_n_patches
134122
# Note: image_size for get_n_patches is (height, width), but PIL uses (width, height)
135-
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(
136-
(image_size[1], image_size[0]), patch_size
137-
)
123+
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches((image_size[1], image_size[0]), patch_size)
138124
expected_num_patches = n_patches_x * n_patches_y
139125

140126
# The actual number of patches includes the global image patch (+1)
141127
# So we compare with expected + 1
142-
assert (
143-
actual_num_patches == expected_num_patches + 1
144-
), f"Expected {expected_num_patches + 1} patches (including global), got {actual_num_patches}"
128+
assert actual_num_patches == expected_num_patches + 1, (
129+
f"Expected {expected_num_patches + 1} patches (including global), got {actual_num_patches}"
130+
)
145131

146132

147133
def test_get_image_mask(processor_from_pretrained: ColIdefics3Processor):

tests/models/modernvbert/test_interpretability_colmodernvbert.py

Lines changed: 36 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@
1313
import torch
1414
from PIL import Image
1515

16-
from colpali_engine.models import ColModernVBert, ColModernVBertProcessor
1716
from colpali_engine.interpretability.similarity_map_utils import (
1817
normalize_similarity_map,
1918
)
19+
from colpali_engine.models import ColModernVBert, ColModernVBertProcessor
2020

2121

2222
@pytest.fixture(scope="module")
@@ -28,9 +28,7 @@ def model_name() -> str:
2828
def processor_from_pretrained(
2929
model_name: str,
3030
) -> Generator[ColModernVBertProcessor, None, None]:
31-
yield cast(
32-
ColModernVBertProcessor, ColModernVBertProcessor.from_pretrained(model_name)
33-
)
31+
yield cast(ColModernVBertProcessor, ColModernVBertProcessor.from_pretrained(model_name))
3432

3533

3634
@pytest.fixture(scope="module")
@@ -41,79 +39,55 @@ def model_from_pretrained(model_name: str) -> Generator[ColModernVBert, None, No
4139
class TestGetNPatches:
4240
"""Test the get_n_patches method for calculating patch dimensions."""
4341

44-
def test_get_n_patches_returns_integers(
45-
self, processor_from_pretrained: ColModernVBertProcessor
46-
):
42+
def test_get_n_patches_returns_integers(self, processor_from_pretrained: ColModernVBertProcessor):
4743
"""Test that get_n_patches returns integer values."""
4844
patch_size = 14 # Common patch size for vision transformers
4945
image_size = (100, 100)
5046

51-
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(
52-
image_size, patch_size
53-
)
47+
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size)
5448

5549
assert isinstance(n_patches_x, int)
5650
assert isinstance(n_patches_y, int)
5751
assert n_patches_x > 0
5852
assert n_patches_y > 0
5953

60-
def test_get_n_patches_wide_image(
61-
self, processor_from_pretrained: ColModernVBertProcessor
62-
):
54+
def test_get_n_patches_wide_image(self, processor_from_pretrained: ColModernVBertProcessor):
6355
"""Test that wide images have more patches along width."""
6456
patch_size = 14
6557
image_size = (100, 200) # (height, width) - wider than tall
6658

67-
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(
68-
image_size, patch_size
69-
)
59+
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size)
7060

7161
# n_patches_x is along width, n_patches_y is along height
72-
assert (
73-
n_patches_x >= n_patches_y
74-
), f"Expected more patches along width, got x={n_patches_x}, y={n_patches_y}"
62+
assert n_patches_x >= n_patches_y, f"Expected more patches along width, got x={n_patches_x}, y={n_patches_y}"
7563

76-
def test_get_n_patches_tall_image(
77-
self, processor_from_pretrained: ColModernVBertProcessor
78-
):
64+
def test_get_n_patches_tall_image(self, processor_from_pretrained: ColModernVBertProcessor):
7965
"""Test that tall images have more patches along height."""
8066
patch_size = 14
8167
image_size = (200, 100) # (height, width) - taller than wide
8268

83-
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(
84-
image_size, patch_size
85-
)
69+
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size)
8670

87-
assert (
88-
n_patches_y >= n_patches_x
89-
), f"Expected more patches along height, got x={n_patches_x}, y={n_patches_y}"
71+
assert n_patches_y >= n_patches_x, f"Expected more patches along height, got x={n_patches_x}, y={n_patches_y}"
9072

91-
def test_get_n_patches_square_image(
92-
self, processor_from_pretrained: ColModernVBertProcessor
93-
):
73+
def test_get_n_patches_square_image(self, processor_from_pretrained: ColModernVBertProcessor):
9474
"""Test that square images have equal patches in both dimensions."""
9575
patch_size = 14
9676
image_size = (500, 500)
9777

98-
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(
99-
image_size, patch_size
100-
)
78+
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size)
10179

102-
assert (
103-
n_patches_x == n_patches_y
104-
), f"Expected equal patches for square image, got x={n_patches_x}, y={n_patches_y}"
80+
assert n_patches_x == n_patches_y, (
81+
f"Expected equal patches for square image, got x={n_patches_x}, y={n_patches_y}"
82+
)
10583

106-
def test_get_n_patches_aspect_ratio_preservation(
107-
self, processor_from_pretrained: ColModernVBertProcessor
108-
):
84+
def test_get_n_patches_aspect_ratio_preservation(self, processor_from_pretrained: ColModernVBertProcessor):
10985
"""Test that aspect ratio is approximately preserved in patch dimensions."""
11086
patch_size = 14
11187

11288
# Test with a 2:1 aspect ratio image
11389
image_size = (300, 600) # height=300, width=600
114-
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(
115-
image_size, patch_size
116-
)
90+
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(image_size, patch_size)
11791

11892
# The aspect ratio of patches should be close to 2:1
11993
patch_ratio = n_patches_x / n_patches_y
@@ -124,15 +98,13 @@ def test_get_n_patches_aspect_ratio_preservation(
12498
# 2. Even-dimension rounding in resize logic
12599
# 3. Ceiling division in patch calculations
126100
# These factors can cause ~25% deviation from the ideal aspect ratio
127-
assert 1.5 <= patch_ratio <= 2.5, f"Expected ~2:1 ratio, got {patch_ratio:.2f}"
101+
assert 1.5 <= patch_ratio <= 2.5, f"Expected ~2:1 ratio, got {patch_ratio:.2f}"
128102

129103

130104
class TestGetImageMask:
131105
"""Test the get_image_mask method for identifying image tokens."""
132106

133-
def test_get_image_mask_shape(
134-
self, processor_from_pretrained: ColModernVBertProcessor
135-
):
107+
def test_get_image_mask_shape(self, processor_from_pretrained: ColModernVBertProcessor):
136108
"""Test that image mask has the same shape as input_ids."""
137109
image = Image.new("RGB", (64, 32), color="red")
138110
batch_feature = processor_from_pretrained.process_images([image])
@@ -142,23 +114,17 @@ def test_get_image_mask_shape(
142114
assert image_mask.shape == batch_feature.input_ids.shape
143115
assert image_mask.dtype == torch.bool
144116

145-
def test_get_image_mask_has_image_tokens(
146-
self, processor_from_pretrained: ColModernVBertProcessor
147-
):
117+
def test_get_image_mask_has_image_tokens(self, processor_from_pretrained: ColModernVBertProcessor):
148118
"""Test that the mask identifies some image tokens."""
149119
image = Image.new("RGB", (64, 32), color="blue")
150120
batch_feature = processor_from_pretrained.process_images([image])
151121

152122
image_mask = processor_from_pretrained.get_image_mask(batch_feature)
153123

154124
# There should be image tokens present
155-
assert (
156-
image_mask.sum() > 0
157-
), "Expected to find image tokens in the processed batch"
125+
assert image_mask.sum() > 0, "Expected to find image tokens in the processed batch"
158126

159-
def test_get_image_mask_batch_consistency(
160-
self, processor_from_pretrained: ColModernVBertProcessor
161-
):
127+
def test_get_image_mask_batch_consistency(self, processor_from_pretrained: ColModernVBertProcessor):
162128
"""Test that image mask works correctly with batched images."""
163129
images = [
164130
Image.new("RGB", (64, 32), color="red"),
@@ -197,14 +163,13 @@ def test_similarity_maps_shape(
197163

198164
# Get patch size from the model or processor
199165
# ModernVBert uses patch_size from its config
200-
patch_size = (
201-
14 # Default for many vision transformers (unused but required for API)
202-
)
166+
patch_size = 14 # Default for many vision transformers (unused but required for API)
203167

204168
# Calculate expected patches
205169
# Note: image_size for get_n_patches is (height, width)
206170
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(
207-
(image_size_pil[1], image_size_pil[0]), patch_size # (height, width)
171+
(image_size_pil[1], image_size_pil[0]),
172+
patch_size, # (height, width)
208173
)
209174

210175
# Get embeddings
@@ -230,9 +195,9 @@ def test_similarity_maps_shape(
230195

231196
# similarity_maps[0] should have shape (query_tokens, n_patches_x, n_patches_y)
232197
expected_shape = (query_length, n_patches_x, n_patches_y)
233-
assert (
234-
similarity_maps[0].shape == expected_shape
235-
), f"Expected shape {expected_shape}, got {similarity_maps[0].shape}"
198+
assert similarity_maps[0].shape == expected_shape, (
199+
f"Expected shape {expected_shape}, got {similarity_maps[0].shape}"
200+
)
236201

237202
@pytest.mark.slow
238203
def test_similarity_maps_values(
@@ -248,9 +213,7 @@ def test_similarity_maps_values(
248213
batch_queries = processor_from_pretrained.process_texts([query])
249214

250215
patch_size = 14
251-
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches(
252-
(64, 64), patch_size
253-
)
216+
n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches((64, 64), patch_size)
254217

255218
with torch.no_grad():
256219
image_embeddings = model_from_pretrained(**batch_images)
@@ -274,9 +237,7 @@ def test_similarity_maps_values(
274237
# After normalization, values should be in [0, 1]
275238
assert normalized_map.min() >= 0.0
276239
assert normalized_map.max() <= 1.0
277-
assert (
278-
normalized_map.max() == 1.0
279-
) # Max should be exactly 1.0 after normalization
240+
assert normalized_map.max() == 1.0 # Max should be exactly 1.0 after normalization
280241

281242
@pytest.mark.slow
282243
def test_patch_count_matches_mask_count(
@@ -303,9 +264,9 @@ def test_patch_count_matches_mask_count(
303264
expected_local_patches = n_patches_x * n_patches_y
304265

305266
# LOCAL tokens should match exactly
306-
assert (
307-
actual_local_tokens == expected_local_patches
308-
), f"Expected {expected_local_patches} local image tokens, got {actual_local_tokens}"
267+
assert actual_local_tokens == expected_local_patches, (
268+
f"Expected {expected_local_patches} local image tokens, got {actual_local_tokens}"
269+
)
309270

310271
@pytest.mark.slow
311272
def test_global_patch_excluded(
@@ -326,9 +287,9 @@ def test_global_patch_excluded(
326287

327288
# The difference should be exactly image_seq_len (global patch tokens)
328289
image_seq_len = processor_from_pretrained.image_seq_len
329-
assert (
330-
full_count - local_count == image_seq_len
331-
), f"Expected {image_seq_len} global patch tokens, got {full_count - local_count}"
290+
assert full_count - local_count == image_seq_len, (
291+
f"Expected {image_seq_len} global patch tokens, got {full_count - local_count}"
292+
)
332293

333294

334295
class TestInterpretabilityConsistency:

0 commit comments

Comments
 (0)