Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: stablehlo-opt --stablehlo-aggressive-folder --split-input-file --verify-diagnostics %s | FileCheck %s
// RUN: stablehlo-opt --stablehlo-aggressive-folder=fold-op-element-limit=100 --split-input-file --verify-diagnostics %s | FileCheck %s

////////
// AddOp
Expand Down Expand Up @@ -41,6 +41,21 @@ func.func @broadcast_in_dim_fold_splat(%arg0: tensor<3x3xi32>)

// -----

////////
// ClampOp

// Verifies that the aggressive folder constant-folds stablehlo.clamp when
// min, operand, and max are all constants: element-wise result is
// max(min_i, min(operand_i, max_i)) -> [1, 25, 11] here.
// NOTE: %arg0 is unused; it only keeps the function signature non-trivial.
// CHECK-LABEL: func.func @clamp_fold
func.func @clamp_fold(%arg0: tensor<3xi32>) -> tensor<3xi32> {
%min = stablehlo.constant dense<[1, 5, 10]> : tensor<3xi32>
%max = stablehlo.constant dense<[10, 25, 12]> : tensor<3xi32>
%operand = stablehlo.constant dense<[0, 30, 11]> : tensor<3xi32>
// CHECK: stablehlo.constant dense<[1, 25, 11]> : tensor<3xi32>
%0 = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
func.return %0: tensor<3xi32>
}

// -----

////////
// CompareOp

Expand Down Expand Up @@ -102,6 +117,26 @@ func.func @concatenate_fold() -> (tensor<6xi32>, tensor<3xi32>, tensor<3x3xi32>,

// -----

////////
// DivOp

// Verifies constant folding of stablehlo.divide: signed and unsigned integer
// division folds to a constant (2/2 == 1), while the float divide is left
// un-folded (the folder runs with float optimization disabled, hence the
// DISABLED-CHECK line kept as documentation of the would-be folded result).
// CHECK-LABEL: @div_fold_cst
func.func @div_fold_cst() -> (tensor<i32>, tensor<ui32>, tensor<f32>) {
%cst = stablehlo.constant dense<2> : tensor<i32>
%cst_1 = stablehlo.constant dense<2> : tensor<ui32>
%cst_2 = stablehlo.constant dense<2.0> : tensor<f32>
// CHECK: stablehlo.constant dense<1> : tensor<i32>
// CHECK: stablehlo.constant dense<1> : tensor<ui32>
// CHECK: stablehlo.divide{{.*}} : tensor<f32>
// DISABLED-CHECK: stablehlo.constant dense<1.0{{.*}}> : tensor<f32>
%0 = stablehlo.divide %cst, %cst : tensor<i32>
%1 = stablehlo.divide %cst_1, %cst_1 : tensor<ui32>
%2 = stablehlo.divide %cst_2, %cst_2 : tensor<f32>
return %0, %1, %2 : tensor<i32>, tensor<ui32>, tensor<f32>
}

// -----

////////
// MulOp

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: stablehlo-opt --stablehlo-aggressive-simplification --allow-unregistered-dialect --split-input-file %s | FileCheck %s
// RUN: stablehlo-opt --stablehlo-aggressive-simplification=fold-op-element-limit=100 --allow-unregistered-dialect --split-input-file %s | FileCheck %s

/////////
// AddOp
Expand Down
10 changes: 5 additions & 5 deletions stablehlo/tests/transforms/stablehlo_refine_shapes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -521,16 +521,16 @@ func.func @eval_slice_zerodim() -> tensor<0x2x1xi64> {
// -----

// CHECK-LABEL: func @eval_slice_zerorank
func.func @eval_slice_zerorank() -> tensor<f32> {
// CHECK: [[RESULT:%.*]] = stablehlo.constant dense<3.300000e+01> : tensor<f32>
func.func @eval_slice_zerorank() -> tensor<i32> {
// CHECK: [[RESULT:%.*]] = stablehlo.constant dense<33> : tensor<i32>
// CHECK: return [[RESULT]]
%0 = stablehlo.constant dense<33.0> : tensor<f32>
%0 = stablehlo.constant dense<33> : tensor<i32>
%1 = "stablehlo.slice"(%0) {
start_indices = array<i64>,
limit_indices = array<i64>,
strides = array<i64>
} : (tensor<f32>) -> tensor<f32>
func.return %1 : tensor<f32>
} : (tensor<i32>) -> tensor<i32>
func.return %1 : tensor<i32>
}

// -----
Expand Down
4 changes: 4 additions & 0 deletions stablehlo/transforms/StablehloRefineShapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.

#include <cstddef>
#include <cstdint>
#include <limits>
#include <tuple>
#include <utility>

Expand Down Expand Up @@ -1038,8 +1039,11 @@ LogicalResult applyShapeRefinementPatterns(func::FuncOp func,
// Populate additional patterns for StableHLO extensions.
state.addAdditionalPatterns(patterns);

// No float folding and fold as much as possible. Shape refinement will fail
// if int shape computations are unable to be folded.
StablehloAggressiveFolderPassOptions folderOptions;
folderOptions.optimizeFloat = false;
folderOptions.foldOpElementLimit = std::numeric_limits<int64_t>::max();

// The folding patterns implement partial evaluation of shape computations
// which is a critical part of implementing type refinement for ops like
Expand Down
Loading
Loading