-
Notifications
You must be signed in to change notification settings - Fork 142
Expand file tree
/
Copy pathconvert.py
More file actions
111 lines (93 loc) · 3.28 KB
/
convert.py
File metadata and controls
111 lines (93 loc) · 3.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import argparse
import copy
import importlib
import sys
import torch.multiprocessing as mp
from utils import validate_args
def load_plugin(plugin_type, name):
    """Import a loader/saver plugin module and return it.

    Tries the prefixed module name ``"<plugin_type>_<name>"`` first, then
    falls back to the bare ``name``. Exits the process if neither module can
    be imported, or if the imported module does not expose the required
    ``add_arguments`` hook that marks it as a plugin.
    """
    plugin = None
    for module_name in (f"{plugin_type}_{name}", name):
        try:
            plugin = importlib.import_module(module_name)
            break
        except ModuleNotFoundError as err:
            # Report the failed candidate and fall through to the next one.
            print(err)
    if plugin is None:
        sys.exit(f"Unable to load {plugin_type} plugin {name}. Exiting.")
    if not hasattr(plugin, "add_arguments"):
        sys.exit(f"{module_name} module is not a plugin. Exiting.")
    print(f"Loaded {module_name} as the {plugin_type}.")
    return plugin
def main():
parser = argparse.ArgumentParser(
description="Convert checkpoint", allow_abbrev=False, conflict_handler="resolve"
)
# convert args
parser.add_argument(
"--model-type",
type=str,
default=[],
nargs="+",
required=True,
choices=["aquila", "mistral", "mixtral", "llama", "deepseek_v3", "qwen3", "smollm2"],
help="Type of the model.",
)
parser.add_argument(
"--loader",
type=str,
default="mcore",
choices=["mcore", "transformers"],
help="Module name to load checkpoint, should be on python path",
)
parser.add_argument(
"--saver",
type=str,
default="mcore",
choices=["mcore", "transformers"],
help="Module name to save checkpoint, shdoul be on python path",
)
parser.add_argument(
"--load-dir", type=str, required=True, help="Directory to load model checkpoint from"
)
parser.add_argument(
"--save-dir", type=str, required=True, help="Directory to save model checkpoint to"
)
parser.add_argument(
"--max-queue-size", type=int, default=50, help="Maximum number of tensors in the queue"
)
extend_cases = [["mistral", "mixtral"]]
known_args, _ = parser.parse_known_args()
loader = load_plugin("loader", known_args.loader)
saver = load_plugin("saver", known_args.saver)
loader.add_arguments(parser)
saver.add_arguments(parser)
args = parser.parse_args()
validate_args(args)
queue = mp.Queue(maxsize=args.max_queue_size)
print("Starting saver...")
saver_args = copy.deepcopy(args)
if len(args.model_type) == 1:
saver_args.model_type = args.model_type[0]
elif len(args.model_type) == 2:
assert args.model_type in extend_cases, f"Only support extend cases are {extend_cases}"
saver_args.model_type = args.model_type[1]
else:
raise ValueError("")
saver_proc = mp.Process(target=saver.save_checkpoint, args=(queue, saver_args))
saver_proc.start()
print("Starting loader...")
loader_args = copy.deepcopy(args)
if len(args.model_type) == 1:
loader_args.model_type = args.model_type[0]
elif len(args.model_type) == 2:
assert args.model_type in extend_cases, f"Only support extend cases are {extend_cases}"
loader_args.model_type = args.model_type[0]
else:
raise ValueError("")
loader.load_checkpoint(queue, loader_args)
print("Waiting for saver to complete...")
saver_proc.join()
# Script entry point: run the conversion only when executed directly.
if __name__ == "__main__":
    main()