178 changes: 178 additions & 0 deletions test/WaveOps/WaveActiveMax.fp16.test
@@ -0,0 +1,178 @@
#--- source.hlsl

All of these HLSL sources appear to be identical except for the base type used. Is there any way to reference a shared source file and select the type with a compilation argument, like -D TYPE=half, instead?

StructuredBuffer<half4> In : register(t0);
RWStructuredBuffer<half4> Out1 : register(u1); // test scalar
RWStructuredBuffer<half4> Out2 : register(u2); // test half2
RWStructuredBuffer<half4> Out3 : register(u3); // test half3
RWStructuredBuffer<half4> Out4 : register(u4); // test half4
RWStructuredBuffer<half4> Out5 : register(u5); // constant folding

[numthreads(4,1,1)]
void main(uint3 tid : SV_GroupThreadID)
{
half4 v = In[tid.x];

half s1 = WaveActiveMax( v.x );
half s2 = tid.x < 3 ? WaveActiveMax( v.x ) : 0;
@tex3d tex3d Sep 9, 2025

It's not obvious, but this depends on short-circuiting of the ternary operator, which was introduced in HLSL 2021. That obscures the control flow a bit, which could be confusing on first read.
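
To make the dependence concrete, here is a small host-side model in Python (not HLSL). It is only a sketch: `wave_active_max` and the two `ternary_*` helpers are hypothetical names, and the model assumes the wave op reduces over exactly the lanes that reach it.

```python
def wave_active_max(values, active):
    """Max over the lanes that are active when the op executes."""
    return max(values[lane] for lane in active)

x = [1.0, 2.0, 3.0, 4.0]  # v.x for threads 0..3, as in the test input
lanes = range(len(x))

def ternary_short_circuit(limit):
    # Only lanes with tid.x < limit reach the wave op.
    active = [t for t in lanes if t < limit]
    return [wave_active_max(x, active) if t < limit else 0.0 for t in lanes]

def ternary_eager(limit):
    # Every lane evaluates the wave op first, then selects a result.
    full = wave_active_max(x, list(lanes))
    return [full if t < limit else 0.0 for t in lanes]

print(ternary_short_circuit(2))  # [2.0, 2.0, 0.0, 0.0]
print(ternary_eager(2))          # [4.0, 4.0, 0.0, 0.0]
```

The expected buffers in this test match the short-circuit result, so an eager evaluation of both ternary operands would produce different values.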

Another approach would be to group assignments under explicit control flow blocks. You could even utilize arrays to reduce the duplicated code. I believe the following is equivalent in functionality:

    half s1[4] = (half[4])0;
    half2 v2[4] = (half2[4])0;
    half3 v3[4] = (half3[4])0;
    half4 v4[4] = (half4[4])0;

    for (int i = 0; i < 4; i++) {
        if (tid.x <= i) {
            s1[i] = WaveActiveMax( v.x );
            v2[i] = WaveActiveMax( v.xy );
            v3[i] = WaveActiveMax( v.xyz );
            v4[i] = WaveActiveMax( v );
        }
    }

    Out1[tid.x].x = s1[tid.x];
    Out2[tid.x].xy = v2[tid.x];
    Out3[tid.x].xyz = v3[tid.x];
    Out4[tid.x] = v4[tid.x];

This seems easier to follow. It might also catch implementations that apply illegal control flow optimizations affecting wave ops.

Written this way, the array approach looks a bit odd: each thread writes to local arrays, but only ever outputs the array element corresponding to its own thread id. It seems you could do away with the local arrays altogether. Like this:

    half s1 = 0;
    half2 v2 = 0;
    half3 v3 = 0;
    half4 v4 = 0;

    // Reverse order allows thread local values to end up
    // with max value for all threads <= tid.x.
    for (int i = 4; i > 0; i--) {
        if (tid.x < i) {
            s1 = WaveActiveMax( v.x );
            v2 = WaveActiveMax( v.xy );
            v3 = WaveActiveMax( v.xyz );
            v4 = WaveActiveMax( v );
        }
    }

    Out1[tid.x].x = s1;
    Out2[tid.x].xy = v2;
    Out3[tid.x].xyz = v3;
    Out4[tid.x] = v4;

With this, thread 0 should end up with max of just thread 0, thread 1 will be max of thread 0 and thread 1, and so on with thread 3 being the max of threads 0, 1, 2, and 3. Each thread overwrites the local values until it reaches the final iteration the thread participates in (max of all thread values up to and including this thread).
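
The overwrite behaviour described above can be checked with a small Python model of the reverse loop. This is a sketch, not shader code: `reverse_loop_results` is a hypothetical helper, and it assumes each `WaveActiveMax` reduces over exactly the threads that pass the `tid.x < i` test.

```python
NUM_THREADS = 4

def reverse_loop_results(x):
    """Model the reverse loop: returns the final thread-local s1 per thread."""
    s1 = [0.0] * NUM_THREADS
    for i in range(NUM_THREADS, 0, -1):
        active = [t for t in range(NUM_THREADS) if t < i]
        wave_max = max(x[t] for t in active)
        for t in active:
            s1[t] = wave_max  # later iterations overwrite earlier ones
    return s1

print(reverse_loop_results([1.0, 2.0, 3.0, 4.0]))  # [1.0, 2.0, 3.0, 4.0]
print(reverse_loop_results([2.0, 5.0, 1.0, 3.0]))  # [2.0, 5.0, 5.0, 5.0]
```

The second call uses non-increasing inputs to show that each thread ends up with the running max over threads 0 through its own id, not simply its own value.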

Another limitation of this approach is that we aren't really verifying that earlier threads get the same max value as later threads. A simple way to do so would be to write every result from every thread, instead of just the one corresponding to that thread (using the original version or the local array version I suggested first). This would require 4 times the output, but extending the expected outputs would be straightforward copy-paste or script work.

Like this:

    half s1[4] = (half[4])0;
    half2 v2[4] = (half2[4])0;
    half3 v3[4] = (half3[4])0;
    half4 v4[4] = (half4[4])0;

    for (int i = 0; i < 4; i++) {
        if (tid.x <= i) {
            s1[i] = WaveActiveMax( v.x );
            v2[i] = WaveActiveMax( v.xy );
            v3[i] = WaveActiveMax( v.xyz );
            v4[i] = WaveActiveMax( v );
        }
    }

    // Output all results for each thread to verify max broadcast
    for (int i = 0; i < 4; i++) {
        Out1[i * 4 + tid.x].x = s1[i];
        Out2[i * 4 + tid.x].xy = v2[i];
        Out3[i * 4 + tid.x].xyz = v3[i];
        Out4[i * 4 + tid.x] = v4[i];
    }
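
As a rough illustration of the script work, the Python sketch below generates the expected `.x` values for the expanded Out1 layout. `expected_out1_x` is a hypothetical helper; it assumes the `tid.x <= i` guard above and that threads skipping an iteration write the zero from the array initializer.

```python
NUM_THREADS = 4

def expected_out1_x(x):
    """Expected Out1[i * 4 + tid].x for every iteration i and thread tid."""
    out = [0.0] * (NUM_THREADS * NUM_THREADS)
    for i in range(NUM_THREADS):
        active = [t for t in range(NUM_THREADS) if t <= i]
        wave_max = max(x[t] for t in active)
        for t in range(NUM_THREADS):
            # Inactive threads keep the zero-initialized array element.
            out[i * NUM_THREADS + t] = wave_max if t in active else 0.0
    return out

expected = expected_out1_x([1.0, 2.0, 3.0, 4.0])
for i in range(NUM_THREADS):
    print(expected[i * NUM_THREADS:(i + 1) * NUM_THREADS])
# [1.0, 0.0, 0.0, 0.0]
# [2.0, 2.0, 0.0, 0.0]
# [3.0, 3.0, 3.0, 0.0]
# [4.0, 4.0, 4.0, 4.0]
```

Each printed row is one iteration; every participating thread in a row must report the same value, which is the broadcast property the extra outputs are meant to verify.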

half s3 = tid.x < 2 ? WaveActiveMax( v.x ) : 0;
half s4 = tid.x < 1 ? WaveActiveMax( v.x ) : 0;

half2 v2_1 = WaveActiveMax( v.xy );
half2 v2_2 = tid.x < 3 ? WaveActiveMax( v.xy ) : half2(0,0);
half2 v2_3 = tid.x < 2 ? WaveActiveMax( v.xy ) : half2(0,0);
half2 v2_4 = tid.x < 1 ? WaveActiveMax( v.xy ) : half2(0,0);

half3 v3_1 = WaveActiveMax( v.xyz );
half3 v3_2 = tid.x < 3 ? WaveActiveMax( v.xyz ) : half3(0,0,0);
half3 v3_3 = tid.x < 2 ? WaveActiveMax( v.xyz ) : half3(0,0,0);
half3 v3_4 = tid.x < 1 ? WaveActiveMax( v.xyz ) : half3(0,0,0);

half4 v4_1 = WaveActiveMax( v );
half4 v4_2 = tid.x < 3 ? WaveActiveMax( v ) : half4(0,0,0,0);
half4 v4_3 = tid.x < 2 ? WaveActiveMax( v ) : half4(0,0,0,0);
half4 v4_4 = tid.x < 1 ? WaveActiveMax( v ) : half4(0,0,0,0);

half scalars[4] = { s4, s3, s2, s1 };
half2 vec2s [4] = { v2_4, v2_3, v2_2, v2_1 };
half3 vec3s [4] = { v3_4, v3_3, v3_2, v3_1 };
half4 vec4s [4] = { v4_4, v4_3, v4_2, v4_1 };

Out1[tid.x].x = scalars[tid.x];
Out2[tid.x].xy = vec2s[tid.x];
Out3[tid.x].xyz = vec3s[tid.x];
Out4[tid.x] = vec4s[tid.x];

// constant folding case
Out5[0] = WaveActiveMax(half4(1,2,3,4));
}

//--- pipeline.yaml

---
Shaders:
- Stage: Compute
Entry: main
DispatchSize: [1, 1, 1]
Buffers:
- Name: In
Format: Float16
Stride: 8
# 1, 10, 100, 1000, 2, 20, 200, 2000, 3, 30, 300, 3000, 4, 40, 400, 4000

All tests appear to use the same whole-number test values, which always increase with thread id. I feel like this could miss implementation errors like:

  • implicit casting for wave op to int (no fractional values)
  • implicit casting for wave op to different bit-size (no values requiring selected bit size to accurately express)
  • not handling/preserving denorms when required (for half and double, as well as float when denorm mode is preserve)
  • mishandling of negative values
  • mishandling of INF/-INF
    • I believe inf/-inf should be reliably handled for this op, but could be wrong
  • just returning the value from the highest active thread index (bad implementation)
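
To illustrate two of these cases, the Python sketch below compares a reference max against two deliberately wrong reductions. Everything here is hypothetical: the buggy functions are invented for illustration, and `mixed` is just one example of a value set that is neither whole-numbered nor increasing.

```python
def wave_max_reference(values, active):
    return max(values[t] for t in active)

def wave_max_last_lane(values, active):
    # Hypothetical bug: returns the value held by the highest active lane.
    return values[max(active)]

def wave_max_int_cast(values, active):
    # Hypothetical bug: truncates to int before reducing.
    return float(max(int(values[t]) for t in active))

def caught(buggy, values):
    """True if any prefix set [0..n] differs from the reference result."""
    prefixes = [list(range(n + 1)) for n in range(len(values))]
    return any(buggy(values, p) != wave_max_reference(values, p) for p in prefixes)

increasing = [1.0, 2.0, 3.0, 4.0]  # x components used by the current test
mixed = [3.5, -2.0, 0.25, 1.0]     # fractional, negative, not increasing

for bug in (wave_max_last_lane, wave_max_int_cast):
    print(bug.__name__, caught(bug, increasing), caught(bug, mixed))
# wave_max_last_lane False True
# wave_max_int_cast False True
```

With the increasing whole-number inputs, neither bug is detected; with the mixed inputs, both are.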

@tex3d tex3d Sep 9, 2025

Additional note: the comment with the values is helpful, but would be even more helpful if formatted so you could line up values compared across threads, like:

    # x,  y,   z,    w
    # 1, 10, 100, 1000, # thread 0
    # 2, 20, 200, 2000, # thread 1
    # 3, 30, 300, 3000, # thread 2
    # 4, 40, 400, 4000  # thread 3
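
A per-thread layout like this could also be generated together with the Float16 bit patterns. The Python sketch below uses the standard `struct` module's `e` (binary16) format; `half_bits` and the output format are hypothetical, not part of the test suite.

```python
import struct

def half_bits(value):
    """IEEE 754 binary16 bit pattern of value, as an int."""
    return struct.unpack("<H", struct.pack("<e", value))[0]

rows = [
    (1.0, 10.0, 100.0, 1000.0),
    (2.0, 20.0, 200.0, 2000.0),
    (3.0, 30.0, 300.0, 3000.0),
    (4.0, 40.0, 400.0, 4000.0),
]

for tid, row in enumerate(rows):
    cells = ", ".join(f"0x{half_bits(v):04x}" for v in row)
    print(f"# {cells}  # thread {tid}")
# First line printed: # 0x3c00, 0x4900, 0x5640, 0x63d0  # thread 0
```

The same helper covers values the current data set lacks, for example `half_bits(-1.0)` is `0xbc00`, `half_bits(float("inf"))` is `0x7c00`, and the smallest subnormal `half_bits(2.0 ** -24)` is `0x0001`.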

Additionally, more sets of values could be tested with an outer loop around the value set, if that's desired given my feedback on the limitations of this chosen set of values.

@tex3d tex3d Sep 9, 2025

Another option for selecting sets of threads would be an input set of active masks, instead of always using the max of values from threads [0,n] for each n=[0,3]. That could look like:

    # 13 active mask sets for threads 0, 1, 2, 3:
    # 1 1 1 1
    # 1 0 0 0
    # 0 1 0 0
    # 0 0 1 0
    # 0 0 0 1
    # 0 1 1 1
    # 1 0 1 1
    # 1 1 0 1
    # 1 1 1 0
    # 1 1 0 0
    # 0 0 1 1
    # 0 1 1 0
    # 1 0 0 1

An updated shader that could work with this and multiple value sets:

#define VALUE_SETS 2
#define NUM_MASKS 13
#define NUM_THREADS 4

struct MaskStruct {
    int mask[NUM_THREADS];
};

StructuredBuffer<half4> In  : register(t0);
StructuredBuffer<MaskStruct> Masks  : register(t1);
RWStructuredBuffer<half4> Out1 : register(u2); // test scalar
RWStructuredBuffer<half4> Out2 : register(u3); // test half2
RWStructuredBuffer<half4> Out3 : register(u4); // test half3
RWStructuredBuffer<half4> Out4 : register(u5); // test half4
RWStructuredBuffer<half4> Out5 : register(u6); // constant folding

[numthreads(NUM_THREADS,1,1)]
void main(uint3 tid : SV_GroupThreadID)
{
    for (int ValueSet = 0; ValueSet < VALUE_SETS; ValueSet++) {
        const uint ValueSetOffset = ValueSet * NUM_MASKS * NUM_THREADS;
        half4 v = In[ValueSet * NUM_THREADS + tid.x];
        for (int MaskIdx = 0; MaskIdx < NUM_MASKS; MaskIdx++) {
            const uint OutIdx = ValueSetOffset + MaskIdx * NUM_THREADS + tid.x;
            if (Masks[MaskIdx].mask[tid.x]) {
                Out1[OutIdx].x = WaveActiveMax( v.x );
                Out2[OutIdx].xy = WaveActiveMax( v.xy );
                Out3[OutIdx].xyz = WaveActiveMax( v.xyz );
                Out4[OutIdx] = WaveActiveMax( v );
            }
        }
    }

    // constant folding case
    Out5[0] = WaveActiveMax(half4(1,2,3,4));
}

See: https://www.godbolt.org/z/P1r3E869h
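
Expected values for the masked variant could be produced by a script along these lines. This Python sketch is hypothetical: `expected_for_masks` covers only the scalar `.x` case for a single value set, and it assumes the output buffer is zero-initialized so masked-off threads leave 0 behind.

```python
MASKS = [
    (1, 1, 1, 1), (1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1),
    (0, 1, 1, 1), (1, 0, 1, 1), (1, 1, 0, 1), (1, 1, 1, 0),
    (1, 1, 0, 0), (0, 0, 1, 1), (0, 1, 1, 0), (1, 0, 0, 1),
]

def expected_for_masks(x, masks):
    """Expected Out1[MaskIdx * 4 + tid].x for one value set."""
    out = []
    for mask in masks:
        active = [t for t, bit in enumerate(mask) if bit]
        wave_max = max(x[t] for t in active)
        # Active threads write the reduced value; the rest stay zero.
        out.extend(wave_max if bit else 0.0 for bit in mask)
    return out

expected = expected_for_masks([1.0, 2.0, 3.0, 4.0], MASKS)
print(expected[44:48])  # mask 0 1 1 0 -> [0.0, 3.0, 3.0, 0.0]
```

Additional value sets would append one such block per set, matching the `ValueSetOffset` computation in the shader above.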

Data: [ 0x3c00, 0x4900, 0x5640, 0x63d0, 0x4000, 0x4d00, 0x5a40, 0x67d0, 0x4200, 0x4f80, 0x5cb0, 0x69dc, 0x4400, 0x5100, 0x5e40, 0x6bd0 ]
- Name: Out1
Format: Float16
Stride: 8
ZeroInitSize: 32
- Name: Out2
Format: Float16
Stride: 8
ZeroInitSize: 32
- Name: Out3
Format: Float16
Stride: 8
ZeroInitSize: 32
- Name: Out4
Format: Float16
Stride: 8
ZeroInitSize: 32
- Name: Out5
Format: Float16
Stride: 8
ZeroInitSize: 8
- Name: ExpectedOut1
Format: Float16
Stride: 8
Data: [ 0x3c00, 0x0, 0x0, 0x0, 0x4000, 0x0, 0x0, 0x0, 0x4200, 0x0, 0x0, 0x0, 0x4400, 0x0, 0x0, 0x0 ]
- Name: ExpectedOut2
Format: Float16
Stride: 8
Data: [ 0x3c00, 0x4900, 0x0, 0x0, 0x4000, 0x4d00, 0x0, 0x0, 0x4200, 0x4f80, 0x0, 0x0, 0x4400, 0x5100, 0x0, 0x0 ]
- Name: ExpectedOut3
Format: Float16
Stride: 8
Data: [ 0x3c00, 0x4900, 0x5640, 0x0, 0x4000, 0x4d00, 0x5a40, 0x0, 0x4200, 0x4f80, 0x5cb0, 0x0, 0x4400, 0x5100, 0x5e40, 0x0 ]
- Name: ExpectedOut4
Format: Float16
Stride: 8
Data: [ 0x3c00, 0x4900, 0x5640, 0x63d0, 0x4000, 0x4d00, 0x5a40, 0x67d0, 0x4200, 0x4f80, 0x5cb0, 0x69dc, 0x4400, 0x5100, 0x5e40, 0x6bd0 ]
- Name: ExpectedOut5
Format: Float16
Stride: 8
Data: [ 0x3C00, 0x4000, 0x4200, 0x4400 ]
Results:
- Result: ExpectedOut1
Rule: BufferExact
Actual: Out1
Expected: ExpectedOut1
- Result: ExpectedOut2
Rule: BufferExact
Actual: Out2
Expected: ExpectedOut2
- Result: ExpectedOut3
Rule: BufferExact
Actual: Out3
Expected: ExpectedOut3
- Result: ExpectedOut4
Rule: BufferExact
Actual: Out4
Expected: ExpectedOut4
- Result: ExpectedOut5
Rule: BufferExact
Actual: Out5
Expected: ExpectedOut5
DescriptorSets:
- Resources:
- Name: In
Kind: StructuredBuffer
DirectXBinding:
Register: 0
Space: 0
VulkanBinding:
Binding: 0
- Name: Out1
Kind: RWStructuredBuffer
DirectXBinding:
Register: 1
Space: 0
VulkanBinding:
Binding: 1
- Name: Out2
Kind: RWStructuredBuffer
DirectXBinding:
Register: 2
Space: 0
VulkanBinding:
Binding: 2
- Name: Out3
Kind: RWStructuredBuffer
DirectXBinding:
Register: 3
Space: 0
VulkanBinding:
Binding: 3
- Name: Out4
Kind: RWStructuredBuffer
DirectXBinding:
Register: 4
Space: 0
VulkanBinding:
Binding: 4
- Name: Out5
Kind: RWStructuredBuffer
DirectXBinding:
Register: 5
Space: 0
VulkanBinding:
Binding: 5

...
#--- end

# Bug https://github.com/llvm/llvm-project/issues/156775
# XFAIL: Clang

# Bug https://github.com/llvm/offload-test-suite/issues/393
# XFAIL: Metal

# RUN: split-file %s %t
# RUN: %dxc_target -enable-16bit-types -T cs_6_5 -Fo %t.o %t/source.hlsl
# RUN: %offloader %t/pipeline.yaml %t.o
177 changes: 177 additions & 0 deletions test/WaveOps/WaveActiveMax.fp32.test
@@ -0,0 +1,177 @@
#--- source.hlsl
StructuredBuffer<float4> In : register(t0);
RWStructuredBuffer<float4> Out1 : register(u1); // test scalar
RWStructuredBuffer<float4> Out2 : register(u2); // test float2
RWStructuredBuffer<float4> Out3 : register(u3); // test float3
RWStructuredBuffer<float4> Out4 : register(u4); // test float4
RWStructuredBuffer<float4> Out5 : register(u5); // constant folding

[numthreads(4,1,1)]
void main(uint3 tid : SV_GroupThreadID)
{
float4 v = In[tid.x];

float s1 = WaveActiveMax( v.x );
float s2 = tid.x < 3 ? WaveActiveMax( v.x ) : 0;
float s3 = tid.x < 2 ? WaveActiveMax( v.x ) : 0;
float s4 = tid.x < 1 ? WaveActiveMax( v.x ) : 0;

float2 v2_1 = WaveActiveMax( v.xy );
float2 v2_2 = tid.x < 3 ? WaveActiveMax( v.xy ) : float2(0,0);
float2 v2_3 = tid.x < 2 ? WaveActiveMax( v.xy ) : float2(0,0);
float2 v2_4 = tid.x < 1 ? WaveActiveMax( v.xy ) : float2(0,0);

float3 v3_1 = WaveActiveMax( v.xyz );
float3 v3_2 = tid.x < 3 ? WaveActiveMax( v.xyz ) : float3(0,0,0);
float3 v3_3 = tid.x < 2 ? WaveActiveMax( v.xyz ) : float3(0,0,0);
float3 v3_4 = tid.x < 1 ? WaveActiveMax( v.xyz ) : float3(0,0,0);

float4 v4_1 = WaveActiveMax( v );
float4 v4_2 = tid.x < 3 ? WaveActiveMax( v ) : float4(0,0,0,0);
float4 v4_3 = tid.x < 2 ? WaveActiveMax( v ) : float4(0,0,0,0);
float4 v4_4 = tid.x < 1 ? WaveActiveMax( v ) : float4(0,0,0,0);

float scalars[4] = { s4, s3, s2, s1 };
float2 vec2s [4] = { v2_4, v2_3, v2_2, v2_1 };
float3 vec3s [4] = { v3_4, v3_3, v3_2, v3_1 };
float4 vec4s [4] = { v4_4, v4_3, v4_2, v4_1 };

Out1[tid.x].x = scalars[tid.x];
Out2[tid.x].xy = vec2s[tid.x];
Out3[tid.x].xyz = vec3s[tid.x];
Out4[tid.x] = vec4s[tid.x];

// constant folding case
Out5[0] = WaveActiveMax(float4(1,2,3,4));
}

//--- pipeline.yaml

---
Shaders:
- Stage: Compute
Entry: main
DispatchSize: [1, 1, 1]
Buffers:
- Name: In
Format: Float32
Stride: 16
Data: [ 1.0, 10.0, 100.0, 1000.0, 2.0, 20.0, 200.0, 2000.0, 3.0, 30.0, 300.0, 3000.0, 4.0, 40.0, 400.0, 4000.0 ]
- Name: Out1
Format: Float32
Stride: 16
ZeroInitSize: 64
- Name: Out2
Format: Float32
Stride: 16
ZeroInitSize: 64
- Name: Out3
Format: Float32
Stride: 16
ZeroInitSize: 64
- Name: Out4
Format: Float32
Stride: 16
ZeroInitSize: 64
- Name: Out5
Format: Float32
Stride: 16
ZeroInitSize: 16
- Name: ExpectedOut1
Format: Float32
Stride: 16
Data: [ 1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0 ]
- Name: ExpectedOut2
Format: Float32
Stride: 16
Data: [ 1.0, 10.0, 0.0, 0.0, 2.0, 20.0, 0.0, 0.0, 3.0, 30.0, 0.0, 0.0, 4.0, 40.0, 0.0, 0.0 ]
- Name: ExpectedOut3
Format: Float32
Stride: 16
Data: [ 1.0, 10.0, 100.0, 0.0, 2.0, 20.0, 200.0, 0.0, 3.0, 30.0, 300.0, 0.0, 4.0, 40.0, 400.0, 0.0 ]
- Name: ExpectedOut4
Format: Float32
Stride: 16
Data: [ 1.0, 10.0, 100.0, 1000.0, 2.0, 20.0, 200.0, 2000.0, 3.0, 30.0, 300.0, 3000.0, 4.0, 40.0, 400.0, 4000.0 ]
- Name: ExpectedOut5
Format: Float32
Stride: 16
Data: [ 1.0, 2.0, 3.0, 4.0 ]
Results:
- Result: ExpectedOut1
Rule: BufferExact
Actual: Out1
Expected: ExpectedOut1
- Result: ExpectedOut2
Rule: BufferExact
Actual: Out2
Expected: ExpectedOut2
- Result: ExpectedOut3
Rule: BufferExact
Actual: Out3
Expected: ExpectedOut3
- Result: ExpectedOut4
Rule: BufferExact
Actual: Out4
Expected: ExpectedOut4
- Result: ExpectedOut5
Rule: BufferExact
Actual: Out5
Expected: ExpectedOut5
DescriptorSets:
- Resources:
- Name: In
Kind: StructuredBuffer
DirectXBinding:
Register: 0
Space: 0
VulkanBinding:
Binding: 0
- Name: Out1
Kind: RWStructuredBuffer
DirectXBinding:
Register: 1
Space: 0
VulkanBinding:
Binding: 1
- Name: Out2
Kind: RWStructuredBuffer
DirectXBinding:
Register: 2
Space: 0
VulkanBinding:
Binding: 2
- Name: Out3
Kind: RWStructuredBuffer
DirectXBinding:
Register: 3
Space: 0
VulkanBinding:
Binding: 3
- Name: Out4
Kind: RWStructuredBuffer
DirectXBinding:
Register: 4
Space: 0
VulkanBinding:
Binding: 4
- Name: Out5
Kind: RWStructuredBuffer
DirectXBinding:
Register: 5
Space: 0
VulkanBinding:
Binding: 5

...
#--- end

# Bug https://github.com/llvm/llvm-project/issues/156775
# XFAIL: Clang

# Tracked by https://github.com/llvm/offload-test-suite/issues/393
# XFAIL: Metal

# RUN: split-file %s %t
# RUN: %dxc_target -T cs_6_5 -Fo %t.o %t/source.hlsl
# RUN: %offloader %t/pipeline.yaml %t.o