Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use and() and or() functions for logical-AND and OR #6310

Merged
Empty file added lock
Empty file.
8 changes: 6 additions & 2 deletions source/slang/hlsl.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -6228,7 +6228,9 @@ bool all(vector<T,N> x)
case hlsl:
__intrinsic_asm "all";
case metal:
__intrinsic_asm "all";
if (__isBool<T>())
__intrinsic_asm "all";
__intrinsic_asm "all(bool$N0($0))";
case glsl:
__intrinsic_asm "all(bvec$N0($0))";
case spirv:
Expand Down Expand Up @@ -6256,7 +6258,9 @@ bool all(vector<T,N> x)
};
}
case wgsl:
__intrinsic_asm "all";
if (__isBool<T>())
__intrinsic_asm "all";
__intrinsic_asm "all(vec$N0<bool>($0))";
default:
bool result = true;
for(int i = 0; i < N; ++i)
Expand Down
47 changes: 47 additions & 0 deletions source/slang/slang-emit-hlsl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -821,6 +821,53 @@ bool HLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
}
break;
}
case kIROp_And:
case kIROp_Or:
{
// SM6.0 requires to use `and()` and `or()` functions for the logical-AND and
// logical-OR, respectively, with non-scalar operands.
auto targetProfile = getTargetProgram()->getOptionSet().getProfile();
if (targetProfile.getVersion() < ProfileVersion::DX_6_0)
return false;

if (as<IRBasicType>(inst->getDataType()))
return false;

if (inst->getOp() == kIROp_And)
{
m_writer->emit("and(");
}
else
{
m_writer->emit("or(");
}
emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
m_writer->emit(", ");
emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
m_writer->emit(")");
return true;
}
case kIROp_Select:
{
// SM6.0 requires to use `select()` instead of the ternary operator "?:" when the
// operands are non-scalar.
auto targetProfile = getTargetProgram()->getOptionSet().getProfile();
if (targetProfile.getVersion() < ProfileVersion::DX_6_0)
return false;

if (as<IRBasicType>(inst->getDataType()))
return false;

m_writer->emit("select(");
emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
m_writer->emit(", ");
emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
m_writer->emit(", ");
emitOperand(inst->getOperand(2), getInfo(EmitOp::General));
m_writer->emit(")");
return true;
}

case kIROp_BitCast:
{
// For simplicity, we will handle all bit-cast operations
Expand Down
34 changes: 34 additions & 0 deletions source/slang/slang-emit-wgsl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1312,6 +1312,40 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
}
break;

case kIROp_And:
case kIROp_Or:
{
// WGSL doesn't have operator overloadings for `&&` and `||` when the operands are
// non-scalar. Unlike HLSL, WGSL doesn't have `and()` and `or()`.
auto vecType = as<IRVectorType>(inst->getDataType());
if (!vecType)
return false;

// The function signature for `select` in WGSL is different from others:
// @const @must_use fn select(f: T, t: T, cond: bool) -> T
if (inst->getOp() == kIROp_And)
{
m_writer->emit("select(vec");
m_writer->emit(getIntVal(vecType->getElementCount()));
m_writer->emit("<bool>(false), ");
emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
m_writer->emit(", ");
emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
m_writer->emit(")");
}
else
{
m_writer->emit("select(");
emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
m_writer->emit(", vec");
m_writer->emit(getIntVal(vecType->getElementCount()));
m_writer->emit("<bool>(true), ");
emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
m_writer->emit(")");
}
return true;
}

