
Commit 6a8a917

Remove duplicates
1 parent f388bbe commit 6a8a917

File tree

3 files changed: +138 −351 lines changed


colossalai/checkpoint_io/distributed_checkpoint_utils.py

+30 −255
@@ -11,6 +11,8 @@
 
 from colossalai.interface import ModelWrapper
 from colossalai.utils import get_non_persistent_buffers_set
+from colossalai.shardformer.layer.parallel_module import ParallelModule
+from contextlib import contextmanager
 
 from .index_file import CheckpointIndexFile
 from .utils import (
@@ -32,67 +34,32 @@
 MODEL_META_PREFIX = "pytorch_model-meta-dist-"
 MODEL_WEIGHT_PREFIX = "pytorch_model-dist-"
 SHARD_META_SUFFIX = ".index.json"
+UNSHARD_META_SUFFIX = ".json"
 
 
-def dist_model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool = False):
-    destination = dict()
-    # Save parameters.
-    for name, param in model.named_parameters():
-        if param is None:
-            continue
-        destination[prefix + name] = param
-    # Save buffers.
-    non_persist_buffers_set = get_non_persistent_buffers_set(model)
-    for name, buf in model.named_buffers():
-        if buf is not None and name not in non_persist_buffers_set:
-            buffer = buf if keep_vars else buf.detach()
-            destination[prefix + name] = buffer
-
-    # Save extra states.
-    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
-    if (
-        getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
-        is not torch.nn.Module.get_extra_state
-    ):
-        extra_state = model.get_extra_state()
-        destination[extra_state_key] = extra_state
-    return destination
-
-
-def load_state_dict_into_dist_model(
-    model: nn.Module, state_dict: Dict, prefix: str = "", keep_vars: bool = False, strict: bool = False
-):
-    destination = dict()
-    # Save parameters.
-    for name, param in model.named_parameters():
-        if param is None:
-            continue
-        with torch.no_grad():
-            param.copy_(state_dict[prefix + name])
-    # Save buffers.
-    non_persist_buffers_set = get_non_persistent_buffers_set(model)
-    for name, buf in model.named_buffers():
-        if buf is not None and name not in non_persist_buffers_set:
-            with torch.no_grad():
-                buf.copy_(state_dict[prefix + name])
-
-    # Save extra states.
-    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
-    if (
-        getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
-        is not torch.nn.Module.get_extra_state
-    ):
-        extra_state = model.get_extra_state()
-        with torch.no_grad():
-            extra_state.copy_(state_dict[extra_state_key])
-    return destination
+@contextmanager
+def RestoreDefaultStateDictBehavior(model):
+    original_methods = {}
+    for name, module in model.named_modules():
+        if isinstance(module, ParallelModule):
+            original_methods[module] = (module._save_to_state_dict, module._load_from_state_dict)
+            module._save_to_state_dict = nn.Module._save_to_state_dict.__get__(module, nn.Module)
+            module._load_from_state_dict = nn.Module._load_from_state_dict.__get__(module, nn.Module)
+    try:
+        yield model
+    finally:
+        for module, original_method in original_methods.items():
+            module._save_to_state_dict, module._load_from_state_dict = original_method
 
 
 def create_model_metadata(
-    model: nn.Module,
+    model: ModelWrapper,
     prefix: str = "",
-    tp_size=None,
-    tp_rank=None,
+    tp_size: int = None,
+    tp_rank: int = None,
+    zero_size: int = None,
+    zero_rank: int = None,
 ):
     param_origin_shape = model.param_origin_shape
     model = model.unwrap()
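The new `RestoreDefaultStateDictBehavior` context manager temporarily rebinds `_save_to_state_dict` / `_load_from_state_dict` on every `ParallelModule` back to the plain `nn.Module` implementations, so `state_dict()` and `load_state_dict()` operate on each rank's local shards instead of gathering or scattering along the TP group. A minimal usage sketch (the sharded `model` and its distributed setup are assumed and not part of this commit):

```python
# Hedged sketch: dump and restore a rank's local shards with no TP communication.
# `model` is assumed to be a shardformer-sharded nn.Module already placed on this rank.
with RestoreDefaultStateDictBehavior(model):
    local_state_dict = model.state_dict()  # plain nn.Module path, returns local shards only

# ... later, e.g. after reading the same per-rank shards back from disk:
with RestoreDefaultStateDictBehavior(model):
    model.load_state_dict(local_state_dict, strict=False)
```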
@@ -105,7 +72,7 @@ def create_model_metadata(
         tp_partition_dim = search_tp_partition_dim(
             current_shape=param.shape, original_shape=original_shape, tp_size=tp_size
         )
-        model_metadata[prefix + name]["offsets"] = torch.zeros(len(original_shape), dtype=torch.int)
+        model_metadata[prefix + name]["offsets"] = [0] * len(original_shape)
         model_metadata[prefix + name]["lengths"] = list(param.shape)
         model_metadata[prefix + name]["global_shape"] = list(original_shape)
         if tp_partition_dim is not None:
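Replacing the `torch.zeros(...)` tensor with a plain Python list keeps every field of the per-parameter metadata JSON-serializable before it is written out via `save_metadata`. Purely as an illustration, a row-sharded 1024x1024 weight on `tp_rank=1` of `tp_size=2` would get an entry along these lines (the parameter name and values are hypothetical):

```python
# Hypothetical metadata entry; the keys follow create_model_metadata above.
model_metadata["module.weight"] = {
    "offsets": [512, 0],           # plain ints now, instead of a torch tensor of zeros
    "lengths": [512, 1024],        # local shard shape on this rank
    "global_shape": [1024, 1024],  # original unsharded shape
}
```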
@@ -257,119 +224,9 @@ def is_pytorch_model_meta_dist_file(checkpoint_index_file):
     return False
 
 
-def dist_model_sharder(
-    model: nn.Module,
-    prefix: str = "",
-    keep_vars: bool = False,
-    size_per_shard: int = 1024,
-    pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
-) -> Iterator[Tuple[OrderedDict, int]]:
-    # An internel method that breaks state_dict of model into shards within limited size.
-
-    state_dict_sharder = StateDictSharder(size_per_shard)
-
-    # Save parameters.
-    for name, param in model.named_parameters():
-        if param is None:
-            continue
-        if pinned_state_dicts is not None:
-            if (prefix + name) not in pinned_state_dicts:
-                pinned_state_dicts[prefix + name] = torch.empty_like(param, pin_memory=True, device="cpu")
-            pinned_state_dicts[prefix + name].copy_(param)
-            param = pinned_state_dicts[prefix + name]
-        block, block_size = state_dict_sharder.append_param(prefix + name, param)
-        if block is not None:
-            yield block, block_size
-
-    # Save buffers.
-    non_persist_buffers_set = get_non_persistent_buffers_set(model)
-    for name, buf in model.named_buffers():
-        if buf is not None and name not in non_persist_buffers_set:
-            buffer = buf if keep_vars else buf.detach()
-            if pinned_state_dicts is not None:
-                if (prefix + name) not in pinned_state_dicts:
-                    pinned_state_dicts[prefix + name] = torch.empty_like(buffer, pin_memory=True, device="cpu")
-                pinned_state_dicts[prefix + name].copy_(buffer)
-                buffer = pinned_state_dicts[prefix + name]
-            block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
-            if block is not None:
-                yield block, block_size
-
-    # Save extra states.
-    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
-    if (
-        getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
-        is not torch.nn.Module.get_extra_state
-    ):
-        extra_state = model.get_extra_state()
-        if pinned_state_dicts is not None:
-            if extra_state_key not in pinned_state_dicts:
-                pinned_state_dicts[extra_state_key] = torch.empty_like(extra_state, pin_memory=True, device="cpu")
-            pinned_state_dicts[extra_state_key].copy_(extra_state)
-            extra_state = pinned_state_dicts[extra_state_key]
-        block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
-        if block is not None:
-            yield block, block_size
-
-    # Return the last block in sharder.
-    yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
-
-
-def save_dist_unshard_model(
-    model: ModelWrapper,
-    model_metadata: Dict,
-    checkpoint: str,
-    use_safetensors: bool,
-    use_async: bool = False,
-    dist_id=0,
-    pinned_state_dicts=None,
-):
-    """
-    Save model state dict to a single file with given checkpointing path.
-
-    Args:
-        model (nn.Module): Model on local device to be saved.
-        checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path.
-        gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True.
-        use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
-        use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
-    """
-
-    model = model.unwrap()
-
-    # The logic of collecting parameter shards along tp degree
-    # has been implemented by _save_to_state_dict method of ParallelModule in Shardformer.
-    state_dict = dist_model_state_dict(model)
-
-    Path(checkpoint).mkdir(parents=True, exist_ok=True)
-    file_name = f"{MODEL_WEIGHT_PREFIX}{dist_id:05d}.bin"
-    if use_async:
-        file_name = file_name.replace(".bin", ".safetensors")
-    checkpoint_file = os.path.join(checkpoint, file_name)
-    metadata_file = os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}.json")
-    save_metadata(model_metadata, metadata_file, file_name)
-
-    if use_async:
-        from colossalai.utils.safetensors import save
-
-        if id(model) not in pinned_state_dicts:
-            pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
-        for name, param in state_dict.items():
-            pinned_state_dicts[id(model)][name].copy_(param)
-            state_dict[name] = pinned_state_dicts[id(model)][name]
-        writer = save(path=checkpoint_file, state_dict=state_dict)
-        return writer
-    else:
-        save_state_dict(state_dict, checkpoint_file, use_safetensors)
-        return None
-
-
 def load_dist_model(
-    model: ModelWrapper,
     model_metadata: Dict,
     checkpoint: str,
-    low_cpu_mem_mode: bool = True,
-    num_threads: int = 1,
 ):
     """
     Load model from a single file with the given path of checkpoint.
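With `dist_model_sharder` and `save_dist_unshard_model` removed (they duplicated what `nn.Module.state_dict()` already produces once the TP gather is disabled), the per-rank unsharded save presumably moves to the caller. A hedged sketch of that caller-side pattern, reusing only names that exist in this module (`RestoreDefaultStateDictBehavior`, `save_metadata`, `save_state_dict`, `MODEL_WEIGHT_PREFIX`, `MODEL_META_PREFIX`); the helper function itself is hypothetical and is not what this commit adds:

```python
import os
from pathlib import Path

# Hypothetical caller-side replacement for the removed save_dist_unshard_model.
def save_unshard_dist_checkpoint(model, model_metadata, checkpoint: str, dist_id: int = 0):
    unwrapped = model.unwrap()
    # Keep each rank's local shards: no TP gather during state_dict().
    with RestoreDefaultStateDictBehavior(unwrapped):
        state_dict = unwrapped.state_dict()

    Path(checkpoint).mkdir(parents=True, exist_ok=True)
    file_name = f"{MODEL_WEIGHT_PREFIX}{dist_id:05d}.bin"
    metadata_file = os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}.json")
    save_metadata(model_metadata, metadata_file, file_name)
    save_state_dict(state_dict, os.path.join(checkpoint, file_name), False)  # use_safetensors=False
```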
@@ -380,10 +237,6 @@ def load_dist_model(
         strict (bool, optional): For name matching during loading state_dict. Defaults to False.
                                  This argument should be manually set to False since not all params in checkpoint are needed for each device when pipeline is enabled.
     """
-
-    model_before_wrapping = model
-    model = model.unwrap()
-
     metadata_loaded = load_metadata(checkpoint)
 
     load_files = {}
@@ -420,92 +273,14 @@
         )
         state_dict[key] = state
 
-    if not low_cpu_mem_mode:
-        state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
-
-    load_state_dict_into_dist_model(model=model, state_dict=state_dict)
-
-    # Update master params if mixed-precision training is enabled.
-    model_before_wrapping.update_master_params()
-
+    return state_dict
 
-def save_dist_sharded_model(
-    model: ModelWrapper,
-    model_metadata: Dict,
-    checkpoint: str,
-    prefix: Optional[str] = None,
-    size_per_shard: int = 1024,
-    use_safetensors: bool = False,
-    use_async: bool = False,
-    dist_id: int = 0,
-    pinned_state_dicts=None,
-) -> None:
-    """
-    Save sharded model checkpoint under the given checkpointing path.
-    The following files will be created under the path:
-    - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
-    - Multiple files that store state tensors of models.
-    If pipeline parallelism is used, the filenames are in the form of "pytorch_model.<prefix>-stage-000XX-shard-000XX.bin".
-    If pipeline parallelism is not used, "pytorch_model.<prefix>-000XX.bin"
-
-
-    Args:
-        model (nn.Module): Model on local device to be saved.
-        checkpoint (str): Checkpointing path which should be a directory path.
-        gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
-        prefix (str, optional): Perfix of file to save. Defaults to None.
-        size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
-        use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
-        use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
-    """
-
-    model = model.unwrap()
-
-    if os.path.isfile(checkpoint):
-        logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
-        return
-
-    Path(checkpoint).mkdir(parents=True, exist_ok=True)
-    # Devices along the same dp_group share the same copies of model.
-    # So only let the device with dp_rank == 0 and sp_rank == 0 save the model.
-
-    if use_async:
-        if id(model) not in pinned_state_dicts:
-            pinned_state_dicts[id(model)] = {}
-        pinned_state_dicts = pinned_state_dicts[id(model)]
-    else:
-        pinned_state_dicts = None
-    state_dict_shard = dist_model_sharder(model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts)
-    weights_name, _ = get_model_base_filenames(prefix, use_safetensors)
-    index_file = CheckpointIndexFile(checkpoint)
-
-    # Manage filenames of sharded weights and index file for each pipeline stage.
+def get_dist_files_name(weights_name, dist_id):
     weights_name = weights_name.replace(".bin", f"-dist-{dist_id:05d}-shard.bin")
     weights_name = weights_name.replace(".safetensors", f"-dist-{dist_id:05d}-shard.safetensors")
-    metadata_file = os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{SHARD_META_SUFFIX}")
-    async_writers = []
-    if use_async:
-        total_size, writers = async_save_state_dict_shards(
-            sharded_state_dict=state_dict_shard,
-            checkpoint=checkpoint,
-            index_file=index_file,
-            base_filename=weights_name,
-            is_master=True,
-            state_preprocess=False,
-        )
-        async_writers.extend(writers)
-    else:
-        total_size = save_state_dict_shards(
-            sharded_state_dict=state_dict_shard,
-            checkpoint=checkpoint,
-            index_file=index_file,
-            base_filename=weights_name,
-            is_master=True,
-            use_safetensors=use_safetensors,
-            use_pp_format=True,
-        )
-    for k, _ in model_metadata.items():
-        model_metadata[k]["file"] = index_file.get_checkpoint_file(k)
+    return weights_name
 
-    save_metadata(model_metadata, metadata_file, total_size=total_size)
-    return async_writers
+def get_dist_meta_file_name(checkpoint, dist_id, use_safetensors):
+    if use_safetensors:
+        return os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{SHARD_META_SUFFIX}")
+    return os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{UNSHARD_META_SUFFIX}")
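After this change, `load_dist_model` only materializes and returns the flat state dict, and the save path keeps just the two filename helpers above; the surrounding orchestration presumably lives in the calling checkpoint IO class. A hedged sketch of how the retained pieces compose (the caller-side variables `model` and `model_metadata` and the checkpoint directory are assumptions):

```python
# Filename helpers: derive the per-rank shard name and metadata file name.
weights_name, _ = get_model_base_filenames(None, True)  # (prefix, use_safetensors)
weights_name = get_dist_files_name(weights_name, dist_id=3)
# e.g. "<base name>-dist-00003-shard.safetensors"
meta_file = get_dist_meta_file_name("ckpt_dir", dist_id=3, use_safetensors=True)
# -> "ckpt_dir/pytorch_model-meta-dist-00003.index.json"

# Loading: the caller receives the state dict and applies it itself, e.g. with the
# default (non-gathering) behavior restored, then refreshes any master params.
state_dict = load_dist_model(model_metadata=model_metadata, checkpoint="ckpt_dir")
with RestoreDefaultStateDictBehavior(model.unwrap()):
    model.unwrap().load_state_dict(state_dict, strict=False)
model.update_master_params()
```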
