diff --git a/.gitignore b/.gitignore
index b6e4761..70f6445 100644
--- a/.gitignore
+++ b/.gitignore
@@ -127,3 +127,4 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+CLIP-DiDeMo/
diff --git a/src/didemo_eval.py b/src/didemo_eval.py
new file mode 100644
index 0000000..3a3fda5
--- /dev/null
+++ b/src/didemo_eval.py
@@ -0,0 +1,66 @@
+from aggregation.mean import Mean
+import sys
+import os
+import open_clip
+import torch
+
+sys.path.insert(1, '/Users/daniel/Desktop/LAION_Videoclip/clip-video-encode')
+from clip_video_encode.dataset import EmbeddingWebDatasetReader
+
+from evaluation.retrieval import retrieval_evaluation
+
+def process_times(times, len_embeddings):
+    '''
+    Assumptions:
+    - times contains [start, end] in intervals of 5 seconds,
+      i.e. [0, 1] corresponds to [0*5, (1+1)*5]
+
+    - there's 1 embedding per second, i.e. embeddings[10] is the sole embedding for the 10th second
+    '''
+    SEGMENT_INTERVAL = 5
+    return [
+        SEGMENT_INTERVAL * times[0],
+        min(SEGMENT_INTERVAL * (times[1] + 1), len_embeddings)
+    ]
+
+def zero_pad(e, seq_len, model_dim=1024):
+    out = torch.zeros(size=(seq_len, model_dim))  # pad with zero rows up to seq_len
+    out[0:len(e)] = e
+
+    return out
+
+def process_didemo_segments(embeddings, segments, seq_len=200):
+    times_frames = [
+        process_times(caption_segments[0], seq_len)
+        for caption_segments in segments
+    ]  # taking only the first annotated segment
+
+    out_embeddings = torch.stack([
+        zero_pad(embeddings[:, start:end, :].squeeze(0), seq_len)
+        for (start, end) in times_frames
+    ])
+
+    return out_embeddings
+
+if __name__ == "__main__":
+    eval_path = '/Users/daniel/Documents/GitHub/temporal-embedding-aggregation/CLIP-DiDeMo/data/oc_h14/test/'
+    val_urls = eval_path + '{000000000..000000007}.tar'
+    val_reader = EmbeddingWebDatasetReader(
+        val_urls,
+        standard_seq_len=200,
+        batch_size=1,
+        num_prepro_workers=0,
+        to_tensor=False,
+        enable_text=True,
+        enable_meta=True
+    )
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    model, _, preprocess = open_clip.create_model_and_transforms('ViT-H-14', pretrained='laion2b_s32b_b79k')
+
+    model_text = model.encode_text
+    model_video = Mean().to(device)
+
+    ret_mets = retrieval_evaluation(model_video, model_text, val_reader, multicaption=True, segment=False, process_segments=process_didemo_segments)
+    print(ret_mets)
\ No newline at end of file
diff --git a/src/evaluation/retrieval.py b/src/evaluation/retrieval.py
index 49f976f..56eace3 100644
--- a/src/evaluation/retrieval.py
+++ b/src/evaluation/retrieval.py
@@ -2,8 +2,7 @@
 import numpy as np
 import torch
 
-
-def retrieval_evaluation(model_video, model_text, data, multicaption=False):
+def retrieval_evaluation(model_video, model_text, data, multicaption=False, segment=False, process_segments=None):
     if type(data) == dict:
         dataloader = data["val"].dataloader
     else:
@@ -15,6 +14,8 @@ def retrieval_evaluation(model_video, model_text, data, multicaption=False):
 
     with torch.no_grad():
         for i, batch in enumerate(dataloader):
+            if i == 3:  # NOTE: debug guard, metrics cover only the first 3 batches
+                break
             embeddings = batch["embeddings"]
             toks = []
             # TODO: does this require batch_size = 1 ??
@@ -23,10 +24,24 @@ def retrieval_evaluation(model_video, model_text, data, multicaption=False):
                 for c in cap.split(";"): # multiple captions separated by ;
                     toks.append(open_clip.tokenize(c))
                     ground_truth.append(samp)
+                    samp += segment  # ground-truth index advances per caption only in segment mode
+
+                if segment:
+                    segments = batch["meta"]["segment"]
+                    for idx, seg in enumerate(segments):
+                        start_frame, end_frame = seg
+
+                        # Keep only the annotated segment's frames; zero out the rest
+                        segmented_embedding = torch.zeros_like(embeddings[idx])
+                        segmented_embedding[start_frame:end_frame] = embeddings[idx][start_frame:end_frame]
+
+                        embeddings[idx] = segmented_embedding
+
             else:
                 toks.append(open_clip.tokenize(cap))
                 ground_truth.append(samp)
                 samp += 1
+
         toks = torch.cat(toks)
         embeddings = embeddings.to(device, non_blocking=True)
         toks = toks.to(device, non_blocking=True)
@@ -36,7 +51,7 @@
 
             all_video_features.append(video_embeddings.cpu())
             all_text_features.append(text_embeddings.cpu())
-
+
     val_metrics = get_metrics(
         video_features=torch.cat(all_video_features),
         text_features=torch.cat(all_text_features),
@@ -53,7 +68,9 @@ def get_metrics(video_features, text_features, ground_truth, logit_scale):
 
     logits_per_video = (logit_scale * video_features @ text_features.t()).detach().cpu()
     logits_per_text = logits_per_video.t().detach().cpu()
-
+    # logits_per_video = torch.rand_like(logits_per_video)
+    # logits_per_text = torch.rand_like(logits_per_text)
+
     # TODO: let's to text2video correctly and then figure out how to do video2text
     # maybe video2text is average logits over multiple captions
     '''
@@ -67,7 +84,7 @@
     # logits = {"video_to_text": logits_per_video, "text_to_video": logits_per_text}
     logits = {"text_to_video": logits_per_text}
     ground_truth = torch.tensor(ground_truth).view(-1, 1)
-
+    # print(f'Num samples: {len(logits_per_text)}')
     for name, logit in logits.items():
         ranking = torch.argsort(logit, descending=True)
         preds = torch.where(ranking == ground_truth)[1]
diff --git a/src/evaluation/retrieval_old.py b/src/evaluation/retrieval_old.py
index 668ad8e..30c49a8 100644
--- a/src/evaluation/retrieval_old.py
+++ b/src/evaluation/retrieval_old.py
@@ -10,6 +10,7 @@ def retrieval_evaluation(model_video, model_text, data):
     device = "cuda" if torch.cuda.is_available() else "cpu"
     dataloader = data
     all_video_features, all_text_features = [], []
+
    with torch.no_grad():
         for i, batch in enumerate(dataloader):
             # embeddings, toks = batch
@@ -24,12 +25,11 @@
 
             all_video_features.append(video_embeddings.cpu())
             all_text_features.append(text_embeddings.cpu())
-
+
     val_metrics = get_metrics(
         video_features=torch.cat(all_video_features),
         text_features=torch.cat(all_text_features),
     )
     return val_metrics
-
 def get_metrics(video_features, text_features, logit_scale):
     metrics = {}
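A quick sanity check of the chunk-to-frame mapping that `src/didemo_eval.py` introduces: DiDeMo annotates moments as inclusive indices into 5-second chunks, and the script assumes one embedding per second, so span `[0, 1]` should select seconds 0 through 10, and a span that runs past the end of the video should be clipped. A minimal standalone sketch (the function body is taken from the diff above; the example inputs and expected values are illustrative only):

```python
SEGMENT_INTERVAL = 5  # DiDeMo moments are annotated in 5-second chunks

def process_times(times, len_embeddings):
    # [start_chunk, end_chunk] (inclusive) -> [start_second, end_second),
    # clipped so the span never runs past the available embeddings
    return [
        SEGMENT_INTERVAL * times[0],
        min(SEGMENT_INTERVAL * (times[1] + 1), len_embeddings)
    ]

assert process_times([0, 1], 200) == [0, 10]  # chunks 0-1 cover seconds 0..10
assert process_times([7, 7], 38) == [35, 38]  # clipped: the video ends at second 38
```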
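Both the new `segment=True` branch in `retrieval.py` and `zero_pad` in `didemo_eval.py` fix the sequence length by zeroing every frame outside the annotated window. If the `Mean()` aggregator is a plain average over all frames (an assumption; its implementation is not part of this diff), that padding matters: the pooled vector of a segment is scaled by segment_length / seq_len, so shorter moments yield proportionally smaller vectors. A standalone illustration with arbitrary toy shapes:

```python
import torch

embeddings = torch.ones(6, 4)  # one toy video: 6 frames of 4-dim embeddings
start_frame, end_frame = 2, 5  # annotated moment covers frames [2, 5)

# Same zero-outside-the-window trick as the segment branch in retrieval.py
masked = torch.zeros_like(embeddings)
masked[start_frame:end_frame] = embeddings[start_frame:end_frame]

assert torch.equal(masked[start_frame:end_frame], embeddings[start_frame:end_frame])
assert masked[:start_frame].abs().sum() == 0 and masked[end_frame:].abs().sum() == 0

# Pooling over all 6 rows scales the result by 3/6 versus pooling
# over just the 3 segment rows
assert torch.allclose(masked.mean(dim=0),
                      embeddings[start_frame:end_frame].mean(dim=0) * 3 / 6)
```

This scaling washes out under cosine similarity after normalization, so it mostly matters if pooled vectors are compared without normalizing, as in the raw dot products of `get_metrics`.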