Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow scalar broadcasting in VisitorRowBroadcast and VisitorColBroadcast #1539

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions include/cutlass/epilogue/threadblock/fusion/visitor_load.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,8 @@ struct VisitorAuxLoad{
template<
class ThreadMap,
class Element,
class StrideMNL
class StrideMNL,
bool EnableNullptr = false
>
struct VisitorRowBroadcast {

Expand Down Expand Up @@ -403,10 +404,31 @@ struct VisitorRowBroadcast {
auto src_v = filter(tC_gRow);
auto coord_v = filter(tC_cRow);
auto dst_v = filter(tC_rRow);

if constexpr (EnableNullptr) {
if (params_ptr->ptr_row == nullptr) {
// In this case we are loading from a scalar and broadcasting
VecType filled_vec;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < VecLength; i++) {
reinterpret_cast<Element *>(&filled_vec)[i] = params_ptr->null_default;
}

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
if (get<1>(coord_v(i)) < n) {
dst_v(i) = filled_vec;
}
}
return;
}
}

// In this case we are loading from a row vector and broadcasting
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
bool guard = get<1>(coord_v(i)) < n;
cutlass::arch::global_load<VecType, sizeof(VecType)>(dst_v(i), (void const*)&src_v(i), guard);
cutlass::arch::global_load<VecType, sizeof(VecType)>(dst_v(i), (void const *)&src_v(i), guard);
}
}

Expand Down Expand Up @@ -464,7 +486,8 @@ struct VisitorRowBroadcast {
template<
class ThreadMap,
class Element,
class StrideMNL = Stride<_1,_0,_0>
class StrideMNL = Stride<_1,_0,_0>,
bool EnableNullptr = false
>
struct VisitorColBroadcast {

Expand Down Expand Up @@ -524,11 +547,29 @@ struct VisitorColBroadcast {
CUTLASS_DEVICE void
begin_epilogue() {
clear(tC_rCol);

Tensor pred = make_tensor<bool>(shape(tC_gCol));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tC_cCol(i)) < m;
}

if constexpr (EnableNullptr) {
if (params_ptr->ptr_col == nullptr) {
// In this case we are loading from a scalar and broadcasting
auto dst_v = filter(tC_rCol);

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(dst_v); ++i) {
if (pred(i)) {
dst_v(i) = params_ptr->null_default;
}
}
return;
}
}

// In this case we are loading from a column vector and broadcasting
copy_if(pred, tC_gCol, tC_rCol);
}

Expand Down