-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
115 lines (94 loc) · 3.09 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import json
import os
import random
from argparse import ArgumentParser, Namespace

from configs.config import Config
from misc.logger import get_logger
from misc.global_counter import init_counter
from datasets.gqa import run_gqa
def set_seed(seed: int = 42) -> None:
    """Seed Python's built-in ``random`` module so sampling is reproducible.

    Only the standard-library ``random`` generator is seeded here.

    Args:
        seed (int, optional): value handed to ``random.seed``. Defaults to 42.
    """
    random.seed(a=seed)
def get_args_parser() -> ArgumentParser:
    """Build the parent parser that declares this script's CLI options.

    It is created with ``add_help=False`` so it can be supplied through
    ``parents=[...]`` to another ``ArgumentParser`` without a clashing ``-h``.

    Returns:
        argparse.ArgumentParser: parser defining ``--config``.
    """
    config_parser = ArgumentParser("", add_help=False)
    config_parser.add_argument(
        "--config",
        help="path to config",
        default="./configs/config_default.json",
        metavar="PATH",
    )
    return config_parser
def load_json(path: str) -> dict:
    """Read a JSON file and return its decoded contents.

    Args:
        path (str): path to the JSON file.

    Returns:
        dict: the decoded document (``json.load`` returns whatever top-level
            value the file holds; callers here expect an object).
    """
    # The ``with`` block closes the handle deterministically instead of leaving
    # it to the garbage collector. JSON text is UTF-8 (RFC 8259), so decode
    # explicitly rather than with the platform's locale-dependent default.
    with open(path, encoding="utf-8") as json_file:
        return json.load(json_file)
def _save_results(results, file_name: str, split: str, label: str, logger) -> None:
    """Dump one split's captions to ``./results/<file_name>_<split>.json``.

    Args:
        results: value returned by ``run_gqa`` for the split; it must be
            JSON-serialisable because it is passed to ``json.dump``.
        file_name (str): output file stem (the config's ``output_filename``).
        split (str): filename suffix, e.g. ``"val"`` or ``"train"``.
        label (str): human-readable split name used in the log message.
        logger: logger obtained from ``get_logger``; only ``.info`` is used.
    """
    os.makedirs("./results", exist_ok=True)
    out_path = f"./results/{file_name}_{split}.json"
    logger.info(f"save {label} captions: {out_path}")
    with open(out_path, "w", encoding="utf-8") as outfile:
        json.dump(results, outfile)


def main(args: Namespace) -> None:
    """Generate GQA captions for the splits requested in the config.

    The config's ``generate_captions`` entry selects the work:

    - ``"val"`` or ``"val_subset"`` runs the validation split.
      ``"val_subset"`` also sets ``run_gqa``'s subset flag.
    - ``"train"`` runs the training split.

    When ``save_results`` is truthy, each split's results are written under
    ``./results/`` using the ``output_filename`` stem.

    Args:
        args (argparse.Namespace): parsed command-line arguments. Only
            ``args.config`` (path to the config file) is read.
    """
    cfg = Config(args.config)
    logger = get_logger(cfg)
    logger.info(cfg)
    cpt_counter = init_counter()
    set_seed(seed=42)

    requested = cfg.get("generate_captions")

    # Each split's scene graphs are loaded only when that split is requested.
    # A validation-only run therefore never reads, or holds in memory, the
    # training file, and a missing path for an unused split is not an error.
    if "val" in requested or "val_subset" in requested:
        logger.info("load validation data")
        valid_set = load_json(cfg.get("gqa_sg_valid"))
        logger.info("generate validation captions")
        valid_results = run_gqa(
            valid_set,
            cfg,
            cpt_counter,
            logger,
            cfg.get("filter_noisy"),
            cfg.get("relaxed_mode"),
            "val_subset" in requested,  # run_subset flag
        )
        if cfg.get("save_results"):
            _save_results(
                valid_results, cfg.get("output_filename"), "val", "validation", logger
            )

    if "train" in requested:
        logger.info("load training data")
        train_set = load_json(cfg.get("gqa_sg_train"))
        logger.info("generate training captions")
        train_results = run_gqa(
            train_set,
            cfg,
            cpt_counter,
            logger,
            cfg.get("filter_noisy"),
            cfg.get("relaxed_mode"),
        )
        if cfg.get("save_results"):
            _save_results(
                train_results, cfg.get("output_filename"), "train", "training", logger
            )

    logger.info("generation process completed")
if __name__ == "__main__":
    # ArgumentParser's first positional parameter is ``prog``. Passing the
    # title there would replace the script name in the usage line and in error
    # messages, so supply it as the ``description`` shown by ``--help``.
    parser = ArgumentParser(
        description="Vision and language mismatch detection",
        parents=[get_args_parser()],
    )
    args = parser.parse_args()
    main(args)