Skip to content

Improvements to workgroup reduce + scan #876

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 27 commits into
base: master
Choose a base branch
from

Conversation

keptsecret
Copy link
Contributor

No description provided.

Comment on lines 21 to 52
// Workgroup-wide reduction entry point: dispatches to the multi-level
// implementation selected by Config::LevelCount (1, 2 or 3 subgroup levels).
// DataAccessor supplies/receives the per-invocation values, ScratchAccessor
// provides the shared-memory scratch used between levels.
template<class Config, class BinOp, class device_capabilities=void>
struct reduction
{
template<class DataAccessor, class ScratchAccessor>
static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
{
// instantiate the level-count-specialized implementation and forward both accessors
impl::reduce<Config,BinOp,Config::LevelCount,device_capabilities> fn;
fn.template __call<DataAccessor,ScratchAccessor>(dataAccessor, scratchAccessor);
}
};

// Workgroup-wide inclusive prefix scan entry point; the `false` template
// argument to impl::scan selects inclusive (as opposed to exclusive) behavior.
template<class Config, class BinOp, class device_capabilities=void>
struct inclusive_scan
{
template<class DataAccessor, class ScratchAccessor>
static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
{
impl::scan<Config,BinOp,false,Config::LevelCount,device_capabilities> fn;
fn.template __call<DataAccessor,ScratchAccessor>(dataAccessor, scratchAccessor);
}
};

// Workgroup-wide exclusive prefix scan entry point; the `true` template
// argument to impl::scan selects exclusive behavior (result shifted by one,
// first element gets BinOp::identity).
template<class Config, class BinOp, class device_capabilities=void>
struct exclusive_scan
{
template<class DataAccessor, class ScratchAccessor>
static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
{
impl::scan<Config,BinOp,true,Config::LevelCount,device_capabilities> fn;
fn.template __call<DataAccessor,ScratchAccessor>(dataAccessor, scratchAccessor);
}
};

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you make concepts for the accessors?

Comment on lines 21 to 89
namespace impl
{
template<uint16_t WorkgroupSizeLog2, uint16_t SubgroupSizeLog2>
struct virtual_wg_size_log2
{
NBL_CONSTEXPR_STATIC_INLINE uint16_t levels = conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2),uint16_t,conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2*2+2),uint16_t,3,2>::value,1>::value;
NBL_CONSTEXPR_STATIC_INLINE uint16_t value = mpl::max_v<uint32_t, WorkgroupSizeLog2-SubgroupSizeLog2, SubgroupSizeLog2>+SubgroupSizeLog2;
};

template<class VirtualWorkgroup, uint16_t BaseItemsPerInvocation, uint16_t WorkgroupSizeLog2, uint16_t SubgroupSizeLog2>
struct items_per_invocation
{
NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocationProductLog2 = mpl::max_v<int16_t,WorkgroupSizeLog2-SubgroupSizeLog2*VirtualWorkgroup::levels,0>;
NBL_CONSTEXPR_STATIC_INLINE uint16_t value0 = BaseItemsPerInvocation;
NBL_CONSTEXPR_STATIC_INLINE uint16_t value1 = uint16_t(0x1u) << conditional_value<VirtualWorkgroup::levels==3, uint16_t,mpl::min_v<uint16_t,ItemsPerInvocationProductLog2,2>, ItemsPerInvocationProductLog2>::value;
NBL_CONSTEXPR_STATIC_INLINE uint16_t value2 = uint16_t(0x1u) << mpl::max_v<int16_t,ItemsPerInvocationProductLog2-2,0>;
};
}

