Skip to content

Commit 6b24d4a

Browse files
authored
feat: more aggressive gather op simplification (#1903)
1 parent 0aabb97 commit 6b24d4a

File tree

4 files changed

+287
-10
lines changed

4 files changed

+287
-10
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13939,9 +13939,100 @@ struct GatherOpCanon final
1393913939

1394013940
/// Entry point of the gather canonicalization: tries each specialised rewrite
/// in turn and reports success as soon as one of them replaces `gather`.
///
/// The constant-start-indices rewrite is attempted first; the iota-indexing
/// rewrite is only tried when the former did not apply.
LogicalResult matchAndRewriteImpl(stablehlo::GatherOp gather,
                                  PatternRewriter &rewriter) const {
  const LogicalResult viaConstantIndices =
      tryRewriteGatherWithConstantStartIndices(gather, rewriter);
  if (viaConstantIndices.failed()) {
    // Fall back to the iota-based simplification; its verdict is final.
    const LogicalResult viaIotaIndices =
        tryRewriteGatherWithIotaIndexing(gather, rewriter);
    if (viaIotaIndices.failed()) {
      return failure();
    }
  }

  return success();
}
13953+
13954+
/// Simplifies a gather of single elements out of a rank-1 operand when the
/// start indices have a trivial structure:
///   * exactly one index  -> reshape(dynamic_slice(operand, index))
///   * iota-like indices `start + i * stride`
///                        -> strided stablehlo.slice, followed by a reshape
///                           when the gather result has a different shape.
///
/// Gather clamps out-of-range start indices, whereas slice requires
/// `0 <= start <= limit <= dim` and a strictly positive stride. The slice
/// rewrite is therefore only applied when every generated index is provably
/// inside the operand; otherwise the gather is left untouched.
///
/// \param op        the gather to simplify.
/// \param rewriter  rewriter used to create the replacement ops.
/// \returns success() if `op` was replaced, failure() (IR unchanged) otherwise.
LogicalResult
tryRewriteGatherWithIotaIndexing(stablehlo::GatherOp op,
                                 PatternRewriter &rewriter) const {
  auto operand = op.getOperand();
  auto operandTy = cast<RankedTensorType>(operand.getType());
  // TODO: check if this optimization is possible for higher dimensional
  // tensors?
  if (operandTy.getRank() != 1) {
    return failure();
  }

  for (auto size : op.getSliceSizes()) {
    if (size != 1) {
      return failure();
    }
  }

  auto indices = op.getStartIndices();
  auto indicesTy = cast<RankedTensorType>(indices.getType());
  // Element counts are only meaningful (and only safe to query) for static
  // shapes.
  if (!indicesTy.hasStaticShape()) {
    return failure();
  }

  // size 1 index is implicitly an iota
  if (indicesTy.getNumElements() == 1) {
    auto scalarIndex =
        stablehlo::ReshapeOpCreate(rewriter, op.getLoc(), indices, {});
    auto dsOp = stablehlo::DynamicSliceOpCreate(rewriter, op.getLoc(),
                                                operand, {scalarIndex}, {1});
    auto res =
        stablehlo::ReshapeOpCreate(rewriter, op.getLoc(), dsOp,
                                   cast<ShapedType>(op.getType()).getShape());
    rewriter.replaceOp(op, res);
    return success();
  }

  auto iotaLike = detectIotaLikeTensor(indices);
  if (!iotaLike) {
    return failure();
  }

  auto iota = *iotaLike;
  auto dimNumbers = op.getDimensionNumbers();

  if (dimNumbers.getStartIndexMap().size() <= iota.dimension ||
      dimNumbers.getStartIndexMap()[iota.dimension] != 0) {
    return failure();
  }

  int64_t indexVectorDim = dimNumbers.getIndexVectorDim();
  if (indexVectorDim < indicesTy.getRank() &&
      indicesTy.getDimSize(indexVectorDim) != 1) {
    return failure();
  }

  auto resultTy = cast<RankedTensorType>(op.getType());
  // The bounds reasoning below needs concrete extents.
  if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape()) {
    return failure();
  }

  const int64_t start =
      cast<IntegerAttr>(iota.start).getValue().getSExtValue();
  const int64_t count = indicesTy.getDimSize(iota.dimension);
  const int64_t stride =
      cast<IntegerAttr>(iota.scale).getValue().getSExtValue();
  const int64_t operandSize = operandTy.getDimSize(0);

  if (resultTy.getNumElements() != count) {
    return rewriter.notifyMatchFailure(
        op, "expected num elements of result to match");
  }

  // A slice cannot express clamping, reversal or repetition. Require a
  // non-empty, forward-strided sequence whose last index
  // `start + (count - 1) * stride` is still inside the operand. The division
  // form keeps the comparison free of signed overflow.
  if (count < 1 || stride < 1 || start < 0 || start >= operandSize ||
      (operandSize - 1 - start) / stride < count - 1) {
    return rewriter.notifyMatchFailure(
        op, "iota-like indices do not form an in-bounds strided slice");
  }

  // Exclusive limit: one stride past the last index, clamped to the operand
  // extent so the slice stays valid when the stride does not evenly divide
  // the remaining length (e.g. indices [0, 3, 6, 9] on a 10-element operand).
  const int64_t last = start + (count - 1) * stride;
  const int64_t limit =
      (operandSize - last > stride) ? last + stride : operandSize;

  auto s = stablehlo::SliceOpCreate(rewriter, op.getLoc(), operand, {start},
                                    {limit}, {stride});

  if (s.getType() == resultTy) {
    rewriter.replaceOp(op, s);
    return success();
  }

  rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(op, resultTy, s);
  return success();
}
14028+
14029+
LogicalResult
14030+
tryRewriteGatherWithConstantStartIndices(stablehlo::GatherOp gather,
14031+
PatternRewriter &rewriter) const {
1394214032
DenseIntElementsAttr index;
13943-
if (!matchPattern(gather.getStartIndices(), m_Constant(&index)))
14033+
if (!matchPattern(gather.getStartIndices(), m_Constant(&index))) {
1394414034
return failure();
14035+
}
1394514036

1394614037
stablehlo::GatherDimensionNumbersAttr dnums = gather.getDimensionNumbers();
1394714038
if (dnums.getIndexVectorDim() != 0 || index.getType().getRank() > 1)
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s
2+
3+
// ============================================================================
4+
// Tests for gather with iota-like indexing that converts to slice
5+
// ============================================================================
6+
7+
// Simple iota indexing: gather with iota indices should become a slice
8+
func.func @gather_iota_to_slice(%arg0: tensor<10xi64>) -> tensor<5xi64> {
  // Start indices are the plain sequence 0, 1, 2, 3, 4 (one index per row).
  %iota_idx = stablehlo.iota dim = 0 : tensor<5x1xi64>
  %gathered = "stablehlo.gather"(%arg0, %iota_idx) {
    slice_sizes = array<i64: 1>,
    dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>
  } : (tensor<10xi64>, tensor<5x1xi64>) -> tensor<5xi64>
  return %gathered : tensor<5xi64>
}
20+
// CHECK-LABEL: func.func @gather_iota_to_slice
21+
// CHECK-NEXT: %[[SLICE:.+]] = stablehlo.slice %arg0 [0:5]
22+
// CHECK-NEXT: return %[[SLICE]]
23+
24+
// Iota with offset: gather with indices [2, 3, 4, 5] should become slice [2:6:1]
25+
func.func @gather_iota_offset_to_slice(%arg0: tensor<10xi64>) -> tensor<4xi64> {
  // Start indices 2, 3, 4, 5 built as iota + 2.
  %offset = stablehlo.constant dense<2> : tensor<4x1xi64>
  %base = stablehlo.iota dim = 0 : tensor<4x1xi64>
  %shifted = stablehlo.add %base, %offset : tensor<4x1xi64>
  %gathered = "stablehlo.gather"(%arg0, %shifted) {
    slice_sizes = array<i64: 1>,
    dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>
  } : (tensor<10xi64>, tensor<4x1xi64>) -> tensor<4xi64>
  return %gathered : tensor<4xi64>
}
39+
// CHECK-LABEL: func.func @gather_iota_offset_to_slice
40+
// CHECK-NEXT: %[[SLICE:.+]] = stablehlo.slice %arg0 [2:6]
41+
// CHECK-NEXT: return %[[SLICE]]
42+
43+
// Iota with stride: gather with indices [0, 2, 4, 6] should become slice [0:8:2]
44+
func.func @gather_iota_stride_to_slice(%arg0: tensor<10xi64>) -> tensor<4xi64> {
  // Start indices 0, 2, 4, 6 built as iota * 2.
  %step = stablehlo.constant dense<2> : tensor<4x1xi64>
  %base = stablehlo.iota dim = 0 : tensor<4x1xi64>
  %strided = stablehlo.multiply %base, %step : tensor<4x1xi64>
  %gathered = "stablehlo.gather"(%arg0, %strided) {
    slice_sizes = array<i64: 1>,
    dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>
  } : (tensor<10xi64>, tensor<4x1xi64>) -> tensor<4xi64>
  return %gathered : tensor<4xi64>
}
58+
// CHECK-LABEL: func.func @gather_iota_stride_to_slice
59+
// CHECK-NEXT: %[[SLICE:.+]] = stablehlo.slice %arg0 [0:8:2]
60+
// CHECK-NEXT: return %[[SLICE]]
61+
62+
// Iota with offset and stride: gather with indices [1, 3, 5, 7] should become slice [1:9:2]
63+
func.func @gather_iota_offset_stride_to_slice(%arg0: tensor<10xi64>) -> tensor<4xi64> {
  // Start indices 1, 3, 5, 7 built as iota * 2 + 1.
  %offset = stablehlo.constant dense<1> : tensor<4x1xi64>
  %step = stablehlo.constant dense<2> : tensor<4x1xi64>
  %base = stablehlo.iota dim = 0 : tensor<4x1xi64>
  %strided = stablehlo.multiply %base, %step : tensor<4x1xi64>
  %affine = stablehlo.add %strided, %offset : tensor<4x1xi64>
  %gathered = "stablehlo.gather"(%arg0, %affine) {
    slice_sizes = array<i64: 1>,
    dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>
  } : (tensor<10xi64>, tensor<4x1xi64>) -> tensor<4xi64>
  return %gathered : tensor<4xi64>
}
79+
// CHECK-LABEL: func.func @gather_iota_offset_stride_to_slice
80+
// CHECK-NEXT: %[[SLICE:.+]] = stablehlo.slice %arg0 [1:9:2]
81+
// CHECK-NEXT: return %[[SLICE]]
82+
83+
// Constant iota-like indices: dense constant that forms an iota pattern
84+
func.func @gather_const_iota_to_slice(%arg0: tensor<10xi64>) -> tensor<4xi64> {
  // The start indices are a literal constant that happens to spell 0, 1, 2, 3.
  %const_idx = stablehlo.constant dense<[[0], [1], [2], [3]]> : tensor<4x1xi64>
  %gathered = "stablehlo.gather"(%arg0, %const_idx) {
    slice_sizes = array<i64: 1>,
    dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>
  } : (tensor<10xi64>, tensor<4x1xi64>) -> tensor<4xi64>
  return %gathered : tensor<4xi64>
}
96+
// CHECK-LABEL: func.func @gather_const_iota_to_slice
97+
// CHECK-NEXT: %[[SLICE:.+]] = stablehlo.slice %arg0 [0:4]
98+
// CHECK-NEXT: return %[[SLICE]]
99+
100+
// ============================================================================
101+
// Tests for gather with size-1 index -> dynamic_slice
102+
// ============================================================================
103+
104+
// Size-1 index: scalar-like index should become dynamic_slice
105+
func.func @gather_scalar_index_to_dynamic_slice(%arg0: tensor<10xi64>, %idx: tensor<1xi64>) -> tensor<1xi64> {
  // A single runtime start index: nothing to detect as iota, only one element
  // is ever read.
  %picked = "stablehlo.gather"(%arg0, %idx) {
    slice_sizes = array<i64: 1>,
    dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>
  } : (tensor<10xi64>, tensor<1xi64>) -> tensor<1xi64>
  return %picked : tensor<1xi64>
}
116+
// CHECK-LABEL: func.func @gather_scalar_index_to_dynamic_slice
117+
// CHECK: stablehlo.reshape
118+
// CHECK: stablehlo.dynamic_slice
119+
120+
// Floating point elements in gather
121+
func.func @gather_iota_float(%arg0: tensor<10xf64>) -> tensor<5xf64> {
  // Same indexing as the integer iota case, but the gathered elements are f64.
  %iota_idx = stablehlo.iota dim = 0 : tensor<5x1xi64>
  %gathered = "stablehlo.gather"(%arg0, %iota_idx) {
    slice_sizes = array<i64: 1>,
    dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>
  } : (tensor<10xf64>, tensor<5x1xi64>) -> tensor<5xf64>
  return %gathered : tensor<5xf64>
}
133+
// CHECK-LABEL: func.func @gather_iota_float
134+
// CHECK-NEXT: %[[SLICE:.+]] = stablehlo.slice %arg0 [0:5]
135+
// CHECK-NEXT: return %[[SLICE]]
136+
137+
// ============================================================================
138+
// Negative tests: should NOT be simplified
139+
// ============================================================================
140+
141+
// Non-1D operand: should not simplify (currently only supports 1D operands)
142+
func.func @gather_non_1d_operand(%arg0: tensor<4x4xi64>) -> tensor<2xi64> {
  // Rank-2 operand indexed by (row, col) pairs: reads elements (0,0) and (1,1).
  %coords = stablehlo.constant dense<[[0, 0], [1, 1]]> : tensor<2x2xi64>
  %gathered = "stablehlo.gather"(%arg0, %coords) {
    slice_sizes = array<i64: 1, 1>,
    dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>
  } : (tensor<4x4xi64>, tensor<2x2xi64>) -> tensor<2xi64>
  return %gathered : tensor<2xi64>
}
154+
// CHECK-LABEL: func.func @gather_non_1d_operand
155+
// CHECK: stablehlo.gather
156+
157+
// Slice sizes not all 1: should not simplify
158+
func.func @gather_slice_size_not_1(%arg0: tensor<10xi64>) -> tensor<4x2xi64> {
  // Each start index pulls a window of two elements, kept as result dim 1.
  %iota_idx = stablehlo.iota dim = 0 : tensor<4x1xi64>
  %windows = "stablehlo.gather"(%arg0, %iota_idx) {
    slice_sizes = array<i64: 2>,
    dimension_numbers = #stablehlo.gather<offset_dims = [1], start_index_map = [0], index_vector_dim = 1>
  } : (tensor<10xi64>, tensor<4x1xi64>) -> tensor<4x2xi64>
  return %windows : tensor<4x2xi64>
}
170+
// CHECK-LABEL: func.func @gather_slice_size_not_1
171+
// CHECK: stablehlo.gather
172+
173+
// Non-iota indices: random indices should not simplify
174+
func.func @gather_non_iota_indices(%arg0: tensor<10xi64>) -> tensor<4xi64> {
  // Start indices 3, 1, 4, 2 follow no start + i * stride progression.
  %scrambled = stablehlo.constant dense<[[3], [1], [4], [2]]> : tensor<4x1xi64>
  %gathered = "stablehlo.gather"(%arg0, %scrambled) {
    slice_sizes = array<i64: 1>,
    dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>
  } : (tensor<10xi64>, tensor<4x1xi64>) -> tensor<4xi64>
  return %gathered : tensor<4xi64>
}
186+
// CHECK-LABEL: func.func @gather_non_iota_indices
187+
// CHECK: stablehlo.gather

test/lit_tests/linalg/lu.mlir

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,15 @@ module {
3737
// CPU-NEXT: %7 = stablehlo.add %iterArg, %c_0 {enzymexla.bounds = {{.*}}} : tensor<i32>
3838
// CPU-NEXT: %8 = stablehlo.dynamic_slice %1, %iterArg, sizes = [1] : (tensor<64xi64>, tensor<i32>) -> tensor<1xi64>
3939
// CPU-NEXT: %9 = stablehlo.dynamic_slice %iterArg_4, %iterArg, sizes = [1] : (tensor<64xi64>, tensor<i32>) -> tensor<1xi64>
40-
// CPU-NEXT: %10 = "stablehlo.gather"(%iterArg_4, %8) <{dimension_numbers = #stablehlo.gather<offset_dims = [0], start_index_map = [0]>, indices_are_sorted = false, slice_sizes = array<i64: 1>}> : (tensor<64xi64>, tensor<1xi64>) -> tensor<1xi64>
41-
// CPU-NEXT: %11 = stablehlo.dynamic_update_slice %iterArg_4, %10, %iterArg : (tensor<64xi64>, tensor<1xi64>, tensor<i32>) -> tensor<64xi64>
42-
// CPU-NEXT: %12 = stablehlo.reshape %9 : (tensor<1xi64>) -> tensor<i64>
43-
// CPU-NEXT: %13 = "stablehlo.scatter"(%11, %8, %12) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0]>, unique_indices = false}> ({
40+
// CPU-NEXT: %10 = stablehlo.reshape %8 : (tensor<1xi64>) -> tensor<i64>
41+
// CPU-NEXT: %11 = stablehlo.dynamic_slice %iterArg_4, %10, sizes = [1] : (tensor<64xi64>, tensor<i64>) -> tensor<1xi64>
42+
// CPU-NEXT: %12 = stablehlo.dynamic_update_slice %iterArg_4, %11, %iterArg : (tensor<64xi64>, tensor<1xi64>, tensor<i32>) -> tensor<64xi64>
43+
// CPU-NEXT: %13 = stablehlo.reshape %9 : (tensor<1xi64>) -> tensor<i64>
44+
// CPU-NEXT: %14 = "stablehlo.scatter"(%12, %8, %13) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0]>, unique_indices = false}> ({
4445
// CPU-NEXT: ^bb0(%arg1: tensor<i64>, %arg2: tensor<i64>):
4546
// CPU-NEXT: stablehlo.return %arg2 : tensor<i64>
4647
// CPU-NEXT: }) : (tensor<64xi64>, tensor<1xi64>, tensor<i64>) -> tensor<64xi64>
47-
// CPU-NEXT: stablehlo.return %7, %13 : tensor<i32>, tensor<64xi64>
48+
// CPU-NEXT: stablehlo.return %7, %14 : tensor<i32>, tensor<64xi64>
4849
// CPU-NEXT: }
4950
// CPU-NEXT: %3 = stablehlo.add %2#1, %c_2 : tensor<64xi64>
5051
// CPU-NEXT: %4 = stablehlo.convert %0#1 : (tensor<64xi64>) -> tensor<64xi32>

test/lit_tests/raising/affine_to_stablehlo13.mlir

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@ module {
2929
}
3030
}
3131
// CHECK: func.func private @single_dim_raised(%arg0: tensor<3xi64>, %arg1: tensor<3xi64>) -> (tensor<3xi64>, tensor<3xi64>) {
32-
// CHECK-NEXT: %0 = stablehlo.iota dim = 0 : tensor<3x1xi64>
33-
// CHECK-NEXT: %1 = "stablehlo.gather"(%arg1, %0) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1>}> : (tensor<3xi64>, tensor<3x1xi64>) -> tensor<3xi64>
34-
// CHECK-NEXT: return %1, %arg1 : tensor<3xi64>, tensor<3xi64>
32+
// CHECK-NEXT: return %arg1, %arg1 : tensor<3xi64>, tensor<3xi64>
3533
// CHECK-NEXT: }
3634

3735
module {
@@ -108,10 +106,10 @@ module {
108106
}
109107
return
110108
}
109+
}
111110
// CHECK: func.func private @multiple_ivs_per_index_lanes_raised(%arg0: tensor<10x10xi64>, %arg1: tensor<10xf64>, %arg2: tensor<10x10xf64>) -> (tensor<10x10xi64>, tensor<10xf64>, tensor<10x10xf64>) {
112111
// CHECK-NEXT: %0 = stablehlo.reshape %arg0 : (tensor<10x10xi64>) -> tensor<100x1xi64>
113112
// CHECK-NEXT: %1 = "stablehlo.gather"(%arg1, %0) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1>}> : (tensor<10xf64>, tensor<100x1xi64>) -> tensor<100xf64>
114113
// CHECK-NEXT: %2 = stablehlo.reshape %1 : (tensor<100xf64>) -> tensor<10x10xf64>
115114
// CHECK-NEXT: return %arg0, %arg1, %2 : tensor<10x10xi64>, tensor<10xf64>, tensor<10x10xf64>
116115
// CHECK-NEXT: }
117-
}

0 commit comments

Comments
 (0)