utils.py
import cProfile
import datetime
import functools
import glob
import io
import os
import pstats
import random
from pathlib import Path
from pstats import SortKey
from typing import Optional

import numpy as np
import torch
import torch.nn as nn

from sophia import SophiaG
from Kidus import Kidus


def save_model(epoch: int, model, optimizer, PATH: Path) -> None:
    """Save the model weights, optimizer state, and epoch to PATH/<timestamp>-iter-<epoch>.tar."""
    model_state_dict = {
        "model": model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
    }
    save_path = PATH / f"{current()}-iter-{epoch}.tar"
    torch.save(model_state_dict, str(save_path))


def get_last_epoch(path: Path) -> Optional[int]:
    """Return the highest epoch among the checkpoint files in `path`, or None if there are none."""
    files = glob.glob(f"{str(path)}/*")
    if len(files) == 0:
        return None
    epochs = [get_epoch(filename) for filename in files]
    return max(epochs)


def get_epoch(filename: str) -> int:
    """Parse the epoch number out of a checkpoint filename of the form <timestamp>-iter-<epoch>.tar."""
    basename = os.path.basename(filename)
    return int(basename.split(".")[0].split("-")[-1])


def prepare_for_resuming(path: Path, model_size: str, learning_rate: float, best=True, pretrain=True):
    """Rebuild the model and optimizer from a saved checkpoint so that training can resume."""
    model = Kidus.from_name(model_size, pretrain=pretrain)
    # optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.95))
    rho = 0.03 if pretrain else 0.01
    weight_decay = 0.2 if pretrain else 0.1
    optimizer = SophiaG(model.parameters(), lr=learning_rate, betas=(0.965, 0.99), rho=rho, weight_decay=weight_decay)
    if best:
        # `path` is a directory: pick the checkpoint with the highest epoch number.
        model_dirs = glob.glob(f"{str(path)}/*")
        assert len(model_dirs) != 0, "There are no checkpoints in the directory specified."
        best_model_dir = sorted(model_dirs, key=get_epoch, reverse=True)[0]
        model_state_dict = torch.load(best_model_dir)
    else:
        # `path` is a single checkpoint file.
        assert path.exists(), "Checkpoint file not found; please check that the model was saved completely."
        model_state_dict = torch.load(str(path))
    state_dict = model_state_dict["model"]
    # Strip the prefix that torch.compile adds to parameter names.
    unwanted_prefix = '_orig_mod.'
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    if isinstance(model, nn.DataParallel):
        model.module.load_state_dict(state_dict)
    else:
        model.load_state_dict(state_dict)
    if pretrain:
        optimizer.load_state_dict(model_state_dict["optimizer"])
    start_epoch = model_state_dict["epoch"]
    return model, optimizer, start_epoch


def load_model(model_path: Path, model_size: str, device):
    """Load a pretrained Kidus model from a checkpoint file, or from the newest checkpoint in a directory."""
    if not os.path.isfile(str(model_path)):
        model_dirs = glob.glob(f"{str(model_path.absolute())}/*")
        assert len(model_dirs) != 0, "There are no checkpoints in the directory specified."
        model_path = sorted(model_dirs, key=get_epoch, reverse=True)[0]
    model = Kidus.from_pretrained(model_size, model_path, device=device)
    return model


def is_torch_2() -> bool:
    """Return True when the installed PyTorch major version is 2."""
    return torch.__version__.split(".")[0] == "2"


def tokenizer_setting() -> None:
    """Allow the HuggingFace tokenizers library to use parallelism."""
    os.environ["TOKENIZERS_PARALLELISM"] = "true"


def set_seed(seed: int = 12346) -> None:
    """Seed the Python, NumPy, and PyTorch (CPU and all CUDA devices) RNGs for reproducibility."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def profile(func):
    """Decorator that profiles `func` with cProfile and prints stats sorted by cumulative time."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        pr = cProfile.Profile()
        pr.enable()
        retval = func(*args, **kwargs)
        pr.disable()
        s = io.StringIO()
        sortby = SortKey.CUMULATIVE  # 'cumulative'
        ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
        ps.print_stats()
        print(s.getvalue())
        return retval
    return wrapper
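

# A minimal usage sketch for the decorator above; `train_step` is a hypothetical
# function used only for illustration:
#
#     @profile
#     def train_step(batch):
#         ...
#
# Every call to train_step(batch) then prints a cProfile report sorted by
# cumulative time before returning the step's result.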


def current() -> str:
    """Return the current local time as the ISO-like string used in checkpoint filenames."""
    return datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
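

# A rough end-to-end sketch of how these utilities fit together; the directory,
# model size, and hyperparameters below are illustrative assumptions, not values
# prescribed by this module:
#
#     if __name__ == "__main__":
#         set_seed()
#         tokenizer_setting()
#         checkpoint_dir = Path("checkpoints")  # hypothetical checkpoint directory
#         model, optimizer, start_epoch = prepare_for_resuming(
#             checkpoint_dir, model_size="base", learning_rate=3e-4, best=True, pretrain=True
#         )
#         # ... run one or more training epochs here ...
#         save_model(start_epoch + 1, model, optimizer, checkpoint_dir)
#         # Later, reload the newest checkpoint for inference:
#         model = load_model(checkpoint_dir, model_size="base", device="cuda")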