Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -663,9 +663,11 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
return cir::YieldOp::create(*this, loc, value);
}

cir::PtrStrideOp createPtrStride(mlir::Location loc, mlir::Value base,
mlir::Value stride) {
return cir::PtrStrideOp::create(*this, loc, base.getType(), base, stride);
cir::PtrStrideOp
createPtrStride(mlir::Location loc, mlir::Value base, mlir::Value stride,
std::optional<CIR_GEPNoWrapFlags> flags = std::nullopt) {
return cir::PtrStrideOp::create(*this, loc, base.getType(), base, stride,
flags.value_or(CIR_GEPNoWrapFlags::none));
}

cir::CallOp createCallOp(mlir::Location loc,
Expand Down
5 changes: 5 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIREnumAttr.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
include "mlir/IR/EnumAttr.td"
include "clang/CIR/Dialect/IR/CIRDialect.td"

class CIR_I32BitEnum<string name, string summary, list<BitEnumCaseBase> cases>
: I32BitEnum<name, summary, cases> {
let cppNamespace = "::cir";
}

class CIR_I32EnumAttr<string name, string summary, list<I32EnumAttrCase> cases>
: I32EnumAttr<name, summary, cases> {
let cppNamespace = "::cir";
Expand Down
32 changes: 27 additions & 5 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,24 @@ def CIR_PtrDiffOp : CIR_Op<"ptr_diff", [Pure, SameTypeOperands]> {
//===----------------------------------------------------------------------===//
// PtrStrideOp
//===----------------------------------------------------------------------===//
def CIR_GEPNone : I32BitEnumCaseNone<"none">;
def CIR_GEPInboundsFlag : I32BitEnumCaseBit<"inboundsFlag", 0, "inbounds_flag">;
def CIR_GEPNusw : I32BitEnumCaseBit<"nusw", 1>;
def CIR_GEPNuw : I32BitEnumCaseBit<"nuw", 2>;
def CIR_GEPInbounds
: BitEnumCaseGroup<"inbounds", [CIR_GEPInboundsFlag, CIR_GEPNusw]>;

def CIR_GEPNoWrapFlags
: CIR_I32BitEnum<"CIR_GEPNoWrapFlags", "::cir::CIR_GEPNoWrapFlags",
[CIR_GEPNone, CIR_GEPInboundsFlag, CIR_GEPNusw, CIR_GEPNuw,
CIR_GEPInbounds]> {
let cppNamespace = "::cir";
let printBitEnumPrimaryGroups = 1;
}

def CIR_GEPNoWrapFlagsProp : EnumProp<CIR_GEPNoWrapFlags> {
let defaultValue = interfaceType#"::none";
}

def CIR_PtrStrideOp : CIR_Op<"ptr_stride",[
Pure, AllTypesMatch<["base", "result"]>
Expand All @@ -397,19 +415,23 @@ def CIR_PtrStrideOp : CIR_Op<"ptr_stride",[

```mlir
%3 = cir.const 0 : i32

%4 = cir.ptr_stride(%2 : !cir.ptr<i32>, %3 : i32), !cir.ptr<i32>

%5 = cir.ptr_stride(%2 : !cir.ptr<i32>, %3 : i32, inbounds), !cir.ptr<i32>

%6 = cir.ptr_stride(%2 : !cir.ptr<i32>, %3 : i32, inbounds|nuw), !cir.ptr<i32>

```
}];

let arguments = (ins
CIR_PointerType:$base,
CIR_AnyFundamentalIntType:$stride
);
let arguments = (ins CIR_PointerType:$base, CIR_AnyFundamentalIntType:$stride,
CIR_GEPNoWrapFlagsProp:$noWrapFlags);

let results = (outs CIR_PointerType:$result);

let assemblyFormat = [{
$base`,` $stride `:` functional-type(operands, results) attr-dict
($noWrapFlags^)? $base`,` $stride `:` functional-type(operands, results) attr-dict
}];

let extraClassDeclaration = [{
Expand Down
12 changes: 9 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2842,10 +2842,16 @@ mlir::Value CIRGenFunction::emitCheckedInBoundsGEP(
assert(IdxList.size() == 1 && "multi-index ptr arithmetic NYI");
mlir::Value GEPVal =
builder.create<cir::PtrStrideOp>(CGM.getLoc(Loc), PtrTy, Ptr, IdxList[0]);

// If the pointer overflow sanitizer isn't enabled, do nothing.
if (!SanOpts.has(SanitizerKind::PointerOverflow))
return GEPVal;
if (!SanOpts.has(SanitizerKind::PointerOverflow)) {
cir::CIR_GEPNoWrapFlags nwFlags = cir::CIR_GEPNoWrapFlags::inbounds;
if (!SignedIndices && !IsSubtraction)
nwFlags = nwFlags | cir::CIR_GEPNoWrapFlags::nuw;
return builder.create<cir::PtrStrideOp>(CGM.getLoc(Loc), PtrTy, Ptr,
IdxList[0], nwFlags);
}

return GEPVal;

// TODO(cir): the unreachable code below hides a substantial amount of code
// from the original codegen related with pointer overflow sanitizer.
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "clang/CIR/Dialect/IR/CIRTypes.h"
#include "clang/CIR/Interfaces/CIRLoopOpInterface.h"
#include "clang/CIR/MissingFeatures.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/LogicalResult.h"
Expand Down
38 changes: 28 additions & 10 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,24 @@ void walkRegionSkipping(mlir::Region &region,
});
}

/// Convert from a CIR PtrStrideOp kind to an LLVM IR equivalent of GEP.
mlir::LLVM::GEPNoWrapFlags
convertPtrStrideKindToGEPFlags(cir::CIR_GEPNoWrapFlags flags) {
using CIRFlags = cir::CIR_GEPNoWrapFlags;
using LLVMFlags = mlir::LLVM::GEPNoWrapFlags;

LLVMFlags x = LLVMFlags::none;
if ((flags & CIRFlags::inboundsFlag) == CIRFlags::inboundsFlag)
x = x | LLVMFlags::inboundsFlag;
if ((flags & CIRFlags::nusw) == CIRFlags::nusw)
x = x | LLVMFlags::nusw;
if ((flags & CIRFlags::inbounds) == CIRFlags::inbounds)
x = x | LLVMFlags::inbounds;
if ((flags & CIRFlags::nuw) == CIRFlags::nuw)
x = x | LLVMFlags::nuw;
return x;
}

/// Convert from a CIR comparison kind to an LLVM IR integral comparison kind.
mlir::LLVM::ICmpPredicate convertCmpKindToICmpPredicate(cir::CmpOpKind kind,
bool isSigned) {
Expand Down Expand Up @@ -1023,9 +1041,9 @@ mlir::LogicalResult CIRToLLVMPtrStrideOpLowering::matchAndRewrite(
isUnsigned = strideTy.isUnsigned();
index = promoteIndex(rewriter, index, *layoutWidth, isUnsigned);
}

rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
ptrStrideOp, resultTy, elementTy, adaptor.getBase(), index);
ptrStrideOp, resultTy, elementTy, adaptor.getBase(), index,
convertPtrStrideKindToGEPFlags(adaptor.getNoWrapFlags()));
return mlir::success();
}

Expand Down Expand Up @@ -4299,14 +4317,14 @@ mlir::LogicalResult CIRToLLVMEhSetjmpOpLowering::matchAndRewrite(
return mlir::success();
}

StringRef fnName = "_setjmp";
auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext());
auto fnType = mlir::LLVM::LLVMFunctionType::get(returnType, llvmPtrTy,
/*isVarArg=*/false);
getOrCreateLLVMFuncOp(rewriter, op, fnName, fnType);
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(op, returnType, fnName,
adaptor.getEnv());
return mlir::success();
StringRef fnName = "_setjmp";
auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext());
auto fnType = mlir::LLVM::LLVMFunctionType::get(returnType, llvmPtrTy,
/*isVarArg=*/false);
getOrCreateLLVMFuncOp(rewriter, op, fnName, fnType);
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(op, returnType, fnName,
adaptor.getEnv());
return mlir::success();
}

