
Commit 0f33a72

Author: DavidHuji · committed Nov 1, 2022
Commit message: cosmetics
1 parent 1356575 · commit 0f33a72

File tree

2 files changed (+63, −21)


‎README.md

+10 −21
````diff
@@ -12,6 +12,7 @@ As shown in the paper, CapDec achieves SOTA image-captioning in the setting of t
 This is the formal repository for CapDec, in which you can easily reproduce the papers results.
 
 ## FlickrStyle7k Examples
+Examples of styled captions from CapDec on the FlickrStyle10K dataset.
 ![alt text](https://github.com/DavidHuji/CapDec/blob/main/examples.png)
 
 
@@ -28,38 +29,26 @@ conda env create -f others/environment.yml
 conda activate CapDec
 ```
 
-## Download Data
-###COCO: Download [train_captions](https://drive.google.com/file/d/1D3EzUK1d1lNhD2hAvRiKPThidiVbP2K_/view?usp=sharing) to `data/coco/annotations`.
-
-Download [training images](http://images.cocodataset.org/zips/train2014.zip) and [validation images](http://images.cocodataset.org/zips/val2014.zip) and unzip (We use Karpathy et el. split).
-### Flickr
-TBD
-### Flickr7KStyle
-TBD
+# Datasets
+1. Download the datasets using the following links: [COCO](https://www.kaggle.com/datasets/shtvkumar/karpathy-splits), [Flickr30K](https://www.kaggle.com/datasets/shtvkumar/karpathy-splits), [FlickrStyle10k](https://zhegan27.github.io/Papers/FlickrStyle_v0.9.zip).
+2. Parse the data into the correct format with our script parse_karpathy.py; just make sure to edit the json paths at the head of the script.
 
 
 #Training
-Extract CLIP features using:
+Make sure to edit the json or pkl paths at the head of the scripts.
+1. Extract CLIP features using the following script:
 ```
 python embeddings_generator.py -h
 ```
-Train with fine-tuning of GPT2:
-```
-python train.py --data ./data/coco/oscar_split_ViT-B_32_train.pkl --out_dir ./coco_train/
-```
 
-Train only transformer mapping network:
+2. Train the model using the following script:
 ```
-python train.py --only_prefix --data ./data/coco/oscar_split_ViT-B_32_train.pkl --out_dir ./coco_train/ --mapping_type transformer --num_layres 8 --prefix_length 40 --prefix_length_clip 40
+python train.py --data clip_embeddings_of_last_stage.pkl --out_dir ./coco_train/
 ```
 
-**If you wish to use ResNet-based CLIP:**
-
-```
-python parse_coco.py --clip_model_type RN50x4
-```
+**There are a few interesting configurable parameters for training, as listed in the:**
 ```
-python train.py --only_prefix --data ./data/coco/oscar_split_RN50x4_train.pkl --out_dir ./coco_train/ --mapping_type transformer --num_layres 8 --prefix_length 40 --prefix_length_clip 40 --is_rn
+output of train.py -h
 ```
 
 # Evaluation
````
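After this commit, train.py consumes a pickle of CLIP embeddings produced by embeddings_generator.py. The actual schema is whatever that script writes; purely as an illustration (the field names below are assumptions, not taken from the repository), such a file could pair each caption with its feature vector:

```python
import pickle

# Hypothetical schema: the real layout is defined by embeddings_generator.py;
# "caption" and "clip_embedding" are illustrative names only.
records = [
    {"caption": "a dog runs on the beach", "clip_embedding": [0.12, -0.33, 0.05]},
    {"caption": "two kids playing soccer", "clip_embedding": [0.48, 0.01, -0.27]},
]

# Write the pickle that would be passed to train.py via --data.
with open("clip_embeddings_of_last_stage.pkl", "wb") as f:
    pickle.dump(records, f)

# Round-trip it the way a training script would load it.
with open("clip_embeddings_of_last_stage.pkl", "rb") as f:
    loaded = pickle.load(f)

print(len(loaded))  # prints 2
```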

‎parse_karpathy.py

+53 −0 (new file)
```python
import pickle, json

# Input: Karpathy-split COCO annotations (Kaggle mirror); edit these paths as needed.
kagle_json = 'annotations/dataset_coco_from_kaggle.json'
new_json_train = 'post_processed_karpthy_coco/train.json'
new_json_test = 'post_processed_karpthy_coco/test.json'
new_json_val = 'post_processed_karpthy_coco/val.json'


def map_format_kaggle_to_clipcap():
    # Convert the Karpathy-split json into flat per-split caption lists.
    def extract_imgid_from_name(filename):
        return str(int(filename.split('.')[0].split('_')[-1]))

    with open(kagle_json) as f:
        kaggle_data = json.load(f)
    train_data = []
    test_data = []
    val_data = []
    # 'restval' images are folded into the training split.
    splits = {'train': train_data, 'test': test_data, 'val': val_data, 'restval': train_data}
    out_names = {'train': new_json_train, 'test': new_json_test, 'val': new_json_val}
    for img in kaggle_data['images']:
        imgid = extract_imgid_from_name(img['filename'])
        for cap in img['sentences']:
            correct_format = {"image_id": int(imgid), "caption": cap['raw'], "id": int(cap['sentid'])}
            splits[img['split']].append(correct_format)

    DBG = False
    if not DBG:
        for name in out_names:
            with open(out_names[name], 'w') as f:
                json.dump(splits[name], f)

        # Also write the {"images": ..., "annotations": ...} layout used by the metrics code.
        for name in out_names:
            with open(out_names[name][:-5] + '_metrics_format.json', 'w') as f:
                annos = splits[name]
                ids = [{"id": int(a["image_id"])} for a in annos]
                final = {"images": ids, "annotations": annos}
                json.dump(final, f)

    if DBG:
        # Sanity check against Ron's annotations.
        with open('annotations/train_caption_of_real_training.json') as f:
            # with open('../../train_caption.json') as f:
            cur_data = json.load(f)
        ids = [str(int(c['image_id'])) for c in cur_data]
        new_ids = [str(int(c['image_id'])) for c in train_data]
        ids.sort()  # in place
        new_ids.sort()
        assert ids == new_ids
        print('OK')


if __name__ == '__main__':
    map_format_kaggle_to_clipcap()
```
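The heart of the new script is a two-step conversion: recover the numeric image id from a COCO filename, then flatten each image's sentences into `{"image_id", "caption", "id"}` records per split (with `restval` folded into `train`), finally re-wrapping them in the `{"images", "annotations"}` metrics layout. A minimal self-contained sketch of the same transformation, using an invented sample record:

```python
import json

def extract_imgid_from_name(filename):
    # "COCO_val2014_000000184613.jpg" -> "184613", as in parse_karpathy.py
    return str(int(filename.split('.')[0].split('_')[-1]))

# Invented Karpathy-split-style sample: one image with two sentences.
kaggle_data = {"images": [{
    "filename": "COCO_val2014_000000184613.jpg",
    "split": "restval",
    "sentences": [
        {"raw": "a dog on a beach", "sentid": 10},
        {"raw": "a dog runs by the sea", "sentid": 11},
    ],
}]}

train_data, test_data, val_data = [], [], []
# 'restval' aliases the train list, so those captions land in train.
splits = {'train': train_data, 'test': test_data, 'val': val_data,
          'restval': train_data}

for img in kaggle_data['images']:
    imgid = extract_imgid_from_name(img['filename'])
    for cap in img['sentences']:
        splits[img['split']].append({"image_id": int(imgid),
                                     "caption": cap['raw'],
                                     "id": int(cap['sentid'])})

# Metrics-format wrapper, mirroring the script's second output file.
metrics_format = {"images": [{"id": a["image_id"]} for a in train_data],
                  "annotations": train_data}

print(json.dumps(metrics_format["images"]))  # prints [{"id": 184613}, {"id": 184613}]
```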
