Working Compute Blit #780

Open · wants to merge 7 commits into base: master
22 changes: 17 additions & 5 deletions include/nbl/builtin/hlsl/blit/common.hlsl
@@ -57,18 +57,30 @@ struct HistogramAccessor
InterlockedAdd(statsBuff[wgID * (ConstevalParameters::AlphaBinCount + 1) + bucket], v);
}
};
*/

struct SharedAccessor
{
float32_t get(float32_t idx)
template<typename T NBL_FUNC_REQUIRES(sizeof(T)==sizeof(uint32_t) && is_fundamental_v<T>)
T get(uint16_t idx)
{
return sMem[idx];
return bit_cast<T>(sMem[idx]);
}
void set(float32_t idx, float32_t val)

template<typename T NBL_FUNC_REQUIRES(sizeof(T)==sizeof(uint32_t) && is_integral_v<T>)
void atomicIncr(uint16_t idx)
{
sMem[idx] = val;
glsl::atomicAdd(sMem[idx],1u);
}

// TODO: figure out how to provide 16bit access, subgroup op compact?
template<typename T NBL_FUNC_REQUIRES(sizeof(T)==sizeof(uint32_t) && is_fundamental_v<T>)
void set(uint16_t idx, T val)
{
sMem[idx] = bit_cast<uint32_t>(val);
}
};
*/
static SharedAccessor sharedAccessor;

struct OutImgAccessor
{
// ...
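For reference, a minimal usage sketch (not part of the diff) of the SharedAccessor interface the blit code relies on: everything is stored as a raw DWORD in a uint32_t groupshared array, floats round-trip through bit_cast, and integer slots can be bumped atomically. The index 42 below is purely illustrative:

sharedAccessor.template set<float32_t>(42,0.5f); // stores bit_cast<uint32_t>(0.5f) in DWORD 42
const float32_t f = sharedAccessor.template get<float32_t>(42); // bit_casts back, f==0.5f
sharedAccessor.template set<uint32_t>(42,0u); // the same slot can hold a plain integer
sharedAccessor.template atomicIncr<uint32_t>(42); // glsl::atomicAdd of 1 on the raw DWORD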
290 changes: 137 additions & 153 deletions include/nbl/builtin/hlsl/blit/compute_blit.hlsl
@@ -5,9 +5,8 @@
#define _NBL_BUILTIN_HLSL_BLIT_INCLUDED_


#include <nbl/builtin/hlsl/ndarray_addressing.hlsl>
#include <nbl/builtin/hlsl/glsl_compat/core.hlsl>
#include <nbl/builtin/hlsl/blit/parameters.hlsl>
#include <nbl/builtin/hlsl/blit/common.hlsl>


namespace nbl
@@ -17,177 +16,162 @@ namespace hlsl
namespace blit
{

template <typename ConstevalParameters>
struct compute_blit_t
template<
bool DoCoverage,
uint16_t WorkGroupSize,
int32_t Dims,
typename InCombinedSamplerAccessor,
typename OutImageAccessor,
typename KernelWeightsAccessor,
// typename HistogramAccessor,
typename SharedAccessor
>
void execute(
NBL_CONST_REF_ARG(InCombinedSamplerAccessor) inCombinedSamplerAccessor,
NBL_REF_ARG(OutImageAccessor) outImageAccessor,
NBL_CONST_REF_ARG(KernelWeightsAccessor) kernelWeightsAccessor,
// NBL_REF_ARG(HistogramAccessor) histogramAccessor,
NBL_REF_ARG(SharedAccessor) sharedAccessor,
NBL_CONST_REF_ARG(SPerWorkgroup) params,
const uint16_t layer,
const vector<uint16_t,Dims> virtWorkGroupID
)
{
float32_t3 scale;
float32_t3 negativeSupport;
uint32_t kernelWeightsOffsetY;
uint32_t kernelWeightsOffsetZ;
uint32_t inPixelCount;
uint32_t outPixelCount;
uint16_t3 outputTexelsPerWG;
uint16_t3 inDims;
uint16_t3 outDims;
uint16_t3 windowDims;
uint16_t3 phaseCount;
uint16_t3 preloadRegion;
uint16_t3 iterationRegionXPrefixProducts;
uint16_t3 iterationRegionYPrefixProducts;
uint16_t3 iterationRegionZPrefixProducts;
uint16_t secondScratchOffset;

static compute_blit_t create(NBL_CONST_REF_ARG(parameters_t) params)
const uint16_t lastChannel = params.lastChannel;
const uint16_t coverageChannel = params.coverageChannel;

using uint16_tN = vector<uint16_t,Dims>;
// the dimensional truncation is desired
const uint16_tN outputTexelsPerWG = params.template getPerWGOutputExtent<Dims>();
// it's the min XYZ corner of the output region this workgroup produces
const uint16_tN minOutputTexel = virtWorkGroupID*outputTexelsPerWG;

using float32_tN = vector<float32_t,Dims>;
const float32_tN scale = truncate<Dims>(params.scale);
const float32_tN inputMaxCoord = params.template getInputMaxCoord<Dims>();
const uint16_t inLevel = _static_cast<uint16_t>(params.inLevel);
const float32_tN inImageSizeRcp = inCombinedSamplerAccessor.template extentRcp<Dims>(inLevel);

using int32_tN = vector<int32_t,Dims>;
// can be negative, it's the min XYZ corner of the area the workgroup will sample from to produce its output
const float32_tN regionStartCoord = params.inputUpperBound<Dims>(minOutputTexel);
const float32_tN regionNextStartCoord = params.inputUpperBound<Dims>(minOutputTexel+outputTexelsPerWG);

const uint16_t localInvocationIndex = _static_cast<uint16_t>(glsl::gl_LocalInvocationIndex()); // workgroup::SubgroupContiguousIndex()

// need to clear our atomic coverage counter to 0
const uint16_t coverageDWORD = _static_cast<uint16_t>(params.coverageDWORD);
if (DoCoverage)
{
compute_blit_t compute_blit;

compute_blit.scale = params.fScale;
compute_blit.negativeSupport = params.negativeSupport;
compute_blit.kernelWeightsOffsetY = params.kernelWeightsOffsetY;
compute_blit.kernelWeightsOffsetZ = params.kernelWeightsOffsetZ;
compute_blit.inPixelCount = params.inPixelCount;
compute_blit.outPixelCount = params.outPixelCount;
compute_blit.outputTexelsPerWG = params.getOutputTexelsPerWG();
compute_blit.inDims = params.inputDims;
compute_blit.outDims = params.outputDims;
compute_blit.windowDims = params.windowDims;
compute_blit.phaseCount = params.phaseCount;
compute_blit.preloadRegion = params.preloadRegion;
compute_blit.iterationRegionXPrefixProducts = params.iterationRegionXPrefixProducts;
compute_blit.iterationRegionYPrefixProducts = params.iterationRegionYPrefixProducts;
compute_blit.iterationRegionZPrefixProducts = params.iterationRegionZPrefixProducts;
compute_blit.secondScratchOffset = params.secondScratchOffset;

return compute_blit;
if (localInvocationIndex==0)
sharedAccessor.set(coverageDWORD,0u);
glsl::barrier();
}

template <
typename InCombinedSamplerAccessor,
typename OutImageAccessor,
typename KernelWeightsAccessor,
typename HistogramAccessor,
typename SharedAccessor>
void execute(
NBL_CONST_REF_ARG(InCombinedSamplerAccessor) inCombinedSamplerAccessor,
NBL_REF_ARG(OutImageAccessor) outImageAccessor,
NBL_CONST_REF_ARG(KernelWeightsAccessor) kernelWeightsAccessor,
NBL_REF_ARG(HistogramAccessor) histogramAccessor,
NBL_REF_ARG(SharedAccessor) sharedAccessor,
uint16_t3 workGroupID,
uint16_t localInvocationIndex)
//
const PatchLayout<Dims> preloadLayout = params.getPreloadMeta();
const uint16_t preloadCount = preloadLayout.getLinearEnd();
for (uint16_t virtualInvocation=localInvocationIndex; virtualInvocation<preloadCount; virtualInvocation+=WorkGroupSize)
{
const float3 halfScale = scale * float3(0.5f, 0.5f, 0.5f);
// bottom of the input tile
const uint32_t3 minOutputPixel = workGroupID * outputTexelsPerWG;
const float3 minOutputPixelCenterOfWG = float3(minOutputPixel)*scale + halfScale;
// this can be negative, in which case HW sampler takes care of wrapping for us
const int32_t3 regionStartCoord = int32_t3(ceil(minOutputPixelCenterOfWG - float3(0.5f, 0.5f, 0.5f) + negativeSupport));

const uint32_t virtualInvocations = preloadRegion.x * preloadRegion.y * preloadRegion.z;
for (uint32_t virtualInvocation = localInvocationIndex; virtualInvocation < virtualInvocations; virtualInvocation += ConstevalParameters::WorkGroupSize)
// if we make all args to snakeCurveInverse 16-bit, the compiler can maybe optimize the divisions into float32_t ones
const uint16_tN virtualInvocationID = preloadLayout.getID(virtualInvocation);
const float32_tN inputTexCoordUnnorm = regionStartCoord + float32_tN(virtualInvocationID);

const float32_tN inputTexCoord = (inputTexCoordUnnorm + promote<float32_tN>(0.5f)) * inImageSizeRcp;
const float32_t4 loadedData = inCombinedSamplerAccessor.template get<float32_t,Dims>(inputTexCoord,layer,inLevel);

if (DoCoverage)
if (loadedData[coverageChannel]>=params.alphaRefValue &&
all(inputTexCoordUnnorm<regionNextStartCoord) && // not overlapping with the next tile
all(inputTexCoordUnnorm>=promote<float32_tN>(0.f)) && // within the image from below
all(inputTexCoordUnnorm<=inputMaxCoord) // within the image from above
)
{
const int32_t3 inputPixelCoord = regionStartCoord + int32_t3(ndarray_addressing::snakeCurveInverse(virtualInvocation, preloadRegion));
float32_t3 inputTexCoord = (inputPixelCoord + float32_t3(0.5f, 0.5f, 0.5f)) / inDims;
const float4 loadedData = inCombinedSamplerAccessor.get(inputTexCoord, workGroupID.z);

for (uint32_t ch = 0; ch < ConstevalParameters::BlitOutChannelCount; ++ch)
sharedAccessor.set(ch * ConstevalParameters::SMemFloatsPerChannel + virtualInvocation, loadedData[ch]);
// TODO: atomicIncr or a workgroup reduction of ballots?
// sharedAccessor.template atomicIncr<uint32_t>(coverageDWORD);
}
GroupMemoryBarrierWithGroupSync();

const uint32_t3 iterationRegionPrefixProducts[3] = {iterationRegionXPrefixProducts, iterationRegionYPrefixProducts, iterationRegionZPrefixProducts};

uint32_t readScratchOffset = 0;
uint32_t writeScratchOffset = secondScratchOffset;
for (uint32_t axis = 0; axis < ConstevalParameters::BlitDimCount; ++axis)
[unroll(4)]
for (uint16_t ch=0; ch<4 && ch<=lastChannel; ch++)
sharedAccessor.template set<float32_t>(preloadCount*ch+virtualInvocation,loadedData[ch]);
}
glsl::barrier();
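// the tile now sits in shared memory in planar (channel-major) layout:
// channel `ch` of preloaded texel `i` lives at DWORD `preloadCount*ch+i`,
// matching the `ch*prevPassInvocationCount+inputIndex` reads of the first pass below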

uint16_t readScratchOffset = uint16_t(0);
uint16_t writeScratchOffset = _static_cast<uint16_t>(params.secondScratchOffDWORD);
const uint16_tN windowExtent = params.template getWindowExtent<Dims>();
PatchLayout<Dims> prevLayout = preloadLayout;
uint32_t kernelWeightOffset = 0;
[unroll(3)]
for (int32_t axis=0; axis<Dims; axis++)
{
const PatchLayout<Dims> outputLayout = params.getPassMeta<Dims>(axis);
const uint16_t invocationCount = outputLayout.getLinearEnd();
const uint16_t phaseCount = params.getPhaseCount(axis);
const uint16_t windowLength = windowExtent[axis];
const uint16_t prevPassInvocationCount = prevLayout.getLinearEnd();
for (uint16_t virtualInvocation=localInvocationIndex; virtualInvocation<invocationCount; virtualInvocation+=WorkGroupSize)
{
for (uint32_t virtualInvocation = localInvocationIndex; virtualInvocation < iterationRegionPrefixProducts[axis].z; virtualInvocation += ConstevalParameters::WorkGroupSize)
{
const uint32_t3 virtualInvocationID = ndarray_addressing::snakeCurveInverse(virtualInvocation, iterationRegionPrefixProducts[axis].xy);

uint32_t outputPixel = virtualInvocationID.x;
if (axis == 2)
outputPixel = virtualInvocationID.z;
outputPixel += minOutputPixel[axis];

if (outputPixel >= outDims[axis])
break;

const int32_t minKernelWindow = int32_t(ceil((outputPixel + 0.5f) * scale[axis] - 0.5f + negativeSupport[axis]));

// Combined stride for the two non-blitting dimensions, tightly coupled and experimentally derived with/by `iterationRegionPrefixProducts` above and the general order of iteration we use to avoid
// read bank conflicts.
uint32_t combinedStride;
{
if (axis == 0)
combinedStride = virtualInvocationID.z * preloadRegion.y + virtualInvocationID.y;
else if (axis == 1)
combinedStride = virtualInvocationID.z * outputTexelsPerWG.x + virtualInvocationID.y;
else if (axis == 2)
combinedStride = virtualInvocationID.y * outputTexelsPerWG.y + virtualInvocationID.x;
}

uint32_t offset = readScratchOffset + (minKernelWindow - regionStartCoord[axis]) + combinedStride*preloadRegion[axis];
const uint32_t windowPhase = outputPixel % phaseCount[axis];

uint32_t kernelWeightIndex;
if (axis == 0)
kernelWeightIndex = windowPhase * windowDims.x;
else if (axis == 1)
kernelWeightIndex = kernelWeightsOffsetY + windowPhase * windowDims.y;
else if (axis == 2)
kernelWeightIndex = kernelWeightsOffsetZ + windowPhase * windowDims.z;
// this always maps to the index in the current pass output
const uint16_tN virtualInvocationID = outputLayout.getID(virtualInvocation);

float4 kernelWeight = kernelWeightsAccessor.get(kernelWeightIndex);
// we sweep along a line at a time, `[0]` is not a typo, look at the definition of `params.getPassMeta`
uint16_t localOutputCoord = virtualInvocationID[0];
// we can actually compute the output position of this line
const uint16_t globalOutputCoord = localOutputCoord+minOutputTexel[axis];
// hopefully the compiler will see that float32_t math is possible here because the float32_t mantissa is wider than a uint16_t
const uint32_t windowPhase = globalOutputCoord % phaseCount;

float4 accum = float4(0.f, 0.f, 0.f, 0.f);
for (uint32_t ch = 0; ch < ConstevalParameters::BlitOutChannelCount; ++ch)
accum[ch] = sharedAccessor.get(ch * ConstevalParameters::SMemFloatsPerChannel + offset) * kernelWeight[ch];
//const int32_t windowStart = ceil(localOutputCoord+0.5f);

for (uint32_t i = 1; i < windowDims[axis]; ++i)
// let us sweep
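// the loop below is a 1D convolution along the current axis:
// accum[ch] = sum over i in [0,windowLength) of kernelWeights[kernelWeightOffset+windowPhase*windowLength+i][ch] * scratch[ch][inputIndex+i],
// with `inputIndex` starting at the first previous-pass texel covered by this output's kernel window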
float32_t4 accum = promote<float32_t4>(0.f);
{
uint32_t kernelWeightIndex = windowPhase*windowLength+kernelWeightOffset;
// Need to use global coordinate because of ceil(x*scale) involvement
uint16_tN tmp; tmp[0] = params.inputUpperBound(globalOutputCoord,axis)-regionStartCoord[axis];
[unroll(2)]
for (int32_t i=1; i<Dims; i++)
tmp[i] = virtualInvocationID[i];
// initialize to the first gather texel in range of the window for the output
uint16_t inputIndex = readScratchOffset+prevLayout.getIndex(tmp);
for (uint16_t i=0; i<windowLength; i++,inputIndex++)
{
kernelWeightIndex++;
offset++;

kernelWeight = kernelWeightsAccessor.get(kernelWeightIndex);
for (uint ch = 0; ch < ConstevalParameters::BlitOutChannelCount; ++ch)
accum[ch] += sharedAccessor.get(ch * ConstevalParameters::SMemFloatsPerChannel + offset) * kernelWeight[ch];
const float32_t4 kernelWeight = kernelWeightsAccessor.get(kernelWeightIndex++);
[unroll(4)]
for (uint16_t ch=0; ch<4 && ch<=lastChannel; ch++)
accum[ch] += sharedAccessor.template get<float32_t>(ch*prevPassInvocationCount+inputIndex)*kernelWeight[ch];
}
}

const bool lastPass = (axis == (ConstevalParameters::BlitDimCount - 1));
if (lastPass)
{
// Tightly coupled with iteration order (`iterationRegionPrefixProducts`)
uint32_t3 outCoord = virtualInvocationID.yxz;
if (axis == 0)
outCoord = virtualInvocationID.xyz;
outCoord += minOutputPixel;

const uint32_t bucketIndex = uint32_t(round(clamp(accum.a, 0, 1) * float(ConstevalParameters::AlphaBinCount-1)));
histogramAccessor.atomicAdd(workGroupID.z, bucketIndex, uint32_t(1));

outImageAccessor.set(outCoord, workGroupID.z, accum);
}
else
// now write outputs
if (axis!=Dims-1) // not last pass
{
const uint32_t scratchOffset = writeScratchOffset+params.template getStorageIndex<Dims>(axis,virtualInvocationID);
[unroll(4)]
for (uint16_t ch=0; ch<4 && ch<=lastChannel; ch++)
sharedAccessor.template set<float32_t>(ch*invocationCount+scratchOffset,accum[ch]);
}
else
{
const uint16_tN coord = SPerWorkgroup::unswizzle<Dims>(virtualInvocationID)+minOutputTexel;
outImageAccessor.template set<float32_t,Dims>(coord,layer,accum);
if (DoCoverage)
{
uint32_t scratchOffset = writeScratchOffset;
if (axis == 0)
scratchOffset += ndarray_addressing::snakeCurve(virtualInvocationID.yxz, uint32_t3(preloadRegion.y, outputTexelsPerWG.x, preloadRegion.z));
else
scratchOffset += writeScratchOffset + ndarray_addressing::snakeCurve(virtualInvocationID.zxy, uint32_t3(preloadRegion.z, outputTexelsPerWG.y, outputTexelsPerWG.x));

for (uint32_t ch = 0; ch < ConstevalParameters::BlitOutChannelCount; ++ch)
sharedAccessor.set(ch * ConstevalParameters::SMemFloatsPerChannel + scratchOffset, accum[ch]);
// const uint32_t bucketIndex = uint32_t(round(accum[coverageChannel] * float(ConstevalParameters::AlphaBinCount - 1)));
// histogramAccessor.atomicAdd(workGroupID.z,bucketIndex,uint32_t(1));
// intermediateAlphaImageAccessor.template set<float32_t,Dims>(coord,layer,accum);
}
}

const uint32_t tmp = readScratchOffset;
readScratchOffset = writeScratchOffset;
writeScratchOffset = tmp;
GroupMemoryBarrierWithGroupSync();
}
glsl::barrier();
kernelWeightOffset += uint32_t(phaseCount)*windowLength;
prevLayout = outputLayout;
// TODO: use Przemog's `nbl::hlsl::swap` method when the float64 stuff gets merged
const uint16_t tmp = readScratchOffset;
readScratchOffset = writeScratchOffset;
writeScratchOffset = tmp;
}
};
}

}
}
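For orientation (not part of this diff), the refactored free function is meant to be called from a user-provided entry point along these lines. `WorkGroupSize`, the accessor instances, the `pc` push constants and the way the virtual workgroup ID is obtained are placeholders rather than the PR's actual glue code:

using namespace nbl::hlsl;

[numthreads(WorkGroupSize,1,1)]
void main()
{
InCombinedSamplerAccessor inImg; // placeholder: samples the input image through the combined image sampler
OutImgAccessor outImg; // placeholder: writes the output storage image mip/layer
KernelWeightsAccessor weights; // placeholder: reads the precomputed phased kernel weights
// placeholder: a real dispatcher may derive a virtual workgroup ID rather than use the raw builtin
const uint16_t3 virtWGID = _static_cast<uint16_t3>(glsl::gl_WorkGroupID());
blit::execute<false,WorkGroupSize,3>(inImg,outImg,weights,sharedAccessor,pc.perWG,pc.layer,virtWGID);
}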