mlir::LogicalResult CIRToLLVMCatchParamOpLowering::matchAndRewrite(
Expand Down
21 changes: 11 additions & 10 deletions clang/test/CIR/CodeGen/pointer-arith-ext.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ void *f4(void *a, int b) { return a - b; }
// CIR: %[[PTR:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.ptr<!void>>, !cir.ptr<!void>
// CIR: %[[STRIDE:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!s32i>, !s32i
// CIR: %[[SUB:.*]] = cir.unary(minus, %[[STRIDE]]) : !s32i, !s32i
// CIR: cir.ptr_stride %[[PTR]], %[[SUB]] : (!cir.ptr<!void>, !s32i) -> !cir.ptr<!void>
// CIR: cir.ptr_stride inbounds %[[PTR]], %[[SUB]] : (!cir.ptr<!void>, !s32i) -> !cir.ptr<!void>

// LLVM-LABEL: f4
// LLVM: %[[PTR:.*]] = load ptr, ptr {{.*}}, align 8
// LLVM: %[[TOEXT:.*]] = load i32, ptr {{.*}}, align 4
// LLVM: %[[STRIDE:.*]] = sext i32 %[[TOEXT]] to i64
// LLVM: %[[SUB:.*]] = sub i64 0, %[[STRIDE]]
// LLVM: getelementptr i8, ptr %[[PTR]], i64 %[[SUB]]
// LLVM: getelementptr inbounds i8, ptr %[[PTR]], i64 %[[SUB]]

// Similar to f4, just make sure it does not crash.
void *f4_1(void *a, int b) { return (a -= b); }
Expand All @@ -52,13 +52,13 @@ FP f5(FP a, int b) { return a + b; }
// CIR-LABEL: f5
// CIR: %[[PTR:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.ptr<!cir.func<()>>>, !cir.ptr<!cir.func<()>>
// CIR: %[[STRIDE:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!s32i>, !s32i
// CIR: cir.ptr_stride %[[PTR]], %[[STRIDE]] : (!cir.ptr<!cir.func<()>>, !s32i) -> !cir.ptr<!cir.func<()>>
// CIR: cir.ptr_stride inbounds %[[PTR]], %[[STRIDE]] : (!cir.ptr<!cir.func<()>>, !s32i) -> !cir.ptr<!cir.func<()>>

// LLVM-LABEL: f5
// LLVM: %[[PTR:.*]] = load ptr, ptr {{.*}}, align 8
// LLVM: %[[TOEXT:.*]] = load i32, ptr {{.*}}, align 4
// LLVM: %[[STRIDE:.*]] = sext i32 %[[TOEXT]] to i64
// LLVM: getelementptr i8, ptr %[[PTR]], i64 %[[STRIDE]]
// LLVM: getelementptr inbounds i8, ptr %[[PTR]], i64 %[[STRIDE]]

// These test the same paths above, just make sure it does not crash.
FP f5_1(FP a, int b) { return (a += b); }
Expand All @@ -70,14 +70,14 @@ FP f7(FP a, int b) { return a - b; }
// CIR: %[[PTR:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.ptr<!cir.func<()>>>, !cir.ptr<!cir.func<()>>
// CIR: %[[STRIDE:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!s32i>, !s32i
// CIR: %[[SUB:.*]] = cir.unary(minus, %[[STRIDE]]) : !s32i, !s32i
// CIR: cir.ptr_stride %[[PTR]], %[[SUB]] : (!cir.ptr<!cir.func<()>>, !s32i) -> !cir.ptr<!cir.func<()>>
// CIR: cir.ptr_stride inbounds %[[PTR]], %[[SUB]] : (!cir.ptr<!cir.func<()>>, !s32i) -> !cir.ptr<!cir.func<()>>

// LLVM-LABEL: f7
// LLVM: %[[PTR:.*]] = load ptr, ptr {{.*}}, align 8
// LLVM: %[[TOEXT:.*]] = load i32, ptr {{.*}}, align 4
// LLVM: %[[STRIDE:.*]] = sext i32 %[[TOEXT]] to i64
// LLVM: %[[SUB:.*]] = sub i64 0, %[[STRIDE]]
// LLVM: getelementptr i8, ptr %[[PTR]], i64 %[[SUB]]
// LLVM: getelementptr inbounds i8, ptr %[[PTR]], i64 %[[SUB]]

// Similar to f7, just make sure it does not crash.
FP f7_1(FP a, int b) { return (a -= b); }
Expand All @@ -87,14 +87,14 @@ void f8(void *a, int b) { return *(id(a + b)); }
// CIR-LABEL: f8
// CIR: %[[PTR:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.ptr<!void>>, !cir.ptr<!void>
// CIR: %[[STRIDE:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!s32i>, !s32i
// CIR: cir.ptr_stride %[[PTR]], %[[STRIDE]] : (!cir.ptr<!void>, !s32i) -> !cir.ptr<!void>
// CIR: cir.ptr_stride inbounds %[[PTR]], %[[STRIDE]] : (!cir.ptr<!void>, !s32i) -> !cir.ptr<!void>
// CIR: cir.return

// LLVM-LABEL: f8
// LLVM: %[[PTR:.*]] = load ptr, ptr {{.*}}, align 8
// LLVM: %[[TOEXT:.*]] = load i32, ptr {{.*}}, align 4
// LLVM: %[[STRIDE:.*]] = sext i32 %[[TOEXT]] to i64
// LLVM: getelementptr i8, ptr %[[PTR]], i64 %[[STRIDE]]
// LLVM: getelementptr inbounds i8, ptr %[[PTR]], i64 %[[STRIDE]]
// LLVM: ret void

void f8_1(void *a, int b) { return a[b]; }
Expand All @@ -119,7 +119,8 @@ unsigned char *p(unsigned int x) {

// CIR-LABEL: @p
// CIR: %[[SUB:.*]] = cir.binop(sub
// CIR: cir.ptr_stride {{.*}}, %[[SUB]] : (!cir.ptr<!u8i>, !u32i) -> !cir.ptr<!u8i>
// CIR: cir.ptr_stride inbounds|nuw {{.*}}, %[[SUB]] : (!cir.ptr<!u8i>, !u32i) -> !cir.ptr<!u8i>

// LLVM-LABEL: @p
// LLVM: getelementptr i8, ptr {{.*}}
// LLVM: getelementptr inbounds nuw i8, ptr {{.*}}

12 changes: 6 additions & 6 deletions clang/test/CIR/CodeGen/pointers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,28 @@
void foo(int *iptr, char *cptr, unsigned ustride) {
*(iptr + 2) = 1;
// CHECK: %[[#STRIDE:]] = cir.const #cir.int<2> : !s32i
// CHECK: cir.ptr_stride %{{.+}}, %[[#STRIDE]] : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
// CHECK: cir.ptr_stride inbounds %{{.+}}, %[[#STRIDE]] : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
*(cptr + 3) = 1;
// CHECK: %[[#STRIDE:]] = cir.const #cir.int<3> : !s32i
// CHECK: cir.ptr_stride %{{.+}}, %[[#STRIDE]] : (!cir.ptr<!s8i>, !s32i) -> !cir.ptr<!s8i>
// CHECK: cir.ptr_stride inbounds %{{.+}}, %[[#STRIDE]] : (!cir.ptr<!s8i>, !s32i) -> !cir.ptr<!s8i>
*(iptr - 2) = 1;
// CHECK: %[[#STRIDE:]] = cir.const #cir.int<2> : !s32i
// CHECK: %[[#NEGSTRIDE:]] = cir.unary(minus, %[[#STRIDE]]) : !s32i, !s32i
// CHECK: cir.ptr_stride %{{.+}}, %[[#NEGSTRIDE]] : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
// CHECK: cir.ptr_stride inbounds %{{.+}}, %[[#NEGSTRIDE]] : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
*(cptr - 3) = 1;
// CHECK: %[[#STRIDE:]] = cir.const #cir.int<3> : !s32i
// CHECK: %[[#NEGSTRIDE:]] = cir.unary(minus, %[[#STRIDE]]) : !s32i, !s32i
// CHECK: cir.ptr_stride %{{.+}}, %[[#NEGSTRIDE]] : (!cir.ptr<!s8i>, !s32i) -> !cir.ptr<!s8i>
// CHECK: cir.ptr_stride inbounds %{{.+}}, %[[#NEGSTRIDE]] : (!cir.ptr<!s8i>, !s32i) -> !cir.ptr<!s8i>
*(iptr + ustride) = 1;
// CHECK: %[[#STRIDE:]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!u32i>, !u32i
// CHECK: cir.ptr_stride %{{.+}}, %[[#STRIDE]] : (!cir.ptr<!s32i>, !u32i) -> !cir.ptr<!s32i>
// CHECK: cir.ptr_stride inbounds|nuw %{{.+}}, %[[#STRIDE]] : (!cir.ptr<!s32i>, !u32i) -> !cir.ptr<!s32i>

// Must convert unsigned stride to a signed one.
*(iptr - ustride) = 1;
// CHECK: %[[#STRIDE:]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!u32i>, !u32i
// CHECK: %[[#SIGNSTRIDE:]] = cir.cast(integral, %[[#STRIDE]] : !u32i), !s32i
// CHECK: %[[#NEGSTRIDE:]] = cir.unary(minus, %[[#SIGNSTRIDE]]) : !s32i, !s32i
// CHECK: cir.ptr_stride %{{.+}}, %[[#NEGSTRIDE]] : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
// CHECK: cir.ptr_stride inbounds %{{.+}}, %[[#NEGSTRIDE]] : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
}

void testPointerSubscriptAccess(int *ptr) {
Expand Down
17 changes: 17 additions & 0 deletions clang/test/CIR/IR/ptr_stride.cir
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ module {
%4 = cir.ptr_stride %2, %3 : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
cir.return
}

cir.func @gepflags(%arg0: !cir.ptr<!s32i>, %arg1: !s32i) {
%0 = cir.ptr_stride %arg0, %arg1 : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
%1 = cir.ptr_stride nuw %arg0, %arg1 : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
%2 = cir.ptr_stride inbounds|nuw %arg0, %arg1 : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
%3 = cir.ptr_stride %arg0, %arg1 : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
cir.return
}
}

// CHECK: cir.func @arraysubscript(%arg0: !s32i) {
Expand All @@ -20,3 +28,12 @@ module {
// CHECK-NEXT: %4 = cir.ptr_stride %2, %3 : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
// CHECK-NEXT: cir.return
// CHECK-NEXT: }


// CHECK: cir.func @gepflags(%arg0: !cir.ptr<!s32i>, %arg1: !s32i) {
// CHECK-NEXT: %0 = cir.ptr_stride %arg0, %arg1 : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
// CHECK-NEXT: %1 = cir.ptr_stride nuw %arg0, %arg1 : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
// CHECK-NEXT: %2 = cir.ptr_stride inbounds|nuw %arg0, %arg1 : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
// CHECK-NEXT: %3 = cir.ptr_stride %arg0, %arg1 : (!cir.ptr<!s32i>, !s32i) -> !cir.ptr<!s32i>
// CHECK-NEXT: cir.return
// CHECK-NEXT: }
Loading