 path_to_data_imgs_val = ''
 path_to_data_captions_val = ''
 
+from custom_types import *
 import torch
 import torch.nn as nn
 from torch.nn import functional as nnf
@@ -20,6 +21,10 @@
 import json, math
 from typing import Tuple, Optional, Union
 from parse_coco import add_text_embedding, train_with_noise_data_augmentation
+from PIL import Image
+import clip
+from gpt2_prefix_e2e import ClipCaptionE2E
+
 
 device = torch.device('cuda:0')
 
@@ -308,7 +313,67 @@ def load_model(config_path: str, epoch_or_latest: Union[str, int] = '_latest'):
     return model, parser
 
 
-def train(dataset: ClipCocoDataset, model: ClipCaptionModel, args,
+def train(data, model: ClipCaptionModel, out_path, tokenizer, args=None):
+    device = CUDA(0)
+    model = model.to(device)  # FIXME
+    model.eval()  # FIXME
+    if args.is_rn:
+        clip_model, preprocess = clip.load("RN50x4", device=device, jit=False)
+        normalize = True
+        args.beam = True
+    else:
+        clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
+        normalize = False
+    # preprocess = clip_transform_full()
+    # prefix_length = 10
+
+    images_root = "/home/dcor/datasets/COCO/val2014"
+    if not os.path.isdir(images_root):
+        images_root = "./data/coco/val2014"
+    embeddings = model.gpt.get_input_embeddings().weight.data
+    embeddings = nnf.normalize(embeddings, 2, 1)
+    for ii, d in enumerate(data):
+        # print(ii)
+        # if ii > 20:
+        #     break
+
+        img_id = d["image_id"]
+        filename = f'{images_root}/COCO_val2014_{int(img_id):012d}.jpg'
+        # print(filename)
+
+        image_raw = Image.open(filename).convert("RGB")
+        image = preprocess(image_raw).unsqueeze(0).to(device)
+        with torch.no_grad():
+            if type(model) is ClipCaptionE2E:
+                prefix_embed = model.forward_image(image)
+            else:
+                prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
+                if normalize:
+                    prefix = prefix / prefix.norm(2, -1)
+                prefix_embed = model.clip_project(prefix).reshape(1, args.prefix_length, -1)
+        if args.beam:
+            generated_text_prefix = generate_beam(model, tokenizer, embed=prefix_embed)[0]
+        else:
+            generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)
+
+        print(img_id)
+        print(generated_text_prefix.lower())
+        print(d["caption"])
+        if DEBUG:
+            prefix_sent = get_prefix_tokens(prefix_embed, embeddings, tokenizer)
+            imshow(image_raw, title=f'{generated_text_prefix}\n{prefix_sent}')
+
+        d["caption"] = generated_text_prefix.lower()
+
+        # sys.exit()
+    with open(out_path, 'w') as outfile:
+        json.dump(data, outfile)
+    print("JSON is dumped")
+
+    return 0
+
+
+def regular_train(dataset: ClipCocoDataset, model: ClipCaptionModel, args,
           lr: float = 2e-5, warmup_steps: int = 5000, output_dir: str = ".", output_prefix: str = ""):
 
     device = torch.device('cuda:0')
@@ -322,7 +387,7 @@ def train(dataset: ClipCocoDataset, model: ClipCaptionModel, args,
 
     # save_config(args)
     for epoch in range(epochs):
-        print(f">>> Training epoch {epoch}")
+        print(f">>> calc predictions")
         sys.stdout.flush()
         progress = tqdm(total=len(train_dataloader), desc=output_prefix)
         for idx, (tokens, mask, prefix) in enumerate(train_dataloader):
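
A minimal driver sketch for the prediction-style train() added in this commit. The checkpoint path, the annotation-file layout, the ClipCaptionModel constructor argument, and the argparse defaults below are assumptions chosen to match how train() reads args inside this diff; they are not defined by the commit itself, and the sketch assumes it runs in the same module so that train() and ClipCaptionModel are in scope.

import argparse
import json

import torch
from transformers import GPT2Tokenizer

# Hypothetical paths and flags for illustration only.
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', default='./checkpoints/coco_prefix-009.pt')
parser.add_argument('--annotations', default='./data/coco/annotations/captions_val2014.json')
parser.add_argument('--out_path', default='./val2014_predictions.json')
parser.add_argument('--prefix_length', type=int, default=10)
parser.add_argument('--is_rn', action='store_true')  # read by train() to choose RN50x4 over ViT-B/32
parser.add_argument('--beam', action='store_true')   # read by train() to choose generate_beam over generate2
args = parser.parse_args()

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Assumed COCO-style caption file: a list of entries with "image_id" and
# "caption" keys, which is what the loop inside train() indexes into.
with open(args.annotations) as f:
    data = json.load(f)['annotations']

# Assumes the checkpoint stores a plain state_dict and that
# ClipCaptionModel takes the prefix length as its first argument.
model = ClipCaptionModel(args.prefix_length)
model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'))

# train() writes the generated captions back into data and dumps them to out_path.
train(data, model, args.out_path, tokenizer, args=args)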