1313import torch
1414from PIL import Image
1515
16- from colpali_engine .models import ColModernVBert , ColModernVBertProcessor
1716from colpali_engine .interpretability .similarity_map_utils import (
1817 normalize_similarity_map ,
1918)
19+ from colpali_engine .models import ColModernVBert , ColModernVBertProcessor
2020
2121
2222@pytest .fixture (scope = "module" )
@@ -28,9 +28,7 @@ def model_name() -> str:
2828def processor_from_pretrained (
2929 model_name : str ,
3030) -> Generator [ColModernVBertProcessor , None , None ]:
31- yield cast (
32- ColModernVBertProcessor , ColModernVBertProcessor .from_pretrained (model_name )
33- )
31+ yield cast (ColModernVBertProcessor , ColModernVBertProcessor .from_pretrained (model_name ))
3432
3533
3634@pytest .fixture (scope = "module" )
@@ -41,79 +39,55 @@ def model_from_pretrained(model_name: str) -> Generator[ColModernVBert, None, No
4139class TestGetNPatches :
4240 """Test the get_n_patches method for calculating patch dimensions."""
4341
44- def test_get_n_patches_returns_integers (
45- self , processor_from_pretrained : ColModernVBertProcessor
46- ):
42+ def test_get_n_patches_returns_integers (self , processor_from_pretrained : ColModernVBertProcessor ):
4743 """Test that get_n_patches returns integer values."""
4844 patch_size = 14 # Common patch size for vision transformers
4945 image_size = (100 , 100 )
5046
51- n_patches_x , n_patches_y = processor_from_pretrained .get_n_patches (
52- image_size , patch_size
53- )
47+ n_patches_x , n_patches_y = processor_from_pretrained .get_n_patches (image_size , patch_size )
5448
5549 assert isinstance (n_patches_x , int )
5650 assert isinstance (n_patches_y , int )
5751 assert n_patches_x > 0
5852 assert n_patches_y > 0
5953
60- def test_get_n_patches_wide_image (
61- self , processor_from_pretrained : ColModernVBertProcessor
62- ):
54+ def test_get_n_patches_wide_image (self , processor_from_pretrained : ColModernVBertProcessor ):
6355 """Test that wide images have more patches along width."""
6456 patch_size = 14
6557 image_size = (100 , 200 ) # (height, width) - wider than tall
6658
67- n_patches_x , n_patches_y = processor_from_pretrained .get_n_patches (
68- image_size , patch_size
69- )
59+ n_patches_x , n_patches_y = processor_from_pretrained .get_n_patches (image_size , patch_size )
7060
7161 # n_patches_x is along width, n_patches_y is along height
72- assert (
73- n_patches_x >= n_patches_y
74- ), f"Expected more patches along width, got x={ n_patches_x } , y={ n_patches_y } "
62+ assert n_patches_x >= n_patches_y , f"Expected more patches along width, got x={ n_patches_x } , y={ n_patches_y } "
7563
76- def test_get_n_patches_tall_image (
77- self , processor_from_pretrained : ColModernVBertProcessor
78- ):
64+ def test_get_n_patches_tall_image (self , processor_from_pretrained : ColModernVBertProcessor ):
7965 """Test that tall images have more patches along height."""
8066 patch_size = 14
8167 image_size = (200 , 100 ) # (height, width) - taller than wide
8268
83- n_patches_x , n_patches_y = processor_from_pretrained .get_n_patches (
84- image_size , patch_size
85- )
69+ n_patches_x , n_patches_y = processor_from_pretrained .get_n_patches (image_size , patch_size )
8670
87- assert (
88- n_patches_y >= n_patches_x
89- ), f"Expected more patches along height, got x={ n_patches_x } , y={ n_patches_y } "
71+ assert n_patches_y >= n_patches_x , f"Expected more patches along height, got x={ n_patches_x } , y={ n_patches_y } "
9072
91- def test_get_n_patches_square_image (
92- self , processor_from_pretrained : ColModernVBertProcessor
93- ):
73+ def test_get_n_patches_square_image (self , processor_from_pretrained : ColModernVBertProcessor ):
9474 """Test that square images have equal patches in both dimensions."""
9575 patch_size = 14
9676 image_size = (500 , 500 )
9777
98- n_patches_x , n_patches_y = processor_from_pretrained .get_n_patches (
99- image_size , patch_size
100- )
78+ n_patches_x , n_patches_y = processor_from_pretrained .get_n_patches (image_size , patch_size )
10179
102- assert (
103- n_patches_x == n_patches_y
104- ), f"Expected equal patches for square image, got x= { n_patches_x } , y= { n_patches_y } "
80+ assert n_patches_x == n_patches_y , (
81+ f"Expected equal patches for square image, got x= { n_patches_x } , y= { n_patches_y } "
82+ )
10583
106- def test_get_n_patches_aspect_ratio_preservation (
107- self , processor_from_pretrained : ColModernVBertProcessor
108- ):
84+ def test_get_n_patches_aspect_ratio_preservation (self , processor_from_pretrained : ColModernVBertProcessor ):
10985 """Test that aspect ratio is approximately preserved in patch dimensions."""
11086 patch_size = 14
11187
11288 # Test with a 2:1 aspect ratio image
11389 image_size = (300 , 600 ) # height=300, width=600
114- n_patches_x , n_patches_y = processor_from_pretrained .get_n_patches (
115- image_size , patch_size
116- )
90+ n_patches_x , n_patches_y = processor_from_pretrained .get_n_patches (image_size , patch_size )
11791
11892 # The aspect ratio of patches should be close to 2:1
11993 patch_ratio = n_patches_x / n_patches_y
@@ -124,15 +98,13 @@ def test_get_n_patches_aspect_ratio_preservation(
12498 # 2. Even-dimension rounding in resize logic
12599 # 3. Ceiling division in patch calculations
126100 # These factors can cause ~25% deviation from the ideal aspect ratio
127- assert 1.5 <= patch_ratio <= 2.5 , f"Expected ~2 :1 ratio, got { patch_ratio :.2f} "
101+ assert 1.5 <= patch_ratio <= 2.5 , f"Expected ~{ expected_ratio } :1 ratio, got { patch_ratio :.2f} "
128102
129103
130104class TestGetImageMask :
131105 """Test the get_image_mask method for identifying image tokens."""
132106
133- def test_get_image_mask_shape (
134- self , processor_from_pretrained : ColModernVBertProcessor
135- ):
107+ def test_get_image_mask_shape (self , processor_from_pretrained : ColModernVBertProcessor ):
136108 """Test that image mask has the same shape as input_ids."""
137109 image = Image .new ("RGB" , (64 , 32 ), color = "red" )
138110 batch_feature = processor_from_pretrained .process_images ([image ])
@@ -142,23 +114,17 @@ def test_get_image_mask_shape(
142114 assert image_mask .shape == batch_feature .input_ids .shape
143115 assert image_mask .dtype == torch .bool
144116
145- def test_get_image_mask_has_image_tokens (
146- self , processor_from_pretrained : ColModernVBertProcessor
147- ):
117+ def test_get_image_mask_has_image_tokens (self , processor_from_pretrained : ColModernVBertProcessor ):
148118 """Test that the mask identifies some image tokens."""
149119 image = Image .new ("RGB" , (64 , 32 ), color = "blue" )
150120 batch_feature = processor_from_pretrained .process_images ([image ])
151121
152122 image_mask = processor_from_pretrained .get_image_mask (batch_feature )
153123
154124 # There should be image tokens present
155- assert (
156- image_mask .sum () > 0
157- ), "Expected to find image tokens in the processed batch"
125+ assert image_mask .sum () > 0 , "Expected to find image tokens in the processed batch"
158126
159- def test_get_image_mask_batch_consistency (
160- self , processor_from_pretrained : ColModernVBertProcessor
161- ):
127+ def test_get_image_mask_batch_consistency (self , processor_from_pretrained : ColModernVBertProcessor ):
162128 """Test that image mask works correctly with batched images."""
163129 images = [
164130 Image .new ("RGB" , (64 , 32 ), color = "red" ),
@@ -197,14 +163,13 @@ def test_similarity_maps_shape(
197163
198164 # Get patch size from the model or processor
199165 # ModernVBert uses patch_size from its config
200- patch_size = (
201- 14 # Default for many vision transformers (unused but required for API)
202- )
166+ patch_size = 14 # Default for many vision transformers (unused but required for API)
203167
204168 # Calculate expected patches
205169 # Note: image_size for get_n_patches is (height, width)
206170 n_patches_x , n_patches_y = processor_from_pretrained .get_n_patches (
207- (image_size_pil [1 ], image_size_pil [0 ]), patch_size # (height, width)
171+ (image_size_pil [1 ], image_size_pil [0 ]),
172+ patch_size , # (height, width)
208173 )
209174
210175 # Get embeddings
@@ -230,9 +195,9 @@ def test_similarity_maps_shape(
230195
231196 # similarity_maps[0] should have shape (query_tokens, n_patches_x, n_patches_y)
232197 expected_shape = (query_length , n_patches_x , n_patches_y )
233- assert (
234- similarity_maps [0 ].shape == expected_shape
235- ), f"Expected shape { expected_shape } , got { similarity_maps [ 0 ]. shape } "
198+ assert similarity_maps [ 0 ]. shape == expected_shape , (
199+ f"Expected shape { expected_shape } , got { similarity_maps [0 ].shape } "
200+ )
236201
237202 @pytest .mark .slow
238203 def test_similarity_maps_values (
@@ -248,9 +213,7 @@ def test_similarity_maps_values(
248213 batch_queries = processor_from_pretrained .process_texts ([query ])
249214
250215 patch_size = 14
251- n_patches_x , n_patches_y = processor_from_pretrained .get_n_patches (
252- (64 , 64 ), patch_size
253- )
216+ n_patches_x , n_patches_y = processor_from_pretrained .get_n_patches ((64 , 64 ), patch_size )
254217
255218 with torch .no_grad ():
256219 image_embeddings = model_from_pretrained (** batch_images )
@@ -274,9 +237,7 @@ def test_similarity_maps_values(
274237 # After normalization, values should be in [0, 1]
275238 assert normalized_map .min () >= 0.0
276239 assert normalized_map .max () <= 1.0
277- assert (
278- normalized_map .max () == 1.0
279- ) # Max should be exactly 1.0 after normalization
240+ assert normalized_map .max () == 1.0 # Max should be exactly 1.0 after normalization
280241
281242 @pytest .mark .slow
282243 def test_patch_count_matches_mask_count (
@@ -303,9 +264,9 @@ def test_patch_count_matches_mask_count(
303264 expected_local_patches = n_patches_x * n_patches_y
304265
305266 # LOCAL tokens should match exactly
306- assert (
307- actual_local_tokens == expected_local_patches
308- ), f"Expected { expected_local_patches } local image tokens, got { actual_local_tokens } "
267+ assert actual_local_tokens == expected_local_patches , (
268+ f"Expected { expected_local_patches } local image tokens, got { actual_local_tokens } "
269+ )
309270
310271 @pytest .mark .slow
311272 def test_global_patch_excluded (
@@ -326,9 +287,9 @@ def test_global_patch_excluded(
326287
327288 # The difference should be exactly image_seq_len (global patch tokens)
328289 image_seq_len = processor_from_pretrained .image_seq_len
329- assert (
330- full_count - local_count == image_seq_len
331- ), f"Expected { image_seq_len } global patch tokens, got { full_count - local_count } "
290+ assert full_count - local_count == image_seq_len , (
291+ f"Expected { image_seq_len } global patch tokens, got { full_count - local_count } "
292+ )
332293
333294
334295class TestInterpretabilityConsistency :
0 commit comments