3 changes: 3 additions & 0 deletions colpali_engine/collators/corpus_query_collator.py
@@ -15,16 +15,19 @@ def __init__(
image_dataset: Optional["Dataset"] = None, # noqa: F821
mined_negatives: bool = True,
corpus_format: str = "wikiss",
process_images_before_training: bool = False,
):
super().__init__(
processor=processor,
max_length=max_length,
process_images_before_training=process_images_before_training,
)
if image_dataset is None:
raise ValueError("`image_dataset` must be provided")
self.image_dataset = image_dataset
self.mined_negatives = mined_negatives
self.corpus_format = corpus_format
self.process_images_before_training = process_images_before_training

if self.corpus_format == "wikiss":
print("Mapping docids to indices")
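
For orientation, here is a minimal usage sketch of the new flag on the corpus/query collator. The collator class name, processor class, checkpoint, and corpus dataset below are assumptions for illustration; only the constructor parameters themselves come from the diff above.

```python
# Hypothetical usage sketch -- class, processor, checkpoint, and dataset names are assumptions.
from datasets import load_dataset

from colpali_engine.collators.corpus_query_collator import CorpusQueryCollator  # assumed class name
from colpali_engine.models import ColQwen2Processor  # assumed BaseVisualRetrieverProcessor subclass

processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v1.0")      # assumed checkpoint
image_dataset = load_dataset("org/doc-image-corpus", split="train")        # hypothetical corpus

collator = CorpusQueryCollator(
    processor=processor,
    max_length=2048,
    image_dataset=image_dataset,
    mined_negatives=True,
    corpus_format="wikiss",
    process_images_before_training=True,  # new flag introduced in this PR
)
```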
108 changes: 96 additions & 12 deletions colpali_engine/collators/visual_retriever_collator.py
@@ -1,5 +1,6 @@
from typing import Any, Dict, List, Union, cast
from typing import Any, Dict, List, Union

import torch
from PIL.Image import Image

from colpali_engine.models.idefics_2 import ColIdefics2Processor
@@ -16,10 +17,12 @@ def __init__(
self,
processor: BaseVisualRetrieverProcessor,
max_length: int = 2048,
process_images_before_training: bool = False,
):
self.processor = processor
self.image_token_id = None
self.max_length = max_length
self.process_images_before_training = process_images_before_training

if isinstance(self.processor, ColPaliProcessor) or isinstance(self.processor, ColIdefics2Processor):
self.image_token_id = self.processor.tokenizer.additional_special_tokens_ids[
@@ -35,6 +38,15 @@ def __init__(
def __call__(
self,
examples: List[Dict[str, Any]],
) -> Dict[str, Any]:
if self.process_images_before_training:
return self.offline_processing(examples)
return self.online_processing(examples)


def online_processing(
self,
examples: List[Dict[str, Any]],
) -> Dict[str, Any]:
"""
Collate function for the vision retriever associated to the collator's processor.
@@ -43,20 +55,16 @@ def __call__(
texts_query: Union[List[str], List[None], List[Union[str, None]]] = [] # some documents don't have a query
images: List[Image] = []
neg_images: List[Image] = []
# breakpoint()

if self.processor is None or not isinstance(self.processor, BaseVisualRetrieverProcessor):
raise ValueError("Processor should be provided for vision collator.")

# Process each example
for example in examples:
texts_query.append(example["query"])
if example["image"] is None:
raise ValueError("Image is None - This collator does not support None images yet.")

images.append(cast(Image, example["image"]))
images.append(example["image"])

if "neg_image" in example and example["neg_image"] is not None:
neg_images.append(cast(Image, example["neg_image"]))
neg_images.append(example["neg_image"])

# Process the documents
batch_doc = self.processor.process_images(
@@ -76,11 +84,8 @@ def __call__(
if all([t is None for t in texts_query]):
# print("All queries are `None`. Returning `None` for all queries.")
pass
elif any([t is None for t in texts_query]):
# If it's the first query that is not None but the rest are None, then it's hard negatives.
raise ValueError("Some queries are None. This collator does not support None queries yet.")
else:
texts_query = cast(List[str], texts_query)
texts_query: List[str] = texts_query
batch_query = self.processor.process_queries(
queries=texts_query,
max_length=self.max_length,
@@ -98,3 +103,82 @@ def __call__(
batch_all.update(batch_neg_doc)

return batch_all


def offline_processing(
self,
examples: List[Dict[str, Any]],
) -> Dict[str, Any]:
"""
Collate function for the vision retriever associated to the collator's processor.
"""
# Placeholders
texts_query = []
pixel_values = []
image_grid_thw = []
input_ids = []
attention_mask = []
neg_pixel_values = []
neg_image_grid_thw = []
neg_input_ids = []
neg_attention_mask = []

for example in examples:
texts_query.append(example["query"])
pixel_values.append(example["pixel_values"])
image_grid_thw.append(example["image_grid_thw"])
input_ids.append(example["input_ids"])
attention_mask.append(example["attention_mask"])

if "neg_pixel_values" in example:
neg_pixel_values.append(example["neg_pixel_values"])
neg_image_grid_thw.append(example["neg_image_grid_thw"])
neg_input_ids.append(example["neg_input_ids"])
neg_attention_mask.append(example["neg_attention_mask"])

# Pad pixel values
pixel_values = torch.nn.utils.rnn.pad_sequence(pixel_values, batch_first=True, padding_value=0)
image_grid_thw = torch.stack(image_grid_thw)

# Pad input sequences
batch_doc = self.processor.tokenizer.pad(
{"input_ids": input_ids, "attention_mask": attention_mask},
padding=True,
return_tensors="pt"
)

batch_all = {
"doc_pixel_values": pixel_values,
"doc_image_grid_thw": image_grid_thw,
"doc_input_ids": batch_doc["input_ids"],
"doc_attention_mask": batch_doc["attention_mask"],
}

# Process queries
if any(texts_query): # Ensure there are valid queries
batch_query = self.processor.process_queries(
queries=texts_query,
max_length=self.max_length
)
batch_all["query_input_ids"] = batch_query["input_ids"]
batch_all["query_attention_mask"] = batch_query["attention_mask"]

# Process negatives if present
if neg_pixel_values:
neg_pixel_values = torch.nn.utils.rnn.pad_sequence(neg_pixel_values, batch_first=True, padding_value=0)
neg_image_grid_thw = torch.stack(neg_image_grid_thw)

batch_neg_doc = self.processor.tokenizer.pad(
{"input_ids": neg_input_ids, "attention_mask": neg_attention_mask},
padding=True,
return_tensors="pt"
)

batch_all.update({
"neg_doc_pixel_values": neg_pixel_values,
"neg_doc_image_grid_thw": neg_image_grid_thw,
"neg_doc_input_ids": batch_neg_doc["input_ids"],
"neg_doc_attention_mask": batch_neg_doc["attention_mask"],
})

return batch_all
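
The offline branch assumes every example already carries the processor outputs (`pixel_values`, `image_grid_thw`, `input_ids`, `attention_mask`, plus optional `neg_*` fields) rather than a raw PIL image, so a preprocessing pass must materialize those fields before training, for instance via a `datasets.map` over the corpus. The sketch below shows what that pass might look like; everything in it beyond the field names read by `offline_processing` is an assumption, not code from this PR.

```python
# Hypothetical offline preprocessing sketch. The function name, dataset handling, and the
# assumption that `processor.process_images` yields per-image `pixel_values`, `image_grid_thw`,
# `input_ids`, and `attention_mask` are illustrative only.
def preprocess_for_offline_collation(example, processor):
    batch = processor.process_images([example["image"]])  # run the expensive image processing once
    return {
        "pixel_values": batch["pixel_values"],         # variable-length patch tensor for this image
        "image_grid_thw": batch["image_grid_thw"][0],  # per-image grid descriptor
        "input_ids": batch["input_ids"][0],
        "attention_mask": batch["attention_mask"][0],
        "query": example["query"],
    }

# train_dataset = train_dataset.map(lambda ex: preprocess_for_offline_collation(ex, processor))
# train_dataset.set_format("torch")  # so the collator receives tensors rather than lists
```

At collation time, `offline_processing` then only has to pad the variable-length tensors (`pad_sequence` for patches, `tokenizer.pad` for token ids) instead of re-running the image processor for every batch.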
5 changes: 5 additions & 0 deletions colpali_engine/loss/__init__.py
@@ -3,6 +3,11 @@
BiPairwiseCELoss,
BiPairwiseNegativeCELoss,
)
from .gradcache_late_interaction_losses import (
GradCacheColbertLoss,
GradCacheColbertPairwiseCELoss,
GradCacheColbertPairwiseNegativeCELoss,
)
from .late_interaction_losses import (
ColbertLoss,
ColbertPairwiseCELoss,
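
These new exports make the gradient-cache variants of the ColBERT losses importable directly from `colpali_engine.loss` (e.g. `from colpali_engine.loss import GradCacheColbertPairwiseCELoss`). Gradient caching typically allows large effective contrastive batches by encoding in small chunks and replaying gradients. For background, all of these losses build on late-interaction (MaxSim) scoring between query and document token embeddings; the sketch below illustrates that scoring in isolation, with assumed tensor shapes, and is not the library's implementation.

```python
# Conceptual ColBERT-style (late-interaction) scoring sketch -- not the library's code.
import torch

def late_interaction_scores(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tensor:
    """
    q_emb: (B, Nq, D) query token embeddings
    d_emb: (B, Nd, D) document token embeddings
    Returns a (B, B) score matrix: for each query/document pair, the sum over
    query tokens of the maximum similarity against any document token (MaxSim).
    """
    # (B, B, Nq, Nd) token-level similarities between every query and every document
    sim = torch.einsum("bnd,cmd->bcnm", q_emb, d_emb)
    # max over document tokens, then sum over query tokens
    return sim.max(dim=3).values.sum(dim=2)

# A pairwise CE objective would then treat the diagonal of this matrix as positives:
# loss = torch.nn.functional.cross_entropy(scores, torch.arange(scores.size(0), device=scores.device))
```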