Skip to content
Merged
62 changes: 62 additions & 0 deletions src/enzyme_ad/jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,64 @@ gentbl_cc_library(
],
)

td_library(
name = "PerfifyDialectFiles",
srcs = [
"Dialect/Perfify/Dialect.td",
"Dialect/Perfify/Ops.td",
],
deps = [
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectInterfaces",
"@llvm-project//mlir:ControlFlowInterfacesTdFiles",
],
)

gentbl_cc_library(
name = "PerfifyDialectIncGen",
tbl_outs = [
(
[
"-gen-dialect-decls",
"-dialect=perfify",
],
"Dialect/Perfify/PerfifyDialect.h.inc",
),
(
[
"-gen-dialect-defs",
"-dialect=perfify",
],
"Dialect/Perfify/PerfifyDialect.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Dialect/Perfify/Dialect.td",
deps = [
":PerfifyDialectFiles"
],
)

gentbl_cc_library(
name = "PerfifyOpsIncGen",
tbl_outs = [
(
["-gen-op-decls"],
"Dialect/Perfify/PerfifyOps.h.inc",
),
(
["-gen-op-defs"],
"Dialect/Perfify/PerfifyOps.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Dialect/Perfify/Ops.td",
deps = [
":PerfifyDialectFiles"
],
)

cc_library(
name = "CheckedRewrite",
hdrs = ["CheckedRewrite.h"],
Expand Down Expand Up @@ -721,6 +779,7 @@ cc_library(
"Dialect/*.cpp",
"Dialect/Distributed/*.cpp",
"Dialect/Tessera/*.cpp",
"Dialect/Perfify/*.cpp",
]) + [
"Utils.cpp",
],
Expand All @@ -730,6 +789,7 @@ cc_library(
"Dialect/*.h",
"Dialect/Distributed/*.h",
"Dialect/Tessera/*.h",
"Dialect/Perfify/*.h",
]) + [
"Utils.h",
],
Expand Down Expand Up @@ -758,6 +818,8 @@ cc_library(
":StablehloOptPatternsIncGen",
":TesseraDialectIncGen",
":TesseraOpsIncGen",
":PerfifyDialectIncGen",
":PerfifyOpsIncGen",
":chlo-derivatives",
":enzymexla-derivatives",
":mhlo-derivatives",
Expand Down
14 changes: 14 additions & 0 deletions src/enzyme_ad/jax/Dialect/Perfify/Dialect.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#include "Dialect.h"

#include "mlir/IR/Builders.h"
#include "llvm/ADT/TypeSwitch.h"

#include "src/enzyme_ad/jax/Dialect/Perfify/PerfifyDialect.cpp.inc"

// Initialize the dialect
void mlir::enzyme::perfify::PerfifyDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "src/enzyme_ad/jax/Dialect/Perfify/PerfifyOps.cpp.inc"
>();
}
19 changes: 19 additions & 0 deletions src/enzyme_ad/jax/Dialect/Perfify/Dialect.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#ifndef ENZYME_AD_JAX_DIALECT_PERFIFY_DIALECT_H
#define ENZYME_AD_JAX_DIALECT_PERFIFY_DIALECT_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/Types.h"

// Include the dialect
#include "src/enzyme_ad/jax/Dialect/Perfify/PerfifyDialect.h.inc"

// Operations
#define GET_OP_CLASSES
#include "src/enzyme_ad/jax/Dialect/Perfify/PerfifyOps.h.inc"

#endif // ENZYME_AD_JAX_DIALECT_PERFIFY_DIALECT_H
32 changes: 32 additions & 0 deletions src/enzyme_ad/jax/Dialect/Perfify/Dialect.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#ifndef ENZYME_AD_JAX_DIALECT_PERFIFY_DIALECT_TD
#define ENZYME_AD_JAX_DIALECT_PERFIFY_DIALECT_TD

include "mlir/IR/DialectBase.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/Traits.td"
include "mlir/IR/OpBase.td"

//===----------------------------------------------------------------------===//
// Perfify dialect definition.
//===----------------------------------------------------------------------===//

def PerfifyDialect : Dialect {
let name = "perfify";
let summary = "A dialect for specifying and proving runtime bounds";
let description = [{
Lets users specify a bound on the number of steps/latency (per a predefined cost model) that a function or other operation should take.
Leverages SAT solvers to automatically prove this, or interactive theorem provers to allow for complete proofs.
}];
let cppNamespace = "::mlir::enzyme::perfify";
}

//===----------------------------------------------------------------------===//
// Base Perfify operation definition.
//===----------------------------------------------------------------------===//

class PerfifyOp<string mnemonic, list<Trait> traits = []>
: Op<PerfifyDialect, mnemonic, traits>;

class PerfifyType<string name> : TypeDef<PerfifyDialect, name>; // may need to be modified

#endif // ENZYME_AD_JAX_DIALECT_PERFIFY_DIALECT_TD
15 changes: 15 additions & 0 deletions src/enzyme_ad/jax/Dialect/Perfify/Ops.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#include "mlir/IR/Builders.h"
#include "llvm/ADT/TypeSwitch.h"

#include "Dialect.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

using namespace mlir;
using namespace mlir::enzyme::perfify;

namespace mlir::perfify {} // namespace mlir::perfify

#define GET_OP_CLASSES
#include "src/enzyme_ad/jax/Dialect/Perfify/PerfifyOps.cpp.inc"
50 changes: 50 additions & 0 deletions src/enzyme_ad/jax/Dialect/Perfify/Ops.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#ifndef ENZYME_AD_JAX_DIALECT_PERFIFY_OPS_TD
#define ENZYME_AD_JAX_DIALECT_PERFIFY_OPS_TD

include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "Dialect.td"

// Perfify.cost op
def CostOp : PerfifyOp<"cost", []> {
// summary
// description
// arguments
let arguments = (ins StrAttr:$target_op,
APIntAttr:$cycle_cost);
let assemblyFormat = "$target_op $cycle_cost attr-dict";

}

def ArgOp : PerfifyOp<"arg", []> {
let arguments = (ins I64Attr:$val);
let assemblyFormat = "$val attr-dict";
let results = (outs I64);
}

def AssumeOp : PerfifyOp<"assume", [HasParent<"ConditionsOp">, Terminator]> {
let arguments = (ins I1:$precondition);
let assemblyFormat = "$precondition attr-dict";
}

def ConditionsOp : PerfifyOp<"conditions", [HasParent<"AssumptionsOp">, Terminator]> {
let arguments = (ins FlatSymbolRefAttr:$func_handle,
BoolAttr:$verify_huh);
let regions = (region AnyRegion:$precondition, AnyRegion:$postcondition);

let assemblyFormat = [{
$func_handle $verify_huh attr-dict `pre`
$precondition
`post`
$postcondition
}];
}

def AssumptionsOp : PerfifyOp<"assumptions", [Terminator]> {
let regions = (region AnyRegion:$body);
let assemblyFormat = [{$body attr-dict}];
}

#endif // ENZYME_AD_JAX_DIALECT_PERFIFY_OPS_TD
2 changes: 2 additions & 0 deletions src/enzyme_ad/jax/RegistryUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
#include "src/enzyme_ad/jax/Passes/Passes.h"

#include "src/enzyme_ad/jax/Dialect/Distributed/Dialect.h"
#include "src/enzyme_ad/jax/Dialect/Perfify/Dialect.h"
#include "src/enzyme_ad/jax/Dialect/Tessera/Dialect.h"

#include "shardy/dialect/sdy/ir/dialect.h"
Expand Down Expand Up @@ -208,6 +209,7 @@ void registerDialects(mlir::DialectRegistry &registry) {
registry.insert<mlir::enzymexla::EnzymeXLADialect>();
registry.insert<mlir::enzyme::distributed::DistributedDialect>();
registry.insert<mlir::enzyme::tessera::TesseraDialect>();
registry.insert<mlir::enzyme::perfify::PerfifyDialect>();
registry.insert<mlir::sdy::SdyDialect>();
registry.insert<mlir::ub::UBDialect>();
registry.insert<mlir::triton::TritonDialect>();
Expand Down
26 changes: 26 additions & 0 deletions test/lit_tests/perfify/roundtrip.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
module {
func.func @foo() {func.return}
perfify.assumptions { // operation in the dialect
perfify.cost "arith.mul" 3 // op
perfify.cost "func.return" 0
perfify.cost "scf.yield" 0


perfify.conditions @foo true pre {
%b0 = perfify.arg 0 // op
%c0 = arith.constant 0
%cmp = arith.cmpi eq, %c0, %b0 : i64
perfify.assume %cmp
} post {
// %cost = perfify.fn_cost : perfify.cost
// %c9 = perfify.constant_cost 9 : perfify.cost // then our cost is 9
// %cmp = arith.cmpi eq, %cost, %c9
%b0 = perfify.arg 0 // op
%c0 = arith.constant 0
%cmp = arith.cmpi eq, %c0, %b0 : i64

perfify.assume %cmp
}

}
}