Skip to content

[CIR] Upstream extract op for VectorType #138413

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

AmrDeveloper
Copy link
Member

@AmrDeveloper AmrDeveloper commented May 3, 2025

This change adds extract op for VectorType

Issue #136487

@llvmbot llvmbot added clang Clang issues not falling into any other category ClangIR Anything related to the ClangIR project labels May 3, 2025
@llvmbot
Copy link
Member

llvmbot commented May 3, 2025

@llvm/pr-subscribers-clang

@llvm/pr-subscribers-clangir

Author: Amr Hesham (AmrDeveloper)

Changes

This change adds local zero initialization for VectorType

Issue #136487


Full diff: https://github.com/llvm/llvm-project/pull/138413.diff

7 Files Affected:

  • (modified) clang/include/clang/CIR/Dialect/IR/CIROps.td (+24)
  • (modified) clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp (+5-2)
  • (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp (+10-1)
  • (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h (+10)
  • (modified) clang/test/CIR/CodeGen/vector-ext.cpp (+33)
  • (modified) clang/test/CIR/CodeGen/vector.cpp (+33)
  • (modified) clang/test/CIR/IR/vector.cir (+32)
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 422c89c4f9391..b2121dee8d8b3 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -1976,4 +1976,28 @@ def VecCreateOp : CIR_Op<"vec.create", [Pure]> {
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// VecExtractOp
+//===----------------------------------------------------------------------===//
+
+def VecExtractOp : CIR_Op<"vec.extract", [Pure,
+  TypesMatchWith<"type of 'result' matches element type of 'vec'", "vec",
+                 "result", "cast<VectorType>($_self).getElementType()">]> {
+
+  let summary = "Extract one element from a vector object";
+  let description = [{
+    The `cir.vec.extract` operation extracts the element at the given index
+    from a vector object.
+  }];
+
+  let arguments = (ins CIR_VectorType:$vec, CIR_AnyFundamentalIntType:$index);
+  let results = (outs CIR_AnyType:$result);
+
+  let assemblyFormat = [{
+    $vec `[` $index `:` type($index) `]` attr-dict `:` qualified(type($vec))
+  }];
+
+  let hasVerifier = 0;
+}
+
 #endif // CLANG_CIR_DIALECT_IR_CIROPS_TD
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
index 8ead6e793b4c8..a59b87cb9241b 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
@@ -161,8 +161,11 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
   mlir::Value VisitArraySubscriptExpr(ArraySubscriptExpr *e) {
     if (e->getBase()->getType()->isVectorType()) {
       assert(!cir::MissingFeatures::scalableVectors());
-      cgf.getCIRGenModule().errorNYI("VisitArraySubscriptExpr: VectorType");
-      return {};
+
+      const mlir::Location loc = cgf.getLoc(e->getSourceRange());
+      const mlir::Value vecValue = Visit(e->getBase());
+      const mlir::Value indexValue = Visit(e->getIdx());
+      return cgf.builder.create<cir::VecExtractOp>(loc, vecValue, indexValue);
     }
     // Just load the lvalue formed by the subscript expression.
     return emitLoadOfLValue(e);
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 6137adb1e9936..66f29f8f6cdd0 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1600,7 +1600,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
                CIRToLLVMStackRestoreOpLowering,
                CIRToLLVMTrapOpLowering,
                CIRToLLVMUnaryOpLowering,
-               CIRToLLVMVecCreateOpLowering
+               CIRToLLVMVecCreateOpLowering,
+               CIRToLLVMVecExtractOpLowering
       // clang-format on
       >(converter, patterns.getContext());
 
@@ -1709,6 +1710,14 @@ mlir::LogicalResult CIRToLLVMVecCreateOpLowering::matchAndRewrite(
   return mlir::success();
 }
 
+mlir::LogicalResult CIRToLLVMVecExtractOpLowering::matchAndRewrite(
+    cir::VecExtractOp op, OpAdaptor adaptor,
+    mlir::ConversionPatternRewriter &rewriter) const {
+  rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractElementOp>(
+      op, adaptor.getVec(), adaptor.getIndex());
+  return mlir::success();
+}
+
 std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
   return std::make_unique<ConvertCIRToLLVMPass>();
 }
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
index f248ea31e7844..026505ea31b4c 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
@@ -303,6 +303,16 @@ class CIRToLLVMVecCreateOpLowering
                   mlir::ConversionPatternRewriter &) const override;
 };
 
+class CIRToLLVMVecExtractOpLowering
+    : public mlir::OpConversionPattern<cir::VecExtractOp> {
+public:
+  using mlir::OpConversionPattern<cir::VecExtractOp>::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(cir::VecExtractOp op, OpAdaptor,
+                  mlir::ConversionPatternRewriter &) const override;
+};
+
 } // namespace direct
 } // namespace cir
 
