-
Notifications
You must be signed in to change notification settings - Fork 77
Open
Description
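Issue: Each non-circular-buffered TMA load is handled separately, and each one goes through its own 8 steps: mbarrier alloc, init, sync, arrive-expect-tx (`setExpectTx`, emitted as `mbarrier::arriveExpectTX`), TMA load, mbarrier wait, sync, and mbarrier inval (invalidate). For a single load the generated code is: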
// One non-circular-buffered TMA load, expanded into its own 8-step sequence.
// Steps 1-2: alloc + init an mbarrier (arrival count 1U) at array + smem_offset + 4224.
uint64_t* T16 = reinterpret_cast<uint64_t*>(array + smem_offset + 4224);
mbarrier::init(toSmem(T16), 1U);
// Step 3: block-wide sync so the initialized barrier is visible before use.
__syncthreads();
// Only threads passing electSync(0xFFFFFFFF) and predicate b16 issue the load
// (NOTE(review): presumably a single elected lane -- confirm b16's role).
if ((Hopper::electSync(4294967295U) && b16)) {
uint64_t i18;
// Step 4: arrive on T16 and register 4096U expected transaction bytes;
// the returned value is what the wait below checks against.
i18 = mbarrier::arriveExpectTX(toSmem(T16), 4096U);
// Step 5: issue the TMA tile copy (G2S, index rank 2) into T10, signalling T16.
Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr5, a7, toSmem(T16) }), toSmem(T10));
// Step 6: the issuing thread blocks until the barrier completes.
mbarrier::wait(toSmem(T16), i18);
}
// Step 7: every other thread waits here for the issuing thread, i.e. for the data.
__syncthreads();
// Step 8: invalidate the barrier.
mbarrier::inval(toSmem(T16));
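When there are N non-circular-buffered TMA loads, these 8 steps are repeated N times. That means N separate mbarriers, 2N `__syncthreads()` calls, and N fully serialized issue-and-wait round trips: the next load is not issued until the previous one has landed. This is inefficient. With 4 inputs, the kernel achieves only 84% of SOL (speed-of-light) bandwidth. The corresponding CUDA code is: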
// Four independent TMA loads (into T10, T9, T8, T7). Each gets its own mbarrier
// (offsets 4224, 4096, 12544, 12416) and its own full 8-step sequence.
// NOTE(review): each load is waited on, and the block synchronized, before the
// next load is issued. The four copies therefore never overlap; this
// serialization is the inefficiency this issue reports.
// Load 1 of 4: ptr5 -> T10, barrier T16 @ +4224.
uint64_t* T16 = reinterpret_cast<uint64_t*>(array + smem_offset + 4224);
mbarrier::init(toSmem(T16), 1U);
__syncthreads();
if ((Hopper::electSync(4294967295U) && b16)) {
uint64_t i18;
i18 = mbarrier::arriveExpectTX(toSmem(T16), 4096U);
Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr5, a7, toSmem(T16) }), toSmem(T10));
mbarrier::wait(toSmem(T16), i18);
}
__syncthreads();
mbarrier::inval(toSmem(T16));
// Load 2 of 4: ptr8 -> T9, barrier T17 @ +4096.
uint64_t* T17 = reinterpret_cast<uint64_t*>(array + smem_offset + 4096);
mbarrier::init(toSmem(T17), 1U);
__syncthreads();
if ((Hopper::electSync(4294967295U) && b16)) {
uint64_t i19;
i19 = mbarrier::arriveExpectTX(toSmem(T17), 4096U);
Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr8, a7, toSmem(T17) }), toSmem(T9));
mbarrier::wait(toSmem(T17), i19);
}
__syncthreads();
mbarrier::inval(toSmem(T17));
// Load 3 of 4: ptr9 -> T8, barrier T18 @ +12544.
uint64_t* T18 = reinterpret_cast<uint64_t*>(array + smem_offset + 12544);
mbarrier::init(toSmem(T18), 1U);
__syncthreads();
if ((Hopper::electSync(4294967295U) && b16)) {
uint64_t i20;
i20 = mbarrier::arriveExpectTX(toSmem(T18), 4096U);
Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr9, a7, toSmem(T18) }), toSmem(T8));
mbarrier::wait(toSmem(T18), i20);
}
__syncthreads();
mbarrier::inval(toSmem(T18));
// Load 4 of 4: ptr10 -> T7, barrier T19 @ +12416.
uint64_t* T19 = reinterpret_cast<uint64_t*>(array + smem_offset + 12416);
mbarrier::init(toSmem(T19), 1U);
__syncthreads();
if ((Hopper::electSync(4294967295U) && b16)) {
uint64_t i21;
i21 = mbarrier::arriveExpectTX(toSmem(T19), 4096U);
Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr10, a7, toSmem(T19) }), toSmem(T7));
mbarrier::wait(toSmem(T19), i21);
}
__syncthreads();
mbarrier::inval(toSmem(T19));
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels