Skip to content

Commit 8e6902c

Browse files
committed
fix
1 parent c5b0882 commit 8e6902c

File tree

2 files changed

+32
-38
lines changed

2 files changed

+32
-38
lines changed

colossalai/checkpoint_io/distributed_checkpoint_utils.py

-33
Original file line numberDiff line numberDiff line change
@@ -35,39 +35,6 @@ def RestoreDefaultStateDictBehavior(model):
3535
finally:
3636
for module, original_method in original_methods.items():
3737
module._save_to_state_dict, module._load_from_state_dict = original_method
38-
39-
40-
41-
def create_model_metadata(
    model: ModelWrapper,
    prefix: str = "",
    tp_size: int = None,
    tp_rank: int = None,
    zero_size: int = None,
    zero_rank: int = None,
):
    """Describe where each local parameter shard sits inside its full tensor.

    Args:
        model: Wrapped model; must expose ``param_origin_shape`` (a mapping
            from parameter name to the unpartitioned shape) and ``unwrap()``.
        prefix: String prepended to every parameter name in the result keys.
        tp_size: Tensor-parallel world size.
        tp_rank: Tensor-parallel rank of the calling process.
        zero_size: Accepted for interface compatibility; not read here.
        zero_rank: Accepted for interface compatibility; not read here.

    Returns:
        A dict mapping ``prefix + name`` to a dict with three list entries:
        ``"offsets"`` (start index of the shard per dimension),
        ``"lengths"`` (extent of the shard per dimension) and
        ``"global_shape"`` (the unpartitioned shape).
    """
    origin_shapes = model.param_origin_shape
    unwrapped = model.unwrap()

    metadata = {}
    for param_name, tensor in unwrapped.named_parameters():
        if tensor is None:
            continue

        full_shape = origin_shapes[param_name]
        # Presumably returns the dimension split across tensor-parallel ranks,
        # or None when the parameter is not partitioned -- verify against helper.
        split_dim = search_tp_partition_dim(
            current_shape=tensor.shape, original_shape=full_shape, tp_size=tp_size
        )

        offsets = [0] * len(full_shape)
        lengths = list(tensor.shape)
        if split_dim is not None:
            shard_extent = tensor.shape[split_dim]
            offsets[split_dim] = shard_extent * tp_rank
            # The last rank's length is whatever remains of the full dimension
            # after the preceding ranks' equally sized shards.
            if tp_rank == tp_size - 1:
                lengths[split_dim] = full_shape[split_dim] - (shard_extent * (tp_size - 1))

        metadata[prefix + param_name] = {
            "offsets": offsets,
            "lengths": lengths,
            "global_shape": list(full_shape),
        }
    return metadata
7138

7239

7340
def save_metadata(model_metadata, metadata_file, checkpoint_file=None, total_size=None):

colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py

+32-5
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat
2727

2828
from .distributed_checkpoint_utils import (
29-
create_model_metadata,
3029
is_pytorch_model_meta_dist_file,
3130
load_dist_model,
3231
save_metadata,
@@ -216,6 +215,34 @@ def _optimizer_sharder(
216215
# Return the last block in sharder.
217216
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
218217

218+
def create_model_metadata(
    self,
    model: ModelWrapper,
    prefix: str = "",
):
    """Build shard-placement metadata for every parameter of ``model``.

    Reads ``self.tp_size`` and ``self.tp_rank`` to locate this rank's shard
    along the tensor-parallel partition dimension.

    Args:
        model: Wrapped model; must expose ``param_origin_shape`` (a mapping
            from parameter name to the unpartitioned shape) and ``unwrap()``.
        prefix: String prepended to every parameter name in the result keys.

    Returns:
        A dict mapping ``prefix + name`` to a dict holding ``"offsets"``,
        ``"lengths"`` and ``"global_shape"`` lists for that parameter.
    """
    shapes_before_partition = model.param_origin_shape
    inner_model = model.unwrap()
    tp_size, tp_rank = self.tp_size, self.tp_rank

    result = {}
    for name, param in inner_model.named_parameters():
        if param is None:
            continue

        global_shape = shapes_before_partition[name]
        # Presumably the dimension split across tensor-parallel ranks, or None
        # when the parameter is not partitioned -- verify against the helper.
        partition_dim = search_tp_partition_dim(
            current_shape=param.shape, original_shape=global_shape, tp_size=tp_size
        )

        entry = {
            "offsets": [0] * len(global_shape),
            "lengths": list(param.shape),
            "global_shape": list(global_shape),
        }
        if partition_dim is not None:
            local_extent = param.shape[partition_dim]
            entry["offsets"][partition_dim] = local_extent * tp_rank
            is_last_rank = tp_rank == tp_size - 1
            if is_last_rank:
                # Last rank takes whatever remains of the full dimension after
                # the preceding ranks' equally sized shards.
                entry["lengths"][partition_dim] = global_shape[partition_dim] - (
                    local_extent * (tp_size - 1)
                )
        result[prefix + name] = entry
    return result
245+
219246
def save_sharded_model(
220247
self,
221248
model: ModelWrapper,
@@ -253,7 +280,7 @@ def save_sharded_model(
253280
model_metadata = None
254281
if not gather_dtensor:
255282
# Manage filenames of sharded weights and index file for each pipeline stage.
256-
model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
283+
model_metadata = self.create_model_metadata(model)
257284

258285
model = model.unwrap()
259286

@@ -409,7 +436,7 @@ def load_sharded_model(
409436
model._force_wait_all_gather()
410437

411438
if is_pytorch_model_meta_dist_file(checkpoint_index_file):
412-
model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
439+
model_metadata = self.create_model_metadata(model)
413440
checkpoint = checkpoint_index_file.parent
414441
state_dict = load_dist_model(
415442
model_metadata=model_metadata,
@@ -817,7 +844,7 @@ def save_unsharded_model(
817844
if not gather_dtensor:
818845
dist_id = self.tp_size * self.pp_rank + self.tp_rank
819846
Path(checkpoint).mkdir(parents=True, exist_ok=True)
820-
model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
847+
model_metadata = self.create_model_metadata(model)
821848
checkpoint_file = os.path.join(checkpoint, f"{MODEL_WEIGHT_PREFIX}{dist_id:05d}.bin")
822849
if use_async:
823850
checkpoint_file = checkpoint_file.replace(".bin", f".safetensors")
@@ -903,7 +930,7 @@ def load_unsharded_model(
903930

904931
model_metadata = None # used for dist model
905932
if load_dtensor:
906-
model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
933+
model_metadata = self.create_model_metadata(model)
907934

908935
strict = False
909936
model_before_wrapping = model

0 commit comments

Comments
 (0)