
Commit d339df2

Supports DP+TP+EP hybrid parallel deployment strategy (#3489)
* Support DP+TP+EP hybrid parallel deployment strategy
* fix conflict
* add moe_tp_ep function split_allgather_out
* del tp_group in moe_cutlass_backend
* for ci
* fix parallel_config for ci
* del log
1 parent 52eda7f commit d339df2

File tree: 15 files changed, +306 -226 lines

custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu

Lines changed: 10 additions & 5 deletions
@@ -43,11 +43,16 @@
       __VA_ARGS__                                    \
       break;                                         \
     }                                                \
-    case 48: {                                       \
-      constexpr size_t NUM_EXPERTS_PER_RANK = 48;    \
-      __VA_ARGS__                                    \
-      break;                                         \
-    }                                                \
+    case 32: {                                       \
+      constexpr size_t NUM_EXPERTS_PER_RANK = 32;    \
+      __VA_ARGS__                                    \
+      break;                                         \
+    }                                                \
+    case 48: {                                       \
+      constexpr size_t NUM_EXPERTS_PER_RANK = 48;    \
+      __VA_ARGS__                                    \
+      break;                                         \
+    }                                                \
     case 64: {                                       \
       constexpr size_t NUM_EXPERTS_PER_RANK = 64;    \
       __VA_ARGS__                                    \
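Context for the added `case 32`: the dispatch macro specializes on the number of experts hosted per expert-parallel rank, so every per-rank count a deployment can produce needs its own branch. A hedged arithmetic sketch (plain Python; the expert totals are illustrative, not taken from this commit):

```python
# Experts are divided evenly across expert-parallel ranks, so the kernel's
# compile-time dispatch must cover every per-rank count a deployment can hit.
def experts_per_rank(total_experts: int, ep_size: int) -> int:
    assert total_experts % ep_size == 0, "experts must divide evenly across ranks"
    return total_experts // ep_size

# Hypothetical examples: 64 experts over 2 EP ranks hits the new case 32;
# 96 experts over 2 EP ranks hits the existing case 48.
print(experts_per_rank(64, 2))  # -> 32
print(experts_per_rank(96, 2))  # -> 48
```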

custom_ops/gpu_ops/save_with_output_msg.cc

Lines changed: 2 additions & 1 deletion
@@ -105,7 +105,8 @@ void SaveOutMmsg(const paddle::Tensor& x,
                  int64_t rank_id,
                  int msg_queue_id,
                  bool save_each_rank) {
-  if (!save_each_rank && rank_id > 0) {
+  // don't use save_each_rank now!
+  if (rank_id > 0) {
     return;
   }
   if (x.place() == paddle::CPUPlace()) {

fastdeploy/config.py

Lines changed: 22 additions & 2 deletions
@@ -22,6 +22,7 @@
 from typing import Any, Dict, List, Literal, Optional, Union

 import paddle
+import paddle.distributed as dist
 from paddleformers.transformers.configuration_utils import PretrainedConfig

 import fastdeploy
@@ -308,7 +309,10 @@ def __init__(
             setattr(self, key, value)

         # currently, the expert parallel size is derived from the data and tensor parallel sizes
-        self.expert_parallel_size = self.data_parallel_size
+        if self.enable_expert_parallel:
+            self.expert_parallel_size = self.data_parallel_size * self.tensor_parallel_size
+        else:
+            self.expert_parallel_size = 1
         self.use_ep = self.expert_parallel_size > 1
         if self.splitwise_role == "mixed":
             self.moe_phase = MoEPhase(phase="prefill")
@@ -329,6 +333,22 @@ def __init__(
         else:
             self.pd_disaggregation_mode = "None"

+    def set_tp_group(self):
+        # Use a different tp group id per data-parallel rank to prevent
+        # different tp_groups from reusing the same group_id.
+        dist.collective._set_custom_gid(self.data_parallel_rank + 100)
+        self.tp_group = dist.new_group(
+            range(
+                self.data_parallel_rank * self.tensor_parallel_size,
+                (self.data_parallel_rank + 1) * self.tensor_parallel_size,
+            )
+        )
+        # All ranks share the same ep group id.
+        dist.collective._set_custom_gid(self.data_parallel_size + 100)
+        logger.info(
+            f"data_parallel_size: {self.data_parallel_size}, tensor_parallel_size: {self.tensor_parallel_size}, expert_parallel_size: {self.expert_parallel_size}, data_parallel_rank: {self.data_parallel_rank}, tensor_parallel_rank: {self.tensor_parallel_rank}, expert_parallel_rank: {self.expert_parallel_rank}, tp_group: {self.tp_group}."
+        )
+
     def print(self):
         """
         print all config
@@ -1104,7 +1124,7 @@ def __init__(
         if self.model_config is not None and self.model_config.enable_mm:
             self.max_prefill_batch = 1  # TODO: multimodal prefill currently only supports a parallelism of 1; to be optimized

-        num_ranks = self.parallel_config.tensor_parallel_size * self.parallel_config.expert_parallel_size
+        num_ranks = self.parallel_config.tensor_parallel_size * self.parallel_config.data_parallel_size
         self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
         if num_ranks > self.max_chips_per_node:
             self.worker_num_per_node = self.max_chips_per_node
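To make the new group arithmetic concrete, here is a minimal runnable sketch (plain Python, no Paddle; the DP=2/TP=4 figures are hypothetical) of the rank ranges and custom group ids that `set_tp_group` derives:

```python
# Hypothetical deployment: DP=2 replicas, each sharded over TP=4 GPUs, with
# enable_expert_parallel=True, so expert_parallel_size = DP * TP = 8.
data_parallel_size, tensor_parallel_size = 2, 4
expert_parallel_size = data_parallel_size * tensor_parallel_size

for dp_rank in range(data_parallel_size):
    # Mirrors the range() arithmetic in set_tp_group(): each DP replica owns a
    # contiguous block of tensor_parallel_size ranks.
    tp_ranks = list(
        range(dp_rank * tensor_parallel_size, (dp_rank + 1) * tensor_parallel_size)
    )
    # A per-replica gid (dp_rank + 100) keeps the TP groups distinct.
    print(f"dp_rank={dp_rank}: tp_group ranks={tp_ranks}, tp_gid={dp_rank + 100}")

# The EP group reuses one shared gid (data_parallel_size + 100) on every rank.
print(f"ep group spans all {expert_parallel_size} ranks, ep_gid={data_parallel_size + 100}")
```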

fastdeploy/distributed/communication.py

Lines changed: 8 additions & 3 deletions
@@ -47,15 +47,20 @@ def use_custom_allreduce(custom_all_reduce_max_bytes: int = 8192 * 1024):
 @paddle.jit.marker.unified
 def tensor_model_parallel_all_reduce(
     input_: paddle.Tensor,
+    group_: paddle.distributed.communication.group.Group = None,
 ) -> paddle.Tensor:
     """All-reduce the input tensor across model parallel group."""
     global _TP_AR
     if _TP_AR is not None and _TP_AR.should_custom_ar(input_):
+        # TODO: support custom allreduce over different groups
         _TP_AR.custom_all_reduce(input_)
     elif paddle.in_dynamic_mode():
-        hcg = fleet.get_hybrid_communicate_group()
-        mp_group = hcg.get_model_parallel_group()
-        dist.all_reduce(input_, group=mp_group)
+        if group_ is not None:
+            dist.all_reduce(input_, group=group_)
+        else:
+            hcg = fleet.get_hybrid_communicate_group()
+            mp_group = hcg.get_model_parallel_group()
+            dist.all_reduce(input_, group=mp_group)
     else:
         dist.all_reduce(input_)
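The resolution order after this change: the custom all-reduce wins when usable, then an explicitly passed group, then the hybrid communicate group's model-parallel group, then the global default. A pure-Python sketch of that precedence (function and flag names are illustrative, not from the Paddle API):

```python
def resolve_allreduce_path(custom_ar_usable: bool, dynamic_mode: bool, group_passed: bool) -> str:
    """Illustrative dispatch mirroring tensor_model_parallel_all_reduce."""
    if custom_ar_usable:
        return "custom all-reduce (explicit groups not supported yet)"
    if dynamic_mode:
        return "dist.all_reduce(group=group_)" if group_passed else "dist.all_reduce(group=mp_group)"
    return "dist.all_reduce() on the default group"

# The explicit group only takes effect when the custom AR path is unavailable:
print(resolve_allreduce_path(custom_ar_usable=False, dynamic_mode=True, group_passed=True))
print(resolve_allreduce_path(custom_ar_usable=True, dynamic_mode=True, group_passed=True))
```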

fastdeploy/model_executor/layers/embeddings.py

Lines changed: 31 additions & 40 deletions
@@ -57,43 +57,37 @@ def __init__(
         hcg = fleet.get_hybrid_communicate_group()
         self.mp_rank: int = hcg.get_model_parallel_rank()
         self.column_cut = False
-        self.world_size: int = hcg.get_model_parallel_world_size()
-        self.ring_id: int = hcg.get_model_parallel_group().id
-        self.use_ep: bool = fd_config.parallel_config.use_ep
+        self.world_size: int = fd_config.parallel_config.tensor_parallel_size
+        self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
+        self.tp_group = fd_config.parallel_config.tp_group
         self.hidden_dropout_prob: float = fd_config.model_config.hidden_dropout_prob
         self.initializer_range: float = fd_config.model_config.initializer_range
         self.max_position_embeddings: int = fd_config.model_config.max_position_embeddings
         self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
         self.params_dtype: str = params_dtype

-        if self.use_ep:
-            self.embeddings = nn.Embedding(
+        if not self.column_cut:
+            self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
                 num_embeddings,
                 embedding_dim,
+                mp_group=self.tp_group,
+                weight_attr=paddle.ParamAttr(
+                    initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range),
+                ),
             )
+            if self.world_size > 1:
+                set_weight_attrs(self.embeddings.weight, {"output_dim": False})
         else:
-            if not self.column_cut:
-                self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
-                    num_embeddings,
-                    embedding_dim,
-                    mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    weight_attr=paddle.ParamAttr(
-                        initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range),
-                    ),
-                )
-                if self.world_size > 1:
-                    set_weight_attrs(self.embeddings.weight, {"output_dim": False})
-            else:
-                # column cut embedding
-                self.embeddings = nn.Embedding(
-                    num_embeddings,
-                    embedding_dim // self.world_size,
-                )
-
-                self.embeddings.weight.is_distributed = True
-                self.embeddings.weight.split_axis = 1
-                if self.world_size > 1:
-                    set_weight_attrs(self.embeddings.weight, {"output_dim": True})
+            # column cut embedding
+            self.embeddings = nn.Embedding(
+                num_embeddings,
+                embedding_dim // self.world_size,
+            )
+
+            self.embeddings.weight.is_distributed = True
+            self.embeddings.weight.split_axis = 1
+            if self.world_size > 1:
+                set_weight_attrs(self.embeddings.weight, {"output_dim": True})

         self.prefix = prefix
         self.dropout = nn.Dropout(self.hidden_dropout_prob)
@@ -125,20 +119,17 @@ def forward(self, ids_remove_padding=None) -> paddle.Tensor:
         Returns:
             Tensor: Embedded tensor representation of the input IDs.
         """
-        if self.use_ep:
+        if self.column_cut:
             input_embedings = self.embeddings(ids_remove_padding)
+            inputs_embeds_temp = []
+            paddle.distributed.all_gather(
+                inputs_embeds_temp,
+                input_embedings,
+                group=self.tp_group,
+                sync_op=True,
+            )
+            input_embedings = paddle.concat(inputs_embeds_temp, -1)
         else:
-            if self.column_cut:
-                input_embedings = self.embeddings(ids_remove_padding)
-                inputs_embeds_temp = []
-                paddle.distributed.all_gather(
-                    inputs_embeds_temp,
-                    input_embedings,
-                    group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    sync_op=True,
-                )
-                input_embedings = paddle.concat(inputs_embeds_temp, -1)
-            else:
-                input_embedings = self.embeddings(ids_remove_padding)
+            input_embedings = self.embeddings(ids_remove_padding)

         return input_embedings
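For the column-cut path, each TP rank holds a slice of the hidden dimension, and the full embedding row is reassembled by all-gather plus concat on the last axis. A numpy emulation of that invariant (no distributed runtime; the sizes are made up for illustration):

```python
import numpy as np

# Hypothetical sizes: vocab 16, hidden 8, sharded over 2 TP ranks.
vocab, hidden, world_size = 16, 8, 2
rng = np.random.default_rng(0)
full_table = rng.standard_normal((vocab, hidden))

# Column cut: rank r stores hidden columns [r*hidden/ws, (r+1)*hidden/ws).
shards = np.split(full_table, world_size, axis=1)

ids = np.array([3, 7, 1])
# Each rank looks up its slice; all_gather + concat(-1) restores full vectors.
gathered = [shard[ids] for shard in shards]
restored = np.concatenate(gathered, axis=-1)

assert np.allclose(restored, full_table[ids])
```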

fastdeploy/model_executor/layers/linear.py

Lines changed: 2 additions & 1 deletion
@@ -703,6 +703,7 @@
         self.fd_config = fd_config
         self.skip_quant = False
         self.nranks = fd_config.parallel_config.tensor_parallel_size
+        self.tp_group = fd_config.parallel_config.tp_group
         self.hidden_size = fd_config.model_config.hidden_size
         self.head_dim = fd_config.model_config.head_dim
         self.num_heads = fd_config.model_config.num_attention_heads // self.nranks
@@ -751,7 +752,7 @@ def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
         out = paddle.matmul(x, self.weight)

         if self.reduce_results and self.nranks > 1:
-            tensor_model_parallel_all_reduce(out)
+            tensor_model_parallel_all_reduce(out, self.tp_group)

         return out
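The row-parallel pattern behind this all-reduce: each rank multiplies its input slice by its weight shard, and summing the partial products (the job of the all-reduce, now routed through the per-replica `tp_group`) equals the full matmul. A numpy emulation with made-up sizes:

```python
import numpy as np

# Hypothetical sizes: batch 2, input dim 8, output dim 4, 2 TP ranks.
nranks, rng = 2, np.random.default_rng(1)
x = rng.standard_normal((2, 8))
w = rng.standard_normal((8, 4))

# Row parallel: split x on its last axis and w on its first axis.
x_shards = np.split(x, nranks, axis=1)
w_shards = np.split(w, nranks, axis=0)

# Each rank computes a partial product; the all-reduce is a sum across ranks.
partials = [xs @ ws for xs, ws in zip(x_shards, w_shards)]
out = np.sum(partials, axis=0)

assert np.allclose(out, x @ w)
```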

fastdeploy/model_executor/layers/lm_head.py

Lines changed: 52 additions & 77 deletions
@@ -58,7 +58,7 @@ def __init__(
             self.bias_key: Optional[str] = prefix + ".bias"
         else:
             self.bias_key: Optional[str] = None
-        self.use_ep: bool = fd_config.parallel_config.use_ep
+        self.tp_group = fd_config.parallel_config.tp_group
         self.column_cut = True
         self.nranks = fd_config.parallel_config.tensor_parallel_size
         self.fd_config = fd_config
@@ -68,60 +68,46 @@
         self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings

-        if self.use_ep:
-            self.weight = self.create_parameter(
-                shape=[embedding_dim, num_embeddings],
-                dtype=paddle.get_default_dtype(),
-                is_bias=False,
+        if self.column_cut:
+            need_gather = True
+            self.linear = ColumnParallelLinear(
+                embedding_dim,
+                num_embeddings,
+                mp_group=self.tp_group,
+                weight_attr=None,
+                has_bias=True if self.bias_key is not None else False,
+                gather_output=need_gather,
+                fuse_matmul_bias=False,
             )
-            if self.bias_key is not None:
-                self.bias = self.create_parameter(
-                    shape=[num_embeddings],
-                    dtype=paddle.get_default_dtype(),
-                    is_bias=True,
-                )
-
+            set_weight_attrs(
+                self.linear.weight,
+                {
+                    "weight_loader": default_weight_loader(self.fd_config),
+                    "model_format": self.fd_config.model_config.model_format,
+                },
+            )
+            if self.nranks > 1:
+                set_weight_attrs(self.linear.weight, {"output_dim": True})
         else:
-            if self.column_cut:
-                need_gather = True
-                self.linear = ColumnParallelLinear(
-                    embedding_dim,
-                    num_embeddings,
-                    mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    weight_attr=None,
-                    has_bias=True if self.bias_key is not None else False,
-                    gather_output=need_gather,
-                    fuse_matmul_bias=False,
-                )
-                set_weight_attrs(
-                    self.linear.weight,
-                    {
-                        "weight_loader": default_weight_loader(self.fd_config),
-                        "model_format": self.fd_config.model_config.model_format,
-                    },
-                )
-                if self.nranks > 1:
-                    set_weight_attrs(self.linear.weight, {"output_dim": True})
-            else:
-                self.linear = RowParallelLinear(
-                    embedding_dim,
-                    num_embeddings,
-                    mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    weight_attr=None,
-                    has_bias=True if self.bias_key is not None else False,
-                    input_is_parallel=False,
-                    fuse_matmul_bias=False,
-                )
-                set_weight_attrs(
-                    self.linear.weight,
-                    {
-                        "weight_loader": default_weight_loader(self.fd_config),
-                        "model_format": self.fd_config.model_config.model_format,
-                    },
-                )
-
-                if self.nranks > 1:
-                    set_weight_attrs(self.linear.weight, {"output_dim": False})
+            self.linear = RowParallelLinear(
+                embedding_dim,
+                num_embeddings,
+                mp_group=self.tp_group,
+                weight_attr=None,
+                has_bias=True if self.bias_key is not None else False,
+                input_is_parallel=False,
+                fuse_matmul_bias=False,
+            )
+            set_weight_attrs(
+                self.linear.weight,
+                {
+                    "weight_loader": default_weight_loader(self.fd_config),
+                    "model_format": self.fd_config.model_config.model_format,
+                },
+            )
+
+            if self.nranks > 1:
+                set_weight_attrs(self.linear.weight, {"output_dim": False})

     def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
         """
@@ -131,24 +117,19 @@ def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
             state_dict (dict): A dictionary containing the checkpoint weights and biases.
         """

-        if self.use_ep:
-            self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
-            if self.bias_key is not None:
-                self.bias.set_value(get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype()))
+        if self.tie_word_embeddings:
+            self.linear.weight.set_value(
+                get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()).transpose([1, 0])
+            )
         else:
-            if self.tie_word_embeddings:
-                self.linear.weight.set_value(
-                    get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()).transpose([1, 0])
-                )
-            else:
-                weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
-                if self.linear.weight.shape != weight_tensor.shape:
-                    weight_tensor = weight_tensor.transpose([1, 0])
-                self.linear.weight.set_value(weight_tensor)
-
-            if self.bias_key is not None:
-                bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
-                self.linear.bias.set_value(bias)
+            weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
+            if self.linear.weight.shape != weight_tensor.shape:
+                weight_tensor = weight_tensor.transpose([1, 0])
+            self.linear.weight.set_value(weight_tensor)
+
+        if self.bias_key is not None:
+            bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
+            self.linear.bias.set_value(bias)

     def forward(self, input: paddle.Tensor) -> paddle.Tensor:
         """
@@ -161,11 +142,5 @@ def forward(self, input: paddle.Tensor) -> paddle.Tensor:
             Tensor: The output tensor after processing through the layer.
         """
         logits = input
-        if self.use_ep:
-            if self.bias_key is None:
-                logits = paddle.matmul(logits, self.weight)
-            else:
-                logits = paddle.incubate.nn.functional.fused_linear(logits, self.weight, self.bias)
-        else:
-            logits = self.linear(logits)
+        logits = self.linear(logits)
         return logits
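ColumnParallelLinear with `gather_output=True` is the dual of the row-parallel case above: each rank computes a slice of the output features, and gathering and concatenating the slices reproduces the full logits. A numpy emulation with made-up sizes:

```python
import numpy as np

# Hypothetical sizes: batch 2, hidden 8, vocab 12, 2 TP ranks.
nranks, rng = 2, np.random.default_rng(2)
x = rng.standard_normal((2, 8))
w = rng.standard_normal((8, 12))

# Column parallel: each rank owns a contiguous block of output columns.
w_shards = np.split(w, nranks, axis=1)

# gather_output=True corresponds to concatenating the per-rank slices.
logits = np.concatenate([x @ ws for ws in w_shards], axis=-1)

assert np.allclose(logits, x @ w)
```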

fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py

Lines changed: 1 addition & 1 deletion
@@ -466,6 +466,6 @@ def apply_tp(
             1.0,
         )[0]
         if layer.tp_size > 1:
-            tensor_model_parallel_all_reduce(tmp_ffn_out)
+            tensor_model_parallel_all_reduce(tmp_ffn_out, self.tp_group)

         return tmp_ffn_out

fastdeploy/model_executor/layers/moe/moe.py

Lines changed: 5 additions & 0 deletions
@@ -98,6 +98,11 @@
         self.tp_size = fd_config.parallel_config.tensor_parallel_size
         self.ep_size = fd_config.parallel_config.expert_parallel_size
         self.ep_rank = fd_config.parallel_config.expert_parallel_rank
+        self.tp_group = fd_config.parallel_config.tp_group
+        # NOTE(Zhenyu Li): MoE currently only supports tp_size = 1 when ep_size > 1.
+        if self.ep_size > 1:
+            self.tp_size = 1
+            self.tp_rank = 0

         assert (self.tp_size >= 1 and self.ep_size == 1) or (
             self.tp_size == 1 and self.ep_size > 1
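Putting the pieces together: under the new config rules, a DP=2, TP=2 deployment with expert parallelism enabled yields ep_size = 4, and each MoE layer then runs its experts with tp_size forced to 1 (dense layers keep TP=2). A hedged sketch of the resulting invariant (plain Python; the sizes are illustrative):

```python
# Illustrative check of the MoE constraint enforced above: either pure TP
# (ep_size == 1) or pure EP inside the MoE layer (tp_size forced to 1).
def moe_parallel_sizes(dp: int, tp: int, enable_expert_parallel: bool):
    ep = dp * tp if enable_expert_parallel else 1
    moe_tp = 1 if ep > 1 else tp  # mirrors the new __init__ override
    assert (moe_tp >= 1 and ep == 1) or (moe_tp == 1 and ep > 1)
    return moe_tp, ep

print(moe_parallel_sizes(dp=2, tp=2, enable_expert_parallel=True))   # -> (1, 4)
print(moe_parallel_sizes(dp=1, tp=4, enable_expert_parallel=False))  # -> (4, 1)
```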
