from pathlib import Path
from typing import List
import os
import warnings

import wandb
import omegaconf
from omegaconf import DictConfig, OmegaConf
import hydra
from hydra.utils import instantiate, log
from hydra.core.hydra_config import HydraConfig
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import seed_everything, Callback

from source.common.utils import build_callbacks, log_hyperparameters, PROJECT_ROOT

warnings.simplefilter("ignore", UserWarning)


def run(cfg: DictConfig) -> None:
"""
Generic train loop
:param cfg: run configuration, defined by Hydra in /conf
"""
""" Set up the seeds """
if cfg.train.deterministic:
seed_everything(cfg.train.random_seed)
""" Hydra run directory """
if HydraConfig.get().mode.name == "MULTIRUN":
hydra_dir = Path(
HydraConfig.get().sweep.dir + "/" + HydraConfig.get().sweep.subdir
)
os.chdir("../" + HydraConfig.get().sweep.subdir)
else:
hydra_dir = Path(HydraConfig.get().run.dir)
log.info(f"Saving os.getcwd is <{os.getcwd()}>")
log.info(f"Saving hydra_dir is <{hydra_dir}>")
""" Instantiate datamodule """
log.info(f"Instantiating <{cfg.data.datamodule._target_}>")
datamodule: pl.LightningDataModule = instantiate(
cfg.data.datamodule,
_recursive_=False,
)
""" Instantiate model """
log.info(f"Instantiating <{cfg.model._target_}>")
model: pl.LightningModule = instantiate(
cfg.model,
optim=cfg.optim,
data=cfg.data,
logging=cfg.logging,
train=cfg.train,
_recursive_=False,
)
""" Instantiate the callbacks """
callbacks: List[Callback] = build_callbacks(cfg=cfg)
""" Logger instantiation/configuration """
wandb_logger = None
if "wandb" in cfg.logging:
log.info("Instantiating <WandbLogger>")
wandb_config = cfg.logging.wandb
wandb_logger = WandbLogger(
**wandb_config,
save_dir=hydra_dir,
tags=cfg.core.tags,
)
log.info("W&B is now watching <{cfg.logging.wandb_watch.log}>!")
wandb_logger.watch(
model,
log=cfg.logging.wandb_watch.log,
log_freq=cfg.logging.wandb_watch.log_freq,
)
""" Store the YaML config separately into the wandb dir """
yaml_conf: str = OmegaConf.to_yaml(cfg=cfg)
(hydra_dir / "hparams.yaml").write_text(yaml_conf)
""" Trainer instantiation """
log.info("Instantiating the Trainer")
trainer = pl.Trainer(
accelerator="auto",
default_root_dir=hydra_dir,
logger=wandb_logger,
callbacks=callbacks,
deterministic=cfg.train.deterministic,
# check_val_every_n_epoch=cfg.logging.val_check_interval,
log_every_n_steps=1,
**cfg.train.pl_trainer, # max_steps 포함
)
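    # For reference, a minimal sketch of what the `train.pl_trainer` block expanded
    # above might contain in the YAML config. Only max_steps is confirmed by the code;
    # the other keys are hypothetical pl.Trainer arguments shown for illustration:
    #   pl_trainer:
    #     max_steps: 10000
    #     precision: 32
    #     gradient_clip_val: 1.0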
    log_hyperparameters(trainer=trainer, model=model, cfg=cfg)

    """ Data preparation """
    datamodule.setup()
    train_dataloader = datamodule.train_dataloader()
    val_data_loader = datamodule.val_dataloader()
    test_data_loader = datamodule.test_dataloader()

    """ Train and test """
    log.info("Starting training!")
    trainer.fit(
        model=model,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_data_loader,
    )

    log.info("Starting testing!")
    trainer.test(model=model, dataloaders=test_data_loader, ckpt_path="best")

    if datamodule.rotate:
        model.rotate = True
        log.info("Starting testing on rotated data!")
        test_rotate_dataloader = datamodule.test_rotate_dataloader()
        trainer.test(model=model, dataloaders=test_rotate_dataloader, ckpt_path="best")

    # Close the logger to release resources and avoid multi-run conflicts
    if wandb_logger is not None:
        wandb.finish()
@hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default")
def main(cfg: omegaconf.DictConfig):
    run(cfg)


if __name__ == "__main__":
    main()
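# Example invocation (a sketch: the override keys, e.g. train.deterministic and
# train.random_seed, come from the fields this script reads; values are illustrative):
#   python run.py train.deterministic=true train.random_seed=42
#   python run.py -m train.random_seed=0,1,2   # Hydra multirun sweep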