From f71d047ee73d33d4fe4469b62178e783acd5085d Mon Sep 17 00:00:00 2001
From: danielmend <31258255+danielmend@users.noreply.github.com>
Date: Thu, 20 Oct 2022 16:55:00 -0700
Subject: [PATCH 01/11] testing

---
 src/evaluation/retrieval.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/evaluation/retrieval.py b/src/evaluation/retrieval.py
index 668ad8e..cea0603 100644
--- a/src/evaluation/retrieval.py
+++ b/src/evaluation/retrieval.py
@@ -1,6 +1,6 @@
 import numpy as np
 import torch
-
+# hello
 
 def retrieval_evaluation(model_video, model_text, data):
     if type(data) == dict:

From 806aafac72de151cc0a0b302b6160c49950b91db Mon Sep 17 00:00:00 2001
From: danielmend <31258255+danielmend@users.noreply.github.com>
Date: Thu, 20 Oct 2022 20:28:52 -0700
Subject: [PATCH 02/11] begin work on adapting didemo to multicaption eval

---
 .gitignore                               |  1 +
 src/didemo_eval.py                       | 88 ++++++++++++++++++++++++
 src/evaluation/multicaption_retrieval.py | 31 ++++++---
 src/evaluation/retrieval.py              |  1 -
 4 files changed, 111 insertions(+), 10 deletions(-)
 create mode 100644 src/didemo_eval.py

diff --git a/.gitignore b/.gitignore
index b6e4761..70f6445 100644
--- a/.gitignore
+++ b/.gitignore
@@ -127,3 +127,4 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+CLIP-DiDeMo/
diff --git a/src/didemo_eval.py b/src/didemo_eval.py
new file mode 100644
index 0000000..a64ad48
--- /dev/null
+++ b/src/didemo_eval.py
@@ -0,0 +1,88 @@
+from aggregation.mean import Mean
+import sys
+import os
+import open_clip
+import torch
+
+sys.path.insert(1, '/Users/daniel/Desktop/LAION_Videoclip/clip-video-encode')
+from clip_video_encode.dataset import EmbeddingWebDatasetReader
+from evaluation.multicaption_retrieval import multicaption_retrieval_evaluation
+
+eval_path = '/Users/daniel/Documents/GitHub/temporal-embedding-aggregation/CLIP-DiDeMo/data/oc_h14/test/'
+
+def process_times(times, len_embeddings):
+    '''
+    Assumptions:
+    - times contains [start, end] in intervals of 5 seconds,
+      i.e. [0, 1] corresponds to [0*5, (1+1)*5]
+
+    - there's 1 embedding per second, i.e. embeddings[10] is the sole embedding for the 10th second
+    '''
+    SEGMENT_INTERVAL = 5
+    return [
+        SEGMENT_INTERVAL * times[0],
+        min( SEGMENT_INTERVAL * (times[1] + 1), len_embeddings )
+    ]
+
+def zero_pad(e, seq_len, model_dim=512):
+    out = torch.zeros(size=(seq_len, model_dim))
+    out[0:len(e)] = e
+
+    return out
+
+def process_didemo_batch(batch, caption_sep = ';', device='cuda'):
+    SEGMENT_KEY = 'times'
+    embeddings = batch['embeddings']
+    print(embeddings.shape)
+    seq_len = embeddings.shape[1]
+
+    #print(batch)
+
+    captions = [
+        text.split(caption_sep)
+        for text in batch['text']
+    ]
+
+    times_frames = [
+        process_times(caption_segments[0], seq_len)
+        for caption_segments in batch['meta'][SEGMENT_KEY]
+    ] # just take the first annotation for each caption
+
+    toks = torch.stack([
+        open_clip.tokenize(caption).to(device)
+        for caption in captions
+    ]).squeeze(0)
+
+    out_embeddings = torch.stack([
+        zero_pad(embeddings[:, start:end, :].squeeze(0), seq_len)
+        for (start, end) in times_frames
+    ])
+
+    #print(times_frames)
+    #print(out_embeddings.shape)
+    #print(toks.shape)
+    return out_embeddings, toks
+
+
+if __name__ == "__main__":
+    val_urls = eval_path + '{000000000..000000007}.tar'
+    val_reader = EmbeddingWebDatasetReader(
+        val_urls,
+        standard_seq_len=200,
+        batch_size=1,
+        num_prepro_workers=0,
+        to_tensor=False,
+        enable_text=True,
+        enable_meta=True
+    )
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    #TODO: update dis to work with ViT-H/14
+    model, _, preprocess = open_clip.create_model_and_transforms('ViT-H-14', pretrained='laion400m_e32')
+
+    model_text = model.encode_text
+    model_video = Mean().to(device)
+
+    ret_mets = multicaption_retrieval_evaluation(model_video, model_text, val_reader, segment=True, process_batch=process_didemo_batch)
+    print(next(iter(val_reader)))
diff --git a/src/evaluation/multicaption_retrieval.py b/src/evaluation/multicaption_retrieval.py
index 5711f7f..75bf66f 100644
--- a/src/evaluation/multicaption_retrieval.py
+++ b/src/evaluation/multicaption_retrieval.py
@@ -3,21 +3,35 @@ import torch
 
 
-def multicaption_retrieval_evaluation(model_video, model_text, data):
+def multicaption_retrieval_evaluation(model_video, model_text, data, segment=False, process_batch=None):
     if type(data) == dict:
         dataloader = data["val"].dataloader
     else:
         dataloader = data
     device = "cuda" if torch.cuda.is_available() else "cpu"
     all_video_features, all_text_features = [], []
+
     with torch.no_grad():
         for i, batch in enumerate(dataloader):
-            embeddings = batch["embeddings"]
-            toks = []
-            for k, v in batch["meta"].items():
-                if "caption" in k:
-                    toks.append(clip.tokenize(v))
-            toks = torch.cat(toks)
+            if i % 5 == 0:
+                print(i)
+
+            if i == 10:
+                break
+
+            if segment:
+                embeddings, toks = process_batch(batch, device=device)
+
+            else:
+                embeddings = batch["embeddings"]
+                toks = []
+
+                for k, v in batch["meta"].items():
+                    if "caption" in k:
+                        toks.append(clip.tokenize(v))
+
+                toks = torch.cat(toks)
+
             embeddings = embeddings.to(device, non_blocking=True)
             toks = toks.to(device, non_blocking=True)
@@ -26,7 +40,7 @@
         all_video_features.append(video_embeddings.cpu())
         all_text_features.append(text_embeddings.cpu())
-
+
     val_metrics = get_metrics(
         video_features=torch.cat(all_video_features),
         text_features=torch.cat(all_text_features),
     )
     return val_metrics
-
 def get_metrics(video_features, text_features, logit_scale):
     metrics = {}
diff --git a/src/evaluation/retrieval.py b/src/evaluation/retrieval.py
index cea0603..f2c76d7 100644
--- a/src/evaluation/retrieval.py
+++ b/src/evaluation/retrieval.py
@@ -1,6 +1,5 @@
 import numpy as np
 import torch
-# hello
 
 def retrieval_evaluation(model_video, model_text, data):
     if type(data) == dict:

From f6e2fe1afaa065780d13f6c98030555f3aaa08ea Mon Sep 17 00:00:00 2001
From: danielmend <31258255+danielmend@users.noreply.github.com>
Date: Thu, 20 Oct 2022 20:44:59 -0700
Subject: [PATCH 03/11] process embeddings end-to-end

---
 src/didemo_eval.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/src/didemo_eval.py b/src/didemo_eval.py
index a64ad48..cc45712 100644
--- a/src/didemo_eval.py
+++ b/src/didemo_eval.py
@@ -24,7 +24,7 @@ def process_times(times, len_embeddings):
         min( SEGMENT_INTERVAL * (times[1] + 1), len_embeddings )
     ]
 
-def zero_pad(e, seq_len, model_dim=512):
+def zero_pad(e, seq_len, model_dim=1024):
     out = torch.zeros(size=(seq_len, model_dim))
     out[0:len(e)] = e
@@ -33,7 +33,7 @@ def process_didemo_batch(batch, caption_sep = ';', device='cuda'):
     SEGMENT_KEY = 'times'
     embeddings = batch['embeddings']
-    print(embeddings.shape)
+    #print(embeddings.shape)
     seq_len = embeddings.shape[1]
 
     #print(batch)
@@ -77,12 +77,11 @@ def process_didemo_batch(batch, caption_sep = ';', device='cuda'):
     )
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    #TODO: update dis to work with ViT-H/14
-    model, _, preprocess = open_clip.create_model_and_transforms('ViT-H-14', pretrained='laion400m_e32')
-
-    model_text = model.encode_text
+    model, _, preprocess = open_clip.create_model_and_transforms('ViT-H-14', pretrained='laion2b_s32b_b79k')
+
+    model_text = model.encode_text
     model_video = Mean().to(device)
 
     ret_mets = multicaption_retrieval_evaluation(model_video, model_text, val_reader, segment=True, process_batch=process_didemo_batch)
-    print(next(iter(val_reader)))
+    print(ret_mets)

From c067636a711074fb54a14c97487a89d9671805cb Mon Sep 17 00:00:00 2001
From: danielmend <31258255+danielmend@users.noreply.github.com>
Date: Thu, 20 Oct 2022 21:09:27 -0700
Subject: [PATCH 04/11] update retrieval logic to deal with multicaption
 segmented, started working on processing didemo segments

---
 src/didemo_eval.py          | 28 ++++++++++++++++++----------
 src/evaluation/retrieval.py | 11 ++++++++---
 2 files changed, 26 insertions(+), 13 deletions(-)

diff --git a/src/didemo_eval.py b/src/didemo_eval.py
index cc45712..fe6ef49 100644
--- a/src/didemo_eval.py
+++ b/src/didemo_eval.py
@@ -6,7 +6,7 @@
 
 sys.path.insert(1, '/Users/daniel/Desktop/LAION_Videoclip/clip-video-encode')
 from clip_video_encode.dataset import EmbeddingWebDatasetReader
-from evaluation.multicaption_retrieval import multicaption_retrieval_evaluation
+from evaluation.retrieval import retrieval_evaluation
 
 eval_path = '/Users/daniel/Documents/GitHub/temporal-embedding-aggregation/CLIP-DiDeMo/data/oc_h14/test/'
@@ -30,14 +30,25 @@ def zero_pad(e, seq_len, model_dim=1024):
 
     return out
 
+def process_didemo_segments(embeddings, segments, seq_len=200):
+    times_frames = [
+        process_times(caption_segments[0], seq_len)
+        for caption_segments in segments
+    ]
+
+    out_embeddings = torch.stack([
+        zero_pad(embeddings[:, start:end, :].squeeze(0), seq_len)
+        for (start, end) in times_frames
+    ])
+
+    return out_embeddings
+
+
 def process_didemo_batch(batch, caption_sep = ';', device='cuda'):
     SEGMENT_KEY = 'times'
     embeddings = batch['embeddings']
-    #print(embeddings.shape)
+
     seq_len = embeddings.shape[1]
 
-    #print(batch)
-
     captions = [
         text.split(caption_sep)
         for text in batch['text']
     ]
@@ -58,9 +69,6 @@ def process_didemo_batch(batch, caption_sep = ';', device='cuda'):
         for (start, end) in times_frames
     ])
 
-    #print(times_frames)
-    #print(out_embeddings.shape)
-    #print(toks.shape)
     return out_embeddings, toks
@@ -83,5 +91,5 @@ def process_didemo_batch(batch, caption_sep = ';', device='cuda'):
     model_text = model.encode_text
     model_video = Mean().to(device)
 
-    ret_mets = multicaption_retrieval_evaluation(model_video, model_text, val_reader, segment=True, process_batch=process_didemo_batch)
+    ret_mets = retrieval_evaluation(model_video, model_text, val_reader, multicaption=True, segment=True, segment_key='times', process_segments=process_didemo_segments)
     print(ret_mets)
diff --git a/src/evaluation/retrieval.py b/src/evaluation/retrieval.py
index 49f976f..21fcb2e 100644
--- a/src/evaluation/retrieval.py
+++ b/src/evaluation/retrieval.py
@@ -2,8 +2,7 @@
 import numpy as np
 import torch
 
-
-def retrieval_evaluation(model_video, model_text, data, multicaption=False):
+def retrieval_evaluation(model_video, model_text, data, multicaption=False, segment=False, segment_key=None, process_segments=None):
     if type(data) == dict:
         dataloader = data["val"].dataloader
     else:
         dataloader = data
@@ -18,15 +17,21 @@ def retrieval_evaluation(model_video, model_text, data, multicaption=False):
             embeddings = batch["embeddings"]
             toks = []
             # TODO: does this require batch_size = 1 ??
-            for cap in batch["text"]:
+            for cap_idx, cap in enumerate(batch["text"]):
                 if multicaption:
                     for c in cap.split(";"): # multiple captions separated by ;
                         toks.append(open_clip.tokenize(c))
                         ground_truth.append(samp)
+                        if segment:
+                            samp += 1
+                    if segment:
+                        segments = batch["meta"][segment_key]
+                        embeddings = process_segments(embeddings, segments)
                 else:
                     toks.append(open_clip.tokenize(cap))
                     ground_truth.append(samp)
                     samp += 1
+
             toks = torch.cat(toks)
             embeddings = embeddings.to(device, non_blocking=True)
             toks = toks.to(device, non_blocking=True)

From ea7efc0a3f4688ff8ec74fa46752446200aa36c8 Mon Sep 17 00:00:00 2001
From: danielmend <31258255+danielmend@users.noreply.github.com>
Date: Thu, 20 Oct 2022 21:21:26 -0700
Subject: [PATCH 05/11] first working iteration of didemo eval

---
 src/didemo_eval.py          |  2 +-
 src/evaluation/retrieval.py | 10 ++++++----
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/src/didemo_eval.py b/src/didemo_eval.py
index fe6ef49..84933db 100644
--- a/src/didemo_eval.py
+++ b/src/didemo_eval.py
@@ -40,7 +40,7 @@ def process_didemo_segments(embeddings, segments, seq_len=200):
         zero_pad(embeddings[:, start:end, :].squeeze(0), seq_len)
         for (start, end) in times_frames
     ])
-
+    print(out_embeddings.shape)
     return out_embeddings
 
diff --git a/src/evaluation/retrieval.py b/src/evaluation/retrieval.py
index 21fcb2e..34aa646 100644
--- a/src/evaluation/retrieval.py
+++ b/src/evaluation/retrieval.py
@@ -2,7 +2,7 @@
 import numpy as np
 import torch
 
-def retrieval_evaluation(model_video, model_text, data, multicaption=False, segment=False, segment_key=None, process_segments=None):
+def retrieval_evaluation(model_video, model_text, data, multicaption=False, segment=False, process_segments=None):
     if type(data) == dict:
         dataloader = data["val"].dataloader
     else:
         dataloader = data
@@ -17,7 +17,7 @@ def retrieval_evaluation(model_video, model_text, data, multicaption=False, segm
             embeddings = batch["embeddings"]
             toks = []
             # TODO: does this require batch_size = 1 ??
-            for cap_idx, cap in enumerate(batch["text"]):
+            for cap in batch["text"]:
                 if multicaption:
                     for c in cap.split(";"): # multiple captions separated by ;
                         toks.append(open_clip.tokenize(c))
                         ground_truth.append(samp)
                         if segment:
                             samp += 1
                     if segment:
-                        segments = batch["meta"][segment_key]
+                        segments = batch["meta"]["times"] # change to ...['segment']
                         embeddings = process_segments(embeddings, segments)
+
                 else:
                     toks.append(open_clip.tokenize(cap))
                     ground_truth.append(samp)
                     samp += 1
 
+            #print(len(ground_truth))
+            #print(embeddings.shape)
             toks = torch.cat(toks)
             embeddings = embeddings.to(device, non_blocking=True)
             toks = toks.to(device, non_blocking=True)
@@ -58,7 +61,6 @@ def get_metrics(video_features, text_features, ground_truth, logit_scale):
     logits_per_video = (logit_scale * video_features @ text_features.t()).detach().cpu()
     logits_per_text = logits_per_video.t().detach().cpu()
-
     # TODO: let's do text2video correctly and then figure out how to do video2text
     # maybe video2text is average logits over multiple captions
     '''

From 9dd02b577cfc5706c924e71447cb733875c710ad Mon Sep 17 00:00:00 2001
From: danielmend <31258255+danielmend@users.noreply.github.com>
Date: Thu, 20 Oct 2022 21:33:55 -0700
Subject: [PATCH 06/11] debugging eval

---
 src/didemo_eval.py          | 2 +-
 src/evaluation/retrieval.py | 7 ++++++-
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/didemo_eval.py b/src/didemo_eval.py
index 84933db..8f3e413 100644
--- a/src/didemo_eval.py
+++ b/src/didemo_eval.py
@@ -91,5 +91,5 @@ def process_didemo_batch(batch, caption_sep = ';', device='cuda'):
     model_text = model.encode_text
     model_video = Mean().to(device)
 
-    ret_mets = retrieval_evaluation(model_video, model_text, val_reader, multicaption=True, segment=True, segment_key='times', process_segments=process_didemo_segments)
+    ret_mets = retrieval_evaluation(model_video, model_text, val_reader, multicaption=True, segment=True, process_segments=process_didemo_segments)
     print(ret_mets)
diff --git a/src/evaluation/retrieval.py b/src/evaluation/retrieval.py
index 34aa646..12c07ab 100644
--- a/src/evaluation/retrieval.py
+++ b/src/evaluation/retrieval.py
@@ -16,6 +16,8 @@ def retrieval_evaluation(model_video, model_text, data, multicaption=False, segm
     with torch.no_grad():
         for i, batch in enumerate(dataloader):
             embeddings = batch["embeddings"]
             toks = []
+            if i == 10:
+                break
             # TODO: does this require batch_size = 1 ??
for cap in batch["text"]: if multicaption: @@ -61,6 +63,9 @@ def get_metrics(video_features, text_features, ground_truth, logit_scale): logits_per_video = (logit_scale * video_features @ text_features.t()).detach().cpu() logits_per_text = logits_per_video.t().detach().cpu() + logits_per_video = torch.randn_like(logits_per_video) + logits_per_text = torch.randn_like(logits_per_text) + # TODO: let's to text2video correctly and then figure out how to do video2text # maybe video2text is average logits over multiple captions ''' @@ -74,7 +79,7 @@ def get_metrics(video_features, text_features, ground_truth, logit_scale): # logits = {"video_to_text": logits_per_video, "text_to_video": logits_per_text} logits = {"text_to_video": logits_per_text} ground_truth = torch.tensor(ground_truth).view(-1, 1) - + print(f'Num samples: {len(logits_per_text)}') for name, logit in logits.items(): ranking = torch.argsort(logit, descending=True) preds = torch.where(ranking == ground_truth)[1] From f9097c07badd1e10a291f8e32e670134335ce69d Mon Sep 17 00:00:00 2001 From: danielmend <31258255+danielmend@users.noreply.github.com> Date: Thu, 20 Oct 2022 22:04:21 -0700 Subject: [PATCH 07/11] clean up prints --- src/didemo_eval.py | 37 ++++--------------------------------- src/evaluation/retrieval.py | 10 ++++------ 2 files changed, 8 insertions(+), 39 deletions(-) diff --git a/src/didemo_eval.py b/src/didemo_eval.py index 8f3e413..fa031f1 100644 --- a/src/didemo_eval.py +++ b/src/didemo_eval.py @@ -6,9 +6,8 @@ sys.path.insert(1, '/Users/daniel/Desktop/LAION_Videoclip/clip-video-encode') from clip_video_encode.dataset import EmbeddingWebDatasetReader -from evaluation.retrieval import retrieval_evaluation -eval_path = '/Users/daniel/Documents/GitHub/temporal-embedding-aggregation/CLIP-DiDeMo/data/oc_h14/test/' +from evaluation.retrieval import retrieval_evaluation def process_times(times, len_embeddings): ''' @@ -40,39 +39,11 @@ def process_didemo_segments(embeddings, segments, seq_len=200): zero_pad(embeddings[:, start:end, :].squeeze(0), seq_len) for (start, end) in times_frames ]) - print(out_embeddings.shape) - return out_embeddings - - -def process_didemo_batch(batch, caption_sep = ';', device='cuda'): - SEGMENT_KEY = 'times' - embeddings = batch['embeddings'] - seq_len = embeddings.shape[1] - captions = [ - text.split(caption_sep) - for text in batch['text'] - ] - - times_frames = [ - process_times(caption_segments[0], seq_len) - for caption_segments in batch['meta'][SEGMENT_KEY] - ] # just take the first annotation for each caption - - toks = torch.stack([ - open_clip.tokenize(caption).to(device) - for caption in captions - ]).squeeze(0) - - out_embeddings = torch.stack([ - zero_pad(embeddings[:, start:end, :].squeeze(0), seq_len) - for (start, end) in times_frames - ]) - - return out_embeddings, toks - + return out_embeddings if __name__ == "__main__": + eval_path = '/Users/daniel/Documents/GitHub/temporal-embedding-aggregation/CLIP-DiDeMo/data/oc_h14/test/' val_urls = eval_path + '{000000000..000000007}.tar' val_reader = EmbeddingWebDatasetReader( val_urls, @@ -92,4 +63,4 @@ def process_didemo_batch(batch, caption_sep = ';', device='cuda'): model_video = Mean().to(device) ret_mets = retrieval_evaluation(model_video, model_text, val_reader, multicaption=True, segment=True, process_segments=process_didemo_segments) - print(ret_mets) + print(ret_mets) \ No newline at end of file diff --git a/src/evaluation/retrieval.py b/src/evaluation/retrieval.py index 12c07ab..6c3fab2 100644 --- a/src/evaluation/retrieval.py +++ 
b/src/evaluation/retrieval.py @@ -16,7 +16,7 @@ def retrieval_evaluation(model_video, model_text, data, multicaption=False, segm for i, batch in enumerate(dataloader): embeddings = batch["embeddings"] toks = [] - if i == 10: + if i == 100: break # TODO: does this require batch_size = 1 ?? for cap in batch["text"]: @@ -35,8 +35,6 @@ def retrieval_evaluation(model_video, model_text, data, multicaption=False, segm ground_truth.append(samp) samp += 1 - #print(len(ground_truth)) - #print(embeddings.shape) toks = torch.cat(toks) embeddings = embeddings.to(device, non_blocking=True) toks = toks.to(device, non_blocking=True) @@ -63,8 +61,8 @@ def get_metrics(video_features, text_features, ground_truth, logit_scale): logits_per_video = (logit_scale * video_features @ text_features.t()).detach().cpu() logits_per_text = logits_per_video.t().detach().cpu() - logits_per_video = torch.randn_like(logits_per_video) - logits_per_text = torch.randn_like(logits_per_text) + # logits_per_video = torch.rand_like(logits_per_video) + # logits_per_text = torch.rand_like(logits_per_text) # TODO: let's to text2video correctly and then figure out how to do video2text # maybe video2text is average logits over multiple captions @@ -79,7 +77,7 @@ def get_metrics(video_features, text_features, ground_truth, logit_scale): # logits = {"video_to_text": logits_per_video, "text_to_video": logits_per_text} logits = {"text_to_video": logits_per_text} ground_truth = torch.tensor(ground_truth).view(-1, 1) - print(f'Num samples: {len(logits_per_text)}') + # print(f'Num samples: {len(logits_per_text)}') for name, logit in logits.items(): ranking = torch.argsort(logit, descending=True) preds = torch.where(ranking == ground_truth)[1] From 82d1b7ca59858f0635c1941922ed8f1108dc48dd Mon Sep 17 00:00:00 2001 From: danielmend <31258255+danielmend@users.noreply.github.com> Date: Fri, 21 Oct 2022 15:16:08 -0700 Subject: [PATCH 08/11] init test for vid2text mutlicaption --- src/didemo_eval.py | 4 ++-- src/evaluation/retrieval.py | 20 +++++++++++++------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/didemo_eval.py b/src/didemo_eval.py index fa031f1..3a3fda5 100644 --- a/src/didemo_eval.py +++ b/src/didemo_eval.py @@ -33,7 +33,7 @@ def process_didemo_segments(embeddings, segments, seq_len=200): times_frames = [ process_times(caption_segments[0], seq_len) for caption_segments in segments - ] + ] # taking only the first annotated segment out_embeddings = torch.stack([ zero_pad(embeddings[:, start:end, :].squeeze(0), seq_len) @@ -62,5 +62,5 @@ def process_didemo_segments(embeddings, segments, seq_len=200): model_text = model.encode_text model_video = Mean().to(device) - ret_mets = retrieval_evaluation(model_video, model_text, val_reader, multicaption=True, segment=True, process_segments=process_didemo_segments) + ret_mets = retrieval_evaluation(model_video, model_text, val_reader, multicaption=True, segment=False, process_segments=process_didemo_segments) print(ret_mets) \ No newline at end of file diff --git a/src/evaluation/retrieval.py b/src/evaluation/retrieval.py index 6c3fab2..1065423 100644 --- a/src/evaluation/retrieval.py +++ b/src/evaluation/retrieval.py @@ -14,10 +14,10 @@ def retrieval_evaluation(model_video, model_text, data, multicaption=False, segm with torch.no_grad(): for i, batch in enumerate(dataloader): + if i==3: + break embeddings = batch["embeddings"] toks = [] - if i == 100: - break # TODO: does this require batch_size = 1 ?? 
             for cap in batch["text"]:
@@ -29,7 +29,6 @@ def retrieval_evaluation(model_video, model_text, data, multicaption=False, segm
                         if segment:
                             samp += 1
                     if segment:
                         segments = batch["meta"]["times"] # change to ...['segment']
                         embeddings = process_segments(embeddings, segments)
-
                 else:
                     toks.append(open_clip.tokenize(cap))
                     ground_truth.append(samp)
                     samp += 1
@@ -44,19 +43,26 @@ def retrieval_evaluation(model_video, model_text, data, multicaption=False, segm
             all_video_features.append(video_embeddings.cpu())
             all_text_features.append(text_embeddings.cpu())
-
+
     val_metrics = get_metrics(
         video_features=torch.cat(all_video_features),
         text_features=torch.cat(all_text_features),
         ground_truth=ground_truth,
         logit_scale=100.0,
+        multicaption=multicaption
     )
     return val_metrics
 
-def get_metrics(video_features, text_features, ground_truth, logit_scale):
+def get_metrics(video_features, text_features, ground_truth, logit_scale, multicaption = False):
     metrics = {}
+    print(video_features.shape)
+    print(text_features.shape)
+    if multicaption:
+        video_features = torch.stack([video_features[samp] for samp in ground_truth])
+
+    print(video_features.shape)
     video_features = video_features.float()
     logits_per_video = (logit_scale * video_features @ text_features.t()).detach().cpu()
     logits_per_text = logits_per_video.t().detach().cpu()
@@ -74,8 +80,8 @@ def get_metrics(video_features, text_features, ground_truth, logit_scale):
         logits_per_video = avg_per_20
     '''
 
-    # logits = {"video_to_text": logits_per_video, "text_to_video": logits_per_text}
-    logits = {"text_to_video": logits_per_text}
+    logits = {"video_to_text": logits_per_video, "text_to_video": logits_per_text}
+    #logits = {"text_to_video": logits_per_text}
     ground_truth = torch.tensor(ground_truth).view(-1, 1)
     # print(f'Num samples: {len(logits_per_text)}')
     for name, logit in logits.items():
         ranking = torch.argsort(logit, descending=True)
         preds = torch.where(ranking == ground_truth)[1]

From e40abfd88d023bf5d6e2b7a2c2be3c09247f6e4a Mon Sep 17 00:00:00 2001
From: danielmend <31258255+danielmend@users.noreply.github.com>
Date: Sat, 22 Oct 2022 17:06:27 -0700
Subject: [PATCH 09/11] undo multicaption, moving to separate PR

---
 src/evaluation/retrieval.py | 13 +++----------
 1 file changed, 3 insertions(+), 10 deletions(-)

diff --git a/src/evaluation/retrieval.py b/src/evaluation/retrieval.py
index 1065423..d7436d3 100644
--- a/src/evaluation/retrieval.py
+++ b/src/evaluation/retrieval.py
@@ -49,20 +49,13 @@ def retrieval_evaluation(model_video, model_text, data, multicaption=False, segm
         text_features=torch.cat(all_text_features),
         ground_truth=ground_truth,
         logit_scale=100.0,
-        multicaption=multicaption
     )
     return val_metrics
 
-def get_metrics(video_features, text_features, ground_truth, logit_scale, multicaption = False):
+def get_metrics(video_features, text_features, ground_truth, logit_scale):
     metrics = {}
-    print(video_features.shape)
-    print(text_features.shape)
-    if multicaption:
-        video_features = torch.stack([video_features[samp] for samp in ground_truth])
-
-    print(video_features.shape)
     video_features = video_features.float()
     logits_per_video = (logit_scale * video_features @ text_features.t()).detach().cpu()
     logits_per_text = logits_per_video.t().detach().cpu()
@@ -80,8 +73,8 @@ def get_metrics(video_features, text_features, ground_truth, logit_scale, multic
         logits_per_video = avg_per_20
     '''
 
-    logits = {"video_to_text": logits_per_video, "text_to_video": logits_per_text}
-    #logits = {"text_to_video": logits_per_text}
+    # logits = {"video_to_text": logits_per_video, "text_to_video": logits_per_text}
+    logits = {"text_to_video": logits_per_text}
     ground_truth = torch.tensor(ground_truth).view(-1, 1)
     # print(f'Num samples: {len(logits_per_text)}')
     for name, logit in logits.items():
         ranking = torch.argsort(logit, descending=True)
         preds = torch.where(ranking == ground_truth)[1]

From e4fb4d5abee50663775229dd7821f76649c7f6a8 Mon Sep 17 00:00:00 2001
From: danielmend <31258255+danielmend@users.noreply.github.com>
Date: Sat, 22 Oct 2022 17:26:41 -0700
Subject: [PATCH 10/11] remove need for process_segments in eval script

---
 src/evaluation/retrieval.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/src/evaluation/retrieval.py b/src/evaluation/retrieval.py
index d7436d3..cd790e5 100644
--- a/src/evaluation/retrieval.py
+++ b/src/evaluation/retrieval.py
@@ -2,7 +2,7 @@
 import numpy as np
 import torch
 
-def retrieval_evaluation(model_video, model_text, data, multicaption=False, segment=False, process_segments=None):
+def retrieval_evaluation(model_video, model_text, data, multicaption=False, segment=False):
     if type(data) == dict:
         dataloader = data["val"].dataloader
     else:
         dataloader = data
@@ -24,11 +24,14 @@ def retrieval_evaluation(model_video, model_text, data, multicaption=False, segm
                 for c in cap.split(";"): # multiple captions separated by ;
                     toks.append(open_clip.tokenize(c))
                     ground_truth.append(samp)
-                    if segment:
-                        samp += 1
+                samp += segment
+
                 if segment:
-                    segments = batch["meta"]["times"] # change to ...['segment']
-                    embeddings = process_segments(embeddings, segments)
+                    segments = batch["meta"]["segment"]
+                    for idx, segment in enumerate(segments):
+                        start_frame, end_frame = segment
+                        embeddings[idx] = embeddings[idx][start_frame:end_frame]
+
             else:
                 toks.append(open_clip.tokenize(cap))
                 ground_truth.append(samp)

From 4b03cc47a8193c8dba3ac549c05387cbdb76f555 Mon Sep 17 00:00:00 2001
From: danielmend <31258255+danielmend@users.noreply.github.com>
Date: Sat, 22 Oct 2022 17:29:52 -0700
Subject: [PATCH 11/11] zero pad segments

---
 src/evaluation/retrieval.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/evaluation/retrieval.py b/src/evaluation/retrieval.py
index cd790e5..56eace3 100644
--- a/src/evaluation/retrieval.py
+++ b/src/evaluation/retrieval.py
@@ -30,7 +30,12 @@ def retrieval_evaluation(model_video, model_text, data, multicaption=False, segm
                 if segment:
                     segments = batch["meta"]["segment"]
                     for idx, segment in enumerate(segments):
                         start_frame, end_frame = segment
-                        embeddings[idx] = embeddings[idx][start_frame:end_frame]
+
+                        # Zero-pad our segmented embedding
+                        segmented_embedding = torch.zeros_like(embeddings[idx])
+                        segmented_embedding[start_frame:end_frame] = embeddings[idx][start_frame:end_frame]
+
+                        embeddings[idx] = segmented_embedding
 
             else:
                 toks.append(open_clip.tokenize(cap))
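
Taken together, patches 02, 04, 10, and 11 settle on one scheme for DiDeMo's moment annotations: an annotated [start, end] pair counts 5-second chunks, gets mapped to per-second frame indices, and everything outside that window in the video's per-frame embedding sequence is zeroed, so the tensor shape the aggregator sees never changes. A condensed standalone sketch of that logic (the toy tensor sizes and the zero_pad_segment name are illustrative, not identifiers from the patches):

    import torch

    SEGMENT_INTERVAL = 5  # DiDeMo annotates moments in 5-second chunks

    def process_times(times, len_embeddings):
        # chunk indices -> second (frame) indices: [0, 1] -> [0, 10],
        # clamped to the number of per-second embeddings available
        return [
            SEGMENT_INTERVAL * times[0],
            min(SEGMENT_INTERVAL * (times[1] + 1), len_embeddings),
        ]

    def zero_pad_segment(embeddings, start, end):
        # patch 11's in-place variant: keep the annotated span at its
        # original temporal position and zero out everything else
        out = torch.zeros_like(embeddings)
        out[start:end] = embeddings[start:end]
        return out

    emb = torch.randn(200, 1024)                  # 200 s of per-second ViT-H/14 features
    start, end = process_times([0, 1], len(emb))  # -> 0, 10
    segment_emb = zero_pad_segment(emb, start, end)

Note the shift from patch 02's zero_pad, which moves the segment to position 0 and pads the tail, to patch 11's version, which zeroes in place. For the Mean aggregator the two produce the same sum over the sequence; the in-place form additionally preserves the segment's temporal position for any aggregator that is position-sensitive.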
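The ranking logic in get_metrics is only partially visible in these hunks — everything after the torch.where line is cut off. For reference, a standalone completion in the usual open_clip style (the function name and the R@K/mean-rank bookkeeping below are reconstructed in that style, not code quoted from this series; ground_truth holds the index of the correct video for each caption, as built up in retrieval_evaluation):

    import numpy as np
    import torch

    def text_to_video_metrics(video_features, text_features, ground_truth, logit_scale=100.0):
        # rows = captions, cols = videos; features are assumed L2-normalized
        logits_per_text = (logit_scale * video_features.float() @ text_features.float().t()).t().cpu()

        ranking = torch.argsort(logits_per_text, descending=True)  # best video first
        target = torch.tensor(ground_truth).view(-1, 1)            # true video per caption
        preds = torch.where(ranking == target)[1].numpy()          # rank of the true video

        metrics = {"mean_rank": preds.mean() + 1, "median_rank": np.floor(np.median(preds)) + 1}
        for k in (1, 5, 10):
            metrics[f"R@{k}"] = np.mean(preds < k)
        return metrics

Keep in mind that with batch_size=1 and the early break (i == 3 as of patch 08) still in place, any metrics printed by these debugging commits cover only a handful of samples.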