forked from ilsenatorov/rindti
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain.py
More file actions
91 lines (74 loc) · 2.66 KB
/
Copy pathtrain.py
File metadata and controls
91 lines (74 loc) · 2.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
import random
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, RichModelSummary, RichProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from rindti.data import DTIDataModule
from rindti.models import ClassificationModel, RegressionModel
from rindti.utils import get_git_hash, read_config
models = {
"class": ClassificationModel,
"reg": RegressionModel,
}
def train(**kwargs):
"""Train the whole model"""
seed_everything(kwargs["seed"])
seeds = random.sample(range(1, 100), kwargs["runs"])
folder = os.path.join(
"tb_logs",
f"dti_{kwargs['datamodule']['exp_name']}",
f"{kwargs['datamodule']['filename'].split('/')[-1].split('.')[0]}",
)
if not os.path.exists(folder):
os.makedirs(folder, exist_ok=True)
if len(os.listdir(folder)) == 0:
next_version = 0
else:
next_version = str(
int(
[d for d in os.listdir(folder) if "version" in d and os.path.isdir(os.path.join(folder, d))][-1].split(
"_"
)[1]
)
+ 1
)
for i, seed in enumerate(seeds):
print(f"Run {i+1} of {kwargs['runs']} with seed {seed}")
kwargs["seed"] = seed
single_run(folder, next_version, **kwargs)
def single_run(folder, version, **kwargs):
"""Does a single run."""
seed_everything(kwargs["seed"])
datamodule = DTIDataModule(**kwargs["datamodule"])
datamodule.setup()
datamodule.update_config(kwargs)
logger = TensorBoardLogger(
save_dir=folder,
name=f"version_{version}",
version=kwargs["seed"],
default_hp_metric=False,
)
callbacks = [
ModelCheckpoint(monitor=kwargs["model"]["monitor"], save_top_k=3, mode="min"),
EarlyStopping(monitor=kwargs["model"]["monitor"], mode="min", **kwargs["early_stop"]),
RichModelSummary(),
RichProgressBar(),
]
trainer = Trainer(
callbacks=callbacks,
logger=logger,
log_every_n_steps=25,
enable_model_summary=False,
**kwargs["trainer"],
)
model = models[kwargs["model"]["module"]](**kwargs)
trainer.fit(model, datamodule)
trainer.test(model, datamodule)
if __name__ == "__main__":
from argparse import ArgumentParser
parser = ArgumentParser(prog="Model Trainer")
parser.add_argument("config", type=str, help="Path to YAML config file")
args = parser.parse_args()
orig_config = read_config(args.config)
orig_config["git_hash"] = get_git_hash() # to know the version of the code
train(**orig_config)