Skip to content

Commit 028cb8a

Browse files
committed
[CIR][AMDGPU] Add lowering for amdgcn rsq builtins
1 parent 34fabe1 commit 028cb8a

File tree

7 files changed

+130
-2
lines changed

7 files changed

+130
-2
lines changed

clang/lib/CIR/CodeGen/CIRGenBuiltinAMDGPU.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,11 +332,13 @@ mlir::Value CIRGenFunction::emitAMDGPUBuiltinExpr(unsigned builtinId,
332332
case AMDGPU::BI__builtin_amdgcn_rsqf:
333333
case AMDGPU::BI__builtin_amdgcn_rsqh:
334334
case AMDGPU::BI__builtin_amdgcn_rsq_bf16: {
335-
llvm_unreachable("rsq_* NYI");
335+
return emitBuiltinWithOneOverloadedType<1>(expr, "amdgcn.rsq")
336+
.getScalarVal();
336337
}
337338
case AMDGPU::BI__builtin_amdgcn_rsq_clamp:
338339
case AMDGPU::BI__builtin_amdgcn_rsq_clampf: {
339-
llvm_unreachable("rsq_clamp_* NYI");
340+
return emitBuiltinWithOneOverloadedType<1>(expr, "amdgcn.rsq.clamp")
341+
.getScalarVal();
340342
}
341343
case AMDGPU::BI__builtin_amdgcn_sinf:
342344
case AMDGPU::BI__builtin_amdgcn_sinh:

clang/test/CIR/CodeGen/HIP/builtins-amdgcn-gfx1250.hip

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,14 @@ __device__ void test_rcp_bf16(__bf16* out, __bf16 a)
2727
{
2828
*out = __builtin_amdgcn_rcp_bf16(a);
2929
}
30+
31+
// CIR-LABEL: @_Z13test_rsq_bf16PDF16bDF16b
32+
// CIR: cir.llvm.intrinsic "amdgcn.rsq" {{.*}} : (!cir.bf16) -> !cir.bf16
33+
// LLVM: define{{.*}} void @_Z13test_rsq_bf16PDF16bDF16b
34+
// LLVM: call{{.*}} bfloat @llvm.amdgcn.rsq.bf16(bfloat %{{.*}})
35+
// OGCG: define{{.*}} void @_Z13test_rsq_bf16PDF16bDF16b
36+
// OGCG: call{{.*}} bfloat @llvm.amdgcn.rsq.bf16(bfloat %{{.*}})
37+
__device__ void test_rsq_bf16(__bf16* out, __bf16 a)
38+
{
39+
*out = __builtin_amdgcn_rsq_bf16(a);
40+
}

clang/test/CIR/CodeGen/HIP/builtins-amdgcn-vi.hip

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,14 @@ __device__ void test_rcp_f16(_Float16* out, _Float16 a)
7676
{
7777
*out = __builtin_amdgcn_rcph(a);
7878
}
79+
80+
// CIR-LABEL: @_Z10test_rsq_hPDF16_DF16_
81+
// CIR: cir.llvm.intrinsic "amdgcn.rsq" {{.*}} : (!cir.f16) -> !cir.f16
82+
// LLVM: define{{.*}} void @_Z10test_rsq_hPDF16_DF16_
83+
// LLVM: call{{.*}} half @llvm.amdgcn.rsq.f16(half %{{.*}})
84+
// OGCG: define{{.*}} void @_Z10test_rsq_hPDF16_DF16_
85+
// OGCG: call{{.*}} half @llvm.amdgcn.rsq.f16(half %{{.*}})
86+
__device__ void test_rsq_h(_Float16* out, _Float16 a)
87+
{
88+
*out = __builtin_amdgcn_rsqh(a);
89+
}

clang/test/CIR/CodeGen/HIP/builtins-amdgcn.hip

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,3 +365,45 @@ __device__ void test_rcp_f32(float* out, float a) {
365365
__device__ void test_rcp_f64(double* out, double a) {
366366
*out = __builtin_amdgcn_rcp(a);
367367
}
368+
369+
// CIR-LABEL: @_Z12test_rsq_f32Pff
370+
// CIR: cir.llvm.intrinsic "amdgcn.rsq" {{.*}} : (!cir.float) -> !cir.float
371+
// LLVM: define{{.*}} void @_Z12test_rsq_f32Pff
372+
// LLVM: call{{.*}} float @llvm.amdgcn.rsq.f32(float %{{.*}})
373+
// OGCG: define{{.*}} void @_Z12test_rsq_f32Pff
374+
// OGCG: call{{.*}} float @llvm.amdgcn.rsq.f32(float %{{.*}})
375+
__device__ void test_rsq_f32(float* out, float a)
376+
{
377+
*out = __builtin_amdgcn_rsqf(a);
378+
}
379+
380+
// CIR-LABEL: @_Z12test_rsq_f64Pdd
381+
// CIR: cir.llvm.intrinsic "amdgcn.rsq" {{.*}} : (!cir.double) -> !cir.double
382+
// LLVM: define{{.*}} void @_Z12test_rsq_f64Pdd
383+
// LLVM: call{{.*}} double @llvm.amdgcn.rsq.f64(double %{{.*}})
384+
// OGCG: define{{.*}} void @_Z12test_rsq_f64Pdd
385+
// OGCG: call{{.*}} double @llvm.amdgcn.rsq.f64(double %{{.*}})
386+
__device__ void test_rsq_f64(double* out, double a) {
387+
*out = __builtin_amdgcn_rsq(a);
388+
}
389+
390+
// CIR-LABEL: @_Z18test_rsq_clamp_f32Pff
391+
// CIR: cir.llvm.intrinsic "amdgcn.rsq.clamp" {{.*}} : (!cir.float) -> !cir.float
392+
// LLVM: define{{.*}} void @_Z18test_rsq_clamp_f32Pff
393+
// LLVM: call{{.*}} float @llvm.amdgcn.rsq.clamp.f32(float %{{.*}})
394+
// OGCG: define{{.*}} void @_Z18test_rsq_clamp_f32Pff
395+
// OGCG: call{{.*}} float @llvm.amdgcn.rsq.clamp.f32(float %{{.*}})
396+
__device__ void test_rsq_clamp_f32(float* out, float a)
397+
{
398+
*out = __builtin_amdgcn_rsq_clampf(a);
399+
}
400+
401+
// CIR-LABEL: @_Z18test_rsq_clamp_f64Pdd
402+
// CIR: cir.llvm.intrinsic "amdgcn.rsq.clamp" {{.*}} : (!cir.double) -> !cir.double
403+
// LLVM: define{{.*}} void @_Z18test_rsq_clamp_f64Pdd
404+
// LLVM: call{{.*}} double @llvm.amdgcn.rsq.clamp.f64(double %{{.*}})
405+
// OGCG: define{{.*}} void @_Z18test_rsq_clamp_f64Pdd
406+
// OGCG: call{{.*}} double @llvm.amdgcn.rsq.clamp.f64(double %{{.*}})
407+
__device__ void test_rsq_clamp_f64(double* out, double a) {
408+
*out = __builtin_amdgcn_rsq_clamp(a);
409+
}

clang/test/CIR/CodeGen/OpenCL/builtins-amdgcn-gfx1250.cl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,14 @@ void test_rcp_bf16(global __bf16* out, __bf16 a)
2727
{
2828
*out = __builtin_amdgcn_rcp_bf16(a);
2929
}
30+
31+
// CIR-LABEL: @test_rsq_bf16
32+
// CIR: cir.llvm.intrinsic "amdgcn.rsq" {{.*}} : (!cir.bf16) -> !cir.bf16
33+
// LLVM: define{{.*}} void @test_rsq_bf16
34+
// LLVM: call{{.*}} bfloat @llvm.amdgcn.rsq.bf16(bfloat %{{.*}})
35+
// OGCG: define{{.*}} void @test_rsq_bf16
36+
// OGCG: call{{.*}} bfloat @llvm.amdgcn.rsq.bf16(bfloat %{{.*}})
37+
void test_rsq_bf16(__bf16* out, __bf16 a)
38+
{
39+
*out = __builtin_amdgcn_rsq_bf16(a);
40+
}

clang/test/CIR/CodeGen/OpenCL/builtins-amdgcn-vi.cl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,14 @@ void test_rcp_f16(global half* out, half a)
7676
{
7777
*out = __builtin_amdgcn_rcph(a);
7878
}
79+
80+
// CIR-LABEL: @test_rsq_f16
81+
// CIR: cir.llvm.intrinsic "amdgcn.rsq" {{.*}} : (!cir.f16) -> !cir.f16
82+
// LLVM: define{{.*}} void @test_rsq_f16
83+
// LLVM: call{{.*}} half @llvm.amdgcn.rsq.f16(half %{{.*}})
84+
// OGCG: define{{.*}} void @test_rsq_f16
85+
// OGCG: call{{.*}} half @llvm.amdgcn.rsq.f16(half %{{.*}})
86+
void test_rsq_f16(global half* out, half a)
87+
{
88+
*out = __builtin_amdgcn_rsqh(a);
89+
}

clang/test/CIR/CodeGen/OpenCL/builtins_amdgcn.cl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,3 +380,43 @@ void test_rcp_f32(global float* out, float a) {
380380
void test_rcp_f64(global double* out, double a) {
381381
*out = __builtin_amdgcn_rcp(a);
382382
}
383+
384+
// CIR-LABEL: @test_rsq_f32
385+
// CIR: cir.llvm.intrinsic "amdgcn.rsq" {{.*}} : (!cir.float) -> !cir.float
386+
// LLVM: define{{.*}} void @test_rsq_f32
387+
// LLVM: call{{.*}} float @llvm.amdgcn.rsq.f32(float %{{.*}})
388+
// OGCG: define{{.*}} void @test_rsq_f32
389+
// OGCG: call{{.*}} float @llvm.amdgcn.rsq.f32(float %{{.*}})
390+
void test_rsq_f32(global float* out, float a) {
391+
*out = __builtin_amdgcn_rsqf(a);
392+
}
393+
394+
// CIR-LABEL: @test_rsq_f64
395+
// CIR: cir.llvm.intrinsic "amdgcn.rsq" {{.*}} : (!cir.double) -> !cir.double
396+
// LLVM: define{{.*}} void @test_rsq_f64
397+
// LLVM: call{{.*}} double @llvm.amdgcn.rsq.f64(double %{{.*}})
398+
// OGCG: define{{.*}} void @test_rsq_f64
399+
// OGCG: call{{.*}} double @llvm.amdgcn.rsq.f64(double %{{.*}})
400+
void test_rsq_f64(global double* out, double a) {
401+
*out = __builtin_amdgcn_rsq(a);
402+
}
403+
404+
// CIR-LABEL: @test_rsq_clamp_f32
405+
// CIR: cir.llvm.intrinsic "amdgcn.rsq.clamp" {{.*}} : (!cir.float) -> !cir.float
406+
// LLVM: define{{.*}} void @test_rsq_clamp_f32
407+
// LLVM: call{{.*}} float @llvm.amdgcn.rsq.clamp.f32(float %{{.*}})
408+
// OGCG: define{{.*}} void @test_rsq_clamp_f32
409+
// OGCG: call{{.*}} float @llvm.amdgcn.rsq.clamp.f32(float %{{.*}})
410+
void test_rsq_clamp_f32(global float* out, float a) {
411+
*out = __builtin_amdgcn_rsq_clampf(a);
412+
}
413+
414+
// CIR-LABEL: @test_rsq_clamp_f64
415+
// CIR: cir.llvm.intrinsic "amdgcn.rsq.clamp" {{.*}} : (!cir.double) -> !cir.double
416+
// LLVM: define{{.*}} void @test_rsq_clamp_f64
417+
// LLVM: call{{.*}} double @llvm.amdgcn.rsq.clamp.f64(double %{{.*}})
418+
// OGCG: define{{.*}} void @test_rsq_clamp_f64
419+
// OGCG: call{{.*}} double @llvm.amdgcn.rsq.clamp.f64(double %{{.*}})
420+
void test_rsq_clamp_f64(global double* out, double a) {
421+
*out = __builtin_amdgcn_rsq_clamp(a);
422+
}

0 commit comments

Comments
 (0)