17 changes: 11 additions & 6 deletions custom_ops/xpu_ops/src/ops/block_attn.cc
@@ -158,7 +158,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
rope_head_dim = rotary_embs.dims()[4];
}
std::string pos_emb_type;
- if (use_neox_rotary_style == true) {
+ if (use_neox_rotary_style) {
pos_emb_type = "NEOX";
} else if (rope_head_dim == head_dim / 2) {
pos_emb_type = "HALF_HEAD_DIM";
@@ -344,12 +344,14 @@ std::vector<paddle::Tensor> BlockAttnKernel(
value_cache.data<cdata_t>())),
vsl.usual_lod_vp, // seq_lod
vsl.slot_mapping_vp, // real_batch
+ prefix_lens_vp, // start_tokens
param.batch_size, // batch_size
1, // emb_batch_size
rope_max_seqlen, // max_seqlen
param.head_num,
param.kv_head_num,
param.head_dim,
+ rope_head_dim,
param.max_batch_size,
block_size,
max_block_per_seq,
@@ -600,14 +602,16 @@ std::vector<paddle::Tensor> BlockAttnKernel(
key_cache.data<cdata_t>())),
const_cast<XPU_CType*>(reinterpret_cast<const XPU_CType*>(
value_cache.data<cdata_t>())),
- decoder_seq_lod_vp, // seq_lod
- decoder_batch_map_vp, // real_batch
- param.batch_size, // batch_size
- 1, // emb_batch_size
- rope_max_seqlen, // max_seqlen
+ decoder_seq_lod_vp, // seq_lod
+ decoder_batch_map_vp, // real_batch
+ decoder_context_len_cache_vp, // start_tokens
+ param.batch_size, // batch_size
+ 1, // emb_batch_size
+ rope_max_seqlen, // max_seqlen
param.head_num,
param.kv_head_num,
param.head_dim,
+ rope_head_dim,
param.max_batch_size,
block_size,
max_block_per_seq,
@@ -808,6 +812,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
param.head_num,
param.kv_head_num,
param.head_dim,
+ rope_head_dim,
param.max_batch_size,
block_size,
max_block_per_seq,
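Taken together, the four hunks above thread two extra pieces of state into the XPU attention kernel calls: a per-sequence start-token vector (`prefix_lens_vp` on the prefill path, `decoder_context_len_cache_vp` on the decode path) and the rotary dimension `rope_head_dim`, so partial-rotary models (where `rope_head_dim == head_dim / 2`) are handled explicitly. Below is a minimal Python sketch of the `pos_emb_type` dispatch from the first hunk; the final fallback branch is outside the visible diff, so the value returned there is a placeholder assumption.

```python
def select_pos_emb_type(use_neox_rotary_style: bool, rope_head_dim: int, head_dim: int) -> str:
    """Mirror of the pos_emb_type dispatch in BlockAttnKernel (first hunk above)."""
    if use_neox_rotary_style:
        return "NEOX"
    if rope_head_dim == head_dim // 2:
        # Partial rotary: only half of each head dimension carries rotary position info.
        return "HALF_HEAD_DIM"
    # The remaining branch is not shown in the diff; "FULL" is a hypothetical placeholder.
    return "FULL"
```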
6 changes: 3 additions & 3 deletions custom_ops/xpu_ops/src/ops/fused_noaux_tc.cc
@@ -76,19 +76,19 @@ std::vector<std::vector<int64_t>> FusedNoAuxTcInferShape(
const float routed_scaling_factor) {
std::vector<int64_t> topk_ids_shape = {gating_logits_shape[0], top_k};
std::vector<int64_t> topk_weights_shape = {gating_logits_shape[0], top_k};
- return {gating_logits_shape, topk_ids_shape, topk_weights_shape};
+ return {gating_logits_shape, topk_weights_shape, topk_ids_shape};
}

std::vector<paddle::DataType> FusedNoAuxTcInferDtype(
const paddle::DataType& gating_logits_dtype,
const paddle::DataType& bias_dtype) {
return {
- gating_logits_dtype, paddle::DataType::INT64, paddle::DataType::FLOAT32};
+ gating_logits_dtype, paddle::DataType::FLOAT32, paddle::DataType::INT32};
}

PD_BUILD_STATIC_OP(fused_noaux_tc)
.Inputs({"gating_logits", "bias"})
.Outputs({"gating_logits_out", "topk_ids", "topk_weights"})
.Outputs({"gating_logits_out", "topk_weights", "topk_ids"})
.Attrs({"n_group: int",
"topk_group: int",
"top_k: int",
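The three edits above keep the declarations of `fused_noaux_tc` consistent with each other: the op now advertises its outputs as `(gating_logits_out, topk_weights, topk_ids)`, with float32 weights and int32 ids (previously the id tensor was listed second and typed int64). A small validation helper sketching the new contract follows; only the output order, dtypes, and the `[num_tokens, top_k]` shapes come from the diff — the helper itself is illustrative.

```python
import paddle


def check_fused_noaux_tc_outputs(outputs):
    """Validate the reordered output contract: (gating_logits_out, topk_weights, topk_ids)."""
    gating_logits_out, topk_weights, topk_ids = outputs
    assert topk_weights.dtype == paddle.float32  # second output: per-token top-k weights
    assert topk_ids.dtype == paddle.int32        # third output: selected expert ids (was int64)
    assert topk_ids.shape == topk_weights.shape  # both are [num_tokens, top_k]
    return gating_logits_out, topk_weights, topk_ids
```

Callers that unpack positionally, such as the `get_moe_scores` site in the next hunk, must swap to the new order as well.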
@@ -313,7 +313,7 @@ def apply_tp(
"""
gate_out = gate(x.cast("float32"))
if layer.topk_method == "noaux_tc":
- _, topk_idx, topk_weights = get_moe_scores(
+ _, topk_weights, topk_idx = get_moe_scores(
gate_out,
layer.n_group,
layer.topk_group,
3 changes: 2 additions & 1 deletion fastdeploy/model_executor/layers/linear.py
@@ -61,7 +61,8 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
)

if self.model_format == "torch" and "output_dim" in extra_weight_attrs:
extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"]
if extra_weight_attrs["output_dim"] is not None:
extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"]

set_weight_attrs(
layer.weight,
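The new guard matters because `output_dim` can legitimately be `None` (no shard dimension for that weight); without it, `not None` would silently turn the attribute into `True`. A minimal sketch of the guarded flip with the surrounding layer machinery stripped out; the reason for flipping at all (torch-format checkpoints laying the weight out transposed) is an inference, not stated in the diff.

```python
def flip_output_dim_for_torch_format(extra_weight_attrs: dict, model_format: str) -> None:
    """Invert the sharding-dimension flag only when it carries a real boolean value."""
    if model_format == "torch" and "output_dim" in extra_weight_attrs:
        if extra_weight_attrs["output_dim"] is not None:
            # True/False are flipped; None means "no shard dim" and must stay None
            # rather than becoming `not None == True`.
            extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"]
```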
13 changes: 9 additions & 4 deletions fastdeploy/model_executor/layers/rotary_embedding.py
@@ -43,7 +43,7 @@ def __call__(self, position_ids):
inv_freq = self.base ** (-paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim)
partial_rotary_position_ids = position_ids / self.partial_rotary_factor
freqs = paddle.einsum("ij,k->ijk", partial_rotary_position_ids.cast("float32"), inv_freq)
- if paddle.is_compiled_with_xpu() or paddle.is_compiled_with_custom_device("iluvatar_gpu"):
+ if current_platform.is_xpu() or paddle.is_compiled_with_custom_device("iluvatar_gpu"):
# shape: [B, S, D]
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim), dtype="float32")
emb = paddle.stack([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim))
@@ -89,9 +89,14 @@ def __call__(self, position_ids):
bsz, max_seq_len = position_ids.shape[:2]
inv_freq = self.base ** (-paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim)
freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
- # shape: [B, S, D/2]
- rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")
- emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim // 2))
+ if current_platform.is_xpu():
+     # shape: [B, S, D]
+     rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim), dtype="float32")
+     emb = paddle.concat([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim))
+ else:
+     # shape: [B, S, D/2]
+     rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")
+     emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim // 2))
# shape: [B, S, 1, D]
emb = paddle.unsqueeze(emb, 2)
rot_emb[0] = paddle.cos(emb)
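Net effect of the two rotary-embedding hunks: the platform check is routed through `current_platform.is_xpu()`, and on XPU the cos/sin cache is built over the full rotary dimension by concatenating the half-dim frequency table with itself, while other platforms keep the half-dim layout. A self-contained sketch of the two layouts follows; it assumes the `sin` row is filled symmetrically to the `cos` row, which sits just below the visible hunk.

```python
import paddle


def build_rot_emb(position_ids, rotary_dim, base, full_dim):
    """Build the [2, B, S, 1, D or D/2] cos/sin cache, mirroring the hunk above."""
    bsz, max_seq_len = position_ids.shape[:2]
    inv_freq = base ** (-paddle.arange(0, rotary_dim, 2, dtype="float32") / rotary_dim)
    freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
    if full_dim:  # XPU path: duplicate each frequency so the table spans the full head dim
        rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, rotary_dim), dtype="float32")
        emb = paddle.concat([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, rotary_dim))
    else:  # default path: half-dim table, one entry per frequency
        rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, rotary_dim // 2), dtype="float32")
        emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, rotary_dim // 2))
    emb = paddle.unsqueeze(emb, 2)  # shape: [B, S, 1, D] or [B, S, 1, D/2]
    rot_emb[0] = paddle.cos(emb)
    rot_emb[1] = paddle.sin(emb)  # assumed to mirror the cos assignment just outside the hunk
    return rot_emb
```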
2 changes: 1 addition & 1 deletion fastdeploy/model_executor/models/glm4_moe.py
@@ -71,7 +71,7 @@ def __init__(
fd_config=fd_config,
prefix=f"{prefix}.up_gate_proj",
input_size=fd_config.model_config.hidden_size,
- output_size=[intermediate_size, intermediate_size],
+ output_sizes=[intermediate_size, intermediate_size],
with_bias=False,
)

1 change: 1 addition & 0 deletions fastdeploy/worker/xpu_model_runner.py
@@ -990,6 +990,7 @@ def _init_share_inputs(self, max_num_seqs: int):
position_ids=tmp_position_ids,
base=self.model_config.rope_theta,
model_config=self.model_config,
+ partial_rotary_factor=self.model_config.partial_rotary_factor,
)

# Set block tables
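Here the runner simply forwards `partial_rotary_factor` from the model config into the rotary-embedding constructor; inside the embedding it scales the position ids, as seen in the first `rotary_embedding.py` hunk (`position_ids / self.partial_rotary_factor`). A trivial restatement of that scaling for reference; the function name is illustrative, and a factor of 1.0 leaves positions unchanged.

```python
import paddle


def scale_partial_rotary_positions(position_ids: paddle.Tensor, partial_rotary_factor: float) -> paddle.Tensor:
    """Scale position ids the way the rotary embedding above does; factor 1.0 is a no-op."""
    return position_ids / partial_rotary_factor
```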