diff --git a/clang/test/CIR/CodeGen/vector-ext.cpp b/clang/test/CIR/CodeGen/vector-ext.cpp
index a3880a944de1f..aeeaf655cad18 100644
--- a/clang/test/CIR/CodeGen/vector-ext.cpp
+++ b/clang/test/CIR/CodeGen/vector-ext.cpp
@@ -109,3 +109,36 @@ void foo2(vi4 p) {}
 
 // OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
 // OGCG: store <4 x i32> %{{.*}}, ptr %[[VEC_A]], align 16
+
+void foo3() {
+  vi4 a = { 1, 2, 3, 4 };
+  int e = a[1];
+}
+
+// CIR: %[[VEC:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+// CIR: %[[INIT:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["e", init]
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
+// CIR: %[[VEC_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[VEC_VAL]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[IDX:.*]] = cir.const #cir.int<1> : !s32i
+// CIR: %[[ELE:.*]] = cir.vec.extract %[[TMP]][%[[IDX]] : !s32i] : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[ELE]], %[[INIT]] : !s32i, !cir.ptr<!s32i>
+
+// LLVM: %[[VEC:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[INIT:.*]] = alloca i32, i64 1, align 4
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
+// LLVM: %[[TMP:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
+// LLVM: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP]], i32 1
+// LLVM: store i32 %[[ELE]], ptr %[[INIT]], align 4
+
+// OGCG: %[[VEC:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[INIT:.*]] = alloca i32, align 4
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
+// OGCG: %[[TMP:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
+// OGCG: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP]], i32 1
+// OGCG: store i32 %[[ELE]], ptr %[[INIT]], align 4
diff --git a/clang/test/CIR/CodeGen/vector.cpp b/clang/test/CIR/CodeGen/vector.cpp
index 76a85eab52380..9c85ed4a9e216 100644
--- a/clang/test/CIR/CodeGen/vector.cpp
+++ b/clang/test/CIR/CodeGen/vector.cpp
@@ -96,3 +96,36 @@ void foo2(vi4 p) {}
 
 // OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
 // OGCG: store <4 x i32> %{{.*}}, ptr %[[VEC_A]], align 16
