-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[CIR] Upstream extract op for VectorType #138413
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-clang @llvm/pr-subscribers-clangir Author: Amr Hesham (AmrDeveloper) ChangesThis change adds local zero initialization for VectorType Issue #136487 Full diff: https://github.com/llvm/llvm-project/pull/138413.diff 7 Files Affected:
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 422c89c4f9391..b2121dee8d8b3 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -1976,4 +1976,28 @@ def VecCreateOp : CIR_Op<"vec.create", [Pure]> {
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// VecExtractOp
+//===----------------------------------------------------------------------===//
+
+def VecExtractOp : CIR_Op<"vec.extract", [Pure,
+ TypesMatchWith<"type of 'result' matches element type of 'vec'", "vec",
+ "result", "cast<VectorType>($_self).getElementType()">]> {
+
+ let summary = "Extract one element from a vector object";
+ let description = [{
+ The `cir.vec.extract` operation extracts the element at the given index
+ from a vector object.
+ }];
+
+ let arguments = (ins CIR_VectorType:$vec, CIR_AnyFundamentalIntType:$index);
+ let results = (outs CIR_AnyType:$result);
+
+ let assemblyFormat = [{
+ $vec `[` $index `:` type($index) `]` attr-dict `:` qualified(type($vec))
+ }];
+
+ let hasVerifier = 0;
+}
+
#endif // CLANG_CIR_DIALECT_IR_CIROPS_TD
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
index 8ead6e793b4c8..a59b87cb9241b 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
@@ -161,8 +161,11 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
mlir::Value VisitArraySubscriptExpr(ArraySubscriptExpr *e) {
if (e->getBase()->getType()->isVectorType()) {
assert(!cir::MissingFeatures::scalableVectors());
- cgf.getCIRGenModule().errorNYI("VisitArraySubscriptExpr: VectorType");
- return {};
+
+ const mlir::Location loc = cgf.getLoc(e->getSourceRange());
+ const mlir::Value vecValue = Visit(e->getBase());
+ const mlir::Value indexValue = Visit(e->getIdx());
+ return cgf.builder.create<cir::VecExtractOp>(loc, vecValue, indexValue);
}
// Just load the lvalue formed by the subscript expression.
return emitLoadOfLValue(e);
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 6137adb1e9936..66f29f8f6cdd0 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1600,7 +1600,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
CIRToLLVMStackRestoreOpLowering,
CIRToLLVMTrapOpLowering,
CIRToLLVMUnaryOpLowering,
- CIRToLLVMVecCreateOpLowering
+ CIRToLLVMVecCreateOpLowering,
+ CIRToLLVMVecExtractOpLowering
// clang-format on
>(converter, patterns.getContext());
@@ -1709,6 +1710,14 @@ mlir::LogicalResult CIRToLLVMVecCreateOpLowering::matchAndRewrite(
return mlir::success();
}
+mlir::LogicalResult CIRToLLVMVecExtractOpLowering::matchAndRewrite(
+ cir::VecExtractOp op, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const {
+ rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractElementOp>(
+ op, adaptor.getVec(), adaptor.getIndex());
+ return mlir::success();
+}
+
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
return std::make_unique<ConvertCIRToLLVMPass>();
}
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
index f248ea31e7844..026505ea31b4c 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
@@ -303,6 +303,16 @@ class CIRToLLVMVecCreateOpLowering
mlir::ConversionPatternRewriter &) const override;
};
+class CIRToLLVMVecExtractOpLowering
+ : public mlir::OpConversionPattern<cir::VecExtractOp> {
+public:
+ using mlir::OpConversionPattern<cir::VecExtractOp>::OpConversionPattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(cir::VecExtractOp op, OpAdaptor,
+ mlir::ConversionPatternRewriter &) const override;
+};
+
} // namespace direct
} // namespace cir
diff --git a/clang/test/CIR/CodeGen/vector-ext.cpp b/clang/test/CIR/CodeGen/vector-ext.cpp
index a3880a944de1f..aeeaf655cad18 100644
--- a/clang/test/CIR/CodeGen/vector-ext.cpp
+++ b/clang/test/CIR/CodeGen/vector-ext.cpp
@@ -109,3 +109,36 @@ void foo2(vi4 p) {}
// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
// OGCG: store <4 x i32> %{{.*}}, ptr %[[VEC_A]], align 16
+
+void foo3() {
+ vi4 a = { 1, 2, 3, 4 };
+ int e = a[1];
+}
+
+// CIR: %[[VEC:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+// CIR: %[[INIT:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["e", init]
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
+// CIR: %[[VEC_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[VEC_VAL]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[IDX:.*]] = cir.const #cir.int<1> : !s32i
+// CIR: %[[ELE:.*]] = cir.vec.extract %[[TMP]][%[[IDX]] : !s32i] : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[ELE]], %[[INIT]] : !s32i, !cir.ptr<!s32i>
+
+// LLVM: %[[VEC:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[INIT:.*]] = alloca i32, i64 1, align 4
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
+// LLVM: %[[TMP:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
+// LLVM: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP]], i32 1
+// LLVM: store i32 %[[ELE]], ptr %[[INIT]], align 4
+
+// OGCG: %[[VEC:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[INIT:.*]] = alloca i32, align 4
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
+// OGCG: %[[TMP:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
+// OGCG: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP]], i32 1
+// OGCG: store i32 %[[ELE]], ptr %[[INIT]], align 4
diff --git a/clang/test/CIR/CodeGen/vector.cpp b/clang/test/CIR/CodeGen/vector.cpp
index 76a85eab52380..9c85ed4a9e216 100644
--- a/clang/test/CIR/CodeGen/vector.cpp
+++ b/clang/test/CIR/CodeGen/vector.cpp
@@ -96,3 +96,36 @@ void foo2(vi4 p) {}
// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
// OGCG: store <4 x i32> %{{.*}}, ptr %[[VEC_A]], align 16
+
+void foo3() {
+ vi4 a = { 1, 2, 3, 4 };
+ int e = a[1];
+}
+
+// CIR: %[[VEC:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+// CIR: %[[INIT:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["e", init]
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
+// CIR: %[[VEC_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[VEC_VAL]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[IDX:.*]] = cir.const #cir.int<1> : !s32i
+// CIR: %[[ELE:.*]] = cir.vec.extract %[[TMP]][%[[IDX]] : !s32i] : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[ELE]], %[[INIT]] : !s32i, !cir.ptr<!s32i>
+
+// LLVM: %[[VEC:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[INIT:.*]] = alloca i32, i64 1, align 4
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
+// LLVM: %[[TMP:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
+// LLVM: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP]], i32 1
+// LLVM: store i32 %[[ELE]], ptr %[[INIT]], align 4
+
+// OGCG: %[[VEC:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[INIT:.*]] = alloca i32, align 4
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
+// OGCG: %[[TMP:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
+// OGCG: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP]], i32 1
+// OGCG: store i32 %[[ELE]], ptr %[[INIT]], align 4
diff --git a/clang/test/CIR/IR/vector.cir b/clang/test/CIR/IR/vector.cir
index d2612a7310ad0..aeb268e84c71c 100644
--- a/clang/test/CIR/IR/vector.cir
+++ b/clang/test/CIR/IR/vector.cir
@@ -65,4 +65,36 @@ cir.func @local_vector_create_test() {
// CHECK: cir.return
// CHECK: }
+cir.func @vector_extract_element_test() {
+ %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["d", init]
+ %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["e", init]
+ %2 = cir.const #cir.int<1> : !s32i
+ %3 = cir.const #cir.int<2> : !s32i
+ %4 = cir.const #cir.int<3> : !s32i
+ %5 = cir.const #cir.int<4> : !s32i
+ %6 = cir.vec.create(%2, %3, %4, %5 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+ cir.store %6, %0 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+ %7 = cir.load %0 : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+ %8 = cir.const #cir.int<1> : !s32i
+ %9 = cir.vec.extract %7[%8 : !s32i] : !cir.vector<4 x !s32i>
+ cir.store %9, %1 : !s32i, !cir.ptr<!s32i>
+ cir.return
+}
+
+// CHECK: cir.func @vector_extract_element_test() {
+// CHECK: %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["d", init]
+// CHECK: %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["e", init]
+// CHECK: %2 = cir.const #cir.int<1> : !s32i
+// CHECK: %3 = cir.const #cir.int<2> : !s32i
+// CHECK: %4 = cir.const #cir.int<3> : !s32i
+// CHECK: %5 = cir.const #cir.int<4> : !s32i
+// CHECK: %6 = cir.vec.create(%2, %3, %4, %5 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CHECK: cir.store %6, %0 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CHECK: %7 = cir.load %0 : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CHECK: %8 = cir.const #cir.int<1> : !s32i
+// CHECK: %9 = cir.vec.extract %7[%8 : !s32i] : !cir.vector<4 x !s32i>
+// CHECK: cir.store %9, %1 : !s32i, !cir.ptr<!s32i>
+// CHECK: cir.return
+// CHECK: }
+
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm, besides nits
$vec `[` $index `:` type($index) `]` attr-dict `:` qualified(type($vec)) | ||
}]; | ||
|
||
let hasVerifier = 0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: this is implicit
let hasVerifier = 0; |
let summary = "Extract one element from a vector object"; | ||
let description = [{ | ||
The `cir.vec.extract` operation extracts the element at the given index | ||
from a vector object. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: can we add some example here?
72b4aac
to
80bd84c
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks good. While here, can you please implement a folder for this operation? It should kick-in if both idx and input vector are constants.
@bcardosolopes I implement it like this snippet
But I am thinking, is there a case that codegen will perform extractOp directly from ConstVec, not on load or get_global? I see a similar implementation in MLIR Vector Dialect 🤔 I will try to come up with a test case for testing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me, with a small request.
void foo3() { | ||
vi4 a = { 1, 2, 3, 4 }; | ||
int e = a[1]; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a test where the index of the element being extracted is a variable?
This change adds extract op for VectorType
Issue #136487