[Common] Add checks to CUDA kernel launch and CUDA API calls #2074

Open · wants to merge 3 commits into base: main
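Note on the pattern applied throughout this PR: CUDA runtime calls are wrapped in NVTE_CHECK_CUDA so a failing call raises an error at the call site instead of silently propagating. As a rough illustration only (the real NVTE_CHECK_CUDA is defined in Transformer Engine's headers and may differ), a macro of this kind typically looks like the sketch below; the name CHECK_CUDA_EXAMPLE is a placeholder, not the actual macro:

#include <cuda_runtime.h>
#include <stdexcept>
#include <string>

// Illustrative stand-in for an NVTE_CHECK_CUDA-style macro: evaluate the
// expression once and turn any non-cudaSuccess status into an exception
// that records where the failure happened.
#define CHECK_CUDA_EXAMPLE(expr)                                            \
  do {                                                                      \
    const cudaError_t status_ = (expr);                                     \
    if (status_ != cudaSuccess) {                                           \
      throw std::runtime_error(std::string("CUDA error at ") + __FILE__ +   \
                               ":" + std::to_string(__LINE__) + ": " +      \
                               cudaGetErrorString(status_));                \
    }                                                                       \
  } while (false)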
26 changes: 16 additions & 10 deletions transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
@@ -101,10 +101,10 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
DType::kInt32);
}
// CUDA event creation
- cudaEventCreateWithFlags(&_start_compute, 0);
- cudaEventCreateWithFlags(&_stop_compute, 0);
- cudaEventCreateWithFlags(&_start_comm, 0);
- cudaEventCreateWithFlags(&_stop_comm, 0);
+ NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_compute, 0));
+ NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_compute, 0));
+ NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_comm, 0));
+ NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_comm, 0));

/*
Defining the launcher order between the communication and GEMM kernels
@@ -114,11 +114,11 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
*/
int max_connection = transformer_engine::getenv<int>("CUDA_DEVICE_MAX_CONNECTIONS", 8);
int runtime_version = 0;
- cudaRuntimeGetVersion(&runtime_version);
+ NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&runtime_version));
cudaDeviceProp deviceProp;
- cudaGetDeviceProperties(&deviceProp, 0);
+ NVTE_CHECK_CUDA(cudaGetDeviceProperties(&deviceProp, 0));
if (runtime_version >= 12030 && deviceProp.major == 9 && max_connection > 1) {
- cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming);
+ NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming));
} else {
_comm_launch_event = 0;
}
@@ -129,9 +129,13 @@ CommOverlapCore::~CommOverlapCore() {
cudaEventDestroy(_start_comm);
cudaEventDestroy(_stop_compute);
cudaEventDestroy(_start_compute);
- if (_comm_launch_event) cudaEventDestroy(_comm_launch_event);
+ if (_comm_launch_event) {
+   cudaEventDestroy(_comm_launch_event);
+ }

- if (_atomic_gemm) cudaFree(_counter.dptr());
+ if (_atomic_gemm) {
+   cudaFree(_counter.dptr());
+ }

for (size_t i = 0; i < _stream_compute.size(); i++) {
cudaStreamSynchronize(_stream_compute[i]);
@@ -698,7 +702,9 @@ CommOverlapP2PBase::~CommOverlapP2PBase() {
cudaEventDestroy(_stop_recv);
cudaEventDestroy(_stop_send);
cudaStreamDestroy(_stream_recv);
- for (size_t i = 0; i < _stream_send.size(); i++) cudaStreamDestroy(_stream_send[i]);
+ for (size_t i = 0; i < _stream_send.size(); i++) {
+   cudaStreamDestroy(_stream_send[i]);
+ }
}

TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source,
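A note on the kernel-launch checks in the remaining files: a <<<...>>> launch returns void, so launch and configuration errors are only observable through cudaGetLastError(), which is why NVTE_CHECK_CUDA(cudaGetLastError()) is placed immediately after each launch. A minimal sketch of the pattern, assuming NVTE_CHECK_CUDA is in scope; scale_kernel and launch_scale are hypothetical names used only for illustration:

__global__ void scale_kernel(float *data, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= 2.0f;
}

void launch_scale(float *data, int n, cudaStream_t stream) {
  const dim3 block(256);
  const dim3 grid((n + block.x - 1) / block.x);
  scale_kernel<<<grid, block, 0, stream>>>(data, n);
  // The launch itself returns void; an invalid configuration or a failed
  // launch is reported by cudaGetLastError(), checked right away.
  NVTE_CHECK_CUDA(cudaGetLastError());
}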
@@ -2319,6 +2319,7 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds
if (comm->push == 0) {
kuserbuffers_pullsend<<<1, 1, 0, stream>>>(comm->myrank, peer, &(comm->send_id[peer]),
reinterpret_cast<int *>(flagptr));
NVTE_CHECK_CUDA(cudaGetLastError());
} else {
void *srcptr = reinterpret_cast<char *>(comm->mem_ptr[srchandler]) + srcoffset;
void *dstptr = reinterpret_cast<char *>(comm->peer_ptr[dsthandler][peerlocal]) + dstoffset;
@@ -2516,8 +2517,11 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds
&(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler]), reinterpret_cast<int *>(flagptr),
reinterpret_cast<int4 *>(srcptr), reinterpret_cast<int4 *>(dstptr),
signalonly ? 0 : bytes / 16, comm->ub_timeout);
- if (!signalonly)
+ NVTE_CHECK_CUDA(cudaGetLastError());
+ if (!signalonly) {
kuserbuffers_inc<<<1, 1, 0, stream>>>(&(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler]));
+   NVTE_CHECK_CUDA(cudaGetLastError());
+ }
if (comm->use_ce) {
NVTE_CHECK_CUDA(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
}
@@ -2532,6 +2536,7 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds
reinterpret_cast<int *>(0 ? // temporary disable
GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 2)
: nullptr));
NVTE_CHECK_CUDA(cudaGetLastError());
}
}

@@ -2612,24 +2617,28 @@ void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream) {
dim3 block(1);
dim3 grid(1);
producer_kernel<<<grid, block, 0, stream>>>(atomic_ptr, chunk_i);
NVTE_CHECK_CUDA(cudaGetLastError());
}

void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream) {
dim3 block(1);
dim3 grid(1);
consumer_kernel<<<grid, block, 0, stream>>>(atomic_ptr, chunk_i);
NVTE_CHECK_CUDA(cudaGetLastError());
}

void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStream_t stream) {
dim3 block(1);
dim3 grid(1);
consumer_batch_kernel<<<grid, block, 0, stream>>>(atomic_ptr, first_chunk_i, num_chunks);
NVTE_CHECK_CUDA(cudaGetLastError());
}

void reset_counters(void *atomic_ptr, int num_chunks, bool allgather, cudaStream_t stream) {
dim3 block(1);
dim3 grid(1);
reset_counters_kernel<<<grid, block, 0, stream>>>(atomic_ptr, num_chunks, allgather);
NVTE_CHECK_CUDA(cudaGetLastError());
}

template <typename fp8type, int nvec>
@@ -2683,6 +2692,7 @@ void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_in
reduce_fp8_in_bf16_out_cuda<fp8type, nvec>
<<<grid, block, 0, stream>>>(inputs, output, scale, num_inputs, input_size,
num_aligned_elements_per_input, tot_input_size);
NVTE_CHECK_CUDA(cudaGetLastError());
}

template void reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(void *inputs, void *output, float *scale,
@@ -2738,4 +2748,5 @@ void reduce_bf16(void *inputs, void *output, int num_inputs, int input_size, cud
dim3 grid(num_blocks);
reduce_bf16_cuda<nvec><<<grid, block, 0, stream>>>(
inputs, output, num_inputs, input_size, num_aligned_elements_per_input, tot_input_size);
NVTE_CHECK_CUDA(cudaGetLastError());
}
4 changes: 3 additions & 1 deletion transformer_engine/common/common.cu
@@ -32,6 +32,7 @@ void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) {
update_tensor_scale_inv_kernel<<<1, 1, 0, stream>>>(
reinterpret_cast<const float *>(t->scale.dptr),
reinterpret_cast<float *>(t->scale_inv.dptr));
NVTE_CHECK_CUDA(cudaGetLastError());
}
}

@@ -73,6 +74,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
dim3 grid(numBlocks, 1, 1); \
memset_kernel<vectorizedType> \
<<<grid, kThreadsPerBlock, 0, stream>>>(ptr, value, size_in_bytes); \
NVTE_CHECK_CUDA(cudaGetLastError()); \
return; \
}

@@ -83,7 +85,7 @@ void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream

if (size_in_bytes > 4096) {
// Use cudaMemsetAsync for larger sizes.
- cudaMemsetAsync(ptr, value, size_in_bytes, stream);
+ NVTE_CHECK_CUDA(cudaMemsetAsync(ptr, value, size_in_bytes, stream));
return;
}

9 changes: 9 additions & 0 deletions transformer_engine/common/fused_attn/context_parallel.cu
@@ -341,6 +341,7 @@ void thd_read_half_tensor(const Tensor &tensor, const Tensor &cu_seqlens, Tensor
thd_read_half_tensor_kernel<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
half.data.dptr, tensor.data.dptr, reinterpret_cast<int *>(cu_seqlens.data.dptr), batch,
hidden_size_in_bytes, half_idx, tensor_shape[seq_dim]);
NVTE_CHECK_CUDA(cudaGetLastError());
}

/***************************************************************************************************
@@ -397,11 +398,13 @@ void thd_second_half_lse_correction(Tensor lse, const Tensor &lse_per_step,
reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(lse_per_step.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
second_half_lse_seqlen);
NVTE_CHECK_CUDA(cudaGetLastError());
} else {
thd_lse_kernel<false, LseCorrectionFunctor><<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(lse_per_step.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
second_half_lse_seqlen);
NVTE_CHECK_CUDA(cudaGetLastError());
}
}

@@ -446,11 +449,13 @@ void thd_read_second_half_lse(const Tensor &lse, const Tensor &cu_seqlens, Tenso
reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(half_lse.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
second_half_lse_seqlen);
NVTE_CHECK_CUDA(cudaGetLastError());
} else {
thd_lse_kernel<false, ReadLseFunctor><<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(half_lse.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
second_half_lse_seqlen);
NVTE_CHECK_CUDA(cudaGetLastError());
}
}

@@ -519,6 +524,7 @@ static void thd_out_correction_helper(Tensor out, const Tensor &out_per_step, co
reinterpret_cast<float *>(lse_per_step.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, dim_per_head,
lse_seqlen, lse_per_step_seqlen);
NVTE_CHECK_CUDA(cudaGetLastError());
} else {
thd_out_correction_kernel<dtype, only_second_half, tile, false>
<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
@@ -528,6 +534,7 @@
reinterpret_cast<float *>(lse_per_step.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, dim_per_head,
lse_seqlen, lse_per_step_seqlen);
NVTE_CHECK_CUDA(cudaGetLastError());
}
}

@@ -602,6 +609,7 @@ static void thd_grad_correction_helper(Tensor grad, const Tensor &grad_per_step,
reinterpret_cast<dtype *>(grad.data.dptr),
reinterpret_cast<dtype *>(grad_per_step.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, hidden_size, total_tokens);
NVTE_CHECK_CUDA(cudaGetLastError());
}

template <typename dtype>
@@ -667,6 +675,7 @@ void thd_get_partitioned_indices(const Tensor &cu_seqlens, Tensor output, int to
thd_partition_indices_kernel<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<int *>(output.data.dptr), reinterpret_cast<int *>(cu_seqlens.data.dptr),
batch, total_tokens, world_size, rank);
NVTE_CHECK_CUDA(cudaGetLastError());
}

} // namespace context_parallel
2 changes: 2 additions & 0 deletions transformer_engine/common/fused_attn/flash_attn.cu
@@ -91,6 +91,7 @@ void prepare_flash_attn_fwd(Tensor qkvi, Tensor qkv, cudaStream_t stream) {
prepare_kernel_fwd<dtype><<<grid, threads, 0, stream>>>(
reinterpret_cast<dtype *>(qkvi.data.dptr), reinterpret_cast<dtype *>(qkv.data.dptr),
shape[1], shape[2], shape[3], shape[4]););
NVTE_CHECK_CUDA(cudaGetLastError());
}

void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream_t stream) {
@@ -129,6 +130,7 @@ void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream
reinterpret_cast<dtype *>(q.data.dptr), reinterpret_cast<dtype *>(k.data.dptr),
reinterpret_cast<dtype *>(v.data.dptr), reinterpret_cast<dtype *>(qkv.data.dptr),
q_shape[0], q_shape[1], q_shape[2], q_shape[3]););
NVTE_CHECK_CUDA(cudaGetLastError());
}

} // namespace flash_attention
@@ -416,6 +416,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
actual_b, b, static_cast<const int32_t *>(devPtrCuSeqlensQ),
static_cast<const int32_t *>(devPtrCuSeqlensKV), static_cast<int32_t *>(devActualSeqlenQ),
static_cast<int32_t *>(devActualSeqlenKV));
NVTE_CHECK_CUDA(cudaGetLastError());
variant_pack[seq_q] = devActualSeqlenQ;
variant_pack[seq_kv] = devActualSeqlenKV;
}
@@ -454,6 +455,7 @@
layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast<int32_t *>(devPtrSeqOffsetsQ),
static_cast<int32_t *>(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK,
devOffsetsV, devOffsetsO, devOffsetsS);
NVTE_CHECK_CUDA(cudaGetLastError());
if (is_ragged_q) {
variant_pack[offset_q] = devOffsetsQ;
variant_pack[offset_o] = devOffsetsO;
@@ -883,6 +885,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
actual_b, b, static_cast<const int32_t *>(devPtrCuSeqlensQ),
static_cast<const int32_t *>(devPtrCuSeqlensKV), static_cast<int32_t *>(devActualSeqlenQ),
static_cast<int32_t *>(devActualSeqlenKV));
NVTE_CHECK_CUDA(cudaGetLastError());
variant_pack[seq_q] = devActualSeqlenQ;
variant_pack[seq_kv] = devActualSeqlenKV;
}
@@ -916,6 +919,7 @@
layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast<int32_t *>(devPtrSeqOffsetsQ),
static_cast<int32_t *>(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK,
devOffsetsV, devOffsetsO, devOffsetsS);
NVTE_CHECK_CUDA(cudaGetLastError());
if (is_ragged_q) {
variant_pack[offset_q] = devOffsetsQ;
variant_pack[offset_o] = devOffsetsO;
4 changes: 4 additions & 0 deletions transformer_engine/common/fused_attn/fused_attn_fp8.cu
@@ -1111,6 +1111,7 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in
cu_seqlens_to_offsets<<<gridDims, blockDims, 0, stream>>>(
b, h, d, reinterpret_cast<int32_t*>(devPtrcuSeqlensQ), actual_seqlens_q, qkv_ragged_offset,
o_ragged_offset);
NVTE_CHECK_CUDA(cudaGetLastError());
void* devPtrQKVRaggedOffset = reinterpret_cast<void*>(qkv_ragged_offset);
void* devPtrORaggedOffset = reinterpret_cast<void*>(o_ragged_offset);
void* devPtrMNKOverride = reinterpret_cast<void*>(actual_seqlens_q);
@@ -1577,6 +1578,7 @@ void fused_attn_fp8_bwd_impl(
cu_seqlens_to_offsets<<<gridDims, blockDims, 0, stream>>>(
b, h, d, reinterpret_cast<int32_t*>(devPtrcuSeqlensQ), actual_seqlens_q, qkv_ragged_offset,
o_ragged_offset);
NVTE_CHECK_CUDA(cudaGetLastError());
void* devPtrQKVRaggedOffset = reinterpret_cast<void*>(qkv_ragged_offset);
void* devPtrORaggedOffset = reinterpret_cast<void*>(o_ragged_offset);
void* devPtrMNKOverride = reinterpret_cast<void*>(actual_seqlens_q);
@@ -1933,6 +1935,7 @@ void fused_attn_fp8_fwd_impl_v1(
b, b, static_cast<const int32_t*>(devPtrcuSeqlensQ), // TODO(pass max_b)
static_cast<const int32_t*>(devPtrcuSeqlensKV), static_cast<int32_t*>(devActualSeqlenQ),
static_cast<int32_t*>(devActualSeqlenKV));
NVTE_CHECK_CUDA(cudaGetLastError());
variant_pack[seq_q] = devActualSeqlenQ;
variant_pack[seq_kv] = devActualSeqlenKV;
}
@@ -2329,6 +2332,7 @@ void fused_attn_fp8_bwd_impl_v1(
b, b, static_cast<const int32_t*>(devPtrcuSeqlensQ), // TODO(pass max_b)
static_cast<const int32_t*>(devPtrcuSeqlensKV), static_cast<int32_t*>(devActualSeqlenQ),
static_cast<int32_t*>(devActualSeqlenKV));
NVTE_CHECK_CUDA(cudaGetLastError());
variant_pack[seq_q] = devActualSeqlenQ;
variant_pack[seq_kv] = devActualSeqlenKV;
}
4 changes: 4 additions & 0 deletions transformer_engine/common/fused_attn/kv_cache.cu
@@ -157,6 +157,7 @@ void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tenso
reinterpret_cast<int *>(page_table.data.dptr),
reinterpret_cast<int *>(cu_new_lens.data.dptr),
reinterpret_cast<int *>(cu_cached_lens.data.dptr), h_kv, d_k, d_v, b, max_seq_len);
NVTE_CHECK_CUDA(cudaGetLastError());
}
dim3 grid_size(b, max_ctx_len);
copy_to_kv_cache_kernel<<<grid_size, block_size, 0, stream>>>(
@@ -166,6 +167,7 @@
reinterpret_cast<int *>(cu_new_lens.data.dptr),
reinterpret_cast<int *>(cu_cached_lens.data.dptr), qkv_format, h_kv, d_k, d_v, b,
max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged);
NVTE_CHECK_CUDA(cudaGetLastError());
}
}

@@ -215,6 +217,7 @@ void convert_thd_to_bshd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_se
reinterpret_cast<scalar_t *>(tensor.data.dptr),
reinterpret_cast<scalar_t *>(new_tensor.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), b, max_seq_len, h, d);
NVTE_CHECK_CUDA(cudaGetLastError());
}

void convert_thd_to_bshd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int b,
@@ -254,6 +257,7 @@ void convert_bshd_to_thd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_se
reinterpret_cast<scalar_t *>(tensor.data.dptr),
reinterpret_cast<scalar_t *>(new_tensor.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), b, max_seq_len, h, d);
NVTE_CHECK_CUDA(cudaGetLastError());
}

void convert_bshd_to_thd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int t,
8 changes: 5 additions & 3 deletions transformer_engine/common/fused_attn/utils.cu
@@ -600,13 +600,14 @@ uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cud
// workspace size requires 4 bytes
uint32_t *dout = static_cast<uint32_t *>(workspace);
uint32_t hout{};
- cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream);
+ NVTE_CHECK_CUDA(cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream));
constexpr int threads = 128;
const int blocks = (len - 1) / threads + 1;
get_runtime_num_segments_kernel<<<blocks, threads, 0, stream>>>(static_cast<int32_t *>(cu_seqlen),
len, dout);
- cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream);
- cudaStreamSynchronize(stream);
+ NVTE_CHECK_CUDA(cudaGetLastError());
+ NVTE_CHECK_CUDA(cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream));
+ NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
return hout;
}

@@ -633,4 +634,5 @@ void nvte_extract_seed_and_offset(int64_t *rng_state_ptr, int captured, int64_t

fused_attn::extract_seed_and_offset<<<1, 1, 0, stream>>>(
rng_state_ptr, captured, seed_ptr, seed_val, offset_ptr, offset_val, offset_intragraph);
NVTE_CHECK_CUDA(cudaGetLastError());
}
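The GetRuntimeNumSegments change above covers a complete asynchronous sequence: the memset, the kernel launch, the device-to-host copy, and the stream synchronization that makes the host value valid. A compact sketch of the same pattern, assuming NVTE_CHECK_CUDA is in scope; count_positive_kernel and count_positive are hypothetical names used only for illustration:

__global__ void count_positive_kernel(const int *data, size_t n, unsigned int *out) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n && data[i] > 0) atomicAdd(out, 1u);
}

unsigned int count_positive(const int *d_data, size_t n, unsigned int *d_scratch,
                            cudaStream_t stream) {
  unsigned int h_count = 0;
  NVTE_CHECK_CUDA(cudaMemsetAsync(d_scratch, 0, sizeof(unsigned int), stream));
  constexpr int threads = 128;
  const int blocks = static_cast<int>((n + threads - 1) / threads);
  count_positive_kernel<<<blocks, threads, 0, stream>>>(d_data, n, d_scratch);
  NVTE_CHECK_CUDA(cudaGetLastError());
  NVTE_CHECK_CUDA(cudaMemcpyAsync(&h_count, d_scratch, sizeof(unsigned int),
                                  cudaMemcpyDeviceToHost, stream));
  // Without the synchronization, h_count could be read before the copy lands.
  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
  return h_count;
}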
12 changes: 7 additions & 5 deletions transformer_engine/common/fused_router/fused_moe_aux_loss.cu
@@ -177,9 +177,9 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
config.stream = stream;

// Update the max cluster size based on the device
- cudaOccupancyMaxPotentialClusterSize(
+ NVTE_CHECK_CUDA(cudaOccupancyMaxPotentialClusterSize(
&cluster_size,
- reinterpret_cast<void*>(fused_moe_aux_loss_forward_kernel<DataType, IndexType>), &config);
+ reinterpret_cast<void*>(fused_moe_aux_loss_forward_kernel<DataType, IndexType>), &config));

cudaLaunchAttribute attribute[1];
attribute[0].id = cudaLaunchAttributeClusterDimension;
@@ -189,14 +189,15 @@
config.numAttrs = 1;
config.attrs = attribute;

- cudaLaunchKernelEx(&config, fused_moe_aux_loss_forward_kernel<DataType, IndexType>, probs,
- tokens_per_expert, total_num_tokens, num_experts, num_rows, num_cols, topk,
- coeff, aux_loss, Const_buf);
+ NVTE_CHECK_CUDA(cudaLaunchKernelEx(
+ &config, fused_moe_aux_loss_forward_kernel<DataType, IndexType>, probs, tokens_per_expert,
+ total_num_tokens, num_experts, num_rows, num_cols, topk, coeff, aux_loss, Const_buf));
} else {
size_t smem_size = sizeof(CompType) * num_cols;
fused_moe_aux_loss_forward_kernel<DataType, IndexType>
<<<1, 1024, smem_size, stream>>>(probs, tokens_per_expert, total_num_tokens, num_experts,
num_rows, num_cols, topk, coeff, aux_loss, Const_buf);
NVTE_CHECK_CUDA(cudaGetLastError());
}
}

@@ -247,6 +248,7 @@ void fused_moe_aux_loss_backward_kernel_launcher(const float* Const_buf,
int grid_size = (num_rows + block_size - 1) / block_size;
fused_moe_aux_loss_backward_kernel<DataType, IndexType><<<grid_size, block_size, 0, stream>>>(
Const_buf, tokens_per_expert, num_rows, num_cols, grad_aux_loss, grad_probs);
NVTE_CHECK_CUDA(cudaGetLastError());
}

void fused_moe_aux_loss_backward(const Tensor& Const_buf, const Tensor& tokens_per_expert,
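One design note on the fused_moe_aux_loss.cu change above: cudaLaunchKernelEx returns a cudaError_t, so that launch path can be wrapped in NVTE_CHECK_CUDA directly, while <<<...>>> launches still need the separate cudaGetLastError() check. A sketch of such a checked launch with a cluster-dimension attribute (requires CUDA 12 and SM90-class hardware), assuming NVTE_CHECK_CUDA is in scope; copy_kernel and launch_with_cluster are hypothetical names used only for illustration:

__global__ void copy_kernel(const float *in, float *out, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i];
}

void launch_with_cluster(const float *d_in, float *d_out, int n,
                         int cluster_size, cudaStream_t stream) {
  cudaLaunchConfig_t config = {};
  config.gridDim = dim3(cluster_size);  // grid size must be a multiple of the cluster size
  config.blockDim = dim3(1024);
  config.dynamicSmemBytes = 0;
  config.stream = stream;

  cudaLaunchAttribute attribute[1];
  attribute[0].id = cudaLaunchAttributeClusterDimension;
  attribute[0].val.clusterDim.x = cluster_size;
  attribute[0].val.clusterDim.y = 1;
  attribute[0].val.clusterDim.z = 1;
  config.numAttrs = 1;
  config.attrs = attribute;

  // Unlike a <<<...>>> launch, cudaLaunchKernelEx reports errors through its
  // return value, so the call itself can be wrapped in the check.
  NVTE_CHECK_CUDA(cudaLaunchKernelEx(&config, copy_kernel, d_in, d_out, n));
}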