Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 92 additions & 1 deletion src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13854,9 +13854,100 @@ struct GatherOpCanon final

LogicalResult matchAndRewriteImpl(stablehlo::GatherOp gather,
                                  PatternRewriter &rewriter) const {
  // Try the constant-start-indices rewrite first; fall back to the
  // iota-indexing rewrite. Report success as soon as either fires.
  if (succeeded(tryRewriteGatherWithConstantStartIndices(gather, rewriter)))
    return success();

  if (succeeded(tryRewriteGatherWithIotaIndexing(gather, rewriter)))
    return success();

  return failure();
}

/// Rewrite a gather over a rank-1 operand whose start indices form an
/// iota-like pattern (start + i * stride) into a stablehlo.slice (or, for a
/// single index, into a stablehlo.dynamic_slice).
///
/// Returns success() and replaces `op` when the rewrite applies; otherwise
/// returns failure() and leaves the IR untouched.
LogicalResult
tryRewriteGatherWithIotaIndexing(stablehlo::GatherOp op,
                                 PatternRewriter &rewriter) const {
  auto operand = op.getOperand();
  auto operandTy = cast<RankedTensorType>(operand.getType());
  // TODO: check if this optimization is possible for higher dimensional
  // tensors?
  if (operandTy.getRank() != 1) {
    return failure();
  }

  // Each gathered slice must be a single element; otherwise the gather is
  // not expressible as one strided slice of the 1-D operand.
  for (auto size : op.getSliceSizes()) {
    if (size != 1) {
      return failure();
    }
  }

  auto indices = op.getStartIndices();

  // A single start index is implicitly an iota of length 1: lower it to a
  // dynamic_slice on the index reshaped to a scalar.
  if (indices.getType().getNumElements() == 1) {
    auto scalarIndex =
        stablehlo::ReshapeOpCreate(rewriter, op.getLoc(), indices, {});
    auto dsOp = stablehlo::DynamicSliceOpCreate(rewriter, op.getLoc(),
                                                operand, {scalarIndex}, {1});
    auto res =
        stablehlo::ReshapeOpCreate(rewriter, op.getLoc(), dsOp,
                                   cast<ShapedType>(op.getType()).getShape());
    rewriter.replaceOp(op, res);
    return success();
  }

  auto iotaLike = detectIotaLikeTensor(indices);
  if (!iotaLike) {
    return failure();
  }

  auto iota = *iotaLike;
  auto dimNumbers = op.getDimensionNumbers();

  // The iota dimension of the indices must map onto operand dimension 0.
  // Compare as signed so size() is not promoted to unsigned.
  if (static_cast<int64_t>(dimNumbers.getStartIndexMap().size()) <=
          iota.dimension ||
      dimNumbers.getStartIndexMap()[iota.dimension] != 0) {
    return failure();
  }

  auto indicesTy = cast<RankedTensorType>(indices.getType());
  int64_t indexVectorDim = dimNumbers.getIndexVectorDim();
  // An explicit index-vector dimension must be degenerate (size 1).
  if (indexVectorDim < indicesTy.getRank() &&
      indicesTy.getDimSize(indexVectorDim) != 1) {
    return failure();
  }

  int64_t start = cast<IntegerAttr>(iota.start).getValue().getSExtValue();
  int64_t count = indicesTy.getDimSize(iota.dimension);
  int64_t stride = cast<IntegerAttr>(iota.scale).getValue().getSExtValue();
  int64_t limit = start + count * stride;

  // stablehlo.slice requires 0 <= start <= limit <= dim size and a strictly
  // positive stride. Bail out on decreasing or out-of-bounds iota patterns
  // (and on a dynamic operand dimension) rather than emitting IR that fails
  // the verifier.
  if (stride <= 0 || start < 0 || operandTy.isDynamicDim(0) ||
      limit > operandTy.getDimSize(0)) {
    return rewriter.notifyMatchFailure(
        op, "iota indices do not form a valid forward slice");
  }

  auto resultTy = cast<RankedTensorType>(op.getType());
  if (resultTy.getNumElements() != count) {
    return rewriter.notifyMatchFailure(
        op, "expected num elements of result to match");
  }

  auto s = stablehlo::SliceOpCreate(rewriter, op.getLoc(), operand, {start},
                                    {limit}, {stride});

  if (s.getType() == resultTy) {
    rewriter.replaceOp(op, s);
    return success();
  }

  // Same element count but different shape (e.g. trailing unit dims):
  // reshape the slice to the gather's result type.
  rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(op, resultTy, s);
  return success();
}

