58 changes: 42 additions & 16 deletions python/paddle/distributed/flex_checkpoint/dcp/load_state_dict.py
@@ -35,7 +35,10 @@
 from .metadata import LocalTensorIndex, LocalTensorMetadata, Metadata
 from .metadata_manager import MetadataManager
 from .reshard_comm import CommunicatorFactory
-from .resharder import StateDictResharder
+from .resharder import (
+    StateDictResharder,
+    ThreeDCommGroupStateResharder,
+)
 from .sharded_weight import (
     ShardedWeight,
     ShardedWeightDesc,
@@ -62,7 +65,10 @@
 from paddle.distributed.collective import Group

 PATH_TO_CHECKPOINT_FILES: dict[str, tuple[list, list]] = {}
-_UNINIT_TENSOR_MODES = ["send_recv", "grouped_send_recv"]
+
+# When one of the communication modes listed below is used, newly created tensors are not allocated GPU memory up front.
+# GPU memory for these tensors is allocated only when meaningful values are written to them.
+_UNINIT_TENSOR_MODES = ["send_recv", "grouped_send_recv", "parallel_broadcast"]

 _metadata_manager = MetadataManager()

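As a brief aside on the comment above: a minimal sketch, assuming a hypothetical create_uninitialized_tensor helper, of how a loader might consult _UNINIT_TENSOR_MODES to postpone GPU allocation (illustrative only, not the library's actual code):

import paddle

def make_target_tensor(shape, dtype, comm_method, create_uninitialized_tensor):
    # Lazy modes: skip the up-front GPU allocation; memory is claimed only when
    # real values are written into the tensor (e.g. when a shard is received).
    if comm_method in _UNINIT_TENSOR_MODES:
        return create_uninitialized_tensor(shape, dtype)
    # Broadcast-style modes fill the whole tensor immediately, so allocate now.
    return paddle.zeros(shape, dtype=dtype)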
@@ -734,6 +740,7 @@ def load_state_dict(
         "broadcast",
         "multi_group_broadcast",
         "grouped_send_recv",
+        "parallel_broadcast",
     ]
     assert comm_method in valid_methods, (
         f"Invalid communication method '{comm_method}'. "
@@ -976,7 +983,6 @@ def restore_unflattened_state_dict(
     tmp_metadata.storage_metadata = {
         k: v for d in global_storage_metadata for k, v in d.items()
     }
-
     _load_state_dict(
         target_state_dict=destination_sharded_state_dict,
         source_state_dict=source_state_dict_for_reshard,
@@ -1265,19 +1271,39 @@ def _load_state_dict(
     worker_groups: list[Group] | None = None,
     comm_method: str = 'broadcast',
 ):
-    use_dist = True if paddle.distributed.get_world_size() > 1 else False
-    communicator = CommunicatorFactory.create(
-        comm_method, worker_groups=worker_groups
-    )
-    resharder = StateDictResharder(
-        target_state_dict=target_state_dict,
-        source_state_dict=source_state_dict,
-        metadata_list=metadata_list,
-        communicator=communicator,
-        process_group=process_group,
-        offload=offload,
-        use_dist=use_dist,
-    )
+    if comm_method != "parallel_broadcast":
+        use_dist = paddle.distributed.get_world_size() > 1
+        communicator = CommunicatorFactory.create(
+            comm_method, worker_groups=worker_groups
+        )
+        resharder = StateDictResharder(
+            target_state_dict=target_state_dict,
+            source_state_dict=source_state_dict,
+            metadata_list=metadata_list,
+            communicator=communicator,
+            process_group=process_group,
+            offload=offload,
+            use_dist=use_dist,
+        )
+    else:
+        assert worker_groups is not None and len(worker_groups) == 3, (
+            f"When the reshard communication mode is 'parallel_broadcast', "
+            f"worker_groups must contain exactly 3 groups, i.e., one group each "
+            f"for the horizontal, vertical, and parallel directions. "
+            f"Please check the worker_groups parameter: {worker_groups}"
+        )
+        h_group, v_group, p_group = worker_groups[:3]
+        resharder = ThreeDCommGroupStateResharder(
+            target_state_dict=target_state_dict,
+            source_state_dict=source_state_dict,
+            metadata_list=metadata_list,
+            h_group=h_group,
+            v_group=v_group,
+            p_group=p_group,
+            memory_growth_threshold=8 * (2**30),
+            offload=offload,
+        )
+
     resharder.reshard()


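A hedged usage sketch of the new 'parallel_broadcast' branch (assumed call pattern, not taken from this PR): the group ranks below are placeholders for the caller's real horizontal, vertical, and parallel communication groups, and target_state_dict, source_state_dict, and metadata_list are assumed to come from the surrounding load path:

import paddle.distributed as dist

dist.init_parallel_env()
# Placeholder topology; every rank must create the same groups collectively.
h_group = dist.new_group(ranks=[0, 1])        # horizontal direction
v_group = dist.new_group(ranks=[0, 2])        # vertical direction
p_group = dist.new_group(ranks=[0, 1, 2, 3])  # parallel direction

_load_state_dict(
    target_state_dict=target_state_dict,        # sharded tensors to fill on this rank
    source_state_dict=source_state_dict,        # shards read from the checkpoint files
    metadata_list=metadata_list,                # checkpoint metadata
    worker_groups=[h_group, v_group, p_group],  # exactly three groups are required
    comm_method="parallel_broadcast",
)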