Skip to content

Commit

Permalink
Refactor PolynomialApproximationPass into MathTransformPass. (#19922
Browse files Browse the repository at this point in the history
)

The name `PolynomialApproximationPass` was a misnomer since that pass
did more than polynomial approximation. It also does other
non-approximative rewrites, and casts to f32.

This PR renames it and refactors it to explicitly adjust the rewrites to
the target. This also reverse-engineers, reimplements and deprecates the
`clNativeMathPrecision` flag, which had unintended semantics.

Signed-off-by: Benoit Jacob <[email protected]>
  • Loading branch information
bjacob authored Feb 12, 2025
1 parent 3fce185 commit 7c0259c
Show file tree
Hide file tree
Showing 12 changed files with 234 additions and 116 deletions.
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -131,14 +131,14 @@ iree_compiler_cc_library(
"MaterializeEncodingIntoPadding.cpp",
"MaterializeEncodingPatterns.cpp",
"MaterializeTuningSpecsPass.cpp",
"MathTransformPass.cpp",
"MemrefCopyToLinalg.cpp",
"NormalizeLoopBounds.cpp",
"OptimizeTensorInsertExtractSlices.cpp",
"OptimizeVectorTransferPass.cpp",
"PadDynamicAlloc.cpp",
"PassUtils.cpp",
"Passes.cpp",
"PolynomialApproximationPass.cpp",
"PropagateDispatchSizeBounds.cpp",
"PropagateReshapesByExpansion.cpp",
"ReconcileTranslationInfo.cpp",
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,14 @@ iree_cc_library(
"MaterializeEncodingIntoPadding.cpp"
"MaterializeEncodingPatterns.cpp"
"MaterializeTuningSpecsPass.cpp"
"MathTransformPass.cpp"
"MemrefCopyToLinalg.cpp"
"NormalizeLoopBounds.cpp"
"OptimizeTensorInsertExtractSlices.cpp"
"OptimizeVectorTransferPass.cpp"
"PadDynamicAlloc.cpp"
"PassUtils.cpp"
"Passes.cpp"
"PolynomialApproximationPass.cpp"
"PropagateDispatchSizeBounds.cpp"
"PropagateReshapesByExpansion.cpp"
"ReconcileTranslationInfo.cpp"
Expand Down
164 changes: 164 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/MathTransformPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/Passes.h"
#include "mlir/Dialect/Math/Transforms/Approximation.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::iree_compiler {

/// Deprecated! This flag had buggy/unintentional semantics.
/// Its original comment said:
/// "use native hardware operations instead of polynomial approximation".
///
/// The flag defaults to false. When it is set, the predicates in this file
/// (predicateRewrite, predicateF32Cast, predicateApprox) take their "legacy"
/// branches instead of their default op lists.
static llvm::cl::opt<bool> clNativeMathPrecision(
"iree-codegen-gpu-native-math-precision",
llvm::cl::desc("Deprecated! This flag had buggy/unintentional semantics. "
"Its original description said: \"Skip polynomial lowering "
"for math op natively available on GPU.\""),
llvm::cl::init(false));

#define GEN_PASS_DEF_MATHTRANSFORMPASS
#include "iree/compiler/Codegen/Common/Passes.h.inc"

/// Registers the math op expansion (rewrite) patterns into `patterns`.
///
/// `predicate` is called once per candidate op with that op's operation name;
/// the matching expansion pattern is added only when it returns true.
static void populateMathFunctionsRewritePatterns(
    RewritePatternSet &patterns,
    const std::function<bool(StringRef)> &predicate) {
  // One entry per expandable op: its operation name and the function that
  // registers its expansion pattern.
  struct Expansion {
    StringRef opName;
    void (*populate)(RewritePatternSet &);
  };
  // Entries are visited in declaration order, so patterns are registered in
  // the same order as before.
  const Expansion expansions[] = {
      {math::TanOp::getOperationName(), populateExpandTanPattern},
      {math::SinhOp::getOperationName(), populateExpandSinhPattern},
      {math::CoshOp::getOperationName(), populateExpandCoshPattern},
      {math::AsinhOp::getOperationName(), populateExpandAsinhPattern},
      {math::AcoshOp::getOperationName(), populateExpandAcoshPattern},
      {math::AtanhOp::getOperationName(), populateExpandAtanhPattern},
      {math::PowFOp::getOperationName(), populateExpandPowFPattern},
      {math::FPowIOp::getOperationName(), populateExpandFPowIPattern},
      {math::Exp2Op::getOperationName(), populateExpandExp2FPattern},
      {math::RoundEvenOp::getOperationName(), populateExpandRoundEvenPattern},
  };
  for (const Expansion &expansion : expansions) {
    if (predicate(expansion.opName)) {
      expansion.populate(patterns);
    }
  }
}

/// Returns true if the expansion (rewrite) pattern for the op called `name`
/// should be applied.
///
/// `target` is accepted so the decision can later depend on the target; it is
/// currently unused.
static bool predicateRewrite(StringRef name,
                             IREE::HAL::ExecutableTargetAttr target) {
  (void)target; // Currently unused.
  // Default behavior: enable all non-approximative rewrites.
  if (!clNativeMathPrecision) {
    return true;
  }
  // Legacy behavior of the deprecated flag: every rewrite stays enabled except
  // those of Exp2Op and RoundEvenOp.
  return name != math::Exp2Op::getOperationName() &&
         name != math::RoundEvenOp::getOperationName();
}

/// Returns true if the op called `name` should get the f32 expansion (operand
/// cast) patterns.
///
/// `target` is accepted so the decision can later depend on the target; it is
/// currently unused.
static bool predicateF32Cast(StringRef name,
                             IREE::HAL::ExecutableTargetAttr target) {
  (void)target; // Currently unused.
  // Legacy behavior of the deprecated flag: no op gets the f32 cast.
  if (clNativeMathPrecision) {
    return false;
  }
  // Ops that get the f32 cast by default.
  const StringRef f32CastOps[] = {
      math::AtanOp::getOperationName(),  math::Atan2Op::getOperationName(),
      math::TanhOp::getOperationName(),  math::LogOp::getOperationName(),
      math::Log2Op::getOperationName(),  math::Log1pOp::getOperationName(),
      math::ErfOp::getOperationName(),   math::ExpOp::getOperationName(),
      math::ExpM1Op::getOperationName(), math::CbrtOp::getOperationName(),
      math::SinOp::getOperationName(),   math::CosOp::getOperationName(),
  };
  return llvm::is_contained(f32CastOps, name);
}

/// Returns true if the op called `name` should be replaced by its polynomial
/// approximation.
///
/// `target` is accepted so the decision can later depend on the target; it is
/// currently unused.
static bool predicateApprox(StringRef name,
                            IREE::HAL::ExecutableTargetAttr target) {
  (void)target; // Currently unused.
  if (clNativeMathPrecision) { // Legacy.
    // The legacy implementation had a bug: it always applied polynomial
    // approximation of math.erf, even when clNativeMathPrecision was passed.
    // We actually have CI tests that rely on that bug: they pass
    // clNativeMathPrecision but fail unless math.erf is approximated.
    // So under the legacy flag, math.erf is the only op still approximated.
    return name == math::ErfOp::getOperationName();
  }
  // Ops that are approximated by default.
  const StringRef approximatedOps[] = {
      math::AtanOp::getOperationName(),  math::Atan2Op::getOperationName(),
      math::TanhOp::getOperationName(),  math::LogOp::getOperationName(),
      math::Log2Op::getOperationName(),  math::Log1pOp::getOperationName(),
      math::ErfOp::getOperationName(),   math::AsinOp::getOperationName(),
      math::AcosOp::getOperationName(),  math::ExpOp::getOperationName(),
      math::ExpM1Op::getOperationName(), math::CbrtOp::getOperationName(),
      math::SinOp::getOperationName(),   math::CosOp::getOperationName(),
  };
  return llvm::is_contained(approximatedOps, name);
}

namespace {

/// Pass that applies math op transformations to the operation it runs on:
/// rewrites to other math ops, f32 expansion (operand casts), and polynomial
/// approximations. Which ops each family of patterns applies to is decided by
/// predicateRewrite, predicateF32Cast and predicateApprox, each of which
/// receives the executable target looked up from the operation.
class MathTransformPass final
    : public impl::MathTransformPassBase<MathTransformPass> {
public:
  using Base::Base;

  void runOnOperation() override {
    auto target = IREE::HAL::ExecutableTargetAttr::lookup(getOperation());
    if (!target) {
      // Every predicate takes the target, so without one the pass cannot
      // proceed. Emit a diagnostic so the failure is not silent.
      getOperation()->emitError(
          "MathTransformPass: could not find an executable target "
          "(IREE::HAL::ExecutableTargetAttr) for this operation");
      return signalPassFailure();
    }

    RewritePatternSet patterns(&getContext());

    // Rewrites of math ops into other ops.
    populateMathFunctionsRewritePatterns(patterns, [target](StringRef name) {
      return predicateRewrite(name, target);
    });

    // Expansion of selected ops to f32.
    populateMathF32ExpansionPatterns(patterns, [target](StringRef name) {
      return predicateF32Cast(name, target);
    });

    // Polynomial approximations of selected ops.
    populateMathPolynomialApproximationPatterns(
        patterns,
        [target](StringRef name) { return predicateApprox(name, target); });

    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};

} // namespace
} // namespace mlir::iree_compiler
12 changes: 3 additions & 9 deletions compiler/src/iree/compiler/Codegen/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -537,15 +537,9 @@ def PadDynamicAllocPass :
let summary = "Pass to pad dynamic alloc into static one.";
}

def PolynomialApproximationPass :
Pass<"iree-codegen-polynomial-approximation", ""> {
let summary = "Convert math operations to their polynomial approximation";
let options = [
ListOption<"noApproxOps", "no-approx-ops", "std::string",
[{List of operations that should not be approximated.\n"
"As of now, possible options are:\n"
"\ttan, sinh, cosh, asinh, acosh, atanh, powf, fpowf, erf\n}]>,
];
def MathTransformPass :
Pass<"iree-codegen-math-transform", ""> {
let summary = "Apply math ops transformations: approximations, rewrites to other math ops, operand casts.";
}

def PropagateDispatchSizeBoundsPass :
Expand Down

This file was deleted.

2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ iree_lit_test_suite(
"materialize_tuning_specs_invalid_spec.mlir",
"materialize_user_config_from_tuning_spec.mlir",
"materialize_user_configs.mlir",
"math_transform.mlir",
"normalize_loop_bounds.mlir",
"optimize_tensor_insert_extract_slices.mlir",
"pad_dynamic_alloc.mlir",
"polynomial_approximation.mlir",
"propagate_dispatch_size_bounds.mlir",
"propagate_reshapes_by_expansion.mlir",
"reconcile_translation_info.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ iree_lit_test_suite(
"materialize_tuning_specs_invalid_spec.mlir"
"materialize_user_config_from_tuning_spec.mlir"
"materialize_user_configs.mlir"
"math_transform.mlir"
"normalize_loop_bounds.mlir"
"optimize_tensor_insert_extract_slices.mlir"
"pad_dynamic_alloc.mlir"
"polynomial_approximation.mlir"
"propagate_dispatch_size_bounds.mlir"
"propagate_reshapes_by_expansion.mlir"
"reconcile_translation_info.mlir"
Expand Down
57 changes: 57 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/test/math_transform.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// RUN: iree-opt --pass-pipeline='builtin.module(func.func(iree-codegen-math-transform))' --split-input-file %s | FileCheck %s

// CHECK-LABEL: @rewrite_tan
func.func @rewrite_tan(%arg0: f16) -> f16 attributes {
hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple = "x86_64-xyz-xyz"}>
} {
// Tan should be directly approximated by a rational function. It is also
// possible (though not desirable) that it is first rewritten as sin/cos, with
// each of those then approximated by a rational function. Either way, no
// tan/sin/cos op may remain, and we expect to see rational arithmetic on f32,
// because the operands are cast to f32.
// CHECK-NOT: math.tan
// CHECK-NOT: math.sin
// CHECK-NOT: math.cos
// CHECK: math.fma {{.*}} : f32
// The final division is performed on f16, after the cast to f16.
// CHECK: arith.divf {{.*}} : f16
%0 = math.tan %arg0 : f16
return %0 : f16
}

// -----

// CHECK-LABEL: @rewrite_pow
func.func @rewrite_pow(%arg0: f16, %arg1: f16) -> f16 attributes {
hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple = "x86_64-xyz-xyz"}>
} {

// Powf should either be directly approximated, or first be rewritten into log
// and exp, with those then approximated. Some targets with fast exponentials
// might prefer to keep the exponential form, but that is not the case with the
// current lowering for CPU. So no powf/exp/log op may remain, and we expect to
// see rational arithmetic on f32, because the operands are cast to f32.
// CHECK-NOT: math.powf
// CHECK-NOT: math.exp
// CHECK-NOT: math.log
// CHECK: math.fma {{.*}} : f32
%0 = math.powf %arg0, %arg1 : f16
return %0 : f16
}

// -----

// CHECK-LABEL: @rewrite_erf
func.func @rewrite_erf(%arg0: f16) -> f16 attributes {
hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple = "x86_64-xyz-xyz"}>
} {
// Erf should be directly approximated by a rational function. Some targets
// with fast exponentials might prefer an exponential approximation, but that
// is not the case with the current lowering for CPU. So no erf/exp/log op may
// remain, and we expect to see rational arithmetic on f32, because the
// operands are cast to f32.
// CHECK-NOT: math.erf
// CHECK-NOT: math.exp
// CHECK-NOT: math.log
// CHECK: math.fma {{.*}} : f32
%0 = math.erf %arg0 : f16
return %0 : f16
}

This file was deleted.

Loading

0 comments on commit 7c0259c

Please sign in to comment.