diff --git a/include/cutlass/array.h b/include/cutlass/array.h
index 499d45c724..14dc4757b7 100644
--- a/include/cutlass/array.h
+++ b/include/cutlass/array.h
@@ -1019,6 +1019,114 @@ struct negate<Array<T, N>> {
   }
 };
 
+/// Fused and-popc-add, applied element-wise to Array operands.
+///
+/// Every overload hands element i of its operands to the scalar
+/// and_popc_add<T> functor and stores what it returns in result[i];
+/// the overloads taking a scalar pass that same scalar for every i.
+template <typename T, int N>
+struct and_popc_add<Array<T, N>, Array<T, N>, Array<T, N>> {
+
+  /// Array/array form: result[i] = scalar_op(a[i], b[i], c[i]).
+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) const {
+
+    Array<T, N> result;
+    and_popc_add<T> scalar_op;
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+      result[i] = scalar_op(a[i], b[i], c[i]);
+    }
+
+    return result;
+  }
+
+  /// Array/scalar form: result[i] = scalar_op(a[i], scalar, c[i]).
+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(Array<T, N> const &a, T const &scalar, Array<T, N> const &c) const {
+
+    Array<T, N> result;
+    and_popc_add<T> scalar_op;
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+      result[i] = scalar_op(a[i], scalar, c[i]);
+    }
+
+    return result;
+  }
+
+  /// Scalar/array form: result[i] = scalar_op(scalar, b[i], c[i]).
+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(T const &scalar, Array<T, N> const &b, Array<T, N> const &c) const {
+
+    Array<T, N> result;
+    and_popc_add<T> scalar_op;
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+      result[i] = scalar_op(scalar, b[i], c[i]);
+    }
+
+    return result;
+  }
+};
+
+/// Fused xor-popc-add, applied element-wise to Array operands.
+///
+/// Every overload hands element i of its operands to the scalar
+/// xor_popc_add<T> functor and stores what it returns in result[i];
+/// the overloads taking a scalar pass that same scalar for every i.
+template <typename T, int N>
+struct xor_popc_add<Array<T, N>, Array<T, N>, Array<T, N>> {
+
+  /// Array/array form: result[i] = scalar_op(a[i], b[i], c[i]).
+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) const {
+
+    Array<T, N> result;
+    xor_popc_add<T> scalar_op;
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+      result[i] = scalar_op(a[i], b[i], c[i]);
+    }
+
+    return result;
+  }
+
+  /// Array/scalar form: result[i] = scalar_op(a[i], scalar, c[i]).
+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(Array<T, N> const &a, T const &scalar, Array<T, N> const &c) const {
+
+    Array<T, N> result;
+    xor_popc_add<T> scalar_op;
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+      result[i] = scalar_op(a[i], scalar, c[i]);
+    }
+
+    return result;
+  }
+
+  /// Scalar/array form: result[i] = scalar_op(scalar, b[i], c[i]).
+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(T const &scalar, Array<T, N> const &b, Array<T, N> const &c) const {
+
+    Array<T, N> result;
+    xor_popc_add<T> scalar_op;
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+      result[i] = scalar_op(scalar, b[i], c[i]);
+    }
+
+    return result;
+  }
+};
+
 /// Fused multiply-add
 template <typename T, int N>
 struct multiply_add<Array<T, N>, Array<T, N>, Array<T, N>> {
diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h
index 65e49d5290..1f3a696306 100644
--- a/include/cutlass/functional.h
+++ b/include/cutlass/functional.h
@@ -579,7 +579,7 @@ struct guarded_multiply_add_relu0 {
   }
 };
 
-/// Fused multiply-add
+/// Fused and-add
 template <typename T>
 struct and_add {
   CUTLASS_HOST_DEVICE
@@ -589,6 +589,41 @@ struct and_add {
 };
 
 
+/// Fused and-popc-add: returns C(popcount(a & b)) + c.
+///
+/// The count covers (a & b) as handed to the popcount routine and, when
+/// sizeof(A) == sizeof(uint64_t), additionally the bits obtained by shifting
+/// (a & b) right by 32. The count is converted to C before c is added.
+template <typename A, typename B = A, typename C = A>
+struct and_popc_add {
+  CUTLASS_HOST_DEVICE
+  C operator()(A const &a, B const &b, C const &c) const {
+    A and_result = a & b;
+
+    // The device-compilation macro is __CUDA_ARCH__ (one underscore between
+    // CUDA and ARCH); this branch uses the __popc intrinsic when it is defined.
+    #if defined(__CUDA_ARCH__)
+    int popc_result = __popc(and_result);
+
+    if constexpr (sizeof(A) == sizeof(uint64_t)) {
+      popc_result += __popc(static_cast<uint32_t>(and_result >> 32));
+    }
+
+    #else
+    // NOTE(review): __builtin_popcount is a GCC/Clang builtin - confirm that
+    // other host compilers are covered before relying on this path.
+    int popc_result = __builtin_popcount(and_result);
+    if constexpr (sizeof(A) == sizeof(uint64_t)) {
+      popc_result += __builtin_popcount(static_cast<uint32_t>(and_result >> 32));
+    }
+
+    #endif
+
+    return C(popc_result) + c;
+
+  }
+};
+
-/// Fused multiply-add
+/// Fused xor-add
 template <typename T>
 struct xor_add {
@@ -598,6 +633,42 @@ struct xor_add {
   }
 };
 
+
+/// Fused xor-popc-add: returns C(popcount(a ^ b)) + c.
+///
+/// The count covers (a ^ b) as handed to the popcount routine and, when
+/// sizeof(A) == sizeof(uint64_t), additionally the bits obtained by shifting
+/// (a ^ b) right by 32. The count is converted to C before c is added.
+template <typename A, typename B = A, typename C = A>
+struct xor_popc_add {
+  CUTLASS_HOST_DEVICE
+  C operator()(A const &a, B const &b, C const &c) const {
+    A xor_result = a ^ b;
+
+    // The device-compilation macro is __CUDA_ARCH__ (one underscore between
+    // CUDA and ARCH); this branch uses the __popc intrinsic when it is defined.
+    #if defined(__CUDA_ARCH__)
+    int popc_result = __popc(xor_result);
+
+    if constexpr (sizeof(A) == sizeof(uint64_t)) {
+      popc_result += __popc(static_cast<uint32_t>(xor_result >> 32));
+    }
+
+    #else
+    // NOTE(review): __builtin_popcount is a GCC/Clang builtin - confirm that
+    // other host compilers are covered before relying on this path.
+    int popc_result = __builtin_popcount(xor_result);
+    if constexpr (sizeof(A) == sizeof(uint64_t)) {
+      popc_result += __builtin_popcount(static_cast<uint32_t>(xor_result >> 32));
+    }
+
+    #endif
+
+    return C(popc_result) + c;
+
+  }
+};
+
 namespace detail {
 
 // Whether namespace-unqualified conj(t) for t of type T is
diff --git a/tools/util/include/cutlass/util/reference/host/gemm.h b/tools/util/include/cutlass/util/reference/host/gemm.h
index 0388813109..5a51489d1e 100644
--- a/tools/util/include/cutlass/util/reference/host/gemm.h
+++ 
b/tools/util/include/cutlass/util/reference/host/gemm.h @@ -352,7 +352,7 @@ struct Gemm>( + ScalarType, ComputeType, xor_popc_add>( problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); } @@ -367,7 +367,7 @@ struct Gemm>( + ScalarType, ComputeType, xor_popc_add>( problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); } }; @@ -389,7 +389,7 @@ struct Gemm>( + ScalarType, ComputeType, and_popc_add>( problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); } @@ -404,7 +404,7 @@ struct Gemm>( + ScalarType, ComputeType, and_popc_add>( problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); } };