template<uint32_t WorkgroupSizeLog2, uint32_t _SubgroupSizeLog2, uint32_t _ItemsPerInvocation>
struct Configuration
{
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(0x1u) << WorkgroupSizeLog2;
NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = uint16_t(_SubgroupSizeLog2);
NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = uint16_t(0x1u) << SubgroupSizeLog2;
static_assert(WorkgroupSizeLog2>=_SubgroupSizeLog2, "WorkgroupSize cannot be smaller than SubgroupSize");

// must have at least enough level 0 outputs to feed a single subgroup
NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroupLog2 = mpl::max_v<uint32_t, WorkgroupSizeLog2-SubgroupSizeLog2, SubgroupSizeLog2>;
NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroup = 0x1u << SubgroupsPerVirtualWorkgroupLog2;

using virtual_wg_t = impl::virtual_wg_size_log2<WorkgroupSizeLog2, SubgroupSizeLog2>;
NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelCount = virtual_wg_t::levels;
NBL_CONSTEXPR_STATIC_INLINE uint16_t VirtualWorkgroupSize = uint16_t(0x1u) << virtual_wg_t::value;
using items_per_invoc_t = impl::items_per_invocation<virtual_wg_t, _ItemsPerInvocation, WorkgroupSizeLog2, SubgroupSizeLog2>;
// NBL_CONSTEXPR_STATIC_INLINE uint32_t2 ItemsPerInvocation; TODO? doesn't allow inline definitions for uint32_t2 for some reason, uint32_t[2] as well ; declaring out of line results in not constant expression
NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_0 = items_per_invoc_t::value0;
NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_1 = items_per_invoc_t::value1;
NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_2 = items_per_invoc_t::value2;
static_assert(ItemsPerInvocation_1<=4, "3 level scan would have been needed with this config!");

NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedMemSize = conditional_value<LevelCount==3,uint32_t,SubgroupSize*ItemsPerInvocation_2,0>::value + SubgroupsPerVirtualWorkgroup*ItemsPerInvocation_1;
};

// special case when workgroup size 2048 and subgroup size 16 needs 3 levels and virtual workgroup size 4096 to get a full subgroup scan each on level 1 and 2 16x16x16=4096
// specializing with macros because of DXC bug: https://github.com/microsoft/DirectXShaderCompiler/issues/7007
#define SPECIALIZE_CONFIG_CASE_2048_16(ITEMS_PER_INVOC) template<>\
struct Configuration<11, 4, ITEMS_PER_INVOC>\
{\
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSizeLog2 = uint16_t(11u);\
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(0x1u) << 11u;\
NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = uint16_t(4u);\
NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = uint16_t(0x1u) << SubgroupSizeLog2;\
NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroupLog2 = 7u;\
NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroup = 128u;\
NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelCount = 3u;\
/* FIX: the shift amount is the log2 (12 -> 4096); the old `<< 4096` was an out-of-range shift (UB, not 4096) */\
NBL_CONSTEXPR_STATIC_INLINE uint16_t VirtualWorkgroupSize = uint16_t(0x1u) << 12u;\
NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_0 = ITEMS_PER_INVOC;\
NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_1 = 1u;\
NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_2 = 1u;\
/* FIX: the specialization was missing the SharedMemSize member the generic template provides; 3-level config: 16*1 + 128*1 = 144 elements */\
NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedMemSize = SubgroupSize*ItemsPerInvocation_2 + SubgroupsPerVirtualWorkgroup*ItemsPerInvocation_1;\
};\

SPECIALIZE_CONFIG_CASE_2048_16(1)
SPECIALIZE_CONFIG_CASE_2048_16(2)
SPECIALIZE_CONFIG_CASE_2048_16(4)

#undef SPECIALIZE_CONFIG_CASE_2048_16


namespace impl

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

split the config into a separate header from the implementation

NBL_CONSTEXPR_STATIC_INLINE uint16_t VirtualWorkgroupSize = uint16_t(0x1u) << 4096;\
NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_0 = ITEMS_PER_INVOC;\
NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_1 = 1u;\
NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_2 = 1u;\

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

your specializations are missing the SharedMemSize member, also it should be ElementCount, Elements or Count instead of Size because Size is often associated with a byte-size

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be also useful to have the same virtual_wg_size_log2 alias as a regular default spec

P.S. why do you need to partial specialize?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

imho its better to explicit specialize items_per_invocation and virtual_wg_size_log2 (ideally make them into one struct) to avoid specilizing this one

#include "nbl/builtin/hlsl/cpp_compat.hlsl"
#include "nbl/builtin/hlsl/workgroup/broadcast.hlsl"
#include "nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl"
#include "nbl/builtin/hlsl/subgroup/ballot.hlsl"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do not use non-subgroup2 in workgroup2

Comment on lines 40 to 43
template<uint32_t WorkgroupSizeLog2, uint32_t _SubgroupSizeLog2, uint32_t _ItemsPerInvocation>
struct Configuration
{
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(0x1u) << WorkgroupSizeLog2;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

keep the WorkgroupSizeLog2 around as well

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also why are the template params uint32_t and not uint16_t

NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_2 = items_per_invoc_t::value2;
static_assert(ItemsPerInvocation_1<=4, "3 level scan would have been needed with this config!");

NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedMemSize = conditional_value<LevelCount==3,uint32_t,SubgroupSize*ItemsPerInvocation_2,0>::value + SubgroupsPerVirtualWorkgroup*ItemsPerInvocation_1;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you don't need any SharedMemory if you have LevelCount<2

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw I don't get why you need SubgroupsPerVirtualWorkgroup*ItemsPerInvocation_1 shared memory, the subgroups working at level 0 each produce only one output, and ItemsPerInvocation_1 only control how many level 0 outputs a single invocation from level 1 ingests (doesn't change the level 0 data size, only number of consumers)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I take it that SubgroupsPerVirtualWorkgroup is how many subgroups you'll run at level 0, not 1

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there's an easier formula (pseudocode)

uint16_t minInputLevel_1 = SubgroupSize*ItemsPerInvocation_1;
uint16_t SharedMemSize = LevelCount>1 ? ((LevelCount>2 ? (SubgroupSize*ItemsPerInvocation_2+1u):1u)*minInputLevel_1):0u;

Comment on lines 48 to 50
// must have at least enough level 0 outputs to feed a single subgroup
NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroupLog2 = mpl::max_v<uint32_t, WorkgroupSizeLog2-SubgroupSizeLog2, SubgroupSizeLog2>;
NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroup = 0x1u << SubgroupsPerVirtualWorkgroupLog2;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prefix with __ to indicate its an implementation detail (something that would be private in C++)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw isn't SubgroupsPerVirtualWorkgroup same as 0x1<<(virtual_wg_t::value-SubgroupSizeLog2)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

work all this out from items_per_invoc_t::ItemsPerInvocationProductLog2, level count, and subgroupsizelog2

struct virtual_wg_size_log2
{
NBL_CONSTEXPR_STATIC_INLINE uint16_t levels = conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2),uint16_t,conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2*2+2),uint16_t,3,2>::value,1>::value;
NBL_CONSTEXPR_STATIC_INLINE uint16_t value = mpl::max_v<uint32_t, WorkgroupSizeLog2-SubgroupSizeLog2, SubgroupSizeLog2>+SubgroupSizeLog2;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you need a static assert here that WorkgroupSizeLog2>=_SubgroupSizeLog2

you should also static assert that WorkgroupSizeLog2<=_SubgroupSizeLog2+4

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the computation of value is wrong (however levels is correct), see the spreadsheet https://docs.google.com/spreadsheets/d/1tPbXd8AgSl3U0XSszx5TFiugP_ls_nBlWPqvvctSBnw/edit?usp=sharing

I marked in red the virtual workgroup counts which get needlessly inflated, e.g.:

  • WG size 128 with SG size 128 will spit out a VG of 16k instead of 1, even though the level count is 1
  • 64, 64 does 4096
  • 32,32 does 1024
    and so on

I'd probably explicit specialize this struct for the cases needing 1 level and 3 level scans.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note that is always seems to happen when subgrouplog2==workgrouplog2 and level=1

