Skip to content

Commit 6b80360

Browse files
committed
Handle the implicit cast to/from boolean type
1 parent e6d17ca commit 6b80360

9 files changed

+163
-29
lines changed

source/slang/hlsl.meta.slang

+3-1
Original file line numberDiff line numberDiff line change
@@ -6256,7 +6256,9 @@ bool all(vector<T,N> x)
62566256
};
62576257
}
62586258
case wgsl:
6259-
__intrinsic_asm "all";
6259+
if (__isBool<T>())
6260+
__intrinsic_asm "all";
6261+
__intrinsic_asm "all(vec$N0<bool>($0))";
62606262
default:
62616263
bool result = true;
62626264
for(int i = 0; i < N; ++i)

source/slang/slang-emit-wgsl.cpp

+16-13
Original file line numberDiff line numberDiff line change
@@ -1317,28 +1317,31 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
13171317
{
13181318
// WGSL doesn't have operator overloadings for `&&` and `||` when the operands are
13191319
// non-scalar. Unlike HLSL, WGSL doesn't have `and()` and `or()`.
1320-
if (as<IRBasicType>(inst->getDataType()))
1320+
auto vecType = as<IRVectorType>(inst->getDataType());
1321+
if (!vecType)
13211322
return false;
13221323

1324+
// The function signature for `select` in WGSL is different from others:
1325+
// @const @must_use fn select(f: T, t: T, cond: bool) -> T
13231326
if (inst->getOp() == kIROp_And)
13241327
{
1325-
m_writer->emit("select(");
1326-
emitType(inst->getDataType());
1327-
m_writer->emit("(false), ");
1328-
emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
1329-
m_writer->emit(", ");
1328+
m_writer->emit(" select(vec");
1329+
m_writer->emit(getIntVal(vecType->getElementCount()));
1330+
m_writer->emit("<bool>(false), ");
13301331
emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
1331-
m_writer->emit(")");
1332+
m_writer->emit(", ");
1333+
emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
1334+
m_writer->emit(") ");
13321335
}
13331336
else
13341337
{
1335-
m_writer->emit("select(");
1336-
emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
1337-
m_writer->emit(", ");
1338-
emitType(inst->getDataType());
1339-
m_writer->emit("(true), ");
1338+
m_writer->emit(" select(");
13401339
emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
1341-
m_writer->emit(")");
1340+
m_writer->emit(", vec");
1341+
m_writer->emit(getIntVal(vecType->getElementCount()));
1342+
m_writer->emit("<bool>(true), ");
1343+
emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
1344+
m_writer->emit(") ");
13421345
}
13431346
return true;
13441347
}

source/slang/slang-emit.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
#include "slang-ir-insts.h"
5353
#include "slang-ir-layout.h"
5454
#include "slang-ir-legalize-array-return-type.h"
55+
#include "slang-ir-legalize-binary-operator.h"
5556
#include "slang-ir-legalize-global-values.h"
5657
#include "slang-ir-legalize-image-subscript.h"
5758
#include "slang-ir-legalize-mesh-outputs.h"
@@ -1460,6 +1461,9 @@ Result linkAndOptimizeIR(
14601461
floatNonUniformResourceIndex(irModule, NonUniformResourceIndexFloatMode::Textual);
14611462
}
14621463

1464+
if (isD3DTarget(targetRequest) || isKhronosTarget(targetRequest) || isWGPUTarget(targetRequest))
1465+
legalizeLogicalAndOr(irModule->getModuleInst());
1466+
14631467
// Legalize non struct parameters that are expected to be structs for HLSL.
14641468
if (isD3DTarget(targetRequest))
14651469
legalizeNonStructParameterToStructForHLSL(irModule);

source/slang/slang-ir-insts.h

+3
Original file line numberDiff line numberDiff line change
@@ -4518,6 +4518,9 @@ struct IRBuilder
45184518
IRInst* emitShr(IRType* type, IRInst* op0, IRInst* op1);
45194519
IRInst* emitShl(IRType* type, IRInst* op0, IRInst* op1);
45204520

4521+
IRInst* emitAnd(IRType* type, IRInst* left, IRInst* right);
4522+
IRInst* emitOr(IRType* type, IRInst* left, IRInst* right);
4523+
45214524
IRSPIRVAsmOperand* emitSPIRVAsmOperandLiteral(IRInst* literal);
45224525
IRSPIRVAsmOperand* emitSPIRVAsmOperandInst(IRInst* inst);
45234526
IRSPIRVAsmOperand* createSPIRVAsmOperandInst(IRInst* inst);

source/slang/slang-ir-legalize-binary-operator.cpp

+97
Original file line numberDiff line numberDiff line change
@@ -118,4 +118,101 @@ void legalizeBinaryOp(IRInst* inst)
118118
}
119119
}
120120

121+
void legalizeLogicalAndOr(IRInst* inst)
122+
{
123+
switch (inst->getOp())
124+
{
125+
case kIROp_And:
126+
case kIROp_Or:
127+
{
128+
IRBuilder builder(inst);
129+
builder.setInsertBefore(inst);
130+
131+
// Logical-AND and logical-OR takes boolean types as its operands.
132+
// If they are not, legalize them by casting to boolean type.
133+
//
134+
SLANG_ASSERT(inst->getOperandCount() == 2);
135+
for (UInt i = 0; i < 2; i++)
136+
{
137+
auto operand = inst->getOperand(i);
138+
auto operandDataType = operand->getDataType();
139+
140+
if (auto vecType = as<IRVectorType>(operandDataType))
141+
{
142+
if (!as<IRBoolType>(vecType->getElementType()))
143+
{
144+
// Cast operand to vector<bool,N>
145+
auto elemCount = vecType->getElementCount();
146+
auto vb = builder.getVectorType(builder.getBoolType(), elemCount);
147+
auto v = builder.emitCast(vb, operand);
148+
builder.replaceOperand(inst->getOperands() + i, v);
149+
}
150+
}
151+
else if (!as<IRBoolType>(operandDataType))
152+
{
153+
// Cast operand to bool
154+
auto s = builder.emitCast(builder.getBoolType(), operand);
155+
builder.replaceOperand(inst->getOperands() + i, s);
156+
}
157+
}
158+
159+
// Legalize the return type; mostly for SPIRV.
160+
// The return type of OpLogicalOr must be boolean type.
161+
// If not, we need to recreate the instruction with boolean return type.
162+
// Then, we have to cast it back to the original type so that other instrucitons that
163+
// use have the matching types.
164+
//
165+
auto dataType = inst->getDataType();
166+
auto lhs = inst->getOperand(0);
167+
auto rhs = inst->getOperand(1);
168+
IRInst* newInst = nullptr;
169+
170+
if (auto vecType = as<IRVectorType>(dataType))
171+
{
172+
if (!as<IRBoolType>(vecType->getElementType()))
173+
{
174+
// Return type should be vector<bool,N>
175+
auto elemCount = vecType->getElementCount();
176+
auto vb = builder.getVectorType(builder.getBoolType(), elemCount);
177+
178+
if (inst->getOp() == kIROp_And)
179+
{
180+
newInst = builder.emitAnd(vb, lhs, rhs);
181+
}
182+
else
183+
{
184+
newInst = builder.emitOr(vb, lhs, rhs);
185+
}
186+
newInst = builder.emitCast(dataType, newInst);
187+
}
188+
}
189+
else if (!as<IRBoolType>(dataType))
190+
{
191+
// Return type should be bool
192+
if (inst->getOp() == kIROp_And)
193+
{
194+
newInst = builder.emitAnd(builder.getBoolType(), lhs, rhs);
195+
}
196+
else
197+
{
198+
newInst = builder.emitOr(builder.getBoolType(), lhs, rhs);
199+
}
200+
newInst = builder.emitCast(dataType, newInst);
201+
}
202+
203+
if (newInst && inst != newInst)
204+
{
205+
inst->replaceUsesWith(newInst);
206+
inst->removeAndDeallocate();
207+
}
208+
}
209+
break;
210+
}
211+
212+
for (auto child : inst->getModifiableChildren())
213+
{
214+
legalizeLogicalAndOr(child);
215+
}
216+
}
217+
121218
} // namespace Slang

source/slang/slang-ir-legalize-binary-operator.h

+5
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,9 @@ struct IRInst;
1313
// signed operand is converted to unsigned.
1414
void legalizeBinaryOp(IRInst* inst);
1515

16+
// The logical binary operators such as AND and OR takes boolean types are its input.
17+
// If they are in integer type, as an example, we need to explicitly cast to bool type.
18+
// Also the return type from the logical operators should be a boolean type.
19+
void legalizeLogicalAndOr(IRInst* inst);
20+
1621
} // namespace Slang

source/slang/slang-ir.cpp

