-
Notifications
You must be signed in to change notification settings - Fork 4
Dm/didemo eval #45
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Dm/didemo eval #45
Changes from all commits
f71d047
806aafa
f6e2fe1
eb67689
c067636
ea7efc0
9dd02b5
f9097c0
82d1b7c
e40abfd
e4fb4d5
4b03cc4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -127,3 +127,4 @@ dmypy.json | |
|
|
||
| # Pyre type checker | ||
| .pyre/ | ||
| CLIP-DiDeMo/ | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| from aggregation.mean import Mean | ||
| import sys | ||
| import os | ||
| import open_clip | ||
| import torch | ||
|
|
||
| sys.path.insert(1, '/Users/daniel/Desktop/LAION_Videoclip/clip-video-encode') | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cleanup |
||
| from clip_video_encode.dataset import EmbeddingWebDatasetReader | ||
|
|
||
| from evaluation.retrieval import retrieval_evaluation | ||
|
|
||
| def process_times(times, len_embeddings): | ||
| ''' | ||
| Assumptions: | ||
| - times contains [start, end] in intervals of 5 seconds, | ||
| i.e. [0, 1] corresponds to [0*5, (1+1)*5] | ||
|
|
||
| - there's 1 embedding per second, i.e. embeddings[10] is the sole embedding for the 10th second | ||
| ''' | ||
| SEGMENT_INTERVAL = 5 | ||
| return [ | ||
| SEGMENT_INTERVAL * times[0], | ||
| min( SEGMENT_INTERVAL * (times[1] + 1), len_embeddings ) | ||
| ] | ||
|
|
||
| def zero_pad(e, seq_len, model_dim=1024): | ||
| out = torch.zeros(size=(seq_len, model_dim)) | ||
| out[0:len(e)] = e | ||
|
|
||
| return out | ||
|
|
||
| def process_didemo_segments(embeddings, segments, seq_len=200): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no function, no specific code for each dataset, just move these operations into the "if "segments" in batch["meta"]" part of the retrieval eval |
||
| times_frames = [ | ||
| process_times(caption_segments[0], seq_len) | ||
| for caption_segments in segments | ||
| ] # taking only the first annotated segment | ||
|
|
||
| out_embeddings = torch.stack([ | ||
| zero_pad(embeddings[:, start:end, :].squeeze(0), seq_len) | ||
| for (start, end) in times_frames | ||
| ]) | ||
|
|
||
| return out_embeddings | ||
|
|
||
| if __name__ == "__main__": | ||
| eval_path = '/Users/daniel/Documents/GitHub/temporal-embedding-aggregation/CLIP-DiDeMo/data/oc_h14/test/' | ||
| val_urls = eval_path + '{000000000..000000007}.tar' | ||
| val_reader = EmbeddingWebDatasetReader( | ||
| val_urls, | ||
| standard_seq_len=200, | ||
| batch_size=1, | ||
| num_prepro_workers=0, | ||
| to_tensor=False, | ||
| enable_text=True, | ||
| enable_meta=True | ||
| ) | ||
|
|
||
| device = "cuda" if torch.cuda.is_available() else "cpu" | ||
|
|
||
| model, _, preprocess = open_clip.create_model_and_transforms('ViT-H-14', pretrained='laion2b_s32b_b79k') | ||
|
|
||
| model_text = model.encode_text | ||
| model_video = Mean().to(device) | ||
|
|
||
| ret_mets = retrieval_evaluation(model_video, model_text, val_reader, multicaption=True, segment=False, process_segments=process_didemo_segments) | ||
| print(ret_mets) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,8 +2,7 @@ | |
| import numpy as np | ||
| import torch | ||
|
|
||
|
|
||
| def retrieval_evaluation(model_video, model_text, data, multicaption=False): | ||
| def retrieval_evaluation(model_video, model_text, data, multicaption=False, segment=False): | ||
| if type(data) == dict: | ||
| dataloader = data["val"].dataloader | ||
| else: | ||
|
|
@@ -15,6 +14,8 @@ def retrieval_evaluation(model_video, model_text, data, multicaption=False): | |
|
|
||
| with torch.no_grad(): | ||
| for i, batch in enumerate(dataloader): | ||
| if i==3: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cleanup |
||
| break | ||
| embeddings = batch["embeddings"] | ||
| toks = [] | ||
| # TODO: does this require batch_size = 1 ?? | ||
|
|
@@ -23,10 +24,24 @@ def retrieval_evaluation(model_video, model_text, data, multicaption=False): | |
| for c in cap.split(";"): # multiple captions separated by ; | ||
| toks.append(open_clip.tokenize(c)) | ||
| ground_truth.append(samp) | ||
| samp += segment | ||
|
|
||
| if segment: | ||
| segments = batch["meta"]["segment"] | ||
| for idx, segment in enumerate(segments): | ||
| start_frame, end_frame = segment | ||
|
|
||
| # Zero pad our segmented | ||
| segmented_embedding = torch.zeros_like(embeddings[idx]) | ||
| segmented_embedding[start_frame:end_frame] = embeddings[idx][start_frame:end_frame] | ||
|
|
||
| embeddings[idx] = segmented_embedding | ||
|
|
||
| else: | ||
| toks.append(open_clip.tokenize(cap)) | ||
| ground_truth.append(samp) | ||
| samp += 1 | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cleanup |
||
| toks = torch.cat(toks) | ||
| embeddings = embeddings.to(device, non_blocking=True) | ||
| toks = toks.to(device, non_blocking=True) | ||
|
|
@@ -36,7 +51,7 @@ def retrieval_evaluation(model_video, model_text, data, multicaption=False): | |
|
|
||
| all_video_features.append(video_embeddings.cpu()) | ||
| all_text_features.append(text_embeddings.cpu()) | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cleanup |
||
| val_metrics = get_metrics( | ||
| video_features=torch.cat(all_video_features), | ||
| text_features=torch.cat(all_text_features), | ||
|
|
@@ -53,7 +68,9 @@ def get_metrics(video_features, text_features, ground_truth, logit_scale): | |
| logits_per_video = (logit_scale * video_features @ text_features.t()).detach().cpu() | ||
| logits_per_text = logits_per_video.t().detach().cpu() | ||
|
|
||
|
|
||
| # logits_per_video = torch.rand_like(logits_per_video) | ||
| # logits_per_text = torch.rand_like(logits_per_text) | ||
|
|
||
| # TODO: let's to text2video correctly and then figure out how to do video2text | ||
| # maybe video2text is average logits over multiple captions | ||
| ''' | ||
|
|
@@ -67,7 +84,7 @@ def get_metrics(video_features, text_features, ground_truth, logit_scale): | |
| # logits = {"video_to_text": logits_per_video, "text_to_video": logits_per_text} | ||
| logits = {"text_to_video": logits_per_text} | ||
| ground_truth = torch.tensor(ground_truth).view(-1, 1) | ||
|
|
||
| # print(f'Num samples: {len(logits_per_text)}') | ||
| for name, logit in logits.items(): | ||
| ranking = torch.argsort(logit, descending=True) | ||
| preds = torch.where(ranking == ground_truth)[1] | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cleanup