Skip to content

Commit b4ffaa4

Browse files
committed
[CIR][AMDGPU] Add lowering for amdgcn rsq builtins
1 parent 5237bd4 commit b4ffaa4

File tree

7 files changed

+130
-2
lines changed

7 files changed

+130
-2
lines changed

clang/lib/CIR/CodeGen/CIRGenBuiltinAMDGPU.cpp

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -333,11 +333,13 @@ mlir::Value CIRGenFunction::emitAMDGPUBuiltinExpr(unsigned builtinId,
333333
case AMDGPU::BI__builtin_amdgcn_rsqf:
334334
case AMDGPU::BI__builtin_amdgcn_rsqh:
335335
case AMDGPU::BI__builtin_amdgcn_rsq_bf16: {
336-
llvm_unreachable("rsq_* NYI");
336+
return emitBuiltinWithOneOverloadedType<1>(expr, "amdgcn.rsq")
337+
.getScalarVal();
337338
}
338339
case AMDGPU::BI__builtin_amdgcn_rsq_clamp:
339340
case AMDGPU::BI__builtin_amdgcn_rsq_clampf: {
340-
llvm_unreachable("rsq_clamp_* NYI");
341+
return emitBuiltinWithOneOverloadedType<1>(expr, "amdgcn.rsq.clamp")
342+
.getScalarVal();
341343
}
342344
case AMDGPU::BI__builtin_amdgcn_sinf:
343345
case AMDGPU::BI__builtin_amdgcn_sinh:

clang/test/CIR/CodeGen/HIP/builtins-amdgcn-gfx1250.hip

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,14 @@ __device__ void test_sqrt_bf16(__bf16* out, __bf16 a)
3838
{
3939
*out = __builtin_amdgcn_sqrt_bf16(a);
4040
}
41+
42+
// CIR-LABEL: @_Z13test_rsq_bf16PDF16bDF16b
43+
// CIR: cir.llvm.intrinsic "amdgcn.rsq" {{.*}} : (!cir.bf16) -> !cir.bf16
44+
// LLVM: define{{.*}} void @_Z13test_rsq_bf16PDF16bDF16b
45+
// LLVM: call{{.*}} bfloat @llvm.amdgcn.rsq.bf16(bfloat %{{.*}})
46+
// OGCG: define{{.*}} void @_Z13test_rsq_bf16PDF16bDF16b
47+
// OGCG: call{{.*}} bfloat @llvm.amdgcn.rsq.bf16(bfloat %{{.*}})
48+
__device__ void test_rsq_bf16(__bf16* out, __bf16 a)
49+
{
50+
*out = __builtin_amdgcn_rsq_bf16(a);
51+
}

clang/test/CIR/CodeGen/HIP/builtins-amdgcn-vi.hip

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,14 @@ __device__ void test_sqrt_f16(_Float16* out, _Float16 a)
8787
{
8888
*out = __builtin_amdgcn_sqrth(a);
8989
}
90+
91+
// CIR-LABEL: @_Z10test_rsq_hPDF16_DF16_
92+
// CIR: cir.llvm.intrinsic "amdgcn.rsq" {{.*}} : (!cir.f16) -> !cir.f16
93+
// LLVM: define{{.*}} void @_Z10test_rsq_hPDF16_DF16_
94+
// LLVM: call{{.*}} half @llvm.amdgcn.rsq.f16(half %{{.*}})
95+
// OGCG: define{{.*}} void @_Z10test_rsq_hPDF16_DF16_
96+
// OGCG: call{{.*}} half @llvm.amdgcn.rsq.f16(half %{{.*}})
97+
__device__ void test_rsq_h(_Float16* out, _Float16 a)
98+
{
99+
*out = __builtin_amdgcn_rsqh(a);
100+
}

clang/test/CIR/CodeGen/HIP/builtins-amdgcn.hip

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,3 +385,45 @@ __device__ void test_sqrt_f32(float* out, float a) {
385385
__device__ void test_sqrt_f64(double* out, double a) {
386386
*out = __builtin_amdgcn_sqrt(a);
387387
}
388+
389+
// CIR-LABEL: @_Z12test_rsq_f32Pff
390+
// CIR: cir.llvm.intrinsic "amdgcn.rsq" {{.*}} : (!cir.float) -> !cir.float
391+
// LLVM: define{{.*}} void @_Z12test_rsq_f32Pff
392+
// LLVM: call{{.*}} float @llvm.amdgcn.rsq.f32(float %{{.*}})
393+
// OGCG: define{{.*}} void @_Z12test_rsq_f32Pff
394+
// OGCG: call{{.*}} float @llvm.amdgcn.rsq.f32(float %{{.*}})
395+
__device__ void test_rsq_f32(float* out, float a)
396+
{
397+
*out = __builtin_amdgcn_rsqf(a);
398+
}
399+
400+
// CIR-LABEL: @_Z12test_rsq_f64Pdd
401+
// CIR: cir.llvm.intrinsic "amdgcn.rsq" {{.*}} : (!cir.double) -> !cir.double
402+
// LLVM: define{{.*}} void @_Z12test_rsq_f64Pdd
403+
// LLVM: call{{.*}} double @llvm.amdgcn.rsq.f64(double %{{.*}})
404+
// OGCG: define{{.*}} void @_Z12test_rsq_f64Pdd
405+
// OGCG: call{{.*}} double @llvm.amdgcn.rsq.f64(double %{{.*}})
406+
__device__ void test_rsq_f64(double* out, double a) {
407+
*out = __builtin_amdgcn_rsq(a);
408+
}
409+
410+
// CIR-LABEL: @_Z18test_rsq_clamp_f32Pff
411+
// CIR: cir.llvm.intrinsic "amdgcn.rsq.clamp" {{.*}} : (!cir.float) -> !cir.float
412+
// LLVM: define{{.*}} void @_Z18test_rsq_clamp_f32Pff
413+
// LLVM: call{{.*}} float @llvm.amdgcn.rsq.clamp.f32(float %{{.*}})
414+
// OGCG: define{{.*}} void @_Z18test_rsq_clamp_f32Pff
415+
// OGCG: call{{.*}} float @llvm.amdgcn.rsq.clamp.f32(float %{{.*}})
416+
__device__ void test_rsq_clamp_f32(float* out, float a)
417+
{
418+
*out = __builtin_amdgcn_rsq_clampf(a);
419+
}
420+
421+
// CIR-LABEL: @_Z18test_rsq_clamp_f64Pdd
422+
// CIR: cir.llvm.intrinsic "amdgcn.rsq.clamp" {{.*}} : (!cir.double) -> !cir.double
423+
// LLVM: define{{.*}} void @_Z18test_rsq_clamp_f64Pdd
424+
// LLVM: call{{.*}} double @llvm.amdgcn.rsq.clamp.f64(double %{{.*}})
425+
// OGCG: define{{.*}} void @_Z18test_rsq_clamp_f64Pdd
426+
// OGCG: call{{.*}} double @llvm.amdgcn.rsq.clamp.f64(double %{{.*}})
427+
__device__ void test_rsq_clamp_f64(double* out, double a) {
428+
*out = __builtin_amdgcn_rsq_clamp(a);
429+
}

clang/test/CIR/CodeGen/OpenCL/builtins-amdgcn-gfx1250.cl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,14 @@ void test_sqrt_bf16(global __bf16* out, __bf16 a)
3838
{
3939
*out = __builtin_amdgcn_sqrt_bf16(a);
4040
}
41+
42+
// CIR-LABEL: @test_rsq_bf16
43+
// CIR: cir.llvm.intrinsic "amdgcn.rsq" {{.*}} : (!cir.bf16) -> !cir.bf16
44+
// LLVM: define{{.*}} void @test_rsq_bf16
45+
// LLVM: call{{.*}} bfloat @llvm.amdgcn.rsq.bf16(bfloat %{{.*}})
46+
// OGCG: define{{.*}} void @test_rsq_bf16
47+
// OGCG: call{{.*}} bfloat @llvm.amdgcn.rsq.bf16(bfloat %{{.*}})
48+
void test_rsq_bf16(__bf16* out, __bf16 a)
49+
{
50+
*out = __builtin_amdgcn_rsq_bf16(a);
51+
}

clang/test/CIR/CodeGen/OpenCL/builtins-amdgcn-vi.cl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,14 @@ void test_sqrt_f16(global half* out, half a)
8787
{
8888
*out = __builtin_amdgcn_sqrth(a);
8989
}
90+
91+
// CIR-LABEL: @test_rsq_f16
92+
// CIR: cir.llvm.intrinsic "amdgcn.rsq" {{.*}} : (!cir.f16) -> !cir.f16
93+
// LLVM: define{{.*}} void @test_rsq_f16
94+
// LLVM: call{{.*}} half @llvm.amdgcn.rsq.f16(half %{{.*}})
95+
// OGCG: define{{.*}} void @test_rsq_f16
96+
// OGCG: call{{.*}} half @llvm.amdgcn.rsq.f16(half %{{.*}})
97+
void test_rsq_f16(global half* out, half a)
98+
{
99+
*out = __builtin_amdgcn_rsqh(a);
100+
}

clang/test/CIR/CodeGen/OpenCL/builtins_amdgcn.cl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,3 +400,43 @@ void test_sqrt_f32(global float* out, float a) {
400400
void test_sqrt_f64(global double* out, double a) {
401401
*out = __builtin_amdgcn_sqrt(a);
402402
}
403+
404+
// CIR-LABEL: @test_rsq_f32
405+
// CIR: cir.llvm.intrinsic "amdgcn.rsq" {{.*}} : (!cir.float) -> !cir.float
406+
// LLVM: define{{.*}} void @test_rsq_f32
407+
// LLVM: call{{.*}} float @llvm.amdgcn.rsq.f32(float %{{.*}})
408+
// OGCG: define{{.*}} void @test_rsq_f32
409+
// OGCG: call{{.*}} float @llvm.amdgcn.rsq.f32(float %{{.*}})
410+
void test_rsq_f32(global float* out, float a) {
411+
*out = __builtin_amdgcn_rsqf(a);
412+
}
413+
414+
// CIR-LABEL: @test_rsq_f64
415+
// CIR: cir.llvm.intrinsic "amdgcn.rsq" {{.*}} : (!cir.double) -> !cir.double
416+
// LLVM: define{{.*}} void @test_rsq_f64
417+
// LLVM: call{{.*}} double @llvm.amdgcn.rsq.f64(double %{{.*}})
418+
// OGCG: define{{.*}} void @test_rsq_f64
419+
// OGCG: call{{.*}} double @llvm.amdgcn.rsq.f64(double %{{.*}})
420+
void test_rsq_f64(global double* out, double a) {
421+
*out = __builtin_amdgcn_rsq(a);
422+
}
423+
424+
// CIR-LABEL: @test_rsq_clamp_f32
425+
// CIR: cir.llvm.intrinsic "amdgcn.rsq.clamp" {{.*}} : (!cir.float) -> !cir.float
426+
// LLVM: define{{.*}} void @test_rsq_clamp_f32
427+
// LLVM: call{{.*}} float @llvm.amdgcn.rsq.clamp.f32(float %{{.*}})
428+
// OGCG: define{{.*}} void @test_rsq_clamp_f32
429+
// OGCG: call{{.*}} float @llvm.amdgcn.rsq.clamp.f32(float %{{.*}})
430+
void test_rsq_clamp_f32(global float* out, float a) {
431+
*out = __builtin_amdgcn_rsq_clampf(a);
432+
}
433+
434+
// CIR-LABEL: @test_rsq_clamp_f64
435+
// CIR: cir.llvm.intrinsic "amdgcn.rsq.clamp" {{.*}} : (!cir.double) -> !cir.double
436+
// LLVM: define{{.*}} void @test_rsq_clamp_f64
437+
// LLVM: call{{.*}} double @llvm.amdgcn.rsq.clamp.f64(double %{{.*}})
438+
// OGCG: define{{.*}} void @test_rsq_clamp_f64
439+
// OGCG: call{{.*}} double @llvm.amdgcn.rsq.clamp.f64(double %{{.*}})
440+
void test_rsq_clamp_f64(global double* out, double a) {
441+
*out = __builtin_amdgcn_rsq_clamp(a);
442+
}

0 commit comments

Comments
 (0)