Skip to content

Commit 9aee33e

Browse files
authored
optimize phi::IntArray(common::vectorize()) in kernels/legacy/gpu (#77187)
1 parent dbb3ca7 commit 9aee33e

10 files changed

+36
-79
lines changed

paddle/phi/kernels/legacy/gpu/moe_combine_grad_kernel.cu

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -149,13 +149,11 @@ void MoeCombineGradKernel(const Context& dev_ctx,
149149
DenseTensor* grad_combine_weights_helper) {
150150
dev_ctx.template Alloc<T>(grad_x);
151151
dev_ctx.template Alloc<T>(grad_combine_weights_helper);
152-
phi::Full<T, Context>(
153-
dev_ctx, phi::IntArray(common::vectorize(grad_x->dims())), 0, grad_x);
154-
phi::Full<T, Context>(
155-
dev_ctx,
156-
phi::IntArray(common::vectorize(grad_combine_weights_helper->dims())),
157-
0,
158-
grad_combine_weights_helper);
152+
Full<T, Context>(dev_ctx, grad_x->dims(), 0, grad_x);
153+
Full<T, Context>(dev_ctx,
154+
grad_combine_weights_helper->dims(),
155+
0,
156+
grad_combine_weights_helper);
159157
auto x_shape = x.dims();
160158
auto combine_weights_shape = combine_weights.dims();
161159
moe_combine_bwd<T, Context>(dev_ctx,
@@ -182,18 +180,13 @@ void MoeCombineAutoGradKernel(const Context& dev_ctx,
182180
dev_ctx.template Alloc<T>(grad_combine_weights_helper);
183181
dev_ctx.template Alloc<int32_t>(grad_scatter_index);
184182

185-
phi::Full<T, Context>(
186-
dev_ctx, phi::IntArray(common::vectorize(grad_x->dims())), 0, grad_x);
187-
phi::Full<T, Context>(
188-
dev_ctx,
189-
phi::IntArray(common::vectorize(grad_combine_weights_helper->dims())),
190-
0,
191-
grad_combine_weights_helper);
192-
phi::Full<int32_t, Context>(
193-
dev_ctx,
194-
phi::IntArray(common::vectorize(grad_scatter_index->dims())),
195-
0,
196-
grad_scatter_index);
183+
Full<T, Context>(dev_ctx, grad_x->dims(), 0, grad_x);
184+
Full<T, Context>(dev_ctx,
185+
grad_combine_weights_helper->dims(),
186+
0,
187+
grad_combine_weights_helper);
188+
Full<int32_t, Context>(
189+
dev_ctx, grad_scatter_index->dims(), 0, grad_scatter_index);
197190

198191
// TODO(nieyuntao): Temporarily use 'grad_combine_weight_intermediate' to
199192
// bypass the grad_combine_weights_helper's shape mismatch to kernel shape
@@ -207,11 +200,10 @@ void MoeCombineAutoGradKernel(const Context& dev_ctx,
207200
x.dims()[1]}));
208201
grad_combine_weight_intermediate_meta.set_dtype(combine_weights.dtype());
209202
dev_ctx.template Alloc<T>(grad_combine_weight_intermediate);
210-
phi::Full<T, Context>(dev_ctx,
211-
phi::IntArray(common::vectorize(
212-
grad_combine_weight_intermediate->dims())),
213-
0,
214-
grad_combine_weight_intermediate);
203+
Full<T, Context>(dev_ctx,
204+
grad_combine_weight_intermediate->dims(),
205+
0,
206+
grad_combine_weight_intermediate);
215207

216208
auto x_shape = x.dims();
217209
auto combine_weights_shape = combine_weights.dims();

paddle/phi/kernels/legacy/gpu/moe_combine_kernel.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,7 @@ void MoeCombineKernel(const Context& dev_ctx,
109109
DenseTensor* y) {
110110
dev_ctx.template Alloc<T>(y); // T cannot support phi::dtype::float8 very
111111
// well, maybe replaced with x.dtype();
112-
phi::Full<T, Context>(
113-
dev_ctx, phi::IntArray(common::vectorize(y->dims())), 0, y);
112+
Full<T, Context>(dev_ctx, y->dims(), 0, y);
114113
auto combine_weights_shape = combine_weights.dims();
115114
auto x_shape = x.dims();
116115
moe_combine_fwd<T, Context>(dev_ctx,

paddle/phi/kernels/legacy/gpu/moe_combine_no_weight_grad_kernel.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,7 @@ void MoeCombineNoWeightGradKernel(const Context& dev_ctx,
105105
const int64_t k = scatter_index_shape[1];
106106

107107
dev_ctx.template Alloc<T>(grad_x);
108-
phi::Full<T, Context>(
109-
dev_ctx, phi::IntArray(common::vectorize(grad_x->dims())), 0, grad_x);
108+
Full<T, Context>(dev_ctx, grad_x->dims(), 0, grad_x);
110109

111110
moe_combine_no_weight_bwd<T>(combine_weights.data<T>(),
112111
scatter_index.data<int>(),

paddle/phi/kernels/legacy/gpu/moe_gate_dispatch_and_quant_kernel.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -367,8 +367,7 @@ void MoeDispatchAndQuantKernel(const Context &dev_ctx,
367367
sizeof(phi::float8_e4m3fn) * out_fp8->numel(),
368368
dev_ctx.stream());
369369

370-
phi::Full<float, Context>(
371-
dev_ctx, phi::IntArray(common::vectorize(scale->dims())), 1, scale);
370+
Full<float, Context>(dev_ctx, scale->dims(), 1, scale);
372371

373372
const auto &x_shape = x.dims();
374373
const auto &gate_logits_shape = gate_logits.dims();

paddle/phi/kernels/legacy/gpu/moe_gate_dispatch_kernel.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,7 @@ void MoeGateDispatchKernel(const Context &dev_ctx,
129129
dev_ctx.template Alloc<float>(combine_weights);
130130
dev_ctx.template Alloc<T>(y);
131131

132-
phi::Full<T, Context>(
133-
dev_ctx, phi::IntArray(common::vectorize(y->dims())), 0, y);
132+
Full<T, Context>(dev_ctx, y->dims(), 0, y);
134133
auto x_dims = x.dims();
135134
auto gate_logits_dims = gate_logits.dims();
136135

paddle/phi/kernels/legacy/gpu/moe_gate_dispatch_permute_kernel.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,7 @@ void MoEDispatchPermuteKernel(const Context &dev_ctx,
134134
dev_ctx.template Alloc<int>(scatter_index);
135135
dev_ctx.template Alloc<float>(combine_weights);
136136
dev_ctx.template Alloc<T>(y);
137-
phi::Full<T, Context>(
138-
dev_ctx, phi::IntArray(common::vectorize(y->dims())), 0, y);
137+
Full<T, Context>(dev_ctx, y->dims(), 0, y);
139138
const auto &x_shape = x.dims();
140139
const auto &gate_logits_shape = gate_logits.dims();
141140
int64_t num_rows = x_shape[0];

paddle/phi/kernels/legacy/gpu/moe_ops_partial_nosoftmaxtopk_grad_kernel.cu

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,8 @@ void MoeGateDispatchPartialNoSoftMaxTopkGradKernel(
114114
DenseTensor* combine_weights_grad) {
115115
dev_ctx.template Alloc<T>(x_grad);
116116
dev_ctx.template Alloc<float>(combine_weights_grad);
117-
phi::Full<float, Context>(
118-
dev_ctx,
119-
phi::IntArray(common::vectorize(combine_weights_grad->dims())),
120-
0,
121-
combine_weights_grad);
117+
Full<float, Context>(
118+
dev_ctx, combine_weights_grad->dims(), 0, combine_weights_grad);
122119
DenseTensor t_scatter_index;
123120
phi::Transpose<int, Context>(
124121
dev_ctx, scatter_index, {1, 0}, &t_scatter_index);

paddle/phi/kernels/legacy/gpu/moe_ops_partial_nosoftmaxtopk_kernel.cu

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -439,8 +439,7 @@ void apply_moe_dispatch_fwd(
439439
y->Resize({expert_offset_host.back(), x.dims()[1]});
440440
dev_ctx.template Alloc<T>(y);
441441
}
442-
phi::Full<T, Context>(
443-
dev_ctx, phi::IntArray(common::vectorize(y->dims())), 0, y);
442+
Full<T, Context>(dev_ctx, y->dims(), 0, y);
444443
copy_unpermuted_to_permuted_kernelLauncher(
445444
x.data<T>(),
446445
y->data<T>(), // out
@@ -526,31 +525,14 @@ void MoeGateDispatchPartialNoSoftMaxTopkKernel(
526525
dev_ctx.template Alloc<int64_t>(expert_offset);
527526
dev_ctx.template Alloc<int64_t>(expert_nums_local);
528527
dev_ctx.template Alloc<float>(combine_weights_out);
529-
phi::Full<int32_t, Context>(
530-
dev_ctx,
531-
phi::IntArray(common::vectorize(scatter_index->dims())),
532-
0,
533-
scatter_index);
534-
phi::Full<int32_t, Context>(
535-
dev_ctx,
536-
phi::IntArray(common::vectorize(scatter_index_rev->dims())),
537-
0,
538-
scatter_index_rev);
539-
phi::Full<int64_t, Context>(
540-
dev_ctx,
541-
phi::IntArray(common::vectorize(expert_offset->dims())),
542-
0,
543-
expert_offset);
544-
phi::Full<int64_t, Context>(
545-
dev_ctx,
546-
phi::IntArray(common::vectorize(expert_nums_local->dims())),
547-
0,
548-
expert_nums_local);
549-
phi::Full<float, Context>(
550-
dev_ctx,
551-
phi::IntArray(common::vectorize(combine_weights_out->dims())),
552-
0,
553-
combine_weights_out);
528+
Full<int32_t, Context>(dev_ctx, scatter_index->dims(), 0, scatter_index);
529+
Full<int32_t, Context>(
530+
dev_ctx, scatter_index_rev->dims(), 0, scatter_index_rev);
531+
Full<int64_t, Context>(dev_ctx, expert_offset->dims(), 0, expert_offset);
532+
Full<int64_t, Context>(
533+
dev_ctx, expert_nums_local->dims(), 0, expert_nums_local);
534+
Full<float, Context>(
535+
dev_ctx, combine_weights_out->dims(), 0, combine_weights_out);
554536
phi::Copy(
555537
dev_ctx, combine_weights, dev_ctx.GetPlace(), false, combine_weights_out);
556538
const auto &x_shape = x.dims();

paddle/phi/kernels/stride/indexing_kernel.cu

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -332,11 +332,7 @@ void IndexPutGradKernel_V2(const Context& dev_ctx,
332332
dev_ctx.template Alloc<T>(x_grad);
333333
// Fill value_grad with 0.
334334
if (value_grad) {
335-
phi::Full<T, Context>(
336-
dev_ctx,
337-
phi::IntArray(common::vectorize(value_grad->dims())),
338-
0,
339-
value_grad);
335+
phi::Full<T, Context>(dev_ctx, value_grad->dims(), 0, value_grad);
340336
}
341337
return;
342338
}
@@ -390,10 +386,7 @@ void IndexPutGradKernel_V2(const Context& dev_ctx,
390386
x_grad->ShareInplaceVersionCounterWith(out_grad);
391387
} else {
392388
DenseTensor value_zero;
393-
phi::Full<T, Context>(dev_ctx,
394-
phi::IntArray(common::vectorize(value.dims())),
395-
0,
396-
&value_zero);
389+
phi::Full<T, Context>(dev_ctx, value.dims(), 0, &value_zero);
397390
if (funcs::IsInUint32Range(x_grad->numel(), value.numel())) {
398391
LaunchIndexPutKernel_V2<T, Context>(
399392
dev_ctx, out_grad, indices, value_zero, false, x_grad);

paddle/phi/kernels/stride/reduce_stride_kernel.cu

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -320,8 +320,7 @@ void ProdStrideKernel(const Context& dev_ctx,
320320

321321
if (x_.numel() == 0) {
322322
// fill with 1.
323-
phi::Full<T, Context>(
324-
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 1, out);
323+
phi::Full<T, Context>(dev_ctx, out->dims(), 1, out);
325324
return;
326325
}
327326

@@ -647,8 +646,7 @@ void MeanStrideKernel(const Context& dev_ctx,
647646
}
648647

649648
if (x_.numel() == 0) {
650-
phi::Full<T, Context>(
651-
dev_ctx, phi::IntArray(common::vectorize(out->dims())), NAN, out);
649+
phi::Full<T, Context>(dev_ctx, out->dims(), NAN, out);
652650
return;
653651
}
654652

0 commit comments

Comments (0)