+
+void foo3() {
+  vi4 a = { 1, 2, 3, 4 };
+  int e = a[1];
+}
+
+// CIR: %[[VEC:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+// CIR: %[[INIT:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["e", init]
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
+// CIR: %[[VEC_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[VEC_VAL]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[IDX:.*]] = cir.const #cir.int<1> : !s32i
+// CIR: %[[ELE:.*]] = cir.vec.extract %[[TMP]][%[[IDX]] : !s32i] : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[ELE]], %[[INIT]] : !s32i, !cir.ptr<!s32i>
+
+// LLVM: %[[VEC:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[INIT:.*]] = alloca i32, i64 1, align 4
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
+// LLVM: %[[TMP:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
+// LLVM: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP]], i32 1
+// LLVM: store i32 %[[ELE]], ptr %[[INIT]], align 4
+
+// OGCG: %[[VEC:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[INIT:.*]] = alloca i32, align 4
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
+// OGCG: %[[TMP:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
+// OGCG: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP]], i32 1
+// OGCG: store i32 %[[ELE]], ptr %[[INIT]], align 4
diff --git a/clang/test/CIR/IR/vector.cir b/clang/test/CIR/IR/vector.cir
index d2612a7310ad0..aeb268e84c71c 100644
--- a/clang/test/CIR/IR/vector.cir
+++ b/clang/test/CIR/IR/vector.cir
@@ -65,4 +65,36 @@ cir.func @local_vector_create_test() {
 // CHECK:   cir.return
 // CHECK: }
 
+cir.func @vector_extract_element_test() {
+    %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["d", init]
+    %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["e", init]
+    %2 = cir.const #cir.int<1> : !s32i
+    %3 = cir.const #cir.int<2> : !s32i
+    %4 = cir.const #cir.int<3> : !s32i
+    %5 = cir.const #cir.int<4> : !s32i
+    %6 = cir.vec.create(%2, %3, %4, %5 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+    cir.store %6, %0 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+    %7 = cir.load %0 : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+    %8 = cir.const #cir.int<1> : !s32i
+    %9 = cir.vec.extract %7[%8 : !s32i] : !cir.vector<4 x !s32i>
+    cir.store %9, %1 : !s32i, !cir.ptr<!s32i>
+    cir.return
+}
+
+// CHECK: cir.func @vector_extract_element_test() {
+// CHECK:    %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["d", init]
+// CHECK:    %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["e", init]
+// CHECK:    %2 = cir.const #cir.int<1> : !s32i
+// CHECK:    %3 = cir.const #cir.int<2> : !s32i
+// CHECK:    %4 = cir.const #cir.int<3> : !s32i
+// CHECK:    %5 = cir.const #cir.int<4> : !s32i
+// CHECK:    %6 = cir.vec.create(%2, %3, %4, %5 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CHECK:    cir.store %6, %0 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CHECK:    %7 = cir.load %0 : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CHECK:    %8 = cir.const #cir.int<1> : !s32i
+// CHECK:    %9 = cir.vec.extract %7[%8 : !s32i] : !cir.vector<4 x !s32i>
+// CHECK:    cir.store %9, %1 : !s32i, !cir.ptr<!s32i>
+// CHECK:    cir.return
+// CHECK: }
+
 }

Copy link
Contributor

@xlauko xlauko left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm, besides nits

$vec `[` $index `:` type($index) `]` attr-dict `:` qualified(type($vec))
}];

let hasVerifier = 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this is implicit

Suggested change
let hasVerifier = 0;

let summary = "Extract one element from a vector object";
let description = [{
The `cir.vec.extract` operation extracts the element at the given index
from a vector object.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can we add some example here?

@AmrDeveloper AmrDeveloper force-pushed the cir_upstream_vector_extract branch from 72b4aac to 80bd84c Compare May 3, 2025 19:59
Copy link
Member

@bcardosolopes bcardosolopes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good. While here, can you please implement a folder for this operation? It should kick-in if both idx and input vector are constants.

@AmrDeveloper
Copy link
Member Author

Overall looks good. While here, can you please implement a folder for this operation? It should kick-in if both idx and input vector are constants.

@bcardosolopes I implement it like this snippet

OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) {
  const auto vectorAttr =
      llvm::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getVec());
  if (!vectorAttr)
    return {};

  const auto indexAttr =
      llvm::dyn_cast_if_present<cir::IntAttr>(adaptor.getIndex());
  if (!indexAttr)
    return {};

  const mlir::ArrayAttr elements = vectorAttr.getElts();
  const int64_t index = indexAttr.getSInt();
  return elements[index];
}

But I am thinking, is there a case that codegen will perform extractOp directly from ConstVec, not on load or get_global? I see a similar implementation in MLIR Vector Dialect 🤔 I will try to come up with a test case for testing

Copy link
Contributor

@andykaylor andykaylor left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, with a small request.

void foo3() {
vi4 a = { 1, 2, 3, 4 };
int e = a[1];
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test where the index of the element being extracted is a variable?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang Clang issues not falling into any other category ClangIR Anything related to the ClangIR project
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants