From 2612c2ae2fdab088c9035a2a1880288a0de14e82 Mon Sep 17 00:00:00 2001 From: Jay Kwak <82421531+jkwak-work@users.noreply.github.com> Date: Thu, 6 Feb 2025 16:23:28 -0800 Subject: [PATCH 1/9] Use and() and or() functions for logical-AND and OR With this commit, Slang will emit function calls to `and()` and `or()` for the logical-AND and logical-OR when the operands are non-scalar and the target profile is SM6.0 and above. This is required change from SM6.0. For WGSL, there is no operator overloadings of `&&` and `||` when the operands are non-scalar. Unlike HLSL, WGSL also don't have `and()` nor `or()`. Alternatively, we can use `select()`. --- lock | 0 source/slang/slang-emit-hlsl.cpp | 30 ++++++++++++++ source/slang/slang-emit-wgsl.cpp | 34 ++++++++++++++++ .../logic-no-short-circuit-evaluation.slang | 40 +++++++++++++++++++ .../logic-short-circuit-evaluation.slang | 14 ++++--- ...hort-circuit-evaluation.slang.expected.txt | 16 -------- 6 files changed, 113 insertions(+), 21 deletions(-) create mode 100644 lock create mode 100644 tests/compute/logic-no-short-circuit-evaluation.slang delete mode 100644 tests/compute/logic-short-circuit-evaluation.slang.expected.txt diff --git a/lock b/lock new file mode 100644 index 0000000000..e69de29bb2 diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp index 9ebec0893a..b36c749e2e 100644 --- a/source/slang/slang-emit-hlsl.cpp +++ b/source/slang/slang-emit-hlsl.cpp @@ -821,6 +821,36 @@ bool HLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu } break; } + case kIROp_And: + case kIROp_Or: + { + // SM6.0 requires to use `and()` and `or()` functions for the logical-AND and + // logical-OR, respectively, with non-scalar operands. + auto targetProfile = getTargetProgram()->getOptionSet().getProfile(); + if (targetProfile.getVersion() < ProfileVersion::DX_6_0) + return false; + + if (as(inst->getDataType())) + return false; + + const auto emitOp = getEmitOpForOp(inst->getOp()); + const auto info = getInfo(emitOp); + + if (inst->getOp() == kIROp_And) + { + m_writer->emit(" and("); + } + else + { + m_writer->emit(" or("); + } + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } + case kIROp_BitCast: { // For simplicity, we will handle all bit-cast operations diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp index aea766f9fb..656485cd45 100644 --- a/source/slang/slang-emit-wgsl.cpp +++ b/source/slang/slang-emit-wgsl.cpp @@ -1312,6 +1312,40 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu } break; + case kIROp_And: + case kIROp_Or: + { + // WGSL doesn't have operator overloadings for `&&` and `||` when the operands are + // non-scalar. Unlike HLSL, WGSL doesn't have `and()` and `or()`. + if (as(inst->getDataType())) + return false; + + const auto emitOp = getEmitOpForOp(inst->getOp()); + const auto info = getInfo(emitOp); + + if (inst->getOp() == kIROp_And) + { + m_writer->emit(" select("); + emitType(inst->getDataType()); + m_writer->emit("(false),"); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(")"); + } + else + { + m_writer->emit(" select("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitType(inst->getDataType()); + m_writer->emit("(true), "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(")"); + } + return true; + } + case kIROp_BitCast: { // In WGSL there is a built-in bitcast function! diff --git a/tests/compute/logic-no-short-circuit-evaluation.slang b/tests/compute/logic-no-short-circuit-evaluation.slang new file mode 100644 index 0000000000..f7ce147881 --- /dev/null +++ b/tests/compute/logic-no-short-circuit-evaluation.slang @@ -0,0 +1,40 @@ +//TEST(compute):SIMPLE(filecheck=SM5):-target hlsl -profile cs_5_1 -entry computeMain +//TEST(compute):SIMPLE(filecheck=SM6):-target hlsl -profile cs_6_0 -entry computeMain +//TEST(compute):SIMPLE(filecheck=WGSL):-target wgsl -stage compute -entry computeMain +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-slang -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-vk -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cuda -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cpu -compute -shaderobj + +// SM6.0 and above require to use `and()` and `or()` when the operands are non-scalar. +// And SM below SM6.0 doesn't have the functions, `and()` and `or()`. + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +static int result = 0; + +bool2 assignFunc(int index) +{ + result++; + return bool2(true); +} + +[numthreads(4, 1, 1)] +void computeMain(int3 dispatchThreadID : SV_DispatchThreadID) +{ + int index = dispatchThreadID.x; + + //SM5:!all({{.*}}&& + //SM6:!all({{ *}}and( + //WGSL:!all( select(vec2(false), + if (!all(bool2(index < 2) && assignFunc(index))) + { + result++; + } + + outputBuffer[index] = result; + + //CHK-COUNT-2: 1 + //CHK-COUNT-2: 2 +} diff --git a/tests/compute/logic-short-circuit-evaluation.slang b/tests/compute/logic-short-circuit-evaluation.slang index 585a04770c..31387d4894 100644 --- a/tests/compute/logic-short-circuit-evaluation.slang +++ b/tests/compute/logic-short-circuit-evaluation.slang @@ -1,8 +1,8 @@ -//TEST(compute):COMPARE_COMPUTE:-dx12 -compute -shaderobj -//TEST(compute):COMPARE_COMPUTE:-vk -compute -shaderobj -//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj -//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -compile-arg -O3 -shaderobj -//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-dx12 -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-vk -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cuda -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cpu -compute -compile-arg -O3 -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-slang -compute -shaderobj // Test doing vector comparisons @@ -25,4 +25,8 @@ void computeMain(int3 dispatchThreadID : SV_DispatchThreadID) // Only the last 4 elements will be 1. (index < 12) || assignFunc(index); + + //CHK-COUNT-4: 1 + //CHK-COUNT-8: 0 + //CHK-COUNT-4: 1 } diff --git a/tests/compute/logic-short-circuit-evaluation.slang.expected.txt b/tests/compute/logic-short-circuit-evaluation.slang.expected.txt deleted file mode 100644 index 945f08f2c0..0000000000 --- a/tests/compute/logic-short-circuit-evaluation.slang.expected.txt +++ /dev/null @@ -1,16 +0,0 @@ -1 -1 -1 -1 -0 -0 -0 -0 -0 -0 -0 -0 -1 -1 -1 -1 From 61d5cd9d91e0a17cddf0db43de9823e5c82ce257 Mon Sep 17 00:00:00 2001 From: Jay Kwak <82421531+jkwak-work@users.noreply.github.com> Date: Thu, 6 Feb 2025 17:41:07 -0800 Subject: [PATCH 2/9] Treat select() in the same way --- source/slang/slang-emit-hlsl.cpp | 24 ++++++++++- source/slang/slang-emit-wgsl.cpp | 6 +-- .../logic-no-short-circuit-evaluation.slang | 41 +++++++++++++------ 3 files changed, 54 insertions(+), 17 deletions(-) diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp index b36c749e2e..57ac1e8323 100644 --- a/source/slang/slang-emit-hlsl.cpp +++ b/source/slang/slang-emit-hlsl.cpp @@ -838,11 +838,11 @@ bool HLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu if (inst->getOp() == kIROp_And) { - m_writer->emit(" and("); + m_writer->emit("and("); } else { - m_writer->emit(" or("); + m_writer->emit("or("); } emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); m_writer->emit(", "); @@ -850,6 +850,26 @@ bool HLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu m_writer->emit(")"); return true; } + case kIROp_Select: + { + // SM6.0 requires to use `select()` instead of the ternary operator "?:" when the + // operands are non-scalar. + auto targetProfile = getTargetProgram()->getOptionSet().getProfile(); + if (targetProfile.getVersion() < ProfileVersion::DX_6_0) + return false; + + if (as(inst->getDataType())) + return false; + + m_writer->emit("select("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(2), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } case kIROp_BitCast: { diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp index 656485cd45..6846c8eb7b 100644 --- a/source/slang/slang-emit-wgsl.cpp +++ b/source/slang/slang-emit-wgsl.cpp @@ -1325,9 +1325,9 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu if (inst->getOp() == kIROp_And) { - m_writer->emit(" select("); + m_writer->emit("select("); emitType(inst->getDataType()); - m_writer->emit("(false),"); + m_writer->emit("(false), "); emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); m_writer->emit(", "); emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); @@ -1335,7 +1335,7 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu } else { - m_writer->emit(" select("); + m_writer->emit("select("); emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); m_writer->emit(", "); emitType(inst->getDataType()); diff --git a/tests/compute/logic-no-short-circuit-evaluation.slang b/tests/compute/logic-no-short-circuit-evaluation.slang index f7ce147881..d4fc82b2d1 100644 --- a/tests/compute/logic-no-short-circuit-evaluation.slang +++ b/tests/compute/logic-no-short-circuit-evaluation.slang @@ -1,13 +1,12 @@ //TEST(compute):SIMPLE(filecheck=SM5):-target hlsl -profile cs_5_1 -entry computeMain //TEST(compute):SIMPLE(filecheck=SM6):-target hlsl -profile cs_6_0 -entry computeMain //TEST(compute):SIMPLE(filecheck=WGSL):-target wgsl -stage compute -entry computeMain -//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-slang -compute -shaderobj -//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-vk -compute -shaderobj -//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cuda -compute -shaderobj -//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cpu -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-slang -compute -shaderobj -xslang -Wno-30056 +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-vk -compute -shaderobj -xslang -Wno-30056 +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cuda -compute -shaderobj -xslang -Wno-30056 +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cpu -compute -shaderobj -xslang -Wno-30056 -// SM6.0 and above require to use `and()` and `or()` when the operands are non-scalar. -// And SM below SM6.0 doesn't have the functions, `and()` and `or()`. +// Testnig logical-AND, logical-OR and ternary operator with non-scalar operands //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; @@ -25,16 +24,34 @@ void computeMain(int3 dispatchThreadID : SV_DispatchThreadID) { int index = dispatchThreadID.x; - //SM5:!all({{.*}}&& - //SM6:!all({{ *}}and( - //WGSL:!all( select(vec2(false), - if (!all(bool2(index < 2) && assignFunc(index))) + //SM5:(all({{.*}}&& + //SM6:(all(and( + //WGSL:(all(select(vec2(false), + if (all(bool2(index >= 1) && assignFunc(index))) + { + result++; + } + + //SM5:(all({{.*}}|| + //SM6:(all(or( + //WGSL:(all(select({{.*}}vec2(true), + if (all(bool2(index >= 2) || !assignFunc(index))) + { + result++; + } + + //SM5:(all({{.*}}?{{.*}}: + //SM6:(all(select({{ *}} + //WGSL:(all(select( + if (all(bool2(index >= 3) ? assignFunc(index) : bool2(false))) { result++; } outputBuffer[index] = result; - //CHK-COUNT-2: 1 - //CHK-COUNT-2: 2 + //CHK:3 + //CHK-NEXT:4 + //CHK-NEXT:5 + //CHK-NEXT:6 } From 97344e24dad36619b438075f0e329b6ed0451c4f Mon Sep 17 00:00:00 2001 From: Jay Kwak <82421531+jkwak-work@users.noreply.github.com> Date: Thu, 6 Feb 2025 19:05:15 -0800 Subject: [PATCH 3/9] Fix compiler warning on gcc --- source/slang/slang-emit-wgsl.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp index 6846c8eb7b..e02ca569be 100644 --- a/source/slang/slang-emit-wgsl.cpp +++ b/source/slang/slang-emit-wgsl.cpp @@ -1320,9 +1320,6 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu if (as(inst->getDataType())) return false; - const auto emitOp = getEmitOpForOp(inst->getOp()); - const auto info = getInfo(emitOp); - if (inst->getOp() == kIROp_And) { m_writer->emit("select("); From e6d17ca40ab305810429a9b5c5214bf53ab18a95 Mon Sep 17 00:00:00 2001 From: Jay Kwak <82421531+jkwak-work@users.noreply.github.com> Date: Thu, 6 Feb 2025 20:02:45 -0800 Subject: [PATCH 4/9] Remove unused variable --- source/slang/slang-emit-hlsl.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp index 57ac1e8323..ff4514d695 100644 --- a/source/slang/slang-emit-hlsl.cpp +++ b/source/slang/slang-emit-hlsl.cpp @@ -833,9 +833,6 @@ bool HLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu if (as(inst->getDataType())) return false; - const auto emitOp = getEmitOpForOp(inst->getOp()); - const auto info = getInfo(emitOp); - if (inst->getOp() == kIROp_And) { m_writer->emit("and("); From 6b80360ab8494c9bb8c4feef186d2dfb8003cee8 Mon Sep 17 00:00:00 2001 From: Jay Kwak <82421531+jkwak-work@users.noreply.github.com> Date: Fri, 7 Feb 2025 15:12:26 -0800 Subject: [PATCH 5/9] Handle the implicit cast to/from boolean type --- source/slang/hlsl.meta.slang | 4 +- source/slang/slang-emit-wgsl.cpp | 29 +++--- source/slang/slang-emit.cpp | 4 + source/slang/slang-ir-insts.h | 3 + .../slang-ir-legalize-binary-operator.cpp | 97 +++++++++++++++++++ .../slang/slang-ir-legalize-binary-operator.h | 5 + source/slang/slang-ir.cpp | 14 +++ .../logic-no-short-circuit-evaluation.slang | 35 ++++--- .../logic-short-circuit-evaluation.slang | 1 + 9 files changed, 163 insertions(+), 29 deletions(-) diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 2399df2546..30ad77c4a5 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -6256,7 +6256,9 @@ bool all(vector x) }; } case wgsl: - __intrinsic_asm "all"; + if (__isBool()) + __intrinsic_asm "all"; + __intrinsic_asm "all(vec$N0($0))"; default: bool result = true; for(int i = 0; i < N; ++i) diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp index e02ca569be..4aa796bd25 100644 --- a/source/slang/slang-emit-wgsl.cpp +++ b/source/slang/slang-emit-wgsl.cpp @@ -1317,28 +1317,31 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu { // WGSL doesn't have operator overloadings for `&&` and `||` when the operands are // non-scalar. Unlike HLSL, WGSL doesn't have `and()` and `or()`. - if (as(inst->getDataType())) + auto vecType = as(inst->getDataType()); + if (!vecType) return false; + // The function signature for `select` in WGSL is different from others: + // @const @must_use fn select(f: T, t: T, cond: bool) -> T if (inst->getOp() == kIROp_And) { - m_writer->emit("select("); - emitType(inst->getDataType()); - m_writer->emit("(false), "); - emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); - m_writer->emit(", "); + m_writer->emit(" select(vec"); + m_writer->emit(getIntVal(vecType->getElementCount())); + m_writer->emit("(false), "); emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); - m_writer->emit(")"); + m_writer->emit(", "); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(") "); } else { - m_writer->emit("select("); - emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); - m_writer->emit(", "); - emitType(inst->getDataType()); - m_writer->emit("(true), "); + m_writer->emit(" select("); emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); - m_writer->emit(")"); + m_writer->emit(", vec"); + m_writer->emit(getIntVal(vecType->getElementCount())); + m_writer->emit("(true), "); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(") "); } return true; } diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 58376bbc1c..fb3afb895a 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -52,6 +52,7 @@ #include "slang-ir-insts.h" #include "slang-ir-layout.h" #include "slang-ir-legalize-array-return-type.h" +#include "slang-ir-legalize-binary-operator.h" #include "slang-ir-legalize-global-values.h" #include "slang-ir-legalize-image-subscript.h" #include "slang-ir-legalize-mesh-outputs.h" @@ -1460,6 +1461,9 @@ Result linkAndOptimizeIR( floatNonUniformResourceIndex(irModule, NonUniformResourceIndexFloatMode::Textual); } + if (isD3DTarget(targetRequest) || isKhronosTarget(targetRequest) || isWGPUTarget(targetRequest)) + legalizeLogicalAndOr(irModule->getModuleInst()); + // Legalize non struct parameters that are expected to be structs for HLSL. if (isD3DTarget(targetRequest)) legalizeNonStructParameterToStructForHLSL(irModule); diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index fcafe4bc6e..a151c8d81e 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -4518,6 +4518,9 @@ struct IRBuilder IRInst* emitShr(IRType* type, IRInst* op0, IRInst* op1); IRInst* emitShl(IRType* type, IRInst* op0, IRInst* op1); + IRInst* emitAnd(IRType* type, IRInst* left, IRInst* right); + IRInst* emitOr(IRType* type, IRInst* left, IRInst* right); + IRSPIRVAsmOperand* emitSPIRVAsmOperandLiteral(IRInst* literal); IRSPIRVAsmOperand* emitSPIRVAsmOperandInst(IRInst* inst); IRSPIRVAsmOperand* createSPIRVAsmOperandInst(IRInst* inst); diff --git a/source/slang/slang-ir-legalize-binary-operator.cpp b/source/slang/slang-ir-legalize-binary-operator.cpp index a1affb7e9d..1595aa130d 100644 --- a/source/slang/slang-ir-legalize-binary-operator.cpp +++ b/source/slang/slang-ir-legalize-binary-operator.cpp @@ -118,4 +118,101 @@ void legalizeBinaryOp(IRInst* inst) } } +void legalizeLogicalAndOr(IRInst* inst) +{ + switch (inst->getOp()) + { + case kIROp_And: + case kIROp_Or: + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + // Logical-AND and logical-OR takes boolean types as its operands. + // If they are not, legalize them by casting to boolean type. + // + SLANG_ASSERT(inst->getOperandCount() == 2); + for (UInt i = 0; i < 2; i++) + { + auto operand = inst->getOperand(i); + auto operandDataType = operand->getDataType(); + + if (auto vecType = as(operandDataType)) + { + if (!as(vecType->getElementType())) + { + // Cast operand to vector + auto elemCount = vecType->getElementCount(); + auto vb = builder.getVectorType(builder.getBoolType(), elemCount); + auto v = builder.emitCast(vb, operand); + builder.replaceOperand(inst->getOperands() + i, v); + } + } + else if (!as(operandDataType)) + { + // Cast operand to bool + auto s = builder.emitCast(builder.getBoolType(), operand); + builder.replaceOperand(inst->getOperands() + i, s); + } + } + + // Legalize the return type; mostly for SPIRV. + // The return type of OpLogicalOr must be boolean type. + // If not, we need to recreate the instruction with boolean return type. + // Then, we have to cast it back to the original type so that other instrucitons that + // use have the matching types. + // + auto dataType = inst->getDataType(); + auto lhs = inst->getOperand(0); + auto rhs = inst->getOperand(1); + IRInst* newInst = nullptr; + + if (auto vecType = as(dataType)) + { + if (!as(vecType->getElementType())) + { + // Return type should be vector + auto elemCount = vecType->getElementCount(); + auto vb = builder.getVectorType(builder.getBoolType(), elemCount); + + if (inst->getOp() == kIROp_And) + { + newInst = builder.emitAnd(vb, lhs, rhs); + } + else + { + newInst = builder.emitOr(vb, lhs, rhs); + } + newInst = builder.emitCast(dataType, newInst); + } + } + else if (!as(dataType)) + { + // Return type should be bool + if (inst->getOp() == kIROp_And) + { + newInst = builder.emitAnd(builder.getBoolType(), lhs, rhs); + } + else + { + newInst = builder.emitOr(builder.getBoolType(), lhs, rhs); + } + newInst = builder.emitCast(dataType, newInst); + } + + if (newInst && inst != newInst) + { + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + } + } + break; + } + + for (auto child : inst->getModifiableChildren()) + { + legalizeLogicalAndOr(child); + } +} + } // namespace Slang diff --git a/source/slang/slang-ir-legalize-binary-operator.h b/source/slang/slang-ir-legalize-binary-operator.h index 71c3197183..f9ebf90d89 100644 --- a/source/slang/slang-ir-legalize-binary-operator.h +++ b/source/slang/slang-ir-legalize-binary-operator.h @@ -13,4 +13,9 @@ struct IRInst; // signed operand is converted to unsigned. void legalizeBinaryOp(IRInst* inst); +// The logical binary operators such as AND and OR takes boolean types are its input. +// If they are in integer type, as an example, we need to explicitly cast to bool type. +// Also the return type from the logical operators should be a boolean type. +void legalizeLogicalAndOr(IRInst* inst); + } // namespace Slang diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 3314567f1e..2f3fb4a9bf 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -6015,6 +6015,20 @@ IRInst* IRBuilder::emitShl(IRType* type, IRInst* left, IRInst* right) return inst; } +IRInst* IRBuilder::emitAnd(IRType* type, IRInst* left, IRInst* right) +{ + auto inst = createInst(this, kIROp_And, type, left, right); + addInst(inst); + return inst; +} + +IRInst* IRBuilder::emitOr(IRType* type, IRInst* left, IRInst* right) +{ + auto inst = createInst(this, kIROp_Or, type, left, right); + addInst(inst); + return inst; +} + IRInst* IRBuilder::emitGetNativePtr(IRInst* value) { auto valueType = value->getDataType(); diff --git a/tests/compute/logic-no-short-circuit-evaluation.slang b/tests/compute/logic-no-short-circuit-evaluation.slang index d4fc82b2d1..efd0fe37d7 100644 --- a/tests/compute/logic-no-short-circuit-evaluation.slang +++ b/tests/compute/logic-no-short-circuit-evaluation.slang @@ -1,10 +1,11 @@ //TEST(compute):SIMPLE(filecheck=SM5):-target hlsl -profile cs_5_1 -entry computeMain //TEST(compute):SIMPLE(filecheck=SM6):-target hlsl -profile cs_6_0 -entry computeMain //TEST(compute):SIMPLE(filecheck=WGSL):-target wgsl -stage compute -entry computeMain -//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-slang -compute -shaderobj -xslang -Wno-30056 -//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-vk -compute -shaderobj -xslang -Wno-30056 -//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cuda -compute -shaderobj -xslang -Wno-30056 -//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cpu -compute -shaderobj -xslang -Wno-30056 +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-slang -compute -shaderobj -output-using-type -xslang -Wno-30056 +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-vk -compute -shaderobj -output-using-type -xslang -Wno-30056 +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-mtl -compute -shaderobj -output-using-type -xslang -Wno-30056 +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cuda -compute -shaderobj -output-using-type -xslang -Wno-30056 +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cpu -compute -shaderobj -output-using-type -xslang -Wno-30056 // Testnig logical-AND, logical-OR and ternary operator with non-scalar operands @@ -15,7 +16,7 @@ static int result = 0; bool2 assignFunc(int index) { - result++; + result += 10; return bool2(true); } @@ -24,25 +25,29 @@ void computeMain(int3 dispatchThreadID : SV_DispatchThreadID) { int index = dispatchThreadID.x; + // No short-circuiting for vector types + //SM5:(all({{.*}}&& //SM6:(all(and( - //WGSL:(all(select(vec2(false), + //WGSL:(all( select(vec2(false), if (all(bool2(index >= 1) && assignFunc(index))) { result++; } + // Intentionally using non-boolean type for testing. + //SM5:(all({{.*}}|| - //SM6:(all(or( - //WGSL:(all(select({{.*}}vec2(true), - if (all(bool2(index >= 2) || !assignFunc(index))) + //SM6:(or(vector( + //WGSL:( select({{.*}}, vec2(true), vec2( + if (all(int2(index >= 2) || !assignFunc(index))) { result++; } //SM5:(all({{.*}}?{{.*}}: - //SM6:(all(select({{ *}} - //WGSL:(all(select( + //SM6:(all(select( + //WGSL:(all(select(vec2(false), if (all(bool2(index >= 3) ? assignFunc(index) : bool2(false))) { result++; @@ -50,8 +55,8 @@ void computeMain(int3 dispatchThreadID : SV_DispatchThreadID) outputBuffer[index] = result; - //CHK:3 - //CHK-NEXT:4 - //CHK-NEXT:5 - //CHK-NEXT:6 + //CHK:30 + //CHK-NEXT:31 + //CHK-NEXT:32 + //CHK-NEXT:33 } diff --git a/tests/compute/logic-short-circuit-evaluation.slang b/tests/compute/logic-short-circuit-evaluation.slang index 31387d4894..eed30898f8 100644 --- a/tests/compute/logic-short-circuit-evaluation.slang +++ b/tests/compute/logic-short-circuit-evaluation.slang @@ -1,5 +1,6 @@ //TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-dx12 -compute -shaderobj //TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-vk -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-mtl -compute -shaderobj //TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cuda -compute -shaderobj //TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cpu -compute -compile-arg -O3 -shaderobj //TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-slang -compute -shaderobj From 487f35fae53f1a7447422a72f89a7b575b71563b Mon Sep 17 00:00:00 2001 From: Jay Kwak <82421531+jkwak-work@users.noreply.github.com> Date: Fri, 7 Feb 2025 15:58:03 -0800 Subject: [PATCH 6/9] Disable MacOS test for now --- tests/compute/logic-short-circuit-evaluation.slang | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/compute/logic-short-circuit-evaluation.slang b/tests/compute/logic-short-circuit-evaluation.slang index eed30898f8..6cb3230fd0 100644 --- a/tests/compute/logic-short-circuit-evaluation.slang +++ b/tests/compute/logic-short-circuit-evaluation.slang @@ -1,6 +1,6 @@ //TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-dx12 -compute -shaderobj //TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-vk -compute -shaderobj -//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-mtl -compute -shaderobj +//DISABLE_TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-mtl -compute -shaderobj //TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cuda -compute -shaderobj //TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cpu -compute -compile-arg -O3 -shaderobj //TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-slang -compute -shaderobj From 6dfb9c2a671eed84f36649370250c72858260f16 Mon Sep 17 00:00:00 2001 From: Jay Kwak <82421531+jkwak-work@users.noreply.github.com> Date: Fri, 7 Feb 2025 16:01:31 -0800 Subject: [PATCH 7/9] Remove unnecessary space on emitted wgsl --- source/slang/slang-emit-wgsl.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp index 4aa796bd25..13c79e9acc 100644 --- a/source/slang/slang-emit-wgsl.cpp +++ b/source/slang/slang-emit-wgsl.cpp @@ -1325,23 +1325,23 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu // @const @must_use fn select(f: T, t: T, cond: bool) -> T if (inst->getOp() == kIROp_And) { - m_writer->emit(" select(vec"); + m_writer->emit("select(vec"); m_writer->emit(getIntVal(vecType->getElementCount())); m_writer->emit("(false), "); emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); m_writer->emit(", "); emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); - m_writer->emit(") "); + m_writer->emit(")"); } else { - m_writer->emit(" select("); + m_writer->emit("select("); emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); m_writer->emit(", vec"); m_writer->emit(getIntVal(vecType->getElementCount())); m_writer->emit("(true), "); emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); - m_writer->emit(") "); + m_writer->emit(")"); } return true; } From da029b01131e12f7dd898ce827b58d368aa9373a Mon Sep 17 00:00:00 2001 From: Jay Kwak <82421531+jkwak-work@users.noreply.github.com> Date: Fri, 7 Feb 2025 16:06:26 -0800 Subject: [PATCH 8/9] Disable metal test for now --- tests/compute/logic-no-short-circuit-evaluation.slang | 2 +- tests/compute/logic-short-circuit-evaluation.slang | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/compute/logic-no-short-circuit-evaluation.slang b/tests/compute/logic-no-short-circuit-evaluation.slang index efd0fe37d7..6c5aed2377 100644 --- a/tests/compute/logic-no-short-circuit-evaluation.slang +++ b/tests/compute/logic-no-short-circuit-evaluation.slang @@ -3,7 +3,7 @@ //TEST(compute):SIMPLE(filecheck=WGSL):-target wgsl -stage compute -entry computeMain //TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-slang -compute -shaderobj -output-using-type -xslang -Wno-30056 //TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-vk -compute -shaderobj -output-using-type -xslang -Wno-30056 -//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-mtl -compute -shaderobj -output-using-type -xslang -Wno-30056 +//DISABLE_TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-mtl -compute -shaderobj -output-using-type -xslang -Wno-30056 //TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cuda -compute -shaderobj -output-using-type -xslang -Wno-30056 //TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cpu -compute -shaderobj -output-using-type -xslang -Wno-30056 diff --git a/tests/compute/logic-short-circuit-evaluation.slang b/tests/compute/logic-short-circuit-evaluation.slang index 6cb3230fd0..eed30898f8 100644 --- a/tests/compute/logic-short-circuit-evaluation.slang +++ b/tests/compute/logic-short-circuit-evaluation.slang @@ -1,6 +1,6 @@ //TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-dx12 -compute -shaderobj //TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-vk -compute -shaderobj -//DISABLE_TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-mtl -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-mtl -compute -shaderobj //TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cuda -compute -shaderobj //TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cpu -compute -compile-arg -O3 -shaderobj //TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-slang -compute -shaderobj From 0a8a1b13f3c4e8dd45e0a82bb950a8cbdc82aaac Mon Sep 17 00:00:00 2001 From: Jay Kwak <82421531+jkwak-work@users.noreply.github.com> Date: Fri, 7 Feb 2025 16:59:51 -0800 Subject: [PATCH 9/9] Fix Metal test --- source/slang/hlsl.meta.slang | 4 +++- source/slang/slang-emit.cpp | 3 ++- .../logic-no-short-circuit-evaluation.slang | 14 +++++++++----- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 30ad77c4a5..491f0ef4d6 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -6228,7 +6228,9 @@ bool all(vector x) case hlsl: __intrinsic_asm "all"; case metal: - __intrinsic_asm "all"; + if (__isBool()) + __intrinsic_asm "all"; + __intrinsic_asm "all(bool$N0($0))"; case glsl: __intrinsic_asm "all(bvec$N0($0))"; case spirv: diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index d5b0d12b7a..e20a4a90fd 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -1470,7 +1470,8 @@ Result linkAndOptimizeIR( floatNonUniformResourceIndex(irModule, NonUniformResourceIndexFloatMode::Textual); } - if (isD3DTarget(targetRequest) || isKhronosTarget(targetRequest) || isWGPUTarget(targetRequest)) + if (isD3DTarget(targetRequest) || isKhronosTarget(targetRequest) || + isWGPUTarget(targetRequest) || isMetalTarget(targetRequest)) legalizeLogicalAndOr(irModule->getModuleInst()); // Legalize non struct parameters that are expected to be structs for HLSL. diff --git a/tests/compute/logic-no-short-circuit-evaluation.slang b/tests/compute/logic-no-short-circuit-evaluation.slang index 6c5aed2377..74351a5053 100644 --- a/tests/compute/logic-no-short-circuit-evaluation.slang +++ b/tests/compute/logic-no-short-circuit-evaluation.slang @@ -1,9 +1,10 @@ //TEST(compute):SIMPLE(filecheck=SM5):-target hlsl -profile cs_5_1 -entry computeMain //TEST(compute):SIMPLE(filecheck=SM6):-target hlsl -profile cs_6_0 -entry computeMain -//TEST(compute):SIMPLE(filecheck=WGSL):-target wgsl -stage compute -entry computeMain +//TEST(compute):SIMPLE(filecheck=WGS):-target wgsl -stage compute -entry computeMain +//TEST(compute):SIMPLE(filecheck=MTL):-target metal -stage compute -entry computeMain //TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-slang -compute -shaderobj -output-using-type -xslang -Wno-30056 //TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-vk -compute -shaderobj -output-using-type -xslang -Wno-30056 -//DISABLE_TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-mtl -compute -shaderobj -output-using-type -xslang -Wno-30056 +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-mtl -compute -shaderobj -output-using-type -xslang -Wno-30056 //TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cuda -compute -shaderobj -output-using-type -xslang -Wno-30056 //TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cpu -compute -shaderobj -output-using-type -xslang -Wno-30056 @@ -29,7 +30,8 @@ void computeMain(int3 dispatchThreadID : SV_DispatchThreadID) //SM5:(all({{.*}}&& //SM6:(all(and( - //WGSL:(all( select(vec2(false), + //WGS:(all(select(vec2(false), + //MTL:(all({{.*}}&& if (all(bool2(index >= 1) && assignFunc(index))) { result++; @@ -39,7 +41,8 @@ void computeMain(int3 dispatchThreadID : SV_DispatchThreadID) //SM5:(all({{.*}}|| //SM6:(or(vector( - //WGSL:( select({{.*}}, vec2(true), vec2( + //WGS:(select({{.*}}, vec2(true), vec2( + //MTL:(all(bool2({{.*}}|| if (all(int2(index >= 2) || !assignFunc(index))) { result++; @@ -47,7 +50,8 @@ void computeMain(int3 dispatchThreadID : SV_DispatchThreadID) //SM5:(all({{.*}}?{{.*}}: //SM6:(all(select( - //WGSL:(all(select(vec2(false), + //WGS:(all(select(vec2(false), + //MTL:(all(select(bool2(false) if (all(bool2(index >= 3) ? assignFunc(index) : bool2(false))) { result++;