Skip to content

Commit 2ab5192

Browse files
authored
[Vulkan] Add TIR unary trigonometric/hyperbolic intrinsic definitions (apache#18005)
1 parent 08f3365 commit 2ab5192

File tree

2 files changed

+88
-3
lines changed

2 files changed

+88
-3
lines changed

src/target/spirv/intrin_rule_spirv.cc

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,39 @@ TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
9191
TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
9292
DispatchGLSLPureIntrin<GLSLstd450Cos>);
9393

94+
TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
95+
DispatchGLSLPureIntrin<GLSLstd450Tan>);
96+
97+
TVM_REGISTER_OP("tir.asin")
98+
.set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Asin>);
99+
100+
TVM_REGISTER_OP("tir.acos")
101+
.set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Acos>);
102+
103+
TVM_REGISTER_OP("tir.atan")
104+
.set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Atan>);
105+
106+
TVM_REGISTER_OP("tir.sinh")
107+
.set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Sinh>);
108+
109+
TVM_REGISTER_OP("tir.cosh")
110+
.set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Cosh>);
111+
112+
TVM_REGISTER_OP("tir.tanh")
113+
.set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Tanh>);
114+
115+
TVM_REGISTER_OP("tir.asinh")
116+
.set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Asinh>);
117+
118+
TVM_REGISTER_OP("tir.acosh")
119+
.set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Acosh>);
120+
121+
TVM_REGISTER_OP("tir.atanh")
122+
.set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Atanh>);
123+
124+
TVM_REGISTER_OP("tir.atan2")
125+
.set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Atan2>);
126+
94127
TVM_REGISTER_OP("tir.log").set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
95128
DispatchGLSLPureIntrin<GLSLstd450Log>);
96129

@@ -103,9 +136,6 @@ TVM_REGISTER_OP("tir.sqrt")
103136
TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
104137
DispatchGLSLPureIntrin<GLSLstd450Pow>);
105138

106-
TVM_REGISTER_OP("tir.tanh")
107-
.set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Tanh>);
108-
109139
TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
110140
codegen::intrin ::DispatchFastErf);
111141
} // namespace intrin

tests/python/codegen/test_target_codegen_vulkan.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,5 +568,60 @@ def kernel():
568568
vulkan_codegen(mod, target)
569569

570570

571+
@tvm.testing.requires_gpu
572+
@tvm.testing.requires_vulkan
573+
def test_unary():
574+
test_funcs = [
575+
(tvm.tir.sin, lambda x: np.sin(x)),
576+
(tvm.tir.cos, lambda x: np.cos(x)),
577+
(tvm.tir.tan, lambda x: np.tan(x)),
578+
(tvm.tir.sinh, lambda x: np.sinh(x)),
579+
(tvm.tir.cosh, lambda x: np.cosh(x)),
580+
(tvm.tir.tanh, lambda x: np.tanh(x)),
581+
(tvm.tir.asin, lambda x: np.arcsin(x)),
582+
(tvm.tir.acos, lambda x: np.arccos(x)),
583+
(tvm.tir.atan, lambda x: np.arctan(x)),
584+
(tvm.tir.asinh, lambda x: np.arcsinh(x)),
585+
(tvm.tir.acosh, lambda x: np.arccosh(x)),
586+
(tvm.tir.atanh, lambda x: np.arctanh(x)),
587+
]
588+
589+
def run_test(tvm_intrin, np_func):
590+
m = te.var("m")
591+
A = te.placeholder((m,), name="A", dtype="float32")
592+
B = te.compute((m,), lambda *i: tvm_intrin(A(*i)), name="B")
593+
594+
mod = te.create_prim_func([A, B])
595+
sch = tir.Schedule(mod)
596+
597+
block = sch.get_block("B")
598+
loop = sch.get_loops(block)[0]
599+
bx, tx = sch.split(loop, factors=[None, 64])
600+
sch.bind(bx, "blockIdx.x")
601+
sch.bind(tx, "threadIdx.x")
602+
603+
target = tvm.target.Target("vulkan")
604+
dev = tvm.device(target.kind.name, 0)
605+
func = tvm.compile(sch.mod, target=target)
606+
607+
n = 16
608+
if tvm_intrin in [tvm.tir.asin, tvm.tir.acos]:
609+
data = np.random.uniform(-1.0, 1.0, size=n)
610+
elif tvm_intrin == tvm.tir.atanh:
611+
data = np.random.uniform(-0.999, 0.999, size=n)
612+
elif tvm_intrin == tvm.tir.acosh:
613+
data = np.random.uniform(1.0, 5.0, size=n)
614+
else:
615+
data = np.random.uniform(0.1, 0.9, size=n)
616+
617+
a = tvm.nd.array(data.astype(A.dtype), dev)
618+
b = tvm.nd.array(np.zeros(n, dtype=A.dtype), dev)
619+
func(a, b)
620+
tvm.testing.assert_allclose(b.numpy(), np_func(a.numpy()), atol=1e-3, rtol=1e-3)
621+
622+
for func in test_funcs:
623+
run_test(*func)
624+
625+
571626
if __name__ == "__main__":
572627
tvm.testing.main()

0 commit comments

Comments
 (0)