2 changes: 1 addition & 1 deletion src/target/llvm/codegen_llvm.cc
@@ -928,7 +928,7 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* value)
   } else if (to.is_bool()) {
     if (from.is_float()) {
       llvm::Constant* zero = llvm::ConstantFP::get(DTypeToLLVMType(from), 0.);
-      return builder_->CreateFCmpONE(value, zero);
+      return builder_->CreateFCmpUNE(value, zero);
     } else {
       llvm::Constant* zero = llvm::ConstantInt::get(DTypeToLLVMType(from), 0);
       return builder_->CreateICmpNE(value, zero);
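Background on the one-line change above (not stated in the diff itself, but standard LLVM IR semantics): CreateFCmpONE emits fcmp one, an ordered comparison that yields false whenever either operand is NaN, so casting NaN to bool incorrectly produced false. CreateFCmpUNE emits fcmp une, the unordered variant, which yields true for NaN and therefore matches the usual convention that any non-zero float, including NaN and inf, converts to true. A minimal NumPy sketch of that reference behavior, for illustration only and not part of the diff:

import numpy as np

# Only 0.0 maps to False; 1.0, NaN, and inf are all non-zero and map to True,
# which is what the fcmp une lowering now produces.
a = np.array([0.0, 1.0, np.nan, np.inf], dtype="float32")
print(a.astype("bool"))  # [False  True  True  True]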
25 changes: 25 additions & 0 deletions tests/python/codegen/test_target_codegen_llvm.py
@@ -378,6 +378,31 @@ def check_llvm(n):
     check_llvm(64)


+@tvm.testing.requires_llvm
+def test_llvm_cast_float_to_bool():
+    a_np = np.array([0.0, 1.0, np.nan, np.inf], dtype="float32")
+    n = a_np.shape[0]
+
+    A = te.placeholder((n,), name="A", dtype="float32")
+    C = te.compute((n,), lambda i: A[i].astype("bool"), name="C")
+
+    # Convert to TIR and create schedule
+    mod = te.create_prim_func([A, C])
+    sch = tir.Schedule(mod)
+
+    # build and invoke the kernel.
+    f = tvm.compile(sch.mod, target="llvm")
+    dev = tvm.cpu(0)
+
+    # launch the kernel.
+    a = tvm.runtime.tensor(a_np, dev)
+    c = tvm.runtime.empty((n,), dtype="bool", device=dev)
+    f(a, c)
+    c_np = np.array([False, True, True, True], dtype="bool")
+
+    tvm.testing.assert_allclose(c.numpy(), c_np)
Comment on lines +381 to +403
Contributor (medium):
This is a great test case that covers the essential scenarios for casting floats to booleans. To make it even more comprehensive, I suggest parameterizing it to run against multiple float dtypes (float16, float32, and float64). This will ensure the fix holds for different precisions and improve test coverage.

@tvm.testing.requires_llvm
@pytest.mark.parametrize("dtype", ["float16", "float32", "float64"])
def test_llvm_cast_float_to_bool(dtype):
    if dtype == "float16" and tvm.target.codegen.llvm_version_major() < 8:
        pytest.skip("float16 support requires LLVM 8 or greater")

    a_np = np.array([0.0, 1.0, np.nan, np.inf], dtype=dtype)
    n = a_np.shape[0]

    A = te.placeholder((n,), name="A", dtype=dtype)
    C = te.compute((n,), lambda i: A[i].astype("bool"), name="C")

    # Convert to TIR and create schedule
    mod = te.create_prim_func([A, C])
    sch = tir.Schedule(mod)

    # build and invoke the kernel.
    f = tvm.compile(sch.mod, target="llvm")
    dev = tvm.cpu(0)

    # launch the kernel.
    a = tvm.runtime.tensor(a_np, dev)
    c = tvm.runtime.empty((n,), dtype="bool", device=dev)
    f(a, c)
    c_np = np.array([False, True, True, True], dtype="bool")

    tvm.testing.assert_allclose(c.numpy(), c_np)



 @tvm.testing.requires_llvm
 def test_rank_zero():
     def check_llvm(n):