{
using scalar_t = typename BinOp::type_t;
using vector_t = vector<scalar_t, Config::ItemsPerInvocation_0>; // data accessor needs to be this type
// doesn't use scratch smem, need as param?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes for consistency, user can pass a dummy NOOP Scratch Accessor if they're aware the config says there's 1 level of scanning to do

Comment on lines 114 to 116
dataAccessor.get(glsl::gl_WorkGroupID().x * Config::SubgroupSize + workgroup::SubgroupContiguousIndex(), value);
value = reduction(value);
dataAccessor.set(glsl::gl_WorkGroupID().x * Config::SubgroupSize + workgroup::SubgroupContiguousIndex(), value); // can be safely merged with top line?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no you're not supposed to be interested in the Work Group ID at all, the accessor is supposed to be just invoked with workgroup::SubgroupContiguousIndex

btw if the get and set are templated, aren't you supposed to call them with .template set<vector_t> and so on?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They're not templated though?

Comment on lines 56 to 59
// NBL_CONSTEXPR_STATIC_INLINE uint32_t2 ItemsPerInvocation; TODO? doesn't allow inline definitions for uint32_t2 for some reason, uint32_t[2] as well ; declaring out of line results in not constant expression
NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_0 = items_per_invoc_t::value0;
NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_1 = items_per_invoc_t::value1;
NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_2 = items_per_invoc_t::value2;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let me cook you up something stupid

template<typename T0, typename T1=void, typename T2=void> // TODO: in the future use BOOST_PP to make this
struct tuple;

template<uint32_t N, typename Tuple>
struct tuple_element;

template<typename T0>
struct tuple<T0,void,void>
{
   T0 t0;
};
// specializations for less and less void elements

// base case
template<typename Head, typename T1, typename T2>
struct tuple_element<0,tuple<Head,T1,T2> >
{
   using type = Head;
};

you can slap this stub into tuple.hlsl

then you go

using ItemsPerInvocation = tuple<integral_constant<uint16_t,value0>,integral_constant<uint16_t,value1>,integral_constant<uint16_t,value2> >

I don't know if its useful, but just noting

using params_lv1_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_1, device_capabilities>;
BinOp binop;

vector_lv0_t scan_local[Config::VirtualWorkgroupSize / Config::WorkgroupSize];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because this is a reduction, you don't need to keep any temporaries around

using params_lv1_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_1, device_capabilities>;
BinOp binop;

vector_lv0_t scan_local[Config::VirtualWorkgroupSize / Config::WorkgroupSize];

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd use the dataAccessor.set to store this temporary data, then the user has a choice of using an accessor with such a hidden "register" or save registers and do roundtrips through memory... see @Fletterio's FFT Bloom code

for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
{
dataAccessor.get(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]);
scan_local[idx] = reduction0(scan_local[idx]);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

leverage your knowledge of this being a reduction, and get into a temporary, and scan into a temporary (no array necessary)

if (subgroup::ElectLast())
{
const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID();
const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hoist this out into some function in a base class

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe into the config or something?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

both the virtualSubgroupID and bankedIndex computation

for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
{
scalar_t reduce_val;
scratchAccessor.get(glsl::gl_SubgroupInvocationID(),reduce_val);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

everyone should get the same value from index 0 (outside the loop), its a reduction

for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]);
lv1_val = reduction1(lv1_val);
scratchAccessor.set(invocationIndex, lv1_val[Config::ItemsPerInvocation_1-1]);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

have one invocation write to index 0, its a reduction

struct reduction
{
template<class DataAccessor, class ScratchAccessor>
static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm thinking the reduction should return the reduction instead of void and not use the dataAccessor for setting at all (you can even call it a ReadOnlyAccessor)

[unroll]
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroup+prevIndex,lv1_val[i]);
vector_lv1_t shiftedInput = hlsl::mix(hlsl::promote<vector_lv1_t>(BinOp::identity), lv1_val, bool(invocationIndex));

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you broke your exclusive prefix sum how did the test in the example not catch it?

shouldn't it be

           lv1_val[0] =  hlsl::mix(BinOp::identity,lv1_val[0],bool(invocationIndex));

instead?

scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroup+prevIndex,lv1_val[i]);
vector_lv1_t shiftedInput = hlsl::mix(hlsl::promote<vector_lv1_t>(BinOp::identity), lv1_val, bool(invocationIndex));
shiftedInput = inclusiveScan1(shiftedInput);
scratchAccessor.set(invocationIndex, shiftedInput[Config::ItemsPerInvocation_1-1]);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, you want to store the whole scan output over the whole scan input so that we can read it back properly (each subgroup needs its output's prefix sum back)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

needs to be the exact same loop as for the input

Comment on lines 259 to 261
const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID();
scalar_t left;
scratchAccessor.get(virtualSubgroupID,left);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you're not reading from the same index you wrote to

{
scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(scan_local[idx][Config::ItemsPerInvocation_0-1],1), bool(glsl::gl_SubgroupInvocationID()));
[unroll]
for (uint32_t i = 0; i < Config::ItemsPerInvocation_0; i++)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

write the loop in reverse for readability, from ItemsPerInvocation_0-1 to 0 with a signed integer i

scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(scan_local[idx][Config::ItemsPerInvocation_0-1],1), bool(glsl::gl_SubgroupInvocationID()));
[unroll]
for (uint32_t i = 0; i < Config::ItemsPerInvocation_0; i++)
scan_local[idx][Config::ItemsPerInvocation_0-i-1] = binop(left, hlsl::mix(scan_local[idx][Config::ItemsPerInvocation_0-i-2], left_last_elem, (Config::ItemsPerInvocation_0-i-1==0)));

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't it UB to access an out of bound (negative) component on a vector? This will attempt to access -1

this is one case where it's better to use an if statement than a mix: your condition is constant, you don't want predicated execution

you must write a loop from ItemsPerInvocation_0-1 to 1, and then handle the case of i=0 in a special way without a mix

Comment on lines 280 to 460
using vector_lv0_t = vector<scalar_t, Config::ItemsPerInvocation_0>; // data accessor needs to be this type
using vector_lv1_t = vector<scalar_t, Config::ItemsPerInvocation_1>;
using vector_lv2_t = vector<scalar_t, Config::ItemsPerInvocation_2>;

// 3-level workgroup scan: level 0 scans per-subgroup, level 1 scans the subgroup
// reductions, level 2 scans the level 1 reductions, then results are combined back down.
// NOTE(review): several reviewer-confirmed defects are flagged inline below.
template<class DataAccessor, class ScratchAccessor>
void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
{
using config_t = subgroup2::Configuration<Config::SubgroupSizeLog2>;
using params_lv0_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_0, device_capabilities>;
using params_lv1_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_1, device_capabilities>;
using params_lv2_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_2, device_capabilities>;
BinOp binop;

// per-invocation slice of the virtual workgroup kept in registers between the two passes
vector_lv0_t scan_local[Config::VirtualWorkgroupSize / Config::WorkgroupSize];
const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex();
subgroup2::inclusive_scan<params_lv0_t> inclusiveScan0;
// level 0 scan
[unroll]
for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
{
dataAccessor.get(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]);
scan_local[idx] = inclusiveScan0(scan_local[idx]);
if (subgroup::ElectLast())
{
// banked layout so level 1 invocations gather ItemsPerInvocation_1 values without conflicts
const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID();
const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
scratchAccessor.set(bankedIndex, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan
}
}
scratchAccessor.workgroupExecutionAndMemoryBarrier();

// level 1 scan
const uint32_t lv1_smem_size = Config::SubgroupsPerVirtualWorkgroup*Config::ItemsPerInvocation_1;
subgroup2::inclusive_scan<params_lv1_t> inclusiveScan1;
if (glsl::gl_SubgroupID() < lv1_smem_size)
{
vector_lv1_t lv1_val;
[unroll]
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]);
lv1_val = inclusiveScan1(lv1_val);
if (subgroup::ElectLast())
{
const uint32_t bankedIndex = (glsl::gl_SubgroupID() & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupSize + (glsl::gl_SubgroupID()/Config::ItemsPerInvocation_2);
scratchAccessor.set(lv1_smem_size+bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]);
}
}
scratchAccessor.workgroupExecutionAndMemoryBarrier();

// level 2 scan
subgroup2::inclusive_scan<params_lv2_t> inclusiveScan2;
if (glsl::gl_SubgroupID() == 0)
{
vector_lv2_t lv2_val;
// NOTE(review): prevIndex underflows to 0xFFFFFFFF for invocation 0; the mix below
// masks the value but the out-of-range scratch read itself should be confirmed safe
const uint32_t prevIndex = invocationIndex-1;
[unroll]
for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++)
scratchAccessor.get(lv1_smem_size+i*Config::SubgroupSize+prevIndex,lv2_val[i]);
vector_lv2_t shiftedInput = hlsl::mix(hlsl::promote<vector_lv2_t>(BinOp::identity), lv2_val, bool(invocationIndex));
shiftedInput = inclusiveScan2(shiftedInput);

// combine with level 1, only last element of each
[unroll]
for (uint32_t i = 0; i < Config::SubgroupsPerVirtualWorkgroup; i++)
{
scalar_t last_val;
scratchAccessor.get((Config::ItemsPerInvocation_1-1)*Config::SubgroupsPerVirtualWorkgroup+(Config::SubgroupsPerVirtualWorkgroup-1-i),last_val);
// BUG(review): `val` mixes a vector into a scalar, is overwritten on the next line,
// and the set() below writes back the UNMODIFIED last_val — the level 2 combine is a no-op
scalar_t val = hlsl::mix(hlsl::promote<vector_lv2_t>(BinOp::identity), lv2_val, bool(i));
val = binop(last_val, shiftedInput[Config::ItemsPerInvocation_2-1]);
scratchAccessor.set((Config::ItemsPerInvocation_1-1)*Config::SubgroupsPerVirtualWorkgroup+(Config::SubgroupsPerVirtualWorkgroup-1-i), last_val);
}
}
scratchAccessor.workgroupExecutionAndMemoryBarrier();

// combine with level 0
[unroll]
for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
{
const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID();
// BUG(review): `left` is declared const yet written by the get() below;
// also virtualSubgroupID is NOT the banked index the level 0 pass wrote to,
// so this reads a different slot than was written
const scalar_t left;
scratchAccessor.get(virtualSubgroupID, left);
if (Exclusive)
{
scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(scan_local[idx][Config::ItemsPerInvocation_0-1],1), bool(glsl::gl_SubgroupInvocationID()));
[unroll]
for (uint32_t i = 0; i < Config::ItemsPerInvocation_0; i++)
// BUG(review): when ItemsPerInvocation_0-i-1==0 the first mix operand indexes
// component -1 of the vector; the i==0 element needs special-case handling
scan_local[idx][Config::ItemsPerInvocation_0-i-1] = binop(left, hlsl::mix(scan_local[idx][Config::ItemsPerInvocation_0-i-2], left_last_elem, (Config::ItemsPerInvocation_0-i-1==0)));
}
else
{
[unroll]
for (uint32_t i = 0; i < Config::ItemsPerInvocation_0; i++)
scan_local[idx][i] = binop(left, scan_local[idx][i]);
}
dataAccessor.set(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]);
}
}
};

}

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll review when you fix the rest (and 2 level scans)

uint32_t LastSubgroupInvocation()
{
// why this code was wrong before:
// - only compute can use SubgroupID

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SubgroupID is the subgroup ID, not SubgroupInvocationID

I'd template

template<int32_t AssumeAllActive=false>
uint32_t LastSubgroupInvocation()
{
   if (AssumeAllActive)
      return glsl::gl_SubgroupSize()-1;
   else
      return glsl::subgroupBallotFindMSB(glsl::subgroupBallot(true));
}

btw in code that knows you're going to have all invocations active and you KNOW the subgroup size, you should never call LastSubgroupInvocation or ElectLast

Comment on lines 34 to 36

template<uint16_t _WorkgroupSizeLog2, uint16_t _SubgroupSizeLog2, uint16_t _ItemsPerInvocation>
struct Configuration

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

config, but config for what?

should have arithmetic in the name for both the struct and heeader name

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants