Skip to content
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,4 @@ dmypy.json

# Pyre type checker
.pyre/
CLIP-DiDeMo/
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cleanup

66 changes: 66 additions & 0 deletions src/didemo_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from aggregation.mean import Mean
import sys
import os
import open_clip
import torch

sys.path.insert(1, '/Users/daniel/Desktop/LAION_Videoclip/clip-video-encode')
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cleanup

from clip_video_encode.dataset import EmbeddingWebDatasetReader

from evaluation.retrieval import retrieval_evaluation

def process_times(times, len_embeddings):
    '''
    Convert a DiDeMo segment annotation into a [start, end) frame-index pair.

    Assumptions:
    - times contains [start, end] in intervals of 5 seconds,
      i.e. [0, 1] corresponds to [0*5, (1+1)*5]

    - there's 1 embedding per second, i.e. embeddings[10] is the sole
      embedding for the 10th second
    '''
    SEGMENT_INTERVAL = 5
    start_chunk, end_chunk = times
    start_frame = SEGMENT_INTERVAL * start_chunk
    # The end index is inclusive, so the segment spans (end_chunk + 1) chunks;
    # clamp to the number of available embeddings to avoid overrunning them.
    end_frame = min(SEGMENT_INTERVAL * (end_chunk + 1), len_embeddings)
    return [start_frame, end_frame]

def zero_pad(e, seq_len, model_dim=1024):
    '''
    Zero-pad (or truncate) a sequence of frame embeddings to a fixed length.

    Args:
        e: tensor of shape (n, model_dim) holding n frame embeddings.
        seq_len: target sequence length of the output.
        model_dim: embedding dimensionality (default 1024 — ViT-H-14 width,
            matching the model used in this script).

    Returns:
        Float tensor of shape (seq_len, model_dim) with `e` copied into the
        first rows and the remainder left as zeros. Inputs longer than
        seq_len are truncated (the original raised a RuntimeError on the
        out-of-bounds slice assignment in that case).
    '''
    out = torch.zeros(size=(seq_len, model_dim))
    # Guard against sequences longer than the target buffer.
    n = min(len(e), seq_len)
    out[:n] = e[:n]

    return out

def process_didemo_segments(embeddings, segments, seq_len=200):
    '''
    Slice a video's embedding tensor down to each caption's annotated
    segment and zero-pad every slice back to a fixed length.

    Only the first annotated segment per caption is used.
    '''
    # Map each caption's first [start, end] annotation (5-second chunks)
    # to frame indices, clamped to seq_len.
    frame_ranges = []
    for caption_segments in segments:
        frame_ranges.append(process_times(caption_segments[0], seq_len))

    # embeddings is assumed to be (1, T, D) — TODO confirm (batch_size=1
    # upstream); squeeze(0) drops the batch dim before padding.
    padded = []
    for start, end in frame_ranges:
        segment_embedding = embeddings[:, start:end, :].squeeze(0)
        padded.append(zero_pad(segment_embedding, seq_len))

    return torch.stack(padded)

if __name__ == "__main__":
    # NOTE(review): hard-coded user-specific absolute path — should come from
    # a CLI argument or environment variable before merging.
    eval_path = '/Users/daniel/Documents/GitHub/temporal-embedding-aggregation/CLIP-DiDeMo/data/oc_h14/test/'
    # Brace-expansion pattern selecting 8 webdataset shards of precomputed embeddings.
    val_urls = eval_path + '{000000000..000000007}.tar'
    # Reader yielding per-video frame embeddings plus captions and metadata.
    val_reader = EmbeddingWebDatasetReader(
        val_urls,
        standard_seq_len=200,  # pad/trim every video to 200 frame embeddings
        batch_size=1,
        num_prepro_workers=0,  # load in the main process, no worker subprocesses
        to_tensor=False,
        enable_text=True,  # captions are needed for the text side of retrieval
        enable_meta=True  # metadata carries the DiDeMo segment annotations
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Only the text tower is used at eval time; visual embeddings are precomputed.
    model, _, preprocess = open_clip.create_model_and_transforms('ViT-H-14', pretrained='laion2b_s32b_b79k')

    model_text = model.encode_text
    # Temporal aggregation: mean-pool frame embeddings into one video embedding.
    model_video = Mean().to(device)

    # NOTE(review): segment=False while also passing process_segments, and the
    # retrieval_evaluation signature visible in this diff has no
    # `process_segments` parameter — confirm this call doesn't TypeError.
    ret_mets = retrieval_evaluation(model_video, model_text, val_reader, multicaption=True, segment=False, process_segments=process_didemo_segments)
    print(ret_mets)
27 changes: 22 additions & 5 deletions src/evaluation/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
import numpy as np
import torch


def retrieval_evaluation(model_video, model_text, data, multicaption=False):
def retrieval_evaluation(model_video, model_text, data, multicaption=False, segment=False):
if type(data) == dict:
dataloader = data["val"].dataloader
else:
Expand All @@ -15,6 +14,8 @@ def retrieval_evaluation(model_video, model_text, data, multicaption=False):

with torch.no_grad():
for i, batch in enumerate(dataloader):
if i==3:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cleanup

break
embeddings = batch["embeddings"]
toks = []
# TODO: does this require batch_size = 1 ??
Expand All @@ -23,10 +24,24 @@ def retrieval_evaluation(model_video, model_text, data, multicaption=False):
for c in cap.split(";"): # multiple captions separated by ;
toks.append(open_clip.tokenize(c))
ground_truth.append(samp)
samp += segment

if segment:
segments = batch["meta"]["segment"]
for idx, segment in enumerate(segments):
start_frame, end_frame = segment

# Zero pad our segmented
segmented_embedding = torch.zeros_like(embeddings[idx])
segmented_embedding[start_frame:end_frame] = embeddings[idx][start_frame:end_frame]

embeddings[idx] = segmented_embedding

else:
toks.append(open_clip.tokenize(cap))
ground_truth.append(samp)
samp += 1

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cleanup

toks = torch.cat(toks)
embeddings = embeddings.to(device, non_blocking=True)
toks = toks.to(device, non_blocking=True)
Expand All @@ -36,7 +51,7 @@ def retrieval_evaluation(model_video, model_text, data, multicaption=False):

all_video_features.append(video_embeddings.cpu())
all_text_features.append(text_embeddings.cpu())

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cleanup

val_metrics = get_metrics(
video_features=torch.cat(all_video_features),
text_features=torch.cat(all_text_features),
Expand All @@ -53,7 +68,9 @@ def get_metrics(video_features, text_features, ground_truth, logit_scale):
logits_per_video = (logit_scale * video_features @ text_features.t()).detach().cpu()
logits_per_text = logits_per_video.t().detach().cpu()


# logits_per_video = torch.rand_like(logits_per_video)
# logits_per_text = torch.rand_like(logits_per_text)

# TODO: let's to text2video correctly and then figure out how to do video2text
# maybe video2text is average logits over multiple captions
'''
Expand All @@ -67,7 +84,7 @@ def get_metrics(video_features, text_features, ground_truth, logit_scale):
# logits = {"video_to_text": logits_per_video, "text_to_video": logits_per_text}
logits = {"text_to_video": logits_per_text}
ground_truth = torch.tensor(ground_truth).view(-1, 1)

# print(f'Num samples: {len(logits_per_text)}')
for name, logit in logits.items():
ranking = torch.argsort(logit, descending=True)
preds = torch.where(ranking == ground_truth)[1]
Expand Down
4 changes: 2 additions & 2 deletions src/evaluation/retrieval_old.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def retrieval_evaluation(model_video, model_text, data):
device = "cuda" if torch.cuda.is_available() else "cpu"
dataloader = data
all_video_features, all_text_features = [], []

with torch.no_grad():
for i, batch in enumerate(dataloader):
# embeddings, toks = batch
Expand All @@ -24,15 +25,14 @@ def retrieval_evaluation(model_video, model_text, data):

all_video_features.append(video_embeddings.cpu())
all_text_features.append(text_embeddings.cpu())

val_metrics = get_metrics(
video_features=torch.cat(all_video_features),
text_features=torch.cat(all_text_features),
logit_scale=100.0,
)
return val_metrics


def get_metrics(video_features, text_features, logit_scale):
metrics = {}

Expand Down