+14
Original file line numberDiff line numberDiff line change
@@ -6015,6 +6015,20 @@ IRInst* IRBuilder::emitShl(IRType* type, IRInst* left, IRInst* right)
60156015
return inst;
60166016
}
60176017

6018+
IRInst* IRBuilder::emitAnd(IRType* type, IRInst* left, IRInst* right)
6019+
{
6020+
auto inst = createInst<IRInst>(this, kIROp_And, type, left, right);
6021+
addInst(inst);
6022+
return inst;
6023+
}
6024+
6025+
IRInst* IRBuilder::emitOr(IRType* type, IRInst* left, IRInst* right)
6026+
{
6027+
auto inst = createInst<IRInst>(this, kIROp_Or, type, left, right);
6028+
addInst(inst);
6029+
return inst;
6030+
}
6031+
60186032
IRInst* IRBuilder::emitGetNativePtr(IRInst* value)
60196033
{
60206034
auto valueType = value->getDataType();
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
//TEST(compute):SIMPLE(filecheck=SM5):-target hlsl -profile cs_5_1 -entry computeMain
22
//TEST(compute):SIMPLE(filecheck=SM6):-target hlsl -profile cs_6_0 -entry computeMain
33
//TEST(compute):SIMPLE(filecheck=WGSL):-target wgsl -stage compute -entry computeMain
4-
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-slang -compute -shaderobj -xslang -Wno-30056
5-
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-vk -compute -shaderobj -xslang -Wno-30056
6-
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cuda -compute -shaderobj -xslang -Wno-30056
7-
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cpu -compute -shaderobj -xslang -Wno-30056
4+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-slang -compute -shaderobj -output-using-type -xslang -Wno-30056
5+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-vk -compute -shaderobj -output-using-type -xslang -Wno-30056
6+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-mtl -compute -shaderobj -output-using-type -xslang -Wno-30056
7+
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cuda -compute -shaderobj -output-using-type -xslang -Wno-30056
8+
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cpu -compute -shaderobj -output-using-type -xslang -Wno-30056
89

910
// Testnig logical-AND, logical-OR and ternary operator with non-scalar operands
1011

@@ -15,7 +16,7 @@ static int result = 0;
1516

1617
bool2 assignFunc(int index)
1718
{
18-
result++;
19+
result += 10;
1920
return bool2(true);
2021
}
2122

@@ -24,34 +25,38 @@ void computeMain(int3 dispatchThreadID : SV_DispatchThreadID)
2425
{
2526
int index = dispatchThreadID.x;
2627

28+
// No short-circuiting for vector types
29+
2730
//SM5:(all({{.*}}&&
2831
//SM6:(all(and(
29-
//WGSL:(all(select(vec2<bool>(false),
32+
//WGSL:(all( select(vec2<bool>(false),
3033
if (all(bool2(index >= 1) && assignFunc(index)))
3134
{
3235
result++;
3336
}
3437

38+
// Intentionally using non-boolean type for testing.
39+
3540
//SM5:(all({{.*}}||
36-
//SM6:(all(or(
37-
//WGSL:(all(select({{.*}}vec2<bool>(true),
38-
if (all(bool2(index >= 2) || !assignFunc(index)))
41+
//SM6:(or(vector<bool,2>(
42+
//WGSL:( select({{.*}}, vec2<bool>(true), vec2<bool>(
43+
if (all(int2(index >= 2) || !assignFunc(index)))
3944
{
4045
result++;
4146
}
4247

4348
//SM5:(all({{.*}}?{{.*}}:
44-
//SM6:(all(select({{ *}}
45-
//WGSL:(all(select(
49+
//SM6:(all(select(
50+
//WGSL:(all(select(vec2<bool>(false),
4651
if (all(bool2(index >= 3) ? assignFunc(index) : bool2(false)))
4752
{
4853
result++;
4954
}
5055

5156
outputBuffer[index] = result;
5257

53-
//CHK:3
54-
//CHK-NEXT:4
55-
//CHK-NEXT:5
56-
//CHK-NEXT:6
58+
//CHK:30
59+
//CHK-NEXT:31
60+
//CHK-NEXT:32
61+
//CHK-NEXT:33
5762
}

tests/compute/logic-short-circuit-evaluation.slang

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-dx12 -compute -shaderobj
22
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-vk -compute -shaderobj
3+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHK):-mtl -compute -shaderobj
34
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cuda -compute -shaderobj
45
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-cpu -compute -compile-arg -O3 -shaderobj
56
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-slang -compute -shaderobj

0 commit comments

Comments
 (0)