[checkpointio] Support distributed checkpoint IO for model saving #6181

Open · wants to merge 8 commits into base: feature/dist-ckp-io
Changes from 1 commit
Remove duplicates
flybird11111 committed Jan 21, 2025
commit c5b088219f4f3ab08fa97f441208bc47268da656
302 changes: 31 additions & 271 deletions colossalai/checkpoint_io/distributed_checkpoint_utils.py
@@ -1,98 +1,50 @@
import json
import logging
import os
from pathlib import Path
from typing import Dict, Iterator, Optional, OrderedDict, Tuple
from typing import Dict

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.distributed_c10d import _get_default_group

from colossalai.interface import ModelWrapper
from colossalai.utils import get_non_persistent_buffers_set
from colossalai.shardformer.layer.parallel_module import ParallelModule
from contextlib import contextmanager

from .index_file import CheckpointIndexFile
from .utils import (
StateDictSharder,
async_save_state_dict_shards,
create_pinned_state_dict,
get_model_base_filenames,
load_state_dict,
save_state_dict,
save_state_dict_shards,
search_tp_partition_dim,
)

try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"

MODEL_META_PREFIX = "pytorch_model-meta-dist-"
MODEL_WEIGHT_PREFIX = "pytorch_model-dist-"
SHARD_META_SUFFIX = ".index.json"
UNSHARD_META_SUFFIX = ".json"


def dist_model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool = False):
destination = dict()
# Save parameters.
for name, param in model.named_parameters():
if param is None:
continue
destination[prefix + name] = param
# Save buffers.
non_persist_buffers_set = get_non_persistent_buffers_set(model)
for name, buf in model.named_buffers():
if buf is not None and name not in non_persist_buffers_set:
buffer = buf if keep_vars else buf.detach()
destination[prefix + name] = buffer

# Save extra states.
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if (
getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
is not torch.nn.Module.get_extra_state
):
extra_state = model.get_extra_state()
destination[extra_state_key] = extra_state
return destination


def load_state_dict_into_dist_model(
model: nn.Module, state_dict: Dict, prefix: str = "", keep_vars: bool = False, strict: bool = False
):
destination = dict()
# Save parameters.
for name, param in model.named_parameters():
if param is None:
continue
with torch.no_grad():
param.copy_(state_dict[prefix + name])
# Save buffers.
non_persist_buffers_set = get_non_persistent_buffers_set(model)
for name, buf in model.named_buffers():
if buf is not None and name not in non_persist_buffers_set:
with torch.no_grad():
buf.copy_(state_dict[prefix + name])

# Save extra states.
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if (
getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
is not torch.nn.Module.get_extra_state
):
extra_state = model.get_extra_state()
with torch.no_grad():
extra_state.copy_(state_dict[extra_state_key])
return destination
@contextmanager
def RestoreDefaultStateDictBehavior(model):
original_methods = {}
for name, module in model.named_modules():
if isinstance(module, ParallelModule):
original_methods[module] = (module._save_to_state_dict, module._load_from_state_dict)
module._save_to_state_dict = nn.Module._save_to_state_dict.__get__(module, nn.Module)
module._load_from_state_dict = nn.Module._load_from_state_dict.__get__(module, nn.Module)
try:
yield model
finally:
for module, original_method in original_methods.items():
module._save_to_state_dict, module._load_from_state_dict = original_method
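
A minimal usage sketch of the context manager added above (the surrounding model setup is hypothetical, not part of this PR): inside the with block every Shardformer ParallelModule temporarily falls back to the stock nn.Module state-dict hooks, so state_dict() returns only the local, non-gathered shards.

# Minimal sketch, assuming `model` contains Shardformer ParallelModule submodules.
with RestoreDefaultStateDictBehavior(model):
    local_state_dict = model.state_dict()  # plain nn.Module behavior: no TP gather
# On exit, the original ParallelModule _save_to_state_dict / _load_from_state_dict hooks are restored.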



def create_model_metadata(
model: nn.Module,
model: ModelWrapper,
prefix: str = "",
tp_size=None,
tp_rank=None,
tp_size: int = None,
tp_rank: int = None,
zero_size: int = None,
zero_rank: int = None,
):
param_origin_shape = model.param_origin_shape
model = model.unwrap()
@@ -105,7 +57,7 @@ def create_model_metadata(
tp_partition_dim = search_tp_partition_dim(
current_shape=param.shape, original_shape=original_shape, tp_size=tp_size
)
model_metadata[prefix + name]["offsets"] = torch.zeros(len(original_shape), dtype=torch.int)
model_metadata[prefix + name]["offsets"] = [0] * len(original_shape)
model_metadata[prefix + name]["lengths"] = list(param.shape)
model_metadata[prefix + name]["global_shape"] = list(original_shape)
if tp_partition_dim is not None:
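
For reference, each entry of model_metadata is now a plain JSON-serializable dict. A hypothetical example for a tensor-parallel weight (name and shapes invented for illustration; the rank-dependent offset adjustment happens in the unchanged lines elided after this hunk):

# Illustrative metadata entry, tp_size=2, tp_rank=0 (hypothetical values):
model_metadata["module.linear.weight"] = {
    "offsets": [0, 0],             # plain Python ints instead of a torch.zeros tensor
    "lengths": [2048, 1024],       # local shard shape on this rank
    "global_shape": [4096, 1024],  # full parameter shape before TP partitioning
}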
@@ -257,119 +209,9 @@ def is_pytorch_model_meta_dist_file(checkpoint_index_file):
return False


def dist_model_sharder(
model: nn.Module,
prefix: str = "",
keep_vars: bool = False,
size_per_shard: int = 1024,
pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
) -> Iterator[Tuple[OrderedDict, int]]:
# An internal method that breaks state_dict of model into shards within limited size.

state_dict_sharder = StateDictSharder(size_per_shard)

# Save parameters.
for name, param in model.named_parameters():
if param is None:
continue
if pinned_state_dicts is not None:
if (prefix + name) not in pinned_state_dicts:
pinned_state_dicts[prefix + name] = torch.empty_like(param, pin_memory=True, device="cpu")
pinned_state_dicts[prefix + name].copy_(param)
param = pinned_state_dicts[prefix + name]
block, block_size = state_dict_sharder.append_param(prefix + name, param)
if block is not None:
yield block, block_size

# Save buffers.
non_persist_buffers_set = get_non_persistent_buffers_set(model)
for name, buf in model.named_buffers():
if buf is not None and name not in non_persist_buffers_set:
buffer = buf if keep_vars else buf.detach()
if pinned_state_dicts is not None:
if (prefix + name) not in pinned_state_dicts:
pinned_state_dicts[prefix + name] = torch.empty_like(buffer, pin_memory=True, device="cpu")
pinned_state_dicts[prefix + name].copy_(buffer)
buffer = pinned_state_dicts[prefix + name]
block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
if block is not None:
yield block, block_size

# Save extra states.
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if (
getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
is not torch.nn.Module.get_extra_state
):
extra_state = model.get_extra_state()
if pinned_state_dicts is not None:
if extra_state_key not in pinned_state_dicts:
pinned_state_dicts[extra_state_key] = torch.empty_like(extra_state, pin_memory=True, device="cpu")
pinned_state_dicts[extra_state_key].copy_(extra_state)
extra_state = pinned_state_dicts[extra_state_key]
block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
if block is not None:
yield block, block_size

# Return the last block in sharder.
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size


def save_dist_unshard_model(
model: ModelWrapper,
model_metadata: Dict,
checkpoint: str,
use_safetensors: bool,
use_async: bool = False,
dist_id=0,
pinned_state_dicts=None,
):
"""
Save model state dict to a single file with given checkpointing path.
Args:
model (nn.Module): Model on local device to be saved.
checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path.
gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True.
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
"""

model = model.unwrap()

# The logic of collecting parameter shards along tp degree
# has been implemented by _save_to_state_dict method of ParallelModule in Shardformer.
state_dict = dist_model_state_dict(model)

Path(checkpoint).mkdir(parents=True, exist_ok=True)
file_name = f"{MODEL_WEIGHT_PREFIX}{dist_id:05d}.bin"
if use_async:
file_name = file_name.replace(".bin", ".safetensors")
checkpoint_file = os.path.join(checkpoint, file_name)
metadata_file = os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}.json")
save_metadata(model_metadata, metadata_file, file_name)

if use_async:
from colossalai.utils.safetensors import save

if id(model) not in pinned_state_dicts:
pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
for name, param in state_dict.items():
pinned_state_dicts[id(model)][name].copy_(param)
state_dict[name] = pinned_state_dicts[id(model)][name]
writer = save(path=checkpoint_file, state_dict=state_dict)
return writer
else:
save_state_dict(state_dict, checkpoint_file, use_safetensors)
return None


def load_dist_model(
model: ModelWrapper,
model_metadata: Dict,
checkpoint: str,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load model from a single file with the given path of checkpoint.
@@ -380,10 +222,6 @@ def load_dist_model(
strict (bool, optional): For name matching during loading state_dict. Defaults to False.
This argument should be manually set to False since not all params in checkpoint are needed for each device when pipeline is enabled.
"""

model_before_wrapping = model
model = model.unwrap()

metadata_loaded = load_metadata(checkpoint)

load_files = {}
@@ -420,92 +258,14 @@ def load_dist_model(
)
state_dict[key] = state

if not low_cpu_mem_mode:
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)

load_state_dict_into_dist_model(model=model, state_dict=state_dict)

# Update master params if mixed-precision training is enabled.
model_before_wrapping.update_master_params()
return state_dict


def save_dist_sharded_model(
model: ModelWrapper,
model_metadata: Dict,
checkpoint: str,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
dist_id: int = 0,
pinned_state_dicts=None,
) -> None:
"""
Save sharded model checkpoint under the given checkpointing path.
The following files will be created under the path:
- An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
- Multiple files that store state tensors of models.
If pipeline parallelism is used, the filenames are in the form of "pytorch_model.<prefix>-stage-000XX-shard-000XX.bin".
If pipeline parallelism is not used, "pytorch_model.<prefix>-000XX.bin"
Args:
model (nn.Module): Model on local device to be saved.
checkpoint (str): Checkpointing path which should be a directory path.
gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
prefix (str, optional): Prefix of file to save. Defaults to None.
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
"""

model = model.unwrap()

if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return

Path(checkpoint).mkdir(parents=True, exist_ok=True)
# Devices along the same dp_group share the same copies of model.
# So only let the device with dp_rank == 0 and sp_rank == 0 save the model.

if use_async:
if id(model) not in pinned_state_dicts:
pinned_state_dicts[id(model)] = {}
pinned_state_dicts = pinned_state_dicts[id(model)]
else:
pinned_state_dicts = None
state_dict_shard = dist_model_sharder(model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts)
weights_name, _ = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint)

# Manage filenames of sharded weights and index file for each pipeline stage.
def get_dist_files_name(weights_name, dist_id):
weights_name = weights_name.replace(".bin", f"-dist-{dist_id:05d}-shard.bin")
weights_name = weights_name.replace(".safetensors", f"-dist-{dist_id:05d}-shard.safetensors")
metadata_file = os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{SHARD_META_SUFFIX}")
async_writers = []
if use_async:
total_size, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=True,
state_preprocess=False,
)
async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=True,
use_safetensors=use_safetensors,
use_pp_format=True,
)
for k, _ in model_metadata.items():
model_metadata[k]["file"] = index_file.get_checkpoint_file(k)
return weights_name

save_metadata(model_metadata, metadata_file, total_size=total_size)
return async_writers
def get_dist_meta_file_name(checkpoint, dist_id, use_safetensors):
if use_safetensors:
return os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{SHARD_META_SUFFIX}")
return os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{UNSHARD_META_SUFFIX}")
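
The two helpers above pin down the on-disk naming convention. A quick sketch of the names they produce for dist_id == 3, based on the prefixes and suffixes defined at the top of this module (and assuming the module-level get_dist_files_name keeps the replacement logic of the previous inline helper):

get_dist_meta_file_name("ckpt", 3, use_safetensors=True)
# -> "ckpt/pytorch_model-meta-dist-00003.index.json"
get_dist_meta_file_name("ckpt", 3, use_safetensors=False)
# -> "ckpt/pytorch_model-meta-dist-00003.json"

get_dist_files_name("pytorch_model.safetensors", 3)
# -> "pytorch_model-dist-00003-shard.safetensors"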
197 changes: 103 additions & 94 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -5,6 +5,7 @@
from pathlib import Path
from shutil import rmtree
from typing import Dict, Iterator, Optional, OrderedDict, Tuple
from contextlib import nullcontext

import torch
import torch.distributed as dist
@@ -28,8 +29,11 @@
create_model_metadata,
is_pytorch_model_meta_dist_file,
load_dist_model,
save_dist_sharded_model,
save_dist_unshard_model,
save_metadata,
get_dist_files_name,
get_dist_meta_file_name,
MODEL_WEIGHT_PREFIX,
RestoreDefaultStateDictBehavior
)
from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile
@@ -97,13 +101,14 @@ def __init__(
self.verbose = verbose
self.coordinator = DistCoordinator()

@staticmethod
def _model_sharder(
self,
model: nn.Module,
prefix: str = "",
keep_vars: bool = False,
size_per_shard: int = 1024,
pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
gather_dtensor: bool = True,
) -> Iterator[Tuple[OrderedDict, int]]:
# An internal method that breaks state_dict of model into shards within limited size.

@@ -113,10 +118,15 @@ def _model_sharder(
for name, param in model.named_parameters():
if param is None:
continue
# Gather tensor pieces when using tensor parallel.
if is_padded_tensor(param):
param = to_unpadded_tensor(param)
param_ = gather_distributed_param(param, keep_vars=False)

if gather_dtensor:
# Gather tensor pieces when using tensor parallel.
if is_padded_tensor(param):
param = to_unpadded_tensor(param)
param_ = gather_distributed_param(param, keep_vars=False)
else:
param_ = param

if pinned_state_dicts is not None:
if (prefix + name) not in pinned_state_dicts:
pinned_state_dicts[prefix + name] = torch.empty_like(param_, pin_memory=True, device="cpu")
@@ -237,55 +247,45 @@ def save_sharded_model(

assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model._force_wait_all_gather()

if gather_dtensor:
if self.dp_rank != 0 and self.sp_rank != 0:
return
dist_id = self.tp_size * self.pp_rank + self.tp_rank
model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
async_writers = save_dist_sharded_model(
model=model,
model_metadata=model_metadata,
checkpoint=checkpoint,
prefix=prefix,
size_per_shard=size_per_shard,
use_safetensors=use_safetensors,
use_async=use_async,
dist_id=dist_id,
pinned_state_dicts=self.pinned_state_dicts,
)
self.async_writers.extend(async_writers)
if self.dp_rank != 0 and self.sp_rank != 0:
return


model_metadata = None
if not gather_dtensor:
# Manage filenames of sharded weights and index file for each pipeline stage.
model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)

model = model.unwrap()

if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return

Path(checkpoint).mkdir(parents=True, exist_ok=True)
# Devices along the same dp_group share the same copies of model.
# So only let the device with dp_rank == 0 save the model.
if self.dp_rank != 0:
return

# Then collect the sharded parameters & buffers along tp_group.
# Only devices with tp_rank == 0 are responsible for model saving.
control_saving = self.tp_rank == 0 and self.sp_rank == 0
control_saving = self.tp_rank == 0 if gather_dtensor else True
if control_saving and use_async:
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(model)]
else:
pinned_state_dicts = None
state_dict_shard = HybridParallelCheckpointIO._model_sharder(
model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts
state_dict_shard = self._model_sharder(
model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts, gather_dtensor=gather_dtensor
)

weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint)

if self.pp_size == 1:
if self.pp_size == 1 or not gather_dtensor:
# When pipeline is not used, save the model shards as in general checkpointIO
if not gather_dtensor:
dist_id = self.tp_size * self.pp_rank + self.tp_rank
weights_name = get_dist_files_name(weights_name=weights_name, dist_id=dist_id)
metadata_file = get_dist_meta_file_name(checkpoint=checkpoint, dist_id=dist_id, use_safetensors=use_safetensors)

if use_async:
total_size, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
@@ -305,16 +305,22 @@ def save_sharded_model(
is_master=control_saving,
use_safetensors=use_safetensors,
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
if not gather_dtensor:
# saving metadata for distributed checkpoint
for k, _ in model_metadata.items():
model_metadata[k]["file"] = index_file.get_checkpoint_file(k)
save_metadata(model_metadata, metadata_file, total_size=total_size)
else:
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)

else:
# When pipeline is used, each stage produces its own shard files and index files.
@@ -405,13 +411,15 @@ def load_sharded_model(
if is_pytorch_model_meta_dist_file(checkpoint_index_file):
model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
checkpoint = checkpoint_index_file.parent
load_dist_model(
model=model,
state_dict = load_dist_model(
model_metadata=model_metadata,
checkpoint=checkpoint,
low_cpu_mem_mode=low_cpu_mem_mode,
num_threads=num_threads,
)
model = model.unwrap()
with RestoreDefaultStateDictBehavior(model):
load_state_dict_into_model(
model, state_dict, missing_keys=[], strict=False, load_sub_module=True
)
return

model_before_wrapping = model # backup for model before wrapping
@@ -803,47 +811,43 @@ def save_unsharded_model(
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model._force_wait_all_gather()

model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
if self.dp_rank != 0 and self.sp_rank != 0:
return

if gather_dtensor:
if self.dp_rank != 0 and self.sp_rank != 0:
return
if not gather_dtensor:
dist_id = self.tp_size * self.pp_rank + self.tp_rank
writer = save_dist_unshard_model(
model=model,
model_metadata=model_metadata,
checkpoint=checkpoint,
use_safetensors=use_safetensors,
use_async=use_async,
dist_id=dist_id,
pinned_state_dicts=self.pinned_state_dicts,
)
if writer is not None:
self.async_writers.append(writer)
return
Path(checkpoint).mkdir(parents=True, exist_ok=True)
model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
checkpoint_file = os.path.join(checkpoint, f"{MODEL_WEIGHT_PREFIX}{dist_id:05d}.bin")
if use_async:
checkpoint_file = checkpoint_file.replace(".bin", f".safetensors")
metadata_file = get_dist_meta_file_name(checkpoint=checkpoint, dist_id=dist_id, use_safetensors=use_async)
save_metadata(model_metadata=model_metadata, metadata_file=metadata_file, checkpoint_file=checkpoint_file)
else:
checkpoint_file = checkpoint

model = model.unwrap()
if self.dp_rank != 0:
return

# The logic of collecting parameter shards along tp degree
# has been implemented by _save_to_state_dict method of ParallelModule in Shardformer.
state_dict = model.state_dict()
if self.pp_size == 1:
# When pipeline is not used, let master rank directly save the collected state_dict.
if self.tp_rank == 0:
if use_async:
from colossalai.utils.safetensors import save
ctx = RestoreDefaultStateDictBehavior(model) if not gather_dtensor else nullcontext()
with ctx:
state_dict = model.state_dict()

if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
for name, param in state_dict.items():
self.pinned_state_dicts[id(model)][name].copy_(param)
state_dict[name] = self.pinned_state_dicts[id(model)][name]
writer = save(path=checkpoint, state_dict=state_dict)
self.async_writers.append(writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors)
if (self.pp_size == 1 and self.tp_rank == 0) or not gather_dtensor:
# When pipeline is not used, let master rank directly save the collected state_dict.
if use_async:
from colossalai.utils.safetensors import save

if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
for name, param in state_dict.items():
self.pinned_state_dicts[id(model)][name].copy_(param)
state_dict[name] = self.pinned_state_dicts[id(model)][name]
writer = save(path=checkpoint_file, state_dict=state_dict)
self.async_writers.append(writer)
else:
save_state_dict(state_dict, checkpoint_file, use_safetensors)
else:
# When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
state_dict_list = [None for _ in range(self.pp_size)]
@@ -862,10 +866,10 @@ def save_unsharded_model(
for name, param in complete_state_dict.items():
self.pinned_state_dicts[id(model)][name].copy_(param)
complete_state_dict[name] = self.pinned_state_dicts[id(model)][name]
writer = save(path=checkpoint, state_dict=complete_state_dict)
writer = save(path=checkpoint_file, state_dict=complete_state_dict)
self.async_writers.append(writer)
else:
save_state_dict(complete_state_dict, checkpoint, use_safetensors)
save_state_dict(complete_state_dict, checkpoint_file, use_safetensors)

def load_unsharded_model(
self,
@@ -890,18 +894,16 @@ def load_unsharded_model(
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
model._force_wait_all_gather()

load_dtensor = False
if os.path.isdir(checkpoint):
for filename in os.listdir(checkpoint):
if is_pytorch_model_meta_dist_file(filename):
model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
load_dist_model(
model=model,
model_metadata=model_metadata,
checkpoint=checkpoint,
low_cpu_mem_mode=low_cpu_mem_mode,
num_threads=num_threads,
)
return
load_dtensor = True
break

model_metadata = None # used for dist model
if load_dtensor:
model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)

strict = False
model_before_wrapping = model
@@ -910,10 +912,17 @@ def load_unsharded_model(
# Load from checkpoint. Since the logic of breaking parameter shards along tp degree
# has been implemented by _load_from_state_dict method of ParallelModule in Shardformer,
# model.load_state_dict can be directly called.
state_dict = load_state_dict(checkpoint)
if load_dtensor:
state_dict = load_dist_model(model_metadata=model_metadata, checkpoint=checkpoint)
else:
state_dict = load_state_dict(checkpoint)

if not low_cpu_mem_mode:
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
model.load_state_dict(state_dict, strict=strict)

ctx = RestoreDefaultStateDictBehavior(model) if load_dtensor else nullcontext()
with ctx:
model.load_state_dict(state_dict, strict=strict)

# Update master params if mixed-precision training is enabled.
model_before_wrapping.update_master_params()
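
Taken together, the gather_dtensor=False path lets each (pp, tp) rank write its own weight shard plus a small metadata file, and loading rebuilds full tensors from those per-rank pieces. A rough usage sketch through the booster API (argument names mirror the test below; the setup and paths are hypothetical):

# Hypothetical sketch: save a distributed (non-gathered) checkpoint,
# then load it back on the same parallel layout.
booster.save_model(
    model,
    "./dist_ckpt",
    shard=True,
    gather_dtensor=False,  # keep per-rank shards instead of gathering along TP
    size_per_shard=32,
)
booster.load_model(model, "./dist_ckpt")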
4 changes: 2 additions & 2 deletions tests/test_checkpoint_io/test_dist_checkpointio.py
@@ -82,7 +82,7 @@ def _preprocess_data(data):
model_0,
model_ckpt_path_0,
shard=shard,
gather_dtensor=True,
gather_dtensor=False,
size_per_shard=size_per_shard,
use_async=use_async,
)
@@ -104,7 +104,7 @@ def _preprocess_data(data):
model_1,
model_ckpt_path_1,
shard=shard,
gather_dtensor=True,
gather_dtensor=False,
size_per_shard=size_per_shard,
use_async=use_async,
)
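
Both save calls in the test now exercise the non-gathered path. A natural follow-up check, shown here only as a sketch (the test's actual verification is not part of this hunk), is that parameters round-trip exactly once the second model has loaded the distributed checkpoint:

# Hypothetical assertion sketch, after model_1 has loaded model_0's checkpoint:
for (n0, p0), (n1, p1) in zip(model_0.named_parameters(), model_1.named_parameters()):
    assert n0 == n1
    assert torch.equal(p0.data, p1.data), f"parameter mismatch: {n0}"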