 from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat

 from .distributed_checkpoint_utils import (
-    create_model_metadata,
     is_pytorch_model_meta_dist_file,
     load_dist_model,
     save_metadata,
@@ -216,6 +215,34 @@ def _optimizer_sharder(
         # Return the last block in sharder.
         yield state_dict_sharder.current_block, state_dict_sharder.current_block_size

+    def create_model_metadata(
+        self,
+        model: ModelWrapper,
+        prefix: str = "",
+    ):
+        param_origin_shape = model.param_origin_shape
+        model = model.unwrap()
+        model_metadata = {}
+        for name, param in model.named_parameters():
+            if param is None:
+                continue
+            model_metadata[prefix + name] = {}
+            original_shape = param_origin_shape[name]
+            tp_partition_dim = search_tp_partition_dim(
+                current_shape=param.shape, original_shape=original_shape, tp_size=self.tp_size
+            )
+            model_metadata[prefix + name]["offsets"] = [0] * len(original_shape)
+            model_metadata[prefix + name]["lengths"] = list(param.shape)
+            model_metadata[prefix + name]["global_shape"] = list(original_shape)
+            if tp_partition_dim is not None:
+                partition_size = param.shape[tp_partition_dim]
+                model_metadata[prefix + name]["offsets"][tp_partition_dim] = partition_size * self.tp_rank
+                if self.tp_rank == self.tp_size - 1:
+                    model_metadata[prefix + name]["lengths"][tp_partition_dim] = original_shape[tp_partition_dim] - (
+                        partition_size * (self.tp_size - 1)
+                    )
+        return model_metadata
+
     def save_sharded_model(
         self,
         model: ModelWrapper,
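For orientation: the new create_model_metadata method records, for each parameter, where the local tensor-parallel shard sits inside the full tensor. "global_shape" is the unsharded shape, "lengths" the local shard's shape, and "offsets" the shard's starting index along the partitioned dimension, with the last TP rank's length recomputed from the remainder. The standalone sketch below mirrors that arithmetic for a single parameter; find_partition_dim and shard_metadata are hypothetical helpers, and find_partition_dim is only an assumption about what search_tp_partition_dim does (return the dimension whose local extent differs from the global one, or None for replicated parameters).

from typing import Dict, List, Optional, Sequence

def find_partition_dim(current_shape: Sequence[int], original_shape: Sequence[int]) -> Optional[int]:
    # Hypothetical stand-in for search_tp_partition_dim: assume the TP split
    # dimension is the one whose local extent differs from the global extent.
    for dim, (local, full) in enumerate(zip(current_shape, original_shape)):
        if local != full:
            return dim
    return None  # parameter is replicated across TP ranks

def shard_metadata(current_shape, original_shape, tp_size: int, tp_rank: int) -> Dict[str, List[int]]:
    # Mirrors the per-parameter bookkeeping done in create_model_metadata.
    meta = {
        "offsets": [0] * len(original_shape),
        "lengths": list(current_shape),
        "global_shape": list(original_shape),
    }
    dim = find_partition_dim(current_shape, original_shape)
    if dim is not None:
        partition_size = current_shape[dim]
        meta["offsets"][dim] = partition_size * tp_rank
        if tp_rank == tp_size - 1:
            # Last rank's length = global extent minus tp_size - 1 full partitions,
            # matching the method's handling of dimensions that do not divide evenly.
            meta["lengths"][dim] = original_shape[dim] - partition_size * (tp_size - 1)
    return meta

# A weight with global shape (10, 4), split along dim 0 over tp_size=2:
print(shard_metadata((5, 4), (10, 4), tp_size=2, tp_rank=0))
# {'offsets': [0, 0], 'lengths': [5, 4], 'global_shape': [10, 4]}
print(shard_metadata((5, 4), (10, 4), tp_size=2, tp_rank=1))
# {'offsets': [5, 0], 'lengths': [5, 4], 'global_shape': [10, 4]}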
@@ -253,7 +280,7 @@ def save_sharded_model(
         model_metadata = None
         if not gather_dtensor:
             # Manage filenames of sharded weights and index file for each pipeline stage.
-            model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
+            model_metadata = self.create_model_metadata(model)

         model = model.unwrap()

@@ -409,7 +436,7 @@ def load_sharded_model(
         model._force_wait_all_gather()

         if is_pytorch_model_meta_dist_file(checkpoint_index_file):
-            model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
+            model_metadata = self.create_model_metadata(model)
             checkpoint = checkpoint_index_file.parent
             state_dict = load_dist_model(
                 model_metadata=model_metadata,
@@ -817,7 +844,7 @@ def save_unsharded_model(
         if not gather_dtensor:
             dist_id = self.tp_size * self.pp_rank + self.tp_rank
             Path(checkpoint).mkdir(parents=True, exist_ok=True)
-            model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
+            model_metadata = self.create_model_metadata(model)
             checkpoint_file = os.path.join(checkpoint, f"{MODEL_WEIGHT_PREFIX}{dist_id:05d}.bin")
             if use_async:
                 checkpoint_file = checkpoint_file.replace(".bin", f".safetensors")
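As a side note on the surrounding context in this hunk: when gather_dtensor is False, save_unsharded_model writes one weight file per (pipeline rank, tensor rank) pair, collapsing the two ranks into a single dist_id. The sketch below only illustrates that naming scheme; weight_file is a hypothetical helper and PREFIX is a placeholder for MODEL_WEIGHT_PREFIX, whose actual value is defined elsewhere in the library, so the literal filenames are indicative only.

import os

PREFIX = "model-"  # placeholder for MODEL_WEIGHT_PREFIX; the real constant lives in the library

def weight_file(checkpoint_dir: str, pp_rank: int, tp_rank: int, tp_size: int, use_async: bool = False) -> str:
    # Mirrors the dist_id / filename logic in save_unsharded_model.
    dist_id = tp_size * pp_rank + tp_rank  # unique id per (pp_rank, tp_rank)
    name = os.path.join(checkpoint_dir, f"{PREFIX}{dist_id:05d}.bin")
    if use_async:
        # use_async switches the extension to .safetensors, as in the diff above.
        name = name.replace(".bin", ".safetensors")
    return name

print(weight_file("ckpt", pp_rank=1, tp_rank=1, tp_size=2))  # ckpt/model-00003.bin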
@@ -903,7 +930,7 @@ def load_unsharded_model(

         model_metadata = None  # used for dist model
         if load_dtensor:
-            model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
+            model_metadata = self.create_model_metadata(model)

         strict = False
         model_before_wrapping = model