-
Notifications
You must be signed in to change notification settings - Fork 64
Improvements to workgroup reduce + scan #876
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
template<class Config, class BinOp, class device_capabilities=void>
struct reduction
{
    // Thin facade: picks the impl::reduce specialization matching the Config's level count
    // and forwards both accessors to it untouched.
    template<class DataAccessor, class ScratchAccessor>
    static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
    {
        using impl_t = impl::reduce<Config,BinOp,Config::LevelCount,device_capabilities>;
        impl_t impl_fn;
        impl_fn.template __call<DataAccessor,ScratchAccessor>(dataAccessor, scratchAccessor);
    }
};
|
||
template<class Config, class BinOp, class device_capabilities=void>
struct inclusive_scan
{
    // Facade over impl::scan with Exclusive=false; the level count comes from the Config.
    template<class DataAccessor, class ScratchAccessor>
    static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
    {
        using impl_t = impl::scan<Config,BinOp,false,Config::LevelCount,device_capabilities>;
        impl_t impl_fn;
        impl_fn.template __call<DataAccessor,ScratchAccessor>(dataAccessor, scratchAccessor);
    }
};
|
||
template<class Config, class BinOp, class device_capabilities=void>
struct exclusive_scan
{
    // Facade over impl::scan with Exclusive=true; the level count comes from the Config.
    template<class DataAccessor, class ScratchAccessor>
    static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
    {
        using impl_t = impl::scan<Config,BinOp,true,Config::LevelCount,device_capabilities>;
        impl_t impl_fn;
        impl_fn.template __call<DataAccessor,ScratchAccessor>(dataAccessor, scratchAccessor);
    }
};
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you make concepts for the accessors?
namespace impl
{
// Derives the number of reduce/scan levels (1..3) and the log2 of the virtual
// workgroup size from the real workgroup and subgroup size log2s.
template<uint16_t WorkgroupSizeLog2, uint16_t SubgroupSizeLog2>
struct virtual_wg_size_log2
{
    // Precondition guard: without it `WorkgroupSizeLog2-SubgroupSizeLog2` below
    // underflows (unsigned) for invalid configs and silently yields garbage.
    static_assert(WorkgroupSizeLog2>=SubgroupSizeLog2, "WorkgroupSize cannot be smaller than SubgroupSize");

    // 1 level when the workgroup fits in a single subgroup, 3 levels once the
    // subgroup count overflows what one level-1 subgroup can ingest, else 2.
    NBL_CONSTEXPR_STATIC_INLINE uint16_t levels = conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2),uint16_t,conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2*2+2),uint16_t,3,2>::value,1>::value;
    // log2 of the virtual workgroup size; must have at least enough level-0 outputs to feed a single subgroup.
    // NOTE(review): reported to over-inflate when WorkgroupSizeLog2==SubgroupSizeLog2 (levels==1),
    // e.g. WG 128 / SG 128 yields a 16k virtual workgroup -- TODO confirm and special-case.
    NBL_CONSTEXPR_STATIC_INLINE uint16_t value = mpl::max_v<uint32_t, WorkgroupSizeLog2-SubgroupSizeLog2, SubgroupSizeLog2>+SubgroupSizeLog2;
};