case kIROp_BitCast:
{
// In WGSL there is a built-in bitcast function!
Expand Down
5 changes: 5 additions & 0 deletions source/slang/slang-emit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
#include "slang-ir-insts.h"
#include "slang-ir-layout.h"
#include "slang-ir-legalize-array-return-type.h"
#include "slang-ir-legalize-binary-operator.h"
#include "slang-ir-legalize-global-values.h"
#include "slang-ir-legalize-image-subscript.h"
#include "slang-ir-legalize-mesh-outputs.h"
Expand Down Expand Up @@ -1469,6 +1470,10 @@ Result linkAndOptimizeIR(
floatNonUniformResourceIndex(irModule, NonUniformResourceIndexFloatMode::Textual);
}

if (isD3DTarget(targetRequest) || isKhronosTarget(targetRequest) ||
isWGPUTarget(targetRequest) || isMetalTarget(targetRequest))
legalizeLogicalAndOr(irModule->getModuleInst());

// Legalize non struct parameters that are expected to be structs for HLSL.
if (isD3DTarget(targetRequest))
legalizeNonStructParameterToStructForHLSL(irModule);
Expand Down
3 changes: 3 additions & 0 deletions source/slang/slang-ir-insts.h
Original file line number Diff line number Diff line change
Expand Up @@ -4520,6 +4520,9 @@ struct IRBuilder
IRInst* emitShr(IRType* type, IRInst* op0, IRInst* op1);
IRInst* emitShl(IRType* type, IRInst* op0, IRInst* op1);

IRInst* emitAnd(IRType* type, IRInst* left, IRInst* right);
IRInst* emitOr(IRType* type, IRInst* left, IRInst* right);

IRSPIRVAsmOperand* emitSPIRVAsmOperandLiteral(IRInst* literal);
IRSPIRVAsmOperand* emitSPIRVAsmOperandInst(IRInst* inst);
IRSPIRVAsmOperand* createSPIRVAsmOperandInst(IRInst* inst);
Expand Down
97 changes: 97 additions & 0 deletions source/slang/slang-ir-legalize-binary-operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,101 @@ void legalizeBinaryOp(IRInst* inst)
}
}

void legalizeLogicalAndOr(IRInst* inst)
{
switch (inst->getOp())
{
case kIROp_And:
case kIROp_Or:
{
IRBuilder builder(inst);
builder.setInsertBefore(inst);

// Logical-AND and logical-OR takes boolean types as its operands.
// If they are not, legalize them by casting to boolean type.
//
SLANG_ASSERT(inst->getOperandCount() == 2);
for (UInt i = 0; i < 2; i++)
{
auto operand = inst->getOperand(i);
auto operandDataType = operand->getDataType();

if (auto vecType = as<IRVectorType>(operandDataType))
{
if (!as<IRBoolType>(vecType->getElementType()))
{
// Cast operand to vector<bool,N>
auto elemCount = vecType->getElementCount();
auto vb = builder.getVectorType(builder.getBoolType(), elemCount);
auto v = builder.emitCast(vb, operand);
builder.replaceOperand(inst->getOperands() + i, v);
}
}
else if (!as<IRBoolType>(operandDataType))
{
// Cast operand to bool
auto s = builder.emitCast(builder.getBoolType(), operand);
builder.replaceOperand(inst->getOperands() + i, s);
}
}

// Legalize the return type; mostly for SPIRV.
// The return type of OpLogicalOr must be boolean type.
// If not, we need to recreate the instruction with boolean return type.
// Then, we have to cast it back to the original type so that other instrucitons that
// use have the matching types.
//
auto dataType = inst->getDataType();
auto lhs = inst->getOperand(0);
auto rhs = inst->getOperand(1);
IRInst* newInst = nullptr;

if (auto vecType = as<IRVectorType>(dataType))
{
if (!as<IRBoolType>(vecType->getElementType()))
{
// Return type should be vector<bool,N>
auto elemCount = vecType->getElementCount();
auto vb = builder.getVectorType(builder.getBoolType(), elemCount);

if (inst->getOp() == kIROp_And)
{
newInst = builder.emitAnd(vb, lhs, rhs);
}
else
{
newInst = builder.emitOr(vb, lhs, rhs);
}
newInst = builder.emitCast(dataType, newInst);
}
}
else if (!as<IRBoolType>(dataType))
{
// Return type should be bool
if (inst->getOp() == kIROp_And)
{
newInst = builder.emitAnd(builder.getBoolType(), lhs, rhs);
}
else
{
newInst = builder.emitOr(builder.getBoolType(), lhs, rhs);
}
newInst = builder.emitCast(dataType, newInst);
}

if (newInst && inst != newInst)
{
inst->replaceUsesWith(newInst);
inst->removeAndDeallocate();
}
}
break;
}

for (auto child : inst->getModifiableChildren())
{
legalizeLogicalAndOr(child);
}
}

} // namespace Slang
5 changes: 5 additions & 0 deletions source/slang/slang-ir-legalize-binary-operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,9 @@ struct IRInst;
// signed operand is converted to unsigned.
void legalizeBinaryOp(IRInst* inst);