LogicalResult
tryRewriteGatherWithConstantStartIndices(stablehlo::GatherOp gather,
PatternRewriter &rewriter) const {
DenseIntElementsAttr index;
if (!matchPattern(gather.getStartIndices(), m_Constant(&index)))
if (!matchPattern(gather.getStartIndices(), m_Constant(&index))) {
return failure();
}

stablehlo::GatherDimensionNumbersAttr dnums = gather.getDimensionNumbers();
if (dnums.getIndexVectorDim() != 0 || index.getType().getRank() > 1)
Expand Down
187 changes: 187 additions & 0 deletions test/lit_tests/gather_iota_to_slice.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s

// ============================================================================
// Tests for gather with iota-like indexing that converts to slice
// ============================================================================

// Simple iota indexing: gather with iota indices should become a slice
func.func @gather_iota_to_slice(%arg0: tensor<10xi64>) -> tensor<5xi64> {
  // Indices are the iota [0,1,2,3,4]; five unit-size gathers along dim 0
  // are equivalent to the contiguous slice %arg0[0:5].
  %indices = stablehlo.iota dim = 0 : tensor<5x1xi64>
  %0 = "stablehlo.gather"(%arg0, %indices) {
    dimension_numbers = #stablehlo.gather<
      collapsed_slice_dims = [0],
      start_index_map = [0],
      index_vector_dim = 1
    >,
    slice_sizes = array<i64: 1>
  } : (tensor<10xi64>, tensor<5x1xi64>) -> tensor<5xi64>
  return %0 : tensor<5xi64>
}
// CHECK-LABEL: func.func @gather_iota_to_slice
// CHECK-NEXT: %[[SLICE:.+]] = stablehlo.slice %arg0 [0:5]
// CHECK-NEXT: return %[[SLICE]]

// Iota with offset: gather with indices [2, 3, 4, 5] should become slice [2:6:1]
func.func @gather_iota_offset_to_slice(%arg0: tensor<10xi64>) -> tensor<4xi64> {
  // iota + splat(2) yields indices [2,3,4,5], i.e. slice %arg0[2:6].
  %c = stablehlo.constant dense<2> : tensor<4x1xi64>
  %iota = stablehlo.iota dim = 0 : tensor<4x1xi64>
  %indices = stablehlo.add %iota, %c : tensor<4x1xi64>
  %0 = "stablehlo.gather"(%arg0, %indices) {
    dimension_numbers = #stablehlo.gather<
      collapsed_slice_dims = [0],
      start_index_map = [0],
      index_vector_dim = 1
    >,
    slice_sizes = array<i64: 1>
  } : (tensor<10xi64>, tensor<4x1xi64>) -> tensor<4xi64>
  return %0 : tensor<4xi64>
}
// CHECK-LABEL: func.func @gather_iota_offset_to_slice
// CHECK-NEXT: %[[SLICE:.+]] = stablehlo.slice %arg0 [2:6]
// CHECK-NEXT: return %[[SLICE]]

// Iota with stride: gather with indices [0, 2, 4, 6] should become slice [0:8:2]
func.func @gather_iota_stride_to_slice(%arg0: tensor<10xi64>) -> tensor<4xi64> {
  // iota * splat(2) yields indices [0,2,4,6], i.e. strided slice %arg0[0:8:2].
  %c = stablehlo.constant dense<2> : tensor<4x1xi64>
  %iota = stablehlo.iota dim = 0 : tensor<4x1xi64>
  %indices = stablehlo.multiply %iota, %c : tensor<4x1xi64>
  %0 = "stablehlo.gather"(%arg0, %indices) {
    dimension_numbers = #stablehlo.gather<
      collapsed_slice_dims = [0],
      start_index_map = [0],
      index_vector_dim = 1
    >,
    slice_sizes = array<i64: 1>
  } : (tensor<10xi64>, tensor<4x1xi64>) -> tensor<4xi64>
  return %0 : tensor<4xi64>
}
// CHECK-LABEL: func.func @gather_iota_stride_to_slice
// CHECK-NEXT: %[[SLICE:.+]] = stablehlo.slice %arg0 [0:8:2]
// CHECK-NEXT: return %[[SLICE]]

// Iota with offset and stride: gather with indices [1, 3, 5, 7] should become slice [1:9:2]
func.func @gather_iota_offset_stride_to_slice(%arg0: tensor<10xi64>) -> tensor<4xi64> {
  // 2*iota + 1 yields indices [1,3,5,7], i.e. strided slice %arg0[1:9:2].
  %c_offset = stablehlo.constant dense<1> : tensor<4x1xi64>
  %c_scale = stablehlo.constant dense<2> : tensor<4x1xi64>
  %iota = stablehlo.iota dim = 0 : tensor<4x1xi64>
  %scaled = stablehlo.multiply %iota, %c_scale : tensor<4x1xi64>
  %indices = stablehlo.add %scaled, %c_offset : tensor<4x1xi64>
  %0 = "stablehlo.gather"(%arg0, %indices) {
    dimension_numbers = #stablehlo.gather<
      collapsed_slice_dims = [0],
      start_index_map = [0],
      index_vector_dim = 1
    >,
    slice_sizes = array<i64: 1>
  } : (tensor<10xi64>, tensor<4x1xi64>) -> tensor<4xi64>
  return %0 : tensor<4xi64>
}
// CHECK-LABEL: func.func @gather_iota_offset_stride_to_slice
// CHECK-NEXT: %[[SLICE:.+]] = stablehlo.slice %arg0 [1:9:2]
// CHECK-NEXT: return %[[SLICE]]

// Constant iota-like indices: dense constant that forms an iota pattern
func.func @gather_const_iota_to_slice(%arg0: tensor<10xi64>) -> tensor<4xi64> {
  // A dense constant that happens to be an iota ([0,1,2,3]) must also be
  // recognized and turned into slice %arg0[0:4].
  %indices = stablehlo.constant dense<[[0], [1], [2], [3]]> : tensor<4x1xi64>
  %0 = "stablehlo.gather"(%arg0, %indices) {
    dimension_numbers = #stablehlo.gather<
      collapsed_slice_dims = [0],
      start_index_map = [0],
      index_vector_dim = 1
    >,
    slice_sizes = array<i64: 1>
  } : (tensor<10xi64>, tensor<4x1xi64>) -> tensor<4xi64>
  return %0 : tensor<4xi64>
}
// CHECK-LABEL: func.func @gather_const_iota_to_slice
// CHECK-NEXT: %[[SLICE:.+]] = stablehlo.slice %arg0 [0:4]
// CHECK-NEXT: return %[[SLICE]]

// ============================================================================
// Tests for gather with size-1 index -> dynamic_slice
// ============================================================================

// Size-1 index: scalar-like index should become dynamic_slice
func.func @gather_scalar_index_to_dynamic_slice(%arg0: tensor<10xi64>, %idx: tensor<1xi64>) -> tensor<1xi64> {
  // A single runtime index is lowered to reshape-to-scalar + dynamic_slice
  // rather than a static slice, since its value is unknown at compile time.
  %0 = "stablehlo.gather"(%arg0, %idx) {
    dimension_numbers = #stablehlo.gather<
      collapsed_slice_dims = [0],
      start_index_map = [0],
      index_vector_dim = 1
    >,
    slice_sizes = array<i64: 1>
  } : (tensor<10xi64>, tensor<1xi64>) -> tensor<1xi64>
  return %0 : tensor<1xi64>
}
// CHECK-LABEL: func.func @gather_scalar_index_to_dynamic_slice
// CHECK: stablehlo.reshape
// CHECK: stablehlo.dynamic_slice

// Floating point elements in gather
func.func @gather_iota_float(%arg0: tensor<10xf64>) -> tensor<5xf64> {
  // The rewrite only inspects the indices, so a floating-point operand
  // element type must simplify the same way as the integer case.
  %indices = stablehlo.iota dim = 0 : tensor<5x1xi64>
  %0 = "stablehlo.gather"(%arg0, %indices) {
    dimension_numbers = #stablehlo.gather<
      collapsed_slice_dims = [0],
      start_index_map = [0],
      index_vector_dim = 1
    >,
    slice_sizes = array<i64: 1>
  } : (tensor<10xf64>, tensor<5x1xi64>) -> tensor<5xf64>
  return %0 : tensor<5xf64>
}
// CHECK-LABEL: func.func @gather_iota_float
// CHECK-NEXT: %[[SLICE:.+]] = stablehlo.slice %arg0 [0:5]
// CHECK-NEXT: return %[[SLICE]]

// ============================================================================
// Negative tests: should NOT be simplified
// ============================================================================

// Non-1D operand: should not simplify (currently only supports 1D operands)
func.func @gather_non_1d_operand(%arg0: tensor<4x4xi64>) -> tensor<2xi64> {
  // Negative: the pattern only handles rank-1 operands, so this rank-2
  // gather must be left untouched.
  %indices = stablehlo.constant dense<[[0, 0], [1, 1]]> : tensor<2x2xi64>
  %0 = "stablehlo.gather"(%arg0, %indices) {
    dimension_numbers = #stablehlo.gather<
      collapsed_slice_dims = [0, 1],
      start_index_map = [0, 1],
      index_vector_dim = 1
    >,
    slice_sizes = array<i64: 1, 1>
  } : (tensor<4x4xi64>, tensor<2x2xi64>) -> tensor<2xi64>
  return %0 : tensor<2xi64>
}
// CHECK-LABEL: func.func @gather_non_1d_operand
// CHECK: stablehlo.gather

// Slice sizes not all 1: should not simplify
func.func @gather_slice_size_not_1(%arg0: tensor<10xi64>) -> tensor<4x2xi64> {
  // Negative: a slice size of 2 means each gather pulls two elements, which
  // a single strided slice cannot express; the gather must remain.
  %indices = stablehlo.iota dim = 0 : tensor<4x1xi64>
  %0 = "stablehlo.gather"(%arg0, %indices) {
    dimension_numbers = #stablehlo.gather<
      offset_dims = [1],
      start_index_map = [0],
      index_vector_dim = 1
    >,
    slice_sizes = array<i64: 2>
  } : (tensor<10xi64>, tensor<4x1xi64>) -> tensor<4x2xi64>
  return %0 : tensor<4x2xi64>
}
// CHECK-LABEL: func.func @gather_slice_size_not_1
// CHECK: stablehlo.gather

// Non-iota indices: random indices should not simplify
func.func @gather_non_iota_indices(%arg0: tensor<10xi64>) -> tensor<4xi64> {
  // Negative: [3,1,4,2] is not an affine iota pattern, so no slice rewrite
  // applies and the gather must remain.
  %indices = stablehlo.constant dense<[[3], [1], [4], [2]]> : tensor<4x1xi64>
  %0 = "stablehlo.gather"(%arg0, %indices) {
    dimension_numbers = #stablehlo.gather<
      collapsed_slice_dims = [0],
      start_index_map = [0],
      index_vector_dim = 1
    >,
    slice_sizes = array<i64: 1>
  } : (tensor<10xi64>, tensor<4x1xi64>) -> tensor<4xi64>
  return %0 : tensor<4xi64>
}
// CHECK-LABEL: func.func @gather_non_iota_indices
// CHECK: stablehlo.gather
11 changes: 6 additions & 5 deletions test/lit_tests/linalg/lu.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,15 @@ module {
// CPU-NEXT: %7 = stablehlo.add %iterArg, %c_0 {enzymexla.bounds = {{.*}}} : tensor<i32>
// CPU-NEXT: %8 = stablehlo.dynamic_slice %1, %iterArg, sizes = [1] : (tensor<64xi64>, tensor<i32>) -> tensor<1xi64>
// CPU-NEXT: %9 = stablehlo.dynamic_slice %iterArg_4, %iterArg, sizes = [1] : (tensor<64xi64>, tensor<i32>) -> tensor<1xi64>
// CPU-NEXT: %10 = "stablehlo.gather"(%iterArg_4, %8) <{dimension_numbers = #stablehlo.gather<offset_dims = [0], start_index_map = [0]>, indices_are_sorted = false, slice_sizes = array<i64: 1>}> : (tensor<64xi64>, tensor<1xi64>) -> tensor<1xi64>
// CPU-NEXT: %11 = stablehlo.dynamic_update_slice %iterArg_4, %10, %iterArg : (tensor<64xi64>, tensor<1xi64>, tensor<i32>) -> tensor<64xi64>
// CPU-NEXT: %12 = stablehlo.reshape %9 : (tensor<1xi64>) -> tensor<i64>
// CPU-NEXT: %13 = "stablehlo.scatter"(%11, %8, %12) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0]>, unique_indices = false}> ({
// CPU-NEXT: %10 = stablehlo.reshape %8 : (tensor<1xi64>) -> tensor<i64>
// CPU-NEXT: %11 = stablehlo.dynamic_slice %iterArg_4, %10, sizes = [1] : (tensor<64xi64>, tensor<i64>) -> tensor<1xi64>
// CPU-NEXT: %12 = stablehlo.dynamic_update_slice %iterArg_4, %11, %iterArg : (tensor<64xi64>, tensor<1xi64>, tensor<i32>) -> tensor<64xi64>
// CPU-NEXT: %13 = stablehlo.reshape %9 : (tensor<1xi64>) -> tensor<i64>
// CPU-NEXT: %14 = "stablehlo.scatter"(%12, %8, %13) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0]>, unique_indices = false}> ({
// CPU-NEXT: ^bb0(%arg1: tensor<i64>, %arg2: tensor<i64>):
// CPU-NEXT: stablehlo.return %arg2 : tensor<i64>
// CPU-NEXT: }) : (tensor<64xi64>, tensor<1xi64>, tensor<i64>) -> tensor<64xi64>
// CPU-NEXT: stablehlo.return %7, %13 : tensor<i32>, tensor<64xi64>
// CPU-NEXT: stablehlo.return %7, %14 : tensor<i32>, tensor<64xi64>
// CPU-NEXT: }
// CPU-NEXT: %3 = stablehlo.add %2#1, %c_2 : tensor<64xi64>
// CPU-NEXT: %4 = stablehlo.convert %0#1 : (tensor<64xi64>) -> tensor<64xi32>
Expand Down
6 changes: 2 additions & 4 deletions test/lit_tests/raising/affine_to_stablehlo13.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ module {
}
}
// CHECK: func.func private @single_dim_raised(%arg0: tensor<3xi64>, %arg1: tensor<3xi64>) -> (tensor<3xi64>, tensor<3xi64>) {
// CHECK-NEXT: %0 = stablehlo.iota dim = 0 : tensor<3x1xi64>
// CHECK-NEXT: %1 = "stablehlo.gather"(%arg1, %0) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1>}> : (tensor<3xi64>, tensor<3x1xi64>) -> tensor<3xi64>
// CHECK-NEXT: return %1, %arg1 : tensor<3xi64>, tensor<3xi64>
// CHECK-NEXT: return %arg1, %arg1 : tensor<3xi64>, tensor<3xi64>
// CHECK-NEXT: }

module {
Expand Down Expand Up @@ -108,10 +106,10 @@ module {
}
return
}
}
// CHECK: func.func private @multiple_ivs_per_index_lanes_raised(%arg0: tensor<10x10xi64>, %arg1: tensor<10xf64>, %arg2: tensor<10x10xf64>) -> (tensor<10x10xi64>, tensor<10xf64>, tensor<10x10xf64>) {
// CHECK-NEXT: %0 = stablehlo.reshape %arg0 : (tensor<10x10xi64>) -> tensor<100x1xi64>
// CHECK-NEXT: %1 = "stablehlo.gather"(%arg1, %0) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1>}> : (tensor<10xf64>, tensor<100x1xi64>) -> tensor<100xf64>
// CHECK-NEXT: %2 = stablehlo.reshape %1 : (tensor<100xf64>) -> tensor<10x10xf64>
// CHECK-NEXT: return %arg0, %arg1, %2 : tensor<10x10xi64>, tensor<10xf64>, tensor<10x10xf64>
// CHECK-NEXT: }
}
Loading