Skip to content

Commit 02663d6

Browse files
tohtana and delock authored
Support AutoEP with ZeRO-3 zero.Init source modules (#8060)
This PR enables ZeRO-3 support for AutoEP-managed MoE layers by partitioning expert parameters over expert replica groups, while router and replicated parameters use the global data-parallel group. With ZeRO-3 enabled, AutoEP preserves global data-parallel gradient averaging for AutoEP expert parameters while reducing them over expert replica groups. ZeRO parameters are gathered before AutoEP reads router or expert tensors when replacing MoE modules created under `deepspeed.zero.Init()`. --------- Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com> Co-authored-by: Ma, Guokai <guokai.ma@gmail.com>
1 parent f0253c8 commit 02663d6

21 files changed

Lines changed: 2980 additions & 422 deletions
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Copyright (c) DeepSpeed Team.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# DeepSpeed Team
5+
"""Shared validation for AutoEP ZeRO-3 checkpoint metadata."""
6+
7+
from deepspeed.checkpoint.constants import (
8+
AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION,
9+
AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY,
10+
AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY,
11+
AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT,
12+
)
13+
14+
# Field names looked up on each ds_autoep_layers metadata entry; kept immutable
# so the set can be shared safely and diffed against an entry's keys.
AUTOEP_METADATA_REQUIRED_FIELDS = frozenset((
    'moe_layer_id',
    'module_path',
    'num_experts',
    'num_local_experts',
    'ep_size',
    'expert_key_prefix',
))
22+
23+
# Additional field names specific to the ZeRO-3 partitioned expert-state
# format: the format-version key plus the expert-parallel placement fields.
AUTOEP_ZERO3_PARTITIONED_METADATA_FIELDS = frozenset((
    AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY,
    'ep_group_name',
    'ep_rank',
    'expert_data_parallel_rank',
    'expert_data_parallel_world_size',
    'global_expert_start',
    'global_expert_end',
))
32+
33+
34+
def is_autoep_zero3_partitioned_entry(entry):
    """Tell whether *entry* is tagged with the ZeRO-3 partitioned expert-state format.

    Returns ``False`` for anything that is not a ``dict``; otherwise compares
    the value stored under ``AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY`` (``None``
    when absent) with ``AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT``.
    """
    if not isinstance(entry, dict):
        return False
    format_tag = entry.get(AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY)
    return format_tag == AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT
37+
38+
39+
def validate_autoep_zero3_partitioned_metadata(autoep_metadata,
                                               require_partitioned=True,
                                               expected_expert_prefixes=None,
                                               version_context="This DeepSpeed build"):
    """Validate the ``ds_autoep_layers`` checkpoint metadata list.

    Every entry must be a dict carrying all ``AUTOEP_METADATA_REQUIRED_FIELDS``,
    with ``moe_layer_id`` and ``expert_key_prefix`` unique across the list.
    Entries recognised by ``is_autoep_zero3_partitioned_entry`` are checked
    further: partitioned-format fields present, supported format version,
    ``num_local_experts * ep_size == num_experts``, ``ep_rank`` inside
    ``[0, ep_size)``, and a global expert range matching ``ep_rank``.

    Args:
        autoep_metadata: List of per-layer metadata dicts.
        require_partitioned: When true, fail if no partitioned entry is found.
        expected_expert_prefixes: Optional mapping from ``module_path`` to the
            expected ``expert_key_prefix``. Only consulted for partitioned
            entries; ``None`` disables the check.
        version_context: Text naming the reader in the unsupported-version
            error message.

    Raises:
        RuntimeError: If the metadata is malformed, inconsistent, uses an
            unsupported format version, or contains no partitioned entry
            while ``require_partitioned`` is true.
    """
    if not isinstance(autoep_metadata, list):
        raise RuntimeError(f"ds_autoep_layers metadata is malformed: expected list, got "
                           f"{type(autoep_metadata).__name__}")

    seen_layer_ids = set()
    seen_prefixes = set()
    partitioned_count = 0

    for entry in autoep_metadata:
        if not isinstance(entry, dict):
            raise RuntimeError(f"ds_autoep_layers entry is malformed: expected dict, got "
                               f"{type(entry).__name__}")
        missing = AUTOEP_METADATA_REQUIRED_FIELDS - entry.keys()
        if missing:
            raise RuntimeError(f"ds_autoep_layers entry is invalid: missing fields {sorted(missing)}")

        layer_id = entry['moe_layer_id']
        if layer_id in seen_layer_ids:
            raise RuntimeError(f"ds_autoep_layers metadata has duplicate moe_layer_id: {layer_id}")
        seen_layer_ids.add(layer_id)

        prefix = entry['expert_key_prefix']
        if prefix in seen_prefixes:
            raise RuntimeError(f"ds_autoep_layers metadata has duplicate expert_key_prefix: {prefix}")
        seen_prefixes.add(prefix)

        # Entries in any other format only get the generic checks above.
        if not is_autoep_zero3_partitioned_entry(entry):
            continue

        missing = AUTOEP_ZERO3_PARTITIONED_METADATA_FIELDS - entry.keys()
        if missing:
            raise RuntimeError(f"AutoEP ZeRO-3 checkpoint metadata is invalid: missing fields {sorted(missing)}")
        version = entry[AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY]
        if version != AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION:
            raise RuntimeError("Unsupported AutoEP ZeRO-3 checkpoint format version: "
                               f"{version}. {version_context} supports version "
                               f"{AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION}.")

        num_experts = entry['num_experts']
        num_local_experts = entry['num_local_experts']
        ep_size = entry['ep_size']
        if num_local_experts * ep_size != num_experts:
            raise RuntimeError("AutoEP ZeRO-3 checkpoint metadata is inconsistent: "
                               f"num_local_experts={num_local_experts}, ep_size={ep_size}, "
                               f"num_experts={num_experts}")

        # The range check below is derived from ep_rank, so a rank outside
        # [0, ep_size) with a matching start/end would otherwise be accepted
        # even though its experts lie outside [0, num_experts).
        ep_rank = entry['ep_rank']
        if not 0 <= ep_rank < ep_size:
            raise RuntimeError("AutoEP ZeRO-3 checkpoint metadata has out-of-range ep_rank: "
                               f"ep_rank={ep_rank}, ep_size={ep_size}")

        expected_start = ep_rank * num_local_experts
        expected_end = expected_start + num_local_experts
        if entry['global_expert_start'] != expected_start or entry['global_expert_end'] != expected_end:
            raise RuntimeError("AutoEP ZeRO-3 checkpoint metadata has inconsistent global expert range: "
                               f"got [{entry['global_expert_start']}, {entry['global_expert_end']}), "
                               f"expected [{expected_start}, {expected_end})")

        if expected_expert_prefixes is not None:
            module_path = entry['module_path']
            if module_path not in expected_expert_prefixes:
                raise RuntimeError(f"AutoEP ZeRO-3 checkpoint metadata references missing module: {module_path}")
            expected_prefix = expected_expert_prefixes[module_path]
            if prefix != expected_prefix:
                raise RuntimeError("AutoEP ZeRO-3 checkpoint metadata has unexpected expert key prefix: "
                                   f"got {prefix}, expected {expected_prefix}")

        partitioned_count += 1

    if require_partitioned and partitioned_count == 0:
        raise RuntimeError("AutoEP ZeRO-3 partition-native checkpoint metadata was expected but no "
                           "partitioned AutoEP layer entries were found")

deepspeed/checkpoint/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@
9393
#########################################
9494
AUTOEP_LAYERS_KEY = 'ds_autoep_layers'
9595
AUTOEP_LAYERS_KEY_LEGACY = 'autoep_layers'
96+
# Metadata key that names an entry's expert-state checkpoint format.
AUTOEP_ZERO3_EXPERT_STATE_FORMAT_KEY = 'checkpoint_format'
# Format value identifying the ZeRO-3 partitioned expert-state layout.
AUTOEP_ZERO3_PARTITIONED_EXPERT_STATE_FORMAT = 'zero3_partitioned'
# Metadata key that holds the format version number.
AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION_KEY = 'checkpoint_format_version'
# Format version number defined by this file.
AUTOEP_ZERO3_EXPERT_STATE_FORMAT_VERSION = 1
96100

97101
#########################################
98102
# Universal Checkpoint EP keys

0 commit comments

Comments
 (0)