-
Notifications
You must be signed in to change notification settings - Fork 663
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
The name `PolynomialApproximationPass` was a misnomer since that pass did more than polynomial approximation. It also performed other non-approximative rewrites and casts to f32. This PR renames it and refactors it to explicitly adjust the rewrites to the target. This also reverse-engineers, reimplements and deprecates the `clNativeMathPrecision` flag, which had unintended semantics. Signed-off-by: Benoit Jacob <[email protected]>
- Loading branch information
Showing
12 changed files
with
234 additions
and
116 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
164 changes: 164 additions & 0 deletions
164
compiler/src/iree/compiler/Codegen/Common/MathTransformPass.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
// Copyright 2022 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include "iree/compiler/Codegen/Common/Passes.h" | ||
#include "mlir/Dialect/Math/Transforms/Approximation.h" | ||
#include "mlir/Dialect/Math/Transforms/Passes.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
||
namespace mlir::iree_compiler { | ||
|
||
/// Deprecated! This flag had buggy/unintentional semantics. | ||
/// Its original comment said: | ||
/// "use native hardware operations instead of polynomial approximation". | ||
/// | ||
/// Defaults to false. It is read by predicateRewrite, predicateF32Cast and | ||
/// predicateApprox in this file, each of which has a "Legacy" branch for it. | ||
static llvm::cl::opt<bool> clNativeMathPrecision( | ||
"iree-codegen-gpu-native-math-precision", | ||
llvm::cl::desc("Deprecated! This flag had buggy/unintentional semantics. " | ||
"Its original description said: \"Skip polynomial lowering " | ||
"for math op natively available on GPU.\""), | ||
llvm::cl::init(false)); | ||
|
||
#define GEN_PASS_DEF_MATHTRANSFORMPASS | ||
#include "iree/compiler/Codegen/Common/Passes.h.inc" | ||
|
||
/// Adds math op expansion patterns to `patterns`, filtered by `predicate`.
///
/// Each entry of the table below pairs a math op name with the
/// `populateExpand*Pattern` helper for that op. `predicate` is queried once per
/// entry, in table order, with the op name; the helper is called only when the
/// predicate returns true.
static void populateMathFunctionsRewritePatterns(
    RewritePatternSet &patterns,
    const std::function<bool(StringRef)> &predicate) {
  struct Expansion {
    StringRef opName;
    void (*populate)(RewritePatternSet &);
  };

  const Expansion expansions[] = {
      {math::TanOp::getOperationName(), populateExpandTanPattern},
      {math::SinhOp::getOperationName(), populateExpandSinhPattern},
      {math::CoshOp::getOperationName(), populateExpandCoshPattern},
      {math::AsinhOp::getOperationName(), populateExpandAsinhPattern},
      {math::AcoshOp::getOperationName(), populateExpandAcoshPattern},
      {math::AtanhOp::getOperationName(), populateExpandAtanhPattern},
      {math::PowFOp::getOperationName(), populateExpandPowFPattern},
      {math::FPowIOp::getOperationName(), populateExpandFPowIPattern},
      {math::Exp2Op::getOperationName(), populateExpandExp2FPattern},
      {math::RoundEvenOp::getOperationName(), populateExpandRoundEvenPattern},
  };

  for (const Expansion &expansion : expansions) {
    if (predicate(expansion.opName)) {
      expansion.populate(patterns);
    }
  }
}
|
||
/// Decides whether the expansion rewrite for the math op called `name` should
/// be enabled.
///
/// `target` is accepted but not consulted yet. Every rewrite is enabled, except
/// that with the legacy clNativeMathPrecision flag set, math.exp2 and
/// math.roundeven are left alone.
static bool predicateRewrite(StringRef name,
                             IREE::HAL::ExecutableTargetAttr target) {
  (void)target; // Currently unused.

  // Default behavior: enable all non-approximative rewrites.
  if (!clNativeMathPrecision) {
    return true;
  }

  // Legacy behavior: skip the two ops the old flag kept un-expanded.
  const bool keptByLegacyFlag = name == math::Exp2Op::getOperationName() ||
                                name == math::RoundEvenOp::getOperationName();
  return !keptByLegacyFlag;
}
|
||
/// Decides whether the math op called `name` should get the f32 expansion
/// (this predicate is handed to populateMathF32ExpansionPatterns).
///
/// `target` is accepted but not consulted yet. With the legacy
/// clNativeMathPrecision flag set, no op qualifies. Otherwise the op qualifies
/// exactly when its name appears in the list below.
static bool predicateF32Cast(StringRef name,
                             IREE::HAL::ExecutableTargetAttr target) {
  (void)target; // Currently unused.

  if (clNativeMathPrecision) { // Legacy.
    return false;
  }

  const StringRef f32CastOps[] = {
      math::AtanOp::getOperationName(),  math::Atan2Op::getOperationName(),
      math::TanhOp::getOperationName(),  math::LogOp::getOperationName(),
      math::Log2Op::getOperationName(),  math::Log1pOp::getOperationName(),
      math::ErfOp::getOperationName(),   math::ExpOp::getOperationName(),
      math::ExpM1Op::getOperationName(), math::CbrtOp::getOperationName(),
      math::SinOp::getOperationName(),   math::CosOp::getOperationName(),
  };
  for (StringRef candidate : f32CastOps) {
    if (candidate == name) {
      return true;
    }
  }
  return false;
}
|
||
/// Decides whether the math op called `name` should be replaced by a
/// polynomial approximation (this predicate is handed to
/// populateMathPolynomialApproximationPatterns).
///
/// `target` is accepted but not consulted yet. With the legacy
/// clNativeMathPrecision flag set, only math.erf qualifies. Otherwise the op
/// qualifies exactly when its name appears in the list below.
static bool predicateApprox(StringRef name,
                            IREE::HAL::ExecutableTargetAttr target) {
  (void)target; // Currently unused.

  if (clNativeMathPrecision) { // Legacy.
    // The legacy implementation had a bug: it always applied polynomial
    // approximation of math.erf, even when clNativeMathPrecision was passed.
    // We actually have CI tests that rely on that bug: they pass
    // clNativeMathPrecision but fail unless math.erf is approximated.
    return name == math::ErfOp::getOperationName();
  }

  const StringRef approximatedOps[] = {
      math::AtanOp::getOperationName(),  math::Atan2Op::getOperationName(),
      math::TanhOp::getOperationName(),  math::LogOp::getOperationName(),
      math::Log2Op::getOperationName(),  math::Log1pOp::getOperationName(),
      math::ErfOp::getOperationName(),   math::AsinOp::getOperationName(),
      math::AcosOp::getOperationName(),  math::ExpOp::getOperationName(),
      math::ExpM1Op::getOperationName(), math::CbrtOp::getOperationName(),
      math::SinOp::getOperationName(),   math::CosOp::getOperationName(),
  };
  for (StringRef candidate : approximatedOps) {
    if (candidate == name) {
      return true;
    }
  }
  return false;
}
|
||
namespace { | ||
|
||
class MathTransformPass final | ||
: public impl::MathTransformPassBase<MathTransformPass> { | ||
public: | ||
using Base::Base; | ||
|
||
void runOnOperation() override { | ||
RewritePatternSet patterns(&getContext()); | ||
auto target = IREE::HAL::ExecutableTargetAttr::lookup(getOperation()); | ||
if (!target) { | ||
return signalPassFailure(); | ||
} | ||
populateMathFunctionsRewritePatterns(patterns, [target](StringRef name) { | ||
return predicateRewrite(name, target); | ||
}); | ||
|
||
populateMathF32ExpansionPatterns(patterns, [target](StringRef name) { | ||
return predicateF32Cast(name, target); | ||
}); | ||
|
||
populateMathPolynomialApproximationPatterns( | ||
patterns, | ||
[target](StringRef name) { return predicateApprox(name, target); }); | ||
|
||
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { | ||
return signalPassFailure(); | ||
} | ||
} | ||
}; | ||
|
||
} // namespace | ||
} // namespace mlir::iree_compiler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
78 changes: 0 additions & 78 deletions
78
compiler/src/iree/compiler/Codegen/Common/PolynomialApproximationPass.cpp
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
57 changes: 57 additions & 0 deletions
57
compiler/src/iree/compiler/Codegen/Common/test/math_transform.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
// RUN: iree-opt --pass-pipeline='builtin.module(func.func(iree-codegen-math-transform))' --split-input-file %s | FileCheck %s | ||
|
||
// CHECK-LABEL: @rewrite_tan | ||
func.func @rewrite_tan(%arg0: f16) -> f16 attributes { | ||
hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple = "x86_64-xyz-xyz"}> | ||
} { | ||
// Tan should be directly approximated by a rational function. It's also possible | ||
// (though not good) that it gets rewritten as sin/cos and those get approximated by | ||
// rational functions. Either way, we expect to see rational arithmetic here, on f32 | ||
// as the operands get cast to f32. | ||
// CHECK-NOT: math.tan | ||
// CHECK-NOT: math.sin | ||
// CHECK-NOT: math.cos | ||
// CHECK: math.fma {{.*}} : f32 | ||
// Final division after cast to f16. | ||
// CHECK: arith.divf {{.*}} : f16 | ||
%0 = math.tan %arg0 : f16 | ||
return %0 : f16 | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: @rewrite_pow | ||
func.func @rewrite_pow(%arg0: f16, %arg1: f16) -> f16 attributes { | ||
hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple = "x86_64-xyz-xyz"}> | ||
} { | ||
|
||
// Powf should be either directly approximated, or first rewritten into log and | ||
// exp, with those then being approximated. Some targets with fast exponentials might | ||
// prefer to keep the exponential form, but this is not the case with the current | ||
// lowering for CPU, so we expect to see rational arithmetic here, on f32 as the | ||
// operands get cast to f32. | ||
// CHECK-NOT: math.powf | ||
// CHECK-NOT: math.exp | ||
// CHECK-NOT: math.log | ||
// CHECK: math.fma {{.*}} : f32 | ||
%0 = math.powf %arg0, %arg1 : f16 | ||
return %0 : f16 | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: @rewrite_erf | ||
func.func @rewrite_erf(%arg0: f16) -> f16 attributes { | ||
hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple = "x86_64-xyz-xyz"}> | ||
} { | ||
// Erf should be directly approximated by a rational function. Some targets | ||
// with fast exponentials might prefer an exponential approximation, but this | ||
// is not the case with the current lowering for CPU, so we expect to see rational | ||
// arithmetic here, on f32 as the operands get cast to f32. | ||
// CHECK-NOT: math.erf | ||
// CHECK-NOT: math.exp | ||
// CHECK-NOT: math.log | ||
// CHECK: math.fma {{.*}} : f32 | ||
%0 = math.erf %arg0 : f16 | ||
return %0 : f16 | ||
} |
17 changes: 0 additions & 17 deletions
17
compiler/src/iree/compiler/Codegen/Common/test/polynomial_approximation.mlir
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.