
Commit 916ff36

Author: DavidHuji
Message: adding eval files from private git repo
Parent: 92d8460

6 files changed: +1070 −2 lines

custom_types.py (+46)

@@ -0,0 +1,46 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as nnf
+import sys
+from typing import Tuple, List, Union, Callable, Type, Iterator, Dict, Set, Optional, Any, Sized
+from enum import Enum
+
+
+IS_WINDOWS = sys.platform == 'win32'
+get_trace = getattr(sys, 'gettrace', None)
+DEBUG = get_trace is not None and get_trace() is not None
+
+
+# if DEBUG:
+#     seed = 99
+#     torch.manual_seed(seed)
+#     np.random.seed(seed)
+
+N = type(None)
+V = np.array
+ARRAY = np.ndarray
+ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
+VS = Union[Tuple[V, ...], List[V]]
+VN = Union[V, N]
+VNS = Union[VS, N]
+T = torch.Tensor
+TS = Union[Tuple[T, ...], List[T]]
+TN = Optional[T]
+TNS = Union[Tuple[TN, ...], List[TN]]
+TSN = Optional[TS]
+TA = Union[T, ARRAY]
+
+
+D = torch.device
+CPU = torch.device('cpu')
+
+
+def get_device(device_id: int) -> D:
+    if not torch.cuda.is_available():
+        return CPU
+    device_id = min(torch.cuda.device_count() - 1, device_id)
+    return torch.device(f'cuda:{device_id}')
+
+
+CUDA = get_device
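The aliases above are typing shorthand used across the repo, and `get_device` clamps the requested GPU index before falling back to CPU. A minimal usage sketch (the `to_tensor` helper is illustrative, not part of the commit):

```python
# Illustrative only -- assumes custom_types.py from this commit is importable.
import numpy as np
import torch

from custom_types import ARRAY, CUDA, T, TN


def to_tensor(x: ARRAY, device_id: int = 0) -> T:
    # CUDA is an alias for get_device: it clamps device_id to the number of
    # visible GPUs and returns torch.device('cpu') when CUDA is unavailable.
    return torch.from_numpy(x).float().to(CUDA(device_id))


maybe_cached: TN = None  # TN is Optional[torch.Tensor]
print(to_tensor(np.zeros((2, 3))).device)  # cuda:0, or cpu without a GPU
```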

dave_generte_captions_for_eval.py (+67 −2)

@@ -6,6 +6,7 @@
 path_to_data_imgs_val = ''
 path_to_data_captions_val = ''
 
+from custom_types import *
 import torch
 import torch.nn as nn
 from torch.nn import functional as nnf
@@ -20,6 +21,10 @@
 import json, math
 from typing import Tuple, Optional, Union
 from parse_coco import add_text_embedding, train_with_noise_data_augmentation
+from PIL import Image
+import clip
+from gpt2_prefix_e2e import ClipCaptionE2E
+
 
 device = torch.device('cuda:0')
 
@@ -308,7 +313,67 @@ def load_model(config_path: str, epoch_or_latest: Union[str, int] = '_latest'):
     return model, parser
 
 
-def train(dataset: ClipCocoDataset, model: ClipCaptionModel, args,
+def train(data, model: ClipCaptionModel, out_path, tokenizer, args=None):
+    device = CUDA(0)
+    model = model.to(device)  # FIXME
+    model.eval()  # FIXME
+    if args.is_rn:
+        clip_model, preprocess = clip.load("RN50x4", device=device, jit=False)
+        normalize = True
+        args.beam = True
+    else:
+        clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
+        normalize = False
+        # preprocess = clip_transform_full()
+    # prefix_length = 10
+
+    images_root = "/home/dcor/datasets/COCO/val2014"
+    if not os.path.isdir(images_root):
+        images_root = "./data/coco/val2014"
+    embeddings = model.gpt.get_input_embeddings().weight.data
+    embeddings = nnf.normalize(embeddings, 2, 1)
+    for ii, d in enumerate(data):
+        # print(ii)
+        # if ii > 20:
+        #     break
+
+        img_id = d["image_id"]
+        filename = f'{images_root}/COCO_val2014_{int(img_id):012d}.jpg'
+        # print(filename)
+
+        image_raw = Image.open(filename).convert("RGB")
+        image = preprocess(image_raw).unsqueeze(0).to(device)
+        with torch.no_grad():
+            if type(model) is ClipCaptionE2E:
+                prefix_embed = model.forward_image(image)
+            else:
+                prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
+                if normalize:
+                    prefix = prefix / prefix.norm(2, -1)
+                prefix_embed = model.clip_project(prefix).reshape(1, args.prefix_length, -1)
+        if args.beam:
+            generated_text_prefix = generate_beam(model, tokenizer, embed=prefix_embed)[0]
+        else:
+            generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)
+
+        print(img_id)
+        print(generated_text_prefix.lower())
+        print(d["caption"])
+        if DEBUG:
+            prefix_sent = get_prefix_tokens(prefix_embed, embeddings, tokenizer)
+            imshow(image_raw, title=f'{generated_text_prefix}\n{prefix_sent}')
+
+        d["caption"] = generated_text_prefix.lower()
+
+    # sys.exit()
+    with open(out_path, 'w') as outfile:
+        json.dump(data, outfile)
+    print("JSON is dumped")
+
+    return 0
+
+
+def regular_train(dataset: ClipCocoDataset, model: ClipCaptionModel, args,
           lr: float = 2e-5, warmup_steps: int = 5000, output_dir: str = ".", output_prefix: str = ""):
 
     device = torch.device('cuda:0')
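The new `train` no longer trains: it captions every COCO val2014 image and overwrites each record's `"caption"` field in place before dumping the list to `out_path`. A hypothetical driver, assuming the file's existing `load_model` helper (which returns `model, parser` per the hunk context above) and a GPT-2 tokenizer; the checkpoint path and output filename below are made up:

```python
# Hypothetical driver for the eval-time train() above. load_model and
# path_to_data_captions_val exist in this file; everything else is assumed.
import json
from transformers import GPT2Tokenizer

model, parser = load_model('./checkpoints/coco_prefix.json')  # illustrative path
args = parser.parse_args()
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# COCO-style annotations: a list of {"image_id": ..., "caption": ...} dicts.
with open(path_to_data_captions_val) as f:
    data = json.load(f)['annotations']

train(data, model, out_path='predictions_val.json', tokenizer=tokenizer, args=args)
```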
@@ -322,7 +387,7 @@ def train(dataset: ClipCocoDataset, model: ClipCaptionModel, args,
 
     # save_config(args)
     for epoch in range(epochs):
-        print(f">>> Training epoch {epoch}")
+        print(f">>> calc predictions")
        sys.stdout.flush()
        progress = tqdm(total=len(train_dataloader), desc=output_prefix)
        for idx, (tokens, mask, prefix) in enumerate(train_dataloader):
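Because the dumped JSON keeps the COCO results schema (`image_id` plus a `caption` string), it can be scored with the standard pycocoevalcap toolkit. A sketch, assuming the official val2014 annotations and the output file from the driver above are available locally:

```python
# Scoring sketch using the COCO caption evaluation toolkit
# (pip install pycocotools pycocoevalcap); both file paths are assumptions.
from pycocotools.coco import COCO
from pycocoevalcap.eval import COCOEvalCap

coco = COCO('annotations/captions_val2014.json')
coco_res = coco.loadRes('predictions_val.json')

coco_eval = COCOEvalCap(coco, coco_res)
coco_eval.params['image_id'] = coco_res.getImgIds()  # score only captioned images
coco_eval.evaluate()

for metric, score in coco_eval.eval.items():  # BLEU, METEOR, ROUGE_L, CIDEr, ...
    print(f'{metric}: {score:.3f}')
```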