// Per-level items-per-invocation: splits the leftover work (items beyond what the
// subgroups cover) between level 1 and level 2 as power-of-two vector widths.
template<class VirtualWorkgroup, uint16_t BaseItemsPerInvocation, uint16_t WorkgroupSizeLog2, uint16_t SubgroupSizeLog2>
struct items_per_invocation
{
    // log2 of the total items that level 1 and 2 together must absorb
    NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocationProductLog2 = mpl::max_v<int16_t,WorkgroupSizeLog2-SubgroupSizeLog2*VirtualWorkgroup::levels,0>;
    // level 0 keeps the caller-requested width
    NBL_CONSTEXPR_STATIC_INLINE uint16_t value0 = BaseItemsPerInvocation;
    // level 1 takes up to 4 items (log2 capped at 2) when a third level exists, otherwise everything
    NBL_CONSTEXPR_STATIC_INLINE uint16_t value1 = uint16_t(0x1u) << conditional_value<VirtualWorkgroup::levels==3, uint16_t,mpl::min_v<uint16_t,ItemsPerInvocationProductLog2,2>, ItemsPerInvocationProductLog2>::value;
    // level 2 absorbs whatever level 1 could not
    NBL_CONSTEXPR_STATIC_INLINE uint16_t value2 = uint16_t(0x1u) << mpl::max_v<int16_t,ItemsPerInvocationProductLog2-2,0>;
};
}
|
||
template<uint32_t _WorkgroupSizeLog2, uint32_t _SubgroupSizeLog2, uint32_t _ItemsPerInvocation>
struct Configuration
{
    // keep the log2 around as a member so downstream code need not recompute it
    NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSizeLog2 = uint16_t(_WorkgroupSizeLog2);
    NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(0x1u) << WorkgroupSizeLog2;
    NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = uint16_t(_SubgroupSizeLog2);
    NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = uint16_t(0x1u) << SubgroupSizeLog2;
    static_assert(_WorkgroupSizeLog2>=_SubgroupSizeLog2, "WorkgroupSize cannot be smaller than SubgroupSize");

    // must have at least enough level 0 outputs to feed a single subgroup
    NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroupLog2 = mpl::max_v<uint32_t, WorkgroupSizeLog2-SubgroupSizeLog2, SubgroupSizeLog2>;
    NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroup = 0x1u << SubgroupsPerVirtualWorkgroupLog2;

    using virtual_wg_t = impl::virtual_wg_size_log2<WorkgroupSizeLog2, SubgroupSizeLog2>;
    NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelCount = virtual_wg_t::levels;
    NBL_CONSTEXPR_STATIC_INLINE uint16_t VirtualWorkgroupSize = uint16_t(0x1u) << virtual_wg_t::value;
    using items_per_invoc_t = impl::items_per_invocation<virtual_wg_t, _ItemsPerInvocation, WorkgroupSizeLog2, SubgroupSizeLog2>;
    // NBL_CONSTEXPR_STATIC_INLINE uint32_t2 ItemsPerInvocation; TODO? doesn't allow inline definitions for uint32_t2 for some reason, uint32_t[2] as well ; declaring out of line results in not constant expression
    NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_0 = items_per_invoc_t::value0;
    NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_1 = items_per_invoc_t::value1;
    NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_2 = items_per_invoc_t::value2;
    static_assert(ItemsPerInvocation_1<=4, "3 level scan would have been needed with this config!");

    // scratch element count (NOT bytes): a single-level config needs no shared memory at all,
    // otherwise one slot per level-0 subgroup output plus, for 3-level configs, the level-2 inputs
    NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedMemSize = conditional_value<LevelCount==1,uint32_t,0u,conditional_value<LevelCount==3,uint32_t,SubgroupSize*ItemsPerInvocation_2,0>::value + SubgroupsPerVirtualWorkgroup*ItemsPerInvocation_1>::value;
};
|
||
// special case when workgroup size 2048 and subgroup size 16 needs 3 levels and virtual workgroup size 4096 to get a full subgroup scan each on level 1 and 2 16x16x16=4096
// specializing with macros because of DXC bug: https://github.com/microsoft/DirectXShaderCompiler/issues/7007
// Fixes vs the previous revision:
//  - VirtualWorkgroupSize shifted by 4096 instead of 12 (out-of-range shift); 4096 == 1<<12
//  - SharedMemSize was missing entirely: 3 levels need SubgroupsPerVirtualWorkgroup*ItemsPerInvocation_1
//    + SubgroupSize*ItemsPerInvocation_2 = 128*1 + 16*1 = 144 scratch elements
#define SPECIALIZE_CONFIG_CASE_2048_16(ITEMS_PER_INVOC) template<>\
struct Configuration<11, 4, ITEMS_PER_INVOC>\
{\
    NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSizeLog2 = uint16_t(11u);\
    NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(0x1u) << 11u;\
    NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = uint16_t(4u);\
    NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = uint16_t(0x1u) << SubgroupSizeLog2;\
    NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroupLog2 = 7u;\
    NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroup = 128u;\
    NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelCount = 3u;\
    NBL_CONSTEXPR_STATIC_INLINE uint16_t VirtualWorkgroupSize = uint16_t(0x1u) << 12u;\
    NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_0 = ITEMS_PER_INVOC;\
    NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_1 = 1u;\
    NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_2 = 1u;\
    NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedMemSize = SubgroupsPerVirtualWorkgroup*ItemsPerInvocation_1 + SubgroupSize*ItemsPerInvocation_2;\
};

SPECIALIZE_CONFIG_CASE_2048_16(1)
SPECIALIZE_CONFIG_CASE_2048_16(2)
SPECIALIZE_CONFIG_CASE_2048_16(4)

#undef SPECIALIZE_CONFIG_CASE_2048_16
|
||
|
||
namespace impl |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
split the config into a separate header from the implementation
NBL_CONSTEXPR_STATIC_INLINE uint16_t VirtualWorkgroupSize = uint16_t(0x1u) << 4096;\ | ||
NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_0 = ITEMS_PER_INVOC;\ | ||
NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_1 = 1u;\ | ||
NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_2 = 1u;\ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
your specializations are missing the SharedMemSize
member, also it should be ElementCount
, Elements
or Count
instead of Size
because Size is often associated with a byte-size
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would be also useful to have the same virtual_wg_size_log2
alias as a regular default spec
P.S. why do you need to partial specialize?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
imho its better to explicit specialize items_per_invocation
and virtual_wg_size_log2
(ideally make them into one struct) to avoid specializing this one
#include "nbl/builtin/hlsl/cpp_compat.hlsl" | ||
#include "nbl/builtin/hlsl/workgroup/broadcast.hlsl" | ||
#include "nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl" | ||
#include "nbl/builtin/hlsl/subgroup/ballot.hlsl" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do not use non-subgroup2 in workgroup2
template<uint32_t WorkgroupSizeLog2, uint32_t _SubgroupSizeLog2, uint32_t _ItemsPerInvocation> | ||
struct Configuration | ||
{ | ||
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(0x1u) << WorkgroupSizeLog2; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
keep the WorkgroupSizeLog2
around as well
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also why are the template params uint32_t
and not uint16_t
NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_2 = items_per_invoc_t::value2; | ||
static_assert(ItemsPerInvocation_1<=4, "3 level scan would have been needed with this config!"); | ||
|
||
NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedMemSize = conditional_value<LevelCount==3,uint32_t,SubgroupSize*ItemsPerInvocation_2,0>::value + SubgroupsPerVirtualWorkgroup*ItemsPerInvocation_1; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you don't need any SharedMemory if you have LevelCount<2
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
btw I don't get why you need SubgroupsPerVirtualWorkgroup*ItemsPerInvocation_1
shared memory, the subgroups working at level 0 each produce only one output, and ItemsPerInvocation_1
only control how many level 0 outputs a single invocation from level 1 ingests (doesn't change the level 0 data size, only number of consumers)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I take it that SubgroupsPerVirtualWorkgroup
is how many subgroups you'll run at level 0, not 1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there's an easier formula (pseudocode)
uint16_t minInputLevel_1 = SubgroupSize*ItemsPerInvocation_1;
uint16_t SharedMemSize = LevelCount>1 ? ((LevelCount>2 ? (SubgroupSize*ItemsPerInvocation_2+1u):1u)*minInputLevel_1):0u;
// must have at least enough level 0 outputs to feed a single subgroup | ||
NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroupLog2 = mpl::max_v<uint32_t, WorkgroupSizeLog2-SubgroupSizeLog2, SubgroupSizeLog2>; | ||
NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroup = 0x1u << SubgroupsPerVirtualWorkgroupLog2; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
prefix with __
to indicate its an implementation detail (something that would be private
in C++)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
btw isn't SubgroupsPerVirtualWorkgroup
same as 0x1<<(virtual_wg_t::value-SubgroupSizeLog2)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
work all this out from items_per_invoc_t::ItemsPerInvocationProductLog2
, level count, and subgroupsizelog2
struct virtual_wg_size_log2 | ||
{ | ||
NBL_CONSTEXPR_STATIC_INLINE uint16_t levels = conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2),uint16_t,conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2*2+2),uint16_t,3,2>::value,1>::value; | ||
NBL_CONSTEXPR_STATIC_INLINE uint16_t value = mpl::max_v<uint32_t, WorkgroupSizeLog2-SubgroupSizeLog2, SubgroupSizeLog2>+SubgroupSizeLog2; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you need a static assert here that WorkgroupSizeLog2>=_SubgroupSizeLog2
you should also static assert that WorkgroupSizeLog2<=_SubgroupSizeLog2+4
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the computation of value
is wrong (however levels
is correct), see the spreadsheet https://docs.google.com/spreadsheets/d/1tPbXd8AgSl3U0XSszx5TFiugP_ls_nBlWPqvvctSBnw/edit?usp=sharing
I marked in red the virtual workgroup counts which get needlessly inflated, e.g.:
- WG size 128 with SG size 128 will spit out a VG of 16k instead of 1, even though the level count is 1
- 64, 64 does 4096
- 32,32 does 1024
and so on
I'd probably explicit specialize this struct for the cases needing 1 level and 3 level scans.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
note that is always seems to happen when subgrouplog2==workgrouplog2 and level=1
{ | ||
using scalar_t = typename BinOp::type_t; | ||
using vector_t = vector<scalar_t, Config::ItemsPerInvocation_0>; // data accessor needs to be this type | ||
// doesn't use scratch smem, need as param? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, for consistency; the user can pass a dummy NOOP Scratch Accessor if they're aware the config says there's 1 level of scanning to do
dataAccessor.get(glsl::gl_WorkGroupID().x * Config::SubgroupSize + workgroup::SubgroupContiguousIndex(), value); | ||
value = reduction(value); | ||
dataAccessor.set(glsl::gl_WorkGroupID().x * Config::SubgroupSize + workgroup::SubgroupContiguousIndex(), value); // can be safely merged with top line? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no, you're not supposed to be interested in the Work Group ID at all; the accessor is supposed to be invoked with just workgroup::SubgroupContiguousIndex
btw if the get
and set
are templated, aren't you supposed to call them with .template set<vector_t>
and so on?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They're not templated though?
// NBL_CONSTEXPR_STATIC_INLINE uint32_t2 ItemsPerInvocation; TODO? doesn't allow inline definitions for uint32_t2 for some reason, uint32_t[2] as well ; declaring out of line results in not constant expression | ||
NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_0 = items_per_invoc_t::value0; | ||
NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_1 = items_per_invoc_t::value1; | ||
NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_2 = items_per_invoc_t::value2; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let me cook you up something stupid
template<typename T0, typename T1=void, typename T2=void> // TODO: in the future use BOOST_PP to make this
struct tuple;
template<uint32_t N, typename Tuple>
struct tuple_element;
template<typename T0>
struct tuple<T0,void,void>
{
T0 t0;
};
// specializations for less and less void elements
// base case
template<typename Head, typename T1, typename T2>
struct tuple_element<0,tuple<Head,T1,T2> >
{
using type = Head;
};
you can slap this stub into tuple.hlsl
then you go
using ItemsPerInvocation = tuple<integral_constant<uint16_t,value0>,integral_constant<uint16_t,value1>,integral_constant<uint16_t,value2> >
I don't know if its useful, but just noting
using params_lv1_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_1, device_capabilities>; | ||
BinOp binop; | ||
|
||
vector_lv0_t scan_local[Config::VirtualWorkgroupSize / Config::WorkgroupSize]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because this is a reduction, you don't need to keep any temporaries around
using params_lv1_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_1, device_capabilities>; | ||
BinOp binop; | ||
|
||
vector_lv0_t scan_local[Config::VirtualWorkgroupSize / Config::WorkgroupSize]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd use the dataAccessor.set
to store this temporary data, then the user has a choice of using an accessor with such a hidden "register" or save registers and do roundtrips through memory... see @Fletterio's FFT Bloom code
for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) | ||
{ | ||
dataAccessor.get(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); | ||
scan_local[idx] = reduction0(scan_local[idx]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
leverage your knowledge of this being a reduction, and get into a temporary, and scan into a temporary (no array necessary)
if (subgroup::ElectLast()) | ||
{ | ||
const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); | ||
const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hoist this out into some function in a base class
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe into the config or something?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
both the virtualSubgroupID
and bankedIndex
computation
for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) | ||
{ | ||
scalar_t reduce_val; | ||
scratchAccessor.get(glsl::gl_SubgroupInvocationID(),reduce_val); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
everyone should get the same value from index 0 (outside the loop), its a reduction
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++) | ||
scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]); | ||
lv1_val = reduction1(lv1_val); | ||
scratchAccessor.set(invocationIndex, lv1_val[Config::ItemsPerInvocation_1-1]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
have one invocation write to index 0, its a reduction
struct reduction | ||
{ | ||
template<class DataAccessor, class ScratchAccessor> | ||
static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm thinking the reduction
should return the reduction instead of void
and not use the dataAccessor
for setting at all (you can even call it a ReadOnlyAccessor
)
[unroll] | ||
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++) | ||
scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroup+prevIndex,lv1_val[i]); | ||
vector_lv1_t shiftedInput = hlsl::mix(hlsl::promote<vector_lv1_t>(BinOp::identity), lv1_val, bool(invocationIndex)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you broke your exclusive prefix sum how did the test in the example not catch it?
shouldn't it be
lv1_val[0] = hlsl::mix(BinOp::identity,lv1_val[0],bool(invocationIndex));
instead?
scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroup+prevIndex,lv1_val[i]); | ||
vector_lv1_t shiftedInput = hlsl::mix(hlsl::promote<vector_lv1_t>(BinOp::identity), lv1_val, bool(invocationIndex)); | ||
shiftedInput = inclusiveScan1(shiftedInput); | ||
scratchAccessor.set(invocationIndex, shiftedInput[Config::ItemsPerInvocation_1-1]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no, you want to store the whole scan output over the whole scan input so that we can read it back properly (each subgroup needs its output's prefix sum back)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
needs to be the exact same loop as for the input
const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); | ||
scalar_t left; | ||
scratchAccessor.get(virtualSubgroupID,left); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you're not reading from the same index you wrote to
{ | ||
scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(scan_local[idx][Config::ItemsPerInvocation_0-1],1), bool(glsl::gl_SubgroupInvocationID())); | ||
[unroll] | ||
for (uint32_t i = 0; i < Config::ItemsPerInvocation_0; i++) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
write the loop in reverse for readability, from ItemsPerInvocation_0-1
to 0 with a signed integer i
scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(scan_local[idx][Config::ItemsPerInvocation_0-1],1), bool(glsl::gl_SubgroupInvocationID())); | ||
[unroll] | ||
for (uint32_t i = 0; i < Config::ItemsPerInvocation_0; i++) | ||
scan_local[idx][Config::ItemsPerInvocation_0-i-1] = binop(left, hlsl::mix(scan_local[idx][Config::ItemsPerInvocation_0-i-2], left_last_elem, (Config::ItemsPerInvocation_0-i-1==0))); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
isn't it UB to access an out of bound (negative) component on a vector
? This will attempt to access -1
this is one case where it's better to use an if
statement than a mix: your condition is constant, you don't want predicated execution
you must write a loop from ItemsPerInvocation_0-1
to 1, and then handle the case of i=0
in a special way without a mix
using vector_lv0_t = vector<scalar_t, Config::ItemsPerInvocation_0>; // data accessor needs to be this type | ||
using vector_lv1_t = vector<scalar_t, Config::ItemsPerInvocation_1>; | ||
using vector_lv2_t = vector<scalar_t, Config::ItemsPerInvocation_2>; | ||
|
||
// Multi-level workgroup scan: level-0 subgroup scans over virtual invocations,
// partial sums banked into scratch, scanned again at levels 1/2, then combined back.
// NOTE(review): quoted verbatim from the PR diff (trailing "| ||" is page residue);
// reviewer-flagged suspect lines are annotated below rather than changed.
template<class DataAccessor, class ScratchAccessor> | ||
void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor) | ||
{ | ||
using config_t = subgroup2::Configuration<Config::SubgroupSizeLog2>; | ||
using params_lv0_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_0, device_capabilities>; | ||
using params_lv1_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_1, device_capabilities>; | ||
using params_lv2_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_2, device_capabilities>; | ||
BinOp binop; | ||
|
||
// per-invocation register cache of every virtual iteration's level-0 scan result
vector_lv0_t scan_local[Config::VirtualWorkgroupSize / Config::WorkgroupSize]; | ||
const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex(); | ||
subgroup2::inclusive_scan<params_lv0_t> inclusiveScan0; | ||
// level 0 scan | ||
[unroll] | ||
for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) | ||
{ | ||
// NOTE(review): indexing by gl_WorkGroupID here was flagged -- the accessor is
// supposed to be addressed by SubgroupContiguousIndex alone; TODO confirm contract
dataAccessor.get(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); | ||
scan_local[idx] = inclusiveScan0(scan_local[idx]); | ||
// last active lane publishes its subgroup's reduction for the next level
if (subgroup::ElectLast()) | ||
{ | ||
const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); | ||
// bank-conflict-avoiding layout: component-major across SubgroupsPerVirtualWorkgroup slots
const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1); | ||
scratchAccessor.set(bankedIndex, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan | ||
} | ||
} | ||
scratchAccessor.workgroupExecutionAndMemoryBarrier(); | ||
|
||
// level 1 scan | ||
const uint32_t lv1_smem_size = Config::SubgroupsPerVirtualWorkgroup*Config::ItemsPerInvocation_1; | ||
subgroup2::inclusive_scan<params_lv1_t> inclusiveScan1; | ||
if (glsl::gl_SubgroupID() < lv1_smem_size) | ||
{ | ||
vector_lv1_t lv1_val; | ||
[unroll] | ||
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++) | ||
scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]); | ||
lv1_val = inclusiveScan1(lv1_val); | ||
if (subgroup::ElectLast()) | ||
{ | ||
const uint32_t bankedIndex = (glsl::gl_SubgroupID() & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupSize + (glsl::gl_SubgroupID()/Config::ItemsPerInvocation_2); | ||
scratchAccessor.set(lv1_smem_size+bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]); | ||
} | ||
} | ||
scratchAccessor.workgroupExecutionAndMemoryBarrier(); | ||
|
||
// level 2 scan | ||
subgroup2::inclusive_scan<params_lv2_t> inclusiveScan2; | ||
if (glsl::gl_SubgroupID() == 0) | ||
{ | ||
vector_lv2_t lv2_val; | ||
// NOTE(review): underflows to 0xFFFFFFFF for invocation 0; only safe if the
// accessor tolerates it, since the mix below discards that lane's read -- TODO confirm
const uint32_t prevIndex = invocationIndex-1; | ||
[unroll] | ||
for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++) | ||
scratchAccessor.get(lv1_smem_size+i*Config::SubgroupSize+prevIndex,lv2_val[i]); | ||
// shift-by-one trick: feeding identity at lane 0 turns the inclusive scan into an exclusive one
vector_lv2_t shiftedInput = hlsl::mix(hlsl::promote<vector_lv2_t>(BinOp::identity), lv2_val, bool(invocationIndex)); | ||
shiftedInput = inclusiveScan2(shiftedInput); | ||
|
||
// combine with level 1, only last element of each | ||
[unroll] | ||
for (uint32_t i = 0; i < Config::SubgroupsPerVirtualWorkgroup; i++) | ||
{ | ||
scalar_t last_val; | ||
scratchAccessor.get((Config::ItemsPerInvocation_1-1)*Config::SubgroupsPerVirtualWorkgroup+(Config::SubgroupsPerVirtualWorkgroup-1-i),last_val); | ||
// NOTE(review): suspect -- a vector mix result is assigned to scalar `val`, `val` is then
// overwritten on the next line, and the set() below writes back the UNMODIFIED `last_val`,
// discarding the combined value entirely; reviewer asked for this loop to be reworked
scalar_t val = hlsl::mix(hlsl::promote<vector_lv2_t>(BinOp::identity), lv2_val, bool(i)); | ||
val = binop(last_val, shiftedInput[Config::ItemsPerInvocation_2-1]); | ||
scratchAccessor.set((Config::ItemsPerInvocation_1-1)*Config::SubgroupsPerVirtualWorkgroup+(Config::SubgroupsPerVirtualWorkgroup-1-i), last_val); | ||
} | ||
} | ||
scratchAccessor.workgroupExecutionAndMemoryBarrier(); | ||
|
||
// combine with level 0 | ||
[unroll] | ||
for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) | ||
{ | ||
const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); | ||
// NOTE(review): uninitialized `const` local that get() then writes through -- reviewer also
// flagged that this index differs from the banked index the value was written to
const scalar_t left; | ||
scratchAccessor.get(virtualSubgroupID, left); | ||
if (Exclusive) | ||
{ | ||
// each lane fetches the previous lane's last component to complete the exclusive shift
scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(scan_local[idx][Config::ItemsPerInvocation_0-1],1), bool(glsl::gl_SubgroupInvocationID())); | ||
[unroll] | ||
for (uint32_t i = 0; i < Config::ItemsPerInvocation_0; i++) | ||
// NOTE(review): at the final iteration `ItemsPerInvocation_0-i-2` evaluates to -1; the mix
// selects the other operand but the out-of-bounds component access itself was flagged as UB
scan_local[idx][Config::ItemsPerInvocation_0-i-1] = binop(left, hlsl::mix(scan_local[idx][Config::ItemsPerInvocation_0-i-2], left_last_elem, (Config::ItemsPerInvocation_0-i-1==0))); | ||
} | ||
else | ||
{ | ||
[unroll] | ||
for (uint32_t i = 0; i < Config::ItemsPerInvocation_0; i++) | ||
scan_local[idx][i] = binop(left, scan_local[idx][i]); | ||
} | ||
dataAccessor.set(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); | ||
} | ||
} | ||
}; | ||
|
||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll review when you fix the rest (and 2 level scans)
uint32_t LastSubgroupInvocation() | ||
{ | ||
// why this code was wrong before: | ||
// - only compute can use SubgroupID |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SubgroupID is the subgroup ID, not SubgroupInvocationID
I'd template
template<int32_t AssumeAllActive=false>
uint32_t LastSubgroupInvocation()
{
if (AssumeAllActive)
return glsl::gl_SubgroupSize()-1;
else
return glsl::subgroupBallotFindMSB(glsl::subgroupBallot(true));
}
btw in code that knows you're going to have all invocations active and you KNOW the subgroup size, you should never call LastSubgroupInvocation
or ElectLast
|
||
template<uint16_t _WorkgroupSizeLog2, uint16_t _SubgroupSizeLog2, uint16_t _ItemsPerInvocation> | ||
struct Configuration |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
config, but config for what?
should have arithmetic in the name for both the struct and heeader name
No description provided.