
Commit 2a15001 · committed Jan 16, 2025 · 1 parent 5b094a8

support distributed checkpoint io

File tree

7 files changed: +781 −14 lines

‎colossalai/booster/plugin/hybrid_parallel_plugin.py

+3
@@ -78,6 +78,9 @@ def __init__(
         self.require_grad_sync = True
         self.overlap_allgather = overlap_allgather
         self.use_fp8 = use_fp8
+        self.param_origin_shape = {}
+        for name, param in module.named_parameters():
+            self.param_origin_shape[name] = param.shape

         shardformer = ShardFormer(shard_config)
         if custom_policy is not None:
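The original (unsharded) parameter shapes are recorded before ShardFormer rewrites the module, presumably so the distributed checkpoint loader can reassemble full tensors later. A minimal sketch of that idea, under that assumption (the helper below is hypothetical and not part of this commit):

import torch
from typing import List

def restore_full_param(shards: List[torch.Tensor], origin_shape: torch.Size) -> torch.Tensor:
    # Concatenate flattened shards, then view the result back to the shape
    # captured in `param_origin_shape` before sharding. Padding added by the
    # sharder (if any) is trimmed by slicing to the original element count.
    flat = torch.cat([s.reshape(-1) for s in shards])
    return flat[: origin_shape.numel()].reshape(origin_shape)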

‎colossalai/checkpoint_io/__init__.py

+2
@@ -1,6 +1,7 @@
 from .checkpoint_io_base import CheckpointIO
 from .general_checkpoint_io import GeneralCheckpointIO
 from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
+from .distributed_checkpoint_io import DistributedCheckpointIO
 from .index_file import CheckpointIndexFile
 from .moe_checkpoint import MoECheckpointIO

@@ -10,4 +11,5 @@
     "GeneralCheckpointIO",
     "HybridParallelCheckpointIO",
     "MoECheckpointIO",
+    "DistributedCheckpointIO",
 ]
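With the export in place, a Booster's default checkpoint IO can be swapped for the distributed one. The snippet below is adapted from the test added in this commit; the plugin arguments are just one of the tested configurations:

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.checkpoint_io import DistributedCheckpointIO

plugin = HybridParallelPlugin(tp_size=2, pp_size=1, zero_stage=1, precision="fp16", initial_scale=1)
booster = Booster(plugin=plugin)

# Reuse the process groups already set up by the hybrid-parallel checkpoint IO.
hybrid_ckp = booster.checkpoint_io
booster.checkpoint_io = DistributedCheckpointIO(
    hybrid_ckp.global_dp_group, hybrid_ckp.pp_group, hybrid_ckp.tp_group, hybrid_ckp.sp_group, hybrid_ckp.use_zero
)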

‎colossalai/checkpoint_io/distributed_checkpoint_io.py

+633
Large diffs are not rendered by default.
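The 633-line implementation is not rendered on this page. Judging only from how the new test uses it, the class takes the data-, pipeline-, tensor- and sequence-parallel process groups plus a ZeRO flag, and exposes _sync_d2h / _sync_io to flush asynchronous saves. A rough interface sketch inferred from that usage, not the actual code:

import torch.distributed as dist

class _DistributedCheckpointIOSketch:
    """Rough stand-in for colossalai.checkpoint_io.DistributedCheckpointIO (names inferred)."""

    def __init__(self, dp_group: dist.ProcessGroup, pp_group: dist.ProcessGroup,
                 tp_group: dist.ProcessGroup, sp_group: dist.ProcessGroup, use_zero: bool):
        self.dp_group, self.pp_group, self.tp_group, self.sp_group = dp_group, pp_group, tp_group, sp_group
        self.use_zero = use_zero

    def _sync_d2h(self) -> None:
        # Presumably blocks until asynchronous device-to-host copies of shards finish.
        ...

    def _sync_io(self) -> None:
        # Presumably blocks until background writers have flushed shard files to disk.
        ...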

‎colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py

+5-5
@@ -126,7 +126,7 @@ def _model_sharder(
             buffer = buf if keep_vars else buf.detach()
             if pinned_state_dicts is not None:
                 if (prefix + name) not in pinned_state_dicts:
-                    pinned_state_dicts[prefix + name] = torch.empty_like(param_, pin_memory=True, device="cpu")
+                    pinned_state_dicts[prefix + name] = torch.empty_like(buffer, pin_memory=True, device="cpu")
                 pinned_state_dicts[prefix + name].copy_(buffer)
                 buffer = pinned_state_dicts[prefix + name]
             block, block_size = state_dict_sharder.append_param(prefix + name, buffer)

@@ -142,7 +142,7 @@ def _model_sharder(
         extra_state = model.get_extra_state()
         if pinned_state_dicts is not None:
             if extra_state_key not in pinned_state_dicts:
-                pinned_state_dicts[extra_state_key] = torch.empty_like(param_, pin_memory=True, device="cpu")
+                pinned_state_dicts[extra_state_key] = torch.empty_like(extra_state, pin_memory=True, device="cpu")
             pinned_state_dicts[extra_state_key].copy_(extra_state)
             extra_state = pinned_state_dicts[extra_state_key]
         block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)

@@ -298,9 +298,9 @@ def save_sharded_model(
         Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)

         # Manage filenames of sharded weights and index file for each pipeline stage.
-        weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
-        weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
-        save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
+        weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin")
+        weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank:05d}-shard.safetensors")
+        save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json")
         save_index_file = os.path.join("tmp_index_files", save_index_file)
         if use_async:
             total_size, writers = async_save_state_dict_shards(
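The first two hunks fix the pinned staging buffer to match the tensor actually being copied (the buffer or extra_state, rather than the unrelated param_ from an earlier loop); the third makes the stage suffix 0-based. A small standalone illustration of the pinned-copy pattern used here (my own sketch, not code from the commit):

import torch

def stage_to_pinned(src: torch.Tensor, cache: dict, key: str) -> torch.Tensor:
    # Allocate a pinned CPU buffer shaped like the source tensor on first use, then
    # reuse it across saves. This sketch copies synchronously; the real sharder can
    # issue these copies asynchronously and flush them later.
    if key not in cache:
        cache[key] = torch.empty_like(src, pin_memory=True, device="cpu")
    cache[key].copy_(src)
    return cache[key]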

‎colossalai/checkpoint_io/utils.py

+5-8
@@ -854,14 +854,11 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
         # check if there is only one file ending with .index.json in this directory
         index_files = list(checkpoint_path.glob("*.index.*json"))

-        # if we found a .index.json file, make sure there is only one
-        if len(index_files) > 0:
-            assert (
-                len(index_files) == 1
-            ), f"Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}"
-
         if len(index_files) == 1:
             return True, index_files[0]
+        elif len(index_files) > 1:
+            # Used for distributed checkpoint IO, where the metadata is stored across multiple files.
+            return True, checkpoint_path
         else:
             return False, None
     else:

@@ -943,8 +940,8 @@ def get_shard_filename(weights_name: str, idx: int):
     """
     get shard file name
     """
-    shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
-    shard_file = shard_file.replace(".safetensors", f"-{idx+1:05d}.safetensors")
+    shard_file = weights_name.replace(".bin", f"-{idx:05d}.bin")
+    shard_file = shard_file.replace(".safetensors", f"-{idx:05d}.safetensors")
     return shard_file
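With the strict assert removed, a directory holding several *.index.*json files (one per rank, as the distributed checkpoint IO writes them) is now treated as a valid checkpoint and the directory itself is returned. The shard-name change simply makes the numeric suffix 0-based; a standalone restatement for illustration:

# Standalone copy of the updated helper, for illustration only.
def get_shard_filename(weights_name: str, idx: int) -> str:
    shard_file = weights_name.replace(".bin", f"-{idx:05d}.bin")
    return shard_file.replace(".safetensors", f"-{idx:05d}.safetensors")

assert get_shard_filename("model.safetensors", 0) == "model-00000.safetensors"  # previously model-00001.safetensors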

‎colossalai/shardformer/layer/parallel_module.py

-1
@@ -120,7 +120,6 @@ def _load_from_state_dict(
                     "received {}".format(key, type(input_param))
                 )
                 continue
-
             if is_distributed_tensor(param):
                 # shard the input param
                 device_mesh = get_device_mesh(param)
(New test file, +133; filename not shown in the rendered diff)

@@ -0,0 +1,133 @@
import pytest
import torch
import torch.distributed as dist
from packaging.version import Version
from torch.optim import Adam
from utils import shared_tempdir
from copy import deepcopy

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.checkpoint_io import DistributedCheckpointIO
from colossalai.testing import (
    assert_close_loose,
    check_state_dict_equal,
    clear_cache_before_run,
    parameterize,
    rerun_if_address_is_in_use,
    spawn,
)
from tests.kit.model_zoo import model_zoo


TEST_CONFIGS = [
    (
        {"tp_size": 1, "pp_size": 2, "num_microbatches": 4, "zero_stage": 1, "precision": "fp16", "initial_scale": 1},
        {"tp_size": 2, "pp_size": 1, "num_microbatches": 4, "zero_stage": 1, "precision": "fp16", "initial_scale": 1},
    )
]


@parameterize("shard", [False, True])
@parameterize("model_name", ["transformers_llama_for_causal_lm"])
@parameterize("size_per_shard", [1])
@parameterize("test_config", TEST_CONFIGS)
@parameterize("use_async", [False, True])
@parameterize("low_cpu_mem_mode", [False, True])
@clear_cache_before_run()
def exam_state_dict(
    shard: bool, model_name: str, size_per_shard: int, test_config: dict, use_async: bool, low_cpu_mem_mode: bool
):
    (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(
        iter(model_zoo.get_sub_registry(model_name).values())
    )
    criterion = loss_fn
    test_config_0, test_config_1 = test_config
    plugin_0 = HybridParallelPlugin(**test_config_0)
    booster_0 = Booster(plugin=plugin_0)
    hybrid_ckp_0 = booster_0.checkpoint_io
    booster_0.checkpoint_io = DistributedCheckpointIO(
        hybrid_ckp_0.global_dp_group, hybrid_ckp_0.pp_group, hybrid_ckp_0.tp_group, hybrid_ckp_0.sp_group, hybrid_ckp_0.use_zero
    )

    def _criterion(outputs, inputs):
        outputs = output_transform_fn(outputs)
        loss = criterion(outputs)
        return loss

    def _preprocess_data(data):
        if booster_0.plugin.stage_manager is not None:
            for k, v in data.items():
                if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
                    new_shape = [1] * v.dim()
                    new_shape[0] = 4
                    data[k] = v.to("cuda").repeat(*new_shape)
            return iter([data])
        else:
            return {k: v.cuda() for k, v in data.items()}

    model_0 = model_fn().cuda()
    optimizer_0 = Adam(model_0.parameters(), lr=1e-3)
    model_0, optimizer_0, criterion, _, _ = booster_0.boost(model_0, optimizer_0, criterion)

    data = data_gen_fn()
    model_0.train()
    if booster_0.plugin.stage_manager is not None:
        booster_0.execute_pipeline(_preprocess_data(data), model_0, _criterion, optimizer_0, return_loss=True)
    else:
        output = model_0(**_preprocess_data(data))
        loss = criterion(output)
        optimizer_0.backward(loss)

    optimizer_0.step()
    optimizer_0.zero_grad()
    with shared_tempdir() as tempdir:
        model_ckpt_path_0 = f"{tempdir}/model_0"

        booster_0.save_model(model_0, model_ckpt_path_0, shard=shard, size_per_shard=size_per_shard, use_async=use_async)
        booster_0.checkpoint_io._sync_d2h()
        booster_0.checkpoint_io._sync_io()
        dist.barrier()

        plugin_1 = HybridParallelPlugin(**test_config_1)
        booster_1 = Booster(plugin=plugin_1)
        hybrid_ckp_1 = booster_1.checkpoint_io
        booster_1.checkpoint_io = DistributedCheckpointIO(
            hybrid_ckp_1.global_dp_group, hybrid_ckp_1.pp_group, hybrid_ckp_1.tp_group, hybrid_ckp_1.sp_group, hybrid_ckp_1.use_zero
        )

        model_1 = model_fn().cuda()
        optimizer_1 = Adam(model_1.parameters(), lr=1e-3)
        model_1, optimizer_1, criterion, _, _ = booster_1.boost(model_1, optimizer_1, criterion)

        booster_1.load_model(model_1, model_ckpt_path_0, low_cpu_mem_mode=low_cpu_mem_mode)

        model_ckpt_path_1 = f"{tempdir}/model_1"
        booster_1.save_model(model_1, model_ckpt_path_1, shard=shard, size_per_shard=size_per_shard, use_async=use_async)
        booster_1.checkpoint_io._sync_d2h()
        booster_1.checkpoint_io._sync_io()
        dist.barrier()

        model_2 = model_fn().cuda()
        optimizer_2 = Adam(model_2.parameters(), lr=1e-3)
        model_2, optimizer_2, criterion, _, _ = booster_0.boost(model_2, optimizer_2, criterion)

        booster_0.load_model(model_2, model_ckpt_path_1, low_cpu_mem_mode=low_cpu_mem_mode)
        check_state_dict_equal(model_0.unwrap().state_dict(), model_2.unwrap().state_dict())

    dist.barrier()
    Randomizer.reset_index()
    clear_layout_converter()


def run_dist(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    exam_state_dict()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_hybrid_ckpIO(world_size):
    spawn(run_dist, world_size)


if __name__ == "__main__":
    test_hybrid_ckpIO(4)