// The logical binary operators such as AND and OR takes boolean types are its input.
// If they are in integer type, as an example, we need to explicitly cast to bool type.
// Also the return type from the logical operators should be a boolean type.
void legalizeLogicalAndOr(IRInst* inst);

} // namespace Slang
14 changes: 14 additions & 0 deletions source/slang/slang-ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6020,6 +6020,20 @@ IRInst* IRBuilder::emitShl(IRType* type, IRInst* left, IRInst* right)
return inst;
}

IRInst* IRBuilder::emitAnd(IRType* type, IRInst* left, IRInst* right)
{
auto inst = createInst<IRInst>(this, kIROp_And, type, left, right);
addInst(inst);
return inst;
}

IRInst* IRBuilder::emitOr(IRType* type, IRInst* left, IRInst* right)
{
auto inst = createInst<IRInst>(this, kIROp_Or, type, left, right);
addInst(inst);
return inst;
}

IRInst* IRBuilder::emitGetNativePtr(IRInst* value)
{
auto valueType = value->getDataType();
Expand Down
66 changes: 66 additions & 0 deletions tests/compute/logic-no-short-circuit-evaluation.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
//TEST(compute):SIMPLE(filecheck=SM5):-target hlsl -profile cs_5_1 -entry computeMain
//TEST(compute):SIMPLE(filecheck=SM6):-target hlsl -profile cs_6_0 -entry computeMain
//TEST(compute):SIMPLE(filecheck=WGS):-target wgsl -stage compute -entry computeMain
//TEST(compute):SIMPLE(filecheck=MTL):-target metal -stage compute -entry computeMain
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-slang -compute -shaderobj -output-using-type -xslang -Wno-30056
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-vk -compute -shaderobj -output-using-type -xslang -Wno-30056
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-mtl -compute -shaderobj -output-using-type -xslang -Wno-30056
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cuda -compute -shaderobj -output-using-type -xslang -Wno-30056
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cpu -compute -shaderobj -output-using-type -xslang -Wno-30056

// Testnig logical-AND, logical-OR and ternary operator with non-scalar operands

//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<int> outputBuffer;

static int result = 0;

bool2 assignFunc(int index)
{
result += 10;
return bool2(true);
}

[numthreads(4, 1, 1)]
void computeMain(int3 dispatchThreadID : SV_DispatchThreadID)
{
int index = dispatchThreadID.x;

// No short-circuiting for vector types

//SM5:(all({{.*}}&&
//SM6:(all(and(
//WGS:(all(select(vec2<bool>(false),
//MTL:(all({{.*}}&&
if (all(bool2(index >= 1) && assignFunc(index)))
{
result++;
}

// Intentionally using non-boolean type for testing.

//SM5:(all({{.*}}||
//SM6:(or(vector<bool,2>(
//WGS:(select({{.*}}, vec2<bool>(true), vec2<bool>(
//MTL:(all(bool2({{.*}}||
if (all(int2(index >= 2) || !assignFunc(index)))
{
result++;
}

//SM5:(all({{.*}}?{{.*}}:
//SM6:(all(select(
//WGS:(all(select(vec2<bool>(false),
//MTL:(all(select(bool2(false)
if (all(bool2(index >= 3) ? assignFunc(index) : bool2(false)))
{
result++;
}

outputBuffer[index] = result;

//CHK:30
//CHK-NEXT:31
//CHK-NEXT:32
//CHK-NEXT:33
}
15 changes: 10 additions & 5 deletions tests/compute/logic-short-circuit-evaluation.slang
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
//TEST(compute):COMPARE_COMPUTE:-dx12 -compute -shaderobj
//TEST(compute):COMPARE_COMPUTE:-vk -compute -shaderobj
//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj
//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -compile-arg -O3 -shaderobj
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-dx12 -compute -shaderobj
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-vk -compute -shaderobj
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-mtl -compute -shaderobj
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cuda -compute -shaderobj
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cpu -compute -compile-arg -O3 -shaderobj
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-slang -compute -shaderobj

// Test doing vector comparisons

Expand All @@ -25,4 +26,8 @@ void computeMain(int3 dispatchThreadID : SV_DispatchThreadID)

// Only the last 4 elements will be 1.
(index < 12) || assignFunc(index);

//CHK-COUNT-4: 1
//CHK-COUNT-8: 0
//CHK-COUNT-4: 1
}
16 changes: 0 additions & 16 deletions tests/compute/logic-short-circuit-evaluation.slang.expected.txt

This file was deleted.

Loading