-
Notifications
You must be signed in to change notification settings - Fork 21
Add WaveActiveMax tests #429
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
dba563f
8fec3b6
059b8c4
f810b53
69c5020
c272633
601f2fa
31759d7
3021839
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
#--- source.hlsl | ||
StructuredBuffer<half4> In : register(t0); | ||
RWStructuredBuffer<half4> Out1 : register(u1); // test scalar | ||
RWStructuredBuffer<half4> Out2 : register(u2); // test half2 | ||
RWStructuredBuffer<half4> Out3 : register(u3); // test half3 | ||
RWStructuredBuffer<half4> Out4 : register(u4); // test half4 | ||
RWStructuredBuffer<half4> Out5 : register(u5); // constant folding | ||
|
||
[numthreads(4,1,1)] | ||
void main(uint3 tid : SV_GroupThreadID) | ||
{ | ||
half4 v = In[tid.x]; | ||
|
||
half s1 = WaveActiveMax( v.x ); | ||
half s2 = tid.x < 3 ? WaveActiveMax( v.x ) : 0; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's not obvious, but this is dependent on short-circuiting of ternary operator, introduced in HLSL 2021. It obfuscates the control flow a bit, which could be confusing for some at first read. Another approach would be to group assignments under explicit control flow blocks. You could even utilize arrays to reduce the duplicated code. I believe the following is equivalent in functionality: half s1[4] = (half[4])0;
half2 v2[4] = (half2[4])0;
half3 v3[4] = (half3[4])0;
half4 v4[4] = (half4[4])0;
for (int i = 0; i < 4; i++) {
if (tid.x <= i) {
s1[i] = WaveActiveMax( v.x );
v2[i] = WaveActiveMax( v.xy );
v3[i] = WaveActiveMax( v.xyz );
v4[i] = WaveActiveMax( v );
}
}
Out1[tid.x].x = s1[tid.x];
Out2[tid.x].xy = v2[tid.x];
Out3[tid.x].xyz = v3[tid.x];
Out4[tid.x] = v4[tid.x]; This seems easier to follow. It might also catch implementations that might apply illegal control flow optimizations impacting wave ops. Written this way, I notice that it is a bit of an odd approach with the arrays. While we write to local arrays on each thread, we only ever output the array element corresponding to thread id on each thread. It seems you could do away with the local arrays altogether. Like this: half s1 = 0;
half2 v2 = 0;
half3 v3 = 0;
half4 v4 = 0;
// Reverse order allows thread local values to end up
// with max value for all threads <= tid.x.
for (int i = 4; i > 0; i--) {
if (tid.x < i) {
s1 = WaveActiveMax( v.x );
v2 = WaveActiveMax( v.xy );
v3 = WaveActiveMax( v.xyz );
v4 = WaveActiveMax( v );
}
}
Out1[tid.x].x = s1;
Out2[tid.x].xy = v2;
Out3[tid.x].xyz = v3;
Out4[tid.x] = v4; With this, thread 0 should end up with max of just thread 0, thread 1 will be max of thread 0 and thread 1, and so on with thread 3 being the max of threads 0, 1, 2, and 3. Each thread overwrites the local values until it reaches the final iteration the thread participates in (max of all thread values up to and including this thread). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Another limitation to this approach is that we aren't really verifying that prior threads are getting the same max value as later threads. To do so simply, we might have to expand the outputs verified to write each result for each thread, instead of just the one corresponding to that thread (using the original or local array version I suggested first). This would require 4 times the output, but would be straightforward copy-paste or script work to extend expected outputs. Like this: half s1[4] = (half[4])0;
half2 v2[4] = (half2[4])0;
half3 v3[4] = (half3[4])0;
half4 v4[4] = (half4[4])0;
for (int i = 0; i < 4; i++) {
if (tid.x <= i) {
s1[i] = WaveActiveMax( v.x );
v2[i] = WaveActiveMax( v.xy );
v3[i] = WaveActiveMax( v.xyz );
v4[i] = WaveActiveMax( v );
}
}
// Output all results for each thread to verify max broadcast
for (int i = 0; i < 4; i++) {
Out1[i * 4 + tid.x].x = s1[i];
Out2[i * 4 + tid.x].xy = v2[i];
Out3[i * 4 + tid.x].xyz = v3[i];
Out4[i * 4 + tid.x] = v4[i];
} |
||
half s3 = tid.x < 2 ? WaveActiveMax( v.x ) : 0; | ||
half s4 = tid.x < 1 ? WaveActiveMax( v.x ) : 0; | ||
|
||
half2 v2_1 = WaveActiveMax( v.xy ); | ||
half2 v2_2 = tid.x < 3 ? WaveActiveMax( v.xy ) : half2(0,0); | ||
half2 v2_3 = tid.x < 2 ? WaveActiveMax( v.xy ) : half2(0,0); | ||
half2 v2_4 = tid.x < 1 ? WaveActiveMax( v.xy ) : half2(0,0); | ||
|
||
half3 v3_1 = WaveActiveMax( v.xyz ); | ||
half3 v3_2 = tid.x < 3 ? WaveActiveMax( v.xyz ) : half3(0,0,0); | ||
half3 v3_3 = tid.x < 2 ? WaveActiveMax( v.xyz ) : half3(0,0,0); | ||
half3 v3_4 = tid.x < 1 ? WaveActiveMax( v.xyz ) : half3(0,0,0); | ||
|
||
half4 v4_1 = WaveActiveMax( v ); | ||
half4 v4_2 = tid.x < 3 ? WaveActiveMax( v ) : half4(0,0,0,0); | ||
half4 v4_3 = tid.x < 2 ? WaveActiveMax( v ) : half4(0,0,0,0); | ||
half4 v4_4 = tid.x < 1 ? WaveActiveMax( v ) : half4(0,0,0,0); | ||
|
||
half scalars[4] = { s4, s3, s2, s1 }; | ||
half2 vec2s [4] = { v2_4, v2_3, v2_2, v2_1 }; | ||
half3 vec3s [4] = { v3_4, v3_3, v3_2, v3_1 }; | ||
half4 vec4s [4] = { v4_4, v4_3, v4_2, v4_1 }; | ||
|
||
Out1[tid.x].x = scalars[tid.x]; | ||
Out2[tid.x].xy = vec2s[tid.x]; | ||
Out3[tid.x].xyz = vec3s[tid.x]; | ||
Out4[tid.x] = vec4s[tid.x]; | ||
|
||
// constant folding case | ||
Out5[0] = WaveActiveMax(half4(1,2,3,4)); | ||
} | ||
|
||
//--- pipeline.yaml | ||
|
||
--- | ||
Shaders: | ||
- Stage: Compute | ||
Entry: main | ||
DispatchSize: [1, 1, 1] | ||
Buffers: | ||
- Name: In | ||
Format: Float16 | ||
Stride: 8 | ||
# 1, 10, 100, 1000, 2, 20, 200, 2000, 3, 30, 300, 3000, 4, 40, 400, 4000 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. All tests appear to use the same whole-number test values, which always increase in value for higher thread ids. I feel like this could miss implementation errors like:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Additional note: the comment with the values is helpful, but would be even more helpful if formatted so you could line up values compared across threads, like:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Additionally, more sets of values could be tested with an outer loop around the value set, if that's desired given my feedback on the limitations of this chosen set of values. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Another option for selecting sets of threads could have been an input mask input set instead of always using max of values from threads [0,n] for each n=[0,3]. That could look like:
An updated shader that could work with this and multiple value sets: #define VALUE_SETS 2
#define NUM_MASKS 13
#define NUM_THREADS 4
struct MaskStruct {
int mask[NUM_THREADS];
};
StructuredBuffer<half4> In : register(t0);
StructuredBuffer<MaskStruct> Masks : register(t1);
RWStructuredBuffer<half4> Out1 : register(u2); // test scalar
RWStructuredBuffer<half4> Out2 : register(u3); // test half2
RWStructuredBuffer<half4> Out3 : register(u4); // test half3
RWStructuredBuffer<half4> Out4 : register(u5); // test half4
RWStructuredBuffer<half4> Out5 : register(u6); // constant folding
[numthreads(NUM_THREADS,1,1)]
void main(uint3 tid : SV_GroupThreadID)
{
for (int ValueSet = 0; ValueSet < VALUE_SETS; ValueSet++) {
const uint ValueSetOffset = ValueSet * NUM_MASKS * NUM_THREADS;
half4 v = In[ValueSet * NUM_THREADS + tid.x];
for (int MaskIdx = 0; MaskIdx < NUM_MASKS; MaskIdx++) {
const uint OutIdx = ValueSetOffset + MaskIdx * NUM_THREADS + tid.x;
if (Masks[MaskIdx].mask[tid.x]) {
Out1[OutIdx].x = WaveActiveMax( v.x );
Out2[OutIdx].xy = WaveActiveMax( v.xy );
Out3[OutIdx].xyz = WaveActiveMax( v.xyz );
Out4[OutIdx] = WaveActiveMax( v );
}
}
}
// constant folding case
Out5[0] = WaveActiveMax(half4(1,2,3,4));
} |
||
Data: [ 0x3c00, 0x4900, 0x5640, 0x63d0, 0x4000, 0x4d00, 0x5a40, 0x67d0, 0x4200, 0x4f80, 0x5cb0, 0x69dc, 0x4400, 0x5100, 0x5e40, 0x6bd0 ] | ||
- Name: Out1 | ||
Format: Float16 | ||
Stride: 8 | ||
ZeroInitSize: 32 | ||
- Name: Out2 | ||
Format: Float16 | ||
Stride: 8 | ||
ZeroInitSize: 32 | ||
- Name: Out3 | ||
Format: Float16 | ||
Stride: 8 | ||
ZeroInitSize: 32 | ||
- Name: Out4 | ||
Format: Float16 | ||
Stride: 8 | ||
ZeroInitSize: 32 | ||
- Name: Out5 | ||
Format: Float16 | ||
Stride: 8 | ||
ZeroInitSize: 8 | ||
- Name: ExpectedOut1 | ||
Format: Float16 | ||
Stride: 8 | ||
Data: [ 0x3c00, 0x0, 0x0, 0x0, 0x4000, 0x0, 0x0, 0x0, 0x4200, 0x0, 0x0, 0x0, 0x4400, 0x0, 0x0, 0x0 ] | ||
- Name: ExpectedOut2 | ||
Format: Float16 | ||
Stride: 8 | ||
Data: [ 0x3c00, 0x4900, 0x0, 0x0, 0x4000, 0x4d00, 0x0, 0x0, 0x4200, 0x4f80, 0x0, 0x0, 0x4400, 0x5100, 0x0, 0x0 ] | ||
- Name: ExpectedOut3 | ||
Format: Float16 | ||
Stride: 8 | ||
Data: [ 0x3c00, 0x4900, 0x5640, 0x0, 0x4000, 0x4d00, 0x5a40, 0x0, 0x4200, 0x4f80, 0x5cb0, 0x0, 0x4400, 0x5100, 0x5e40, 0x0 ] | ||
- Name: ExpectedOut4 | ||
Format: Float16 | ||
Stride: 8 | ||
Data: [ 0x3c00, 0x4900, 0x5640, 0x63d0, 0x4000, 0x4d00, 0x5a40, 0x67d0, 0x4200, 0x4f80, 0x5cb0, 0x69dc, 0x4400, 0x5100, 0x5e40, 0x6bd0 ] | ||
- Name: ExpectedOut5 | ||
Format: Float16 | ||
Stride: 8 | ||
Data: [ 0x3C00, 0x4000, 0x4200, 0x4400 ] | ||
Results: | ||
- Result: ExpectedOut1 | ||
Rule: BufferExact | ||
Actual: Out1 | ||
Expected: ExpectedOut1 | ||
- Result: ExpectedOut2 | ||
Rule: BufferExact | ||
Actual: Out2 | ||
Expected: ExpectedOut2 | ||
- Result: ExpectedOut3 | ||
Rule: BufferExact | ||
Actual: Out3 | ||
Expected: ExpectedOut3 | ||
- Result: ExpectedOut4 | ||
Rule: BufferExact | ||
Actual: Out4 | ||
Expected: ExpectedOut4 | ||
- Result: ExpectedOut5 | ||
Rule: BufferExact | ||
Actual: Out5 | ||
Expected: ExpectedOut5 | ||
DescriptorSets: | ||
- Resources: | ||
- Name: In | ||
Kind: StructuredBuffer | ||
DirectXBinding: | ||
Register: 0 | ||
Space: 0 | ||
VulkanBinding: | ||
Binding: 0 | ||
- Name: Out1 | ||
Kind: RWStructuredBuffer | ||
DirectXBinding: | ||
Register: 1 | ||
Space: 0 | ||
VulkanBinding: | ||
Binding: 1 | ||
- Name: Out2 | ||
Kind: RWStructuredBuffer | ||
DirectXBinding: | ||
Register: 2 | ||
Space: 0 | ||
VulkanBinding: | ||
Binding: 2 | ||
- Name: Out3 | ||
Kind: RWStructuredBuffer | ||
DirectXBinding: | ||
Register: 3 | ||
Space: 0 | ||
VulkanBinding: | ||
Binding: 3 | ||
- Name: Out4 | ||
Kind: RWStructuredBuffer | ||
DirectXBinding: | ||
Register: 4 | ||
Space: 0 | ||
VulkanBinding: | ||
Binding: 4 | ||
- Name: Out5 | ||
Kind: RWStructuredBuffer | ||
DirectXBinding: | ||
Register: 5 | ||
Space: 0 | ||
VulkanBinding: | ||
Binding: 5 | ||
|
||
... | ||
#--- end | ||
|
||
# Bug https://github.com/llvm/llvm-project/issues/156775 | ||
# XFAIL: Clang | ||
|
||
# Bug https://github.com/llvm/offload-test-suite/issues/393 | ||
# XFAIL: Metal | ||
|
||
# RUN: split-file %s %t | ||
# RUN: %dxc_target -enable-16bit-types -T cs_6_5 -Fo %t.o %t/source.hlsl | ||
# RUN: %offloader %t/pipeline.yaml %t.o |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
#--- source.hlsl | ||
StructuredBuffer<float4> In : register(t0); | ||
RWStructuredBuffer<float4> Out1 : register(u1); // test scalar | ||
RWStructuredBuffer<float4> Out2 : register(u2); // test float2 | ||
RWStructuredBuffer<float4> Out3 : register(u3); // test float3 | ||
RWStructuredBuffer<float4> Out4 : register(u4); // test float4 | ||
RWStructuredBuffer<float4> Out5 : register(u5); // constant folding | ||
|
||
[numthreads(4,1,1)] | ||
void main(uint3 tid : SV_GroupThreadID) | ||
{ | ||
float4 v = In[tid.x]; | ||
|
||
float s1 = WaveActiveMax( v.x ); | ||
float s2 = tid.x < 3 ? WaveActiveMax( v.x ) : 0; | ||
float s3 = tid.x < 2 ? WaveActiveMax( v.x ) : 0; | ||
float s4 = tid.x < 1 ? WaveActiveMax( v.x ) : 0; | ||
|
||
float2 v2_1 = WaveActiveMax( v.xy ); | ||
float2 v2_2 = tid.x < 3 ? WaveActiveMax( v.xy ) : float2(0,0); | ||
float2 v2_3 = tid.x < 2 ? WaveActiveMax( v.xy ) : float2(0,0); | ||
float2 v2_4 = tid.x < 1 ? WaveActiveMax( v.xy ) : float2(0,0); | ||
|
||
float3 v3_1 = WaveActiveMax( v.xyz ); | ||
float3 v3_2 = tid.x < 3 ? WaveActiveMax( v.xyz ) : float3(0,0,0); | ||
float3 v3_3 = tid.x < 2 ? WaveActiveMax( v.xyz ) : float3(0,0,0); | ||
float3 v3_4 = tid.x < 1 ? WaveActiveMax( v.xyz ) : float3(0,0,0); | ||
|
||
float4 v4_1 = WaveActiveMax( v ); | ||
float4 v4_2 = tid.x < 3 ? WaveActiveMax( v ) : float4(0,0,0,0); | ||
float4 v4_3 = tid.x < 2 ? WaveActiveMax( v ) : float4(0,0,0,0); | ||
float4 v4_4 = tid.x < 1 ? WaveActiveMax( v ) : float4(0,0,0,0); | ||
|
||
float scalars[4] = { s4, s3, s2, s1 }; | ||
float2 vec2s [4] = { v2_4, v2_3, v2_2, v2_1 }; | ||
float3 vec3s [4] = { v3_4, v3_3, v3_2, v3_1 }; | ||
float4 vec4s [4] = { v4_4, v4_3, v4_2, v4_1 }; | ||
|
||
Out1[tid.x].x = scalars[tid.x]; | ||
Out2[tid.x].xy = vec2s[tid.x]; | ||
Out3[tid.x].xyz = vec3s[tid.x]; | ||
Out4[tid.x] = vec4s[tid.x]; | ||
|
||
// constant folding case | ||
Out5[0] = WaveActiveMax(float4(1,2,3,4)); | ||
} | ||
|
||
//--- pipeline.yaml | ||
|
||
--- | ||
Shaders: | ||
- Stage: Compute | ||
Entry: main | ||
DispatchSize: [1, 1, 1] | ||
Buffers: | ||
- Name: In | ||
Format: Float32 | ||
Stride: 16 | ||
Data: [ 1.0, 10.0, 100.0, 1000.0, 2.0, 20.0, 200.0, 2000.0, 3.0, 30.0, 300.0, 3000.0, 4.0, 40.0, 400.0, 4000.0 ] | ||
- Name: Out1 | ||
Format: Float32 | ||
Stride: 16 | ||
ZeroInitSize: 64 | ||
- Name: Out2 | ||
Format: Float32 | ||
Stride: 16 | ||
ZeroInitSize: 64 | ||
- Name: Out3 | ||
Format: Float32 | ||
Stride: 16 | ||
ZeroInitSize: 64 | ||
- Name: Out4 | ||
Format: Float32 | ||
Stride: 16 | ||
ZeroInitSize: 64 | ||
- Name: Out5 | ||
Format: Float32 | ||
Stride: 16 | ||
ZeroInitSize: 16 | ||
- Name: ExpectedOut1 | ||
Format: Float32 | ||
Stride: 16 | ||
Data: [ 1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0 ] | ||
- Name: ExpectedOut2 | ||
Format: Float32 | ||
Stride: 16 | ||
Data: [ 1.0, 10.0, 0.0, 0.0, 2.0, 20.0, 0.0, 0.0, 3.0, 30.0, 0.0, 0.0, 4.0, 40.0, 0.0, 0.0 ] | ||
- Name: ExpectedOut3 | ||
Format: Float32 | ||
Stride: 16 | ||
Data: [ 1.0, 10.0, 100.0, 0.0, 2.0, 20.0, 200.0, 0.0, 3.0, 30.0, 300.0, 0.0, 4.0, 40.0, 400.0, 0.0 ] | ||
- Name: ExpectedOut4 | ||
Format: Float32 | ||
Stride: 16 | ||
Data: [ 1.0, 10.0, 100.0, 1000.0, 2.0, 20.0, 200.0, 2000.0, 3.0, 30.0, 300.0, 3000.0, 4.0, 40.0, 400.0, 4000.0 ] | ||
- Name: ExpectedOut5 | ||
Format: Float32 | ||
Stride: 16 | ||
Data: [ 1.0, 2.0, 3.0, 4.0 ] | ||
Results: | ||
- Result: ExpectedOut1 | ||
Rule: BufferExact | ||
Actual: Out1 | ||
Expected: ExpectedOut1 | ||
- Result: ExpectedOut2 | ||
Rule: BufferExact | ||
Actual: Out2 | ||
Expected: ExpectedOut2 | ||
- Result: ExpectedOut3 | ||
Rule: BufferExact | ||
Actual: Out3 | ||
Expected: ExpectedOut3 | ||
- Result: ExpectedOut4 | ||
Rule: BufferExact | ||
Actual: Out4 | ||
Expected: ExpectedOut4 | ||
- Result: ExpectedOut5 | ||
Rule: BufferExact | ||
Actual: Out5 | ||
Expected: ExpectedOut5 | ||
DescriptorSets: | ||
- Resources: | ||
- Name: In | ||
Kind: StructuredBuffer | ||
DirectXBinding: | ||
Register: 0 | ||
Space: 0 | ||
VulkanBinding: | ||
Binding: 0 | ||
- Name: Out1 | ||
Kind: RWStructuredBuffer | ||
DirectXBinding: | ||
Register: 1 | ||
Space: 0 | ||
VulkanBinding: | ||
Binding: 1 | ||
- Name: Out2 | ||
Kind: RWStructuredBuffer | ||
DirectXBinding: | ||
Register: 2 | ||
Space: 0 | ||
VulkanBinding: | ||
Binding: 2 | ||
- Name: Out3 | ||
Kind: RWStructuredBuffer | ||
DirectXBinding: | ||
Register: 3 | ||
Space: 0 | ||
VulkanBinding: | ||
Binding: 3 | ||
- Name: Out4 | ||
Kind: RWStructuredBuffer | ||
DirectXBinding: | ||
Register: 4 | ||
Space: 0 | ||
VulkanBinding: | ||
Binding: 4 | ||
- Name: Out5 | ||
Kind: RWStructuredBuffer | ||
DirectXBinding: | ||
Register: 5 | ||
Space: 0 | ||
VulkanBinding: | ||
Binding: 5 | ||
|
||
... | ||
#--- end | ||
|
||
# Bug https://github.com/llvm/llvm-project/issues/156775 | ||
# XFAIL: Clang | ||
|
||
# Tracked by https://github.com/llvm/offload-test-suite/issues/393 | ||
# XFAIL: Metal | ||
|
||
# RUN: split-file %s %t | ||
# RUN: %dxc_target -T cs_6_5 -Fo %t.o %t/source.hlsl | ||
# RUN: %offloader %t/pipeline.yaml %t.o |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All of these HLSL sources appear to be identical except the base type used. Is there any way to reference a shared source file and use compilation arguments, like
-D TYPE=half
instead?