diff --git a/.gitignore b/.gitignore index 6b516f7..91e7cb5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ -# +# Custom /build +/.vscode # Prerequisites *.d diff --git a/CMakeLists.txt b/CMakeLists.txt index 7ee9b49..9bbf487 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,187 +1,5 @@ -# # cmake_minimum_required(VERSION 3.16) -# # project(ellvm CXX) - -# # set(CMAKE_CXX_STANDARD 17) -# # set(CMAKE_CXX_STANDARD_REQUIRED ON) -# # #set(LLVM_DIR "/home/user/miniconda3/envs/tvm5/lib/cmake/llvm") -# # set(CMAKE_EXPORT_COMPILE_COMMANDS ON) - -# # find_package(LLVM REQUIRED CONFIG) - -# # link_directories(${LLVM_LIBRARY_DIRS}) - -# # message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") -# # message(STATUS "LLVM DIR: ${LLVM_DIR}") -# # message(STATUS "LLVM include dirs: ${LLVM_INCLUDE_DIRS}") - - -# # include_directories( -# # ${LLVM_INCLUDE_DIRS} -# # ${LLVM_BUILD_INCLUDE_DIR} -# # ${CMAKE_SOURCE_DIR}/include -# # ) - - -# # add_definitions(${LLVM_DEFINITIONS}) - -# # file(GLOB SOURCES "${CMAKE_SOURCE_DIR}/src/*.cpp") -# # add_executable(ellvm ${SOURCES}) - - - -# # target_link_libraries(ellvm -# # PRIVATE -# # LLVMCore -# # LLVMSupport -# # LLVMIRReader -# # LLVMAsmPrinter -# # LLVMAsmParser -# # LLVMBitWriter -# # LLVMExecutionEngine -# # LLVMOrcJIT -# # LLVMTarget -# # LLVMX86CodeGen -# # LLVMX86AsmParser -# # LLVMX86Desc -# # LLVMX86Info -# # ) - - -# # target_compile_options(ellvm PRIVATE -# # -Wall # Enable all common warnings -# # -Wextra # Extra warnings -# # -Wno-unused-parameter # Suppress unused parameter warnings -# # -Wunused-variable # Warn about unused variables -# # -Wreturn-type # Warn if a function might not return a value -# # -Wno-redundant-move # Suppress LLVM redundant move warning -# # -g # Include debug information -# # ) - - -# cmake_minimum_required(VERSION 3.13.4) -# project(standalone_mlir LANGUAGES CXX C) - -# set(CMAKE_CXX_STANDARD 17) -# set(CMAKE_CXX_STANDARD_REQUIRED YES) - -# set(LLVM_DIR 
"/home/hamza/Repos/llvm-project/build/lib/cmake/llvm") -# set(MLIR_DIR "/home/hamza/Repos/llvm-project/build/lib/cmake/mlir") - - -# set(CMAKE_EXPORT_COMPILE_COMMANDS ON) - -# find_package(LLVM REQUIRED CONFIG) -# find_package(MLIR REQUIRED CONFIG) - -# list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") - -# include(AddLLVM) -# include(AddMLIR) -# include(TableGen) -# include(HandleLLVMOptions) - -# set(LLVM_LINK_COMPONENTS Core Support nativecodegen OrcJIT) - -# include_directories(include) - -# include_directories( -# ${LLVM_INCLUDE_DIRS} ${MLIR_INCLUDE_DIRS} ${CMAKE_CURRENT_SOURCE_DIR}/include -# ${CMAKE_CURRENT_BINARY_DIR}/include # For TableGen-generated .inc files -# ${CMAKE_CURRENT_BINARY_DIR}) - -# add_definitions(${LLVM_DEFINITIONS}) - -# add_subdirectory(include) - -# set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td) -# mlir_tablegen(ToyCombine.inc -gen-rewriters) -# add_public_tablegen_target(CombineIncGen) - - -# # file( -# # GLOB CCMLIR_SRC "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp" -# # ) -# add_executable(standalone_mlir -# main.cpp -# parser/*.cpp -# mlir/*.cpp -# ) - -# #add_executable(standalone_mlir ${CCMLIR_SRC}) - -# # add_dependencies(standalone_mlir Ops ShapeInfer CombineIncGen) -# add_dependencies(standalone_mlir ToyCh7OpsIncGen ToyCh7ShapeInferenceInterfaceIncGen CombineIncGen) - -# llvm_update_compile_flags(standalone_mlir) - -# get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) -# get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -# get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) - -# target_link_libraries( -# standalone_mlir -# PRIVATE ${dialect_libs} -# ${conversion_libs} -# ${extension_libs} -# LLVMSupport -# MLIRAnalysis -# MLIRBuiltinToLLVMIRTranslation -# MLIRCallInterfaces -# MLIRCastInterfaces -# MLIRExecutionEngine -# MLIRFunctionInterfaces -# MLIRIR -# MLIRLLVMCommonConversion -# MLIRLLVMDialect -# MLIRLLVMToLLVMIRTranslation -# MLIRMemRefDialect -# MLIRParser -# MLIRPass -# 
MLIRSideEffectInterfaces -# MLIRSupport -# MLIRTargetLLVMIRExport -# MLIRTransforms -# MLIRLinalgDialect -# MLIRTensorDialect -# MLIRArithDialect -# MLIRIR -# MLIRSupport -# MLIRParser -# MLIRAnalysis -# MLIRPass -# MLIRTransforms -# MLIRFuncDialect -# MLIRArithDialect -# MLIRMemRefDialect -# MLIRLinalgDialect -# MLIRAffineDialect -# MLIRSCFDialect -# MLIRLLVMDialect -# MLIRTensorDialect -# MLIRFuncTransforms -# MLIRLinalgTransforms -# MLIRAffineToStandard -# MLIRSCFToControlFlow -# MLIRArithToLLVM -# MLIRMemRefToLLVM -# MLIRFuncToLLVM -# MLIRLLVMCommonConversion -# MLIRLLVMToLLVMIRTranslation -# MLIRTargetLLVMIRExport -# MLIRBuiltinToLLVMIRTranslation -# MLIRTranslateLib -# MLIRExecutionEngine -# MLIRCallInterfaces -# MLIRCastInterfaces -# MLIRFunctionInterfaces -# MLIRSideEffectInterfaces -# LLVMSupport) - -# install(TARGETS standalone_mlir RUNTIME DESTINATION bin) - - cmake_minimum_required(VERSION 3.13.4) -project(standalone_mlir LANGUAGES CXX C) +project(mlp_mlir LANGUAGES CXX C) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED YES) @@ -189,9 +7,9 @@ set(CMAKE_CXX_STANDARD_REQUIRED YES) set(LLVM_DIR "/home/hamza/Repos/llvm-project/build/lib/cmake/llvm") set(MLIR_DIR "/home/hamza/Repos/llvm-project/build/lib/cmake/mlir") - set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + find_package(LLVM REQUIRED CONFIG) find_package(MLIR REQUIRED CONFIG) @@ -215,40 +33,35 @@ add_definitions(${LLVM_DEFINITIONS}) add_subdirectory(include) -set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td) -mlir_tablegen(ToyCombine.inc -gen-rewriters) -add_public_tablegen_target(CombineIncGen) -# file( -# GLOB CCMLIR_SRC "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp" -# ) -add_executable(standalone_mlir - main.cpp - parser/AST.cpp - mlir/MLIRGen.cpp - mlir/Dialect.cpp - mlir/LowerToAffineLoops.cpp - mlir/LowerToLLVM.cpp - mlir/ShapeInferencePass.cpp - mlir/ToyCombine.cpp - ) +set(LLVM_TARGET_DEFINITIONS src/MlpCombine.td) +mlir_tablegen(MlpCombine.inc -gen-rewriters) 
+add_public_tablegen_target(CombineIncGen) +file( + GLOB CCMLIR_SRC "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp" + # GLOB CCMLIR_SRC "${CMAKE_CURRENT_SOURCE_DIR}/targets/cpu/*.cpp" + # GLOB CCMLIR_SRC "${CMAKE_CURRENT_SOURCE_DIR}/targets/gpu/*.cpp" + # GLOB CCMLIR_SRC "${CMAKE_CURRENT_SOURCE_DIR}/targets/metal/*.cpp" + # GLOB CCMLIR_SRC "${CMAKE_CURRENT_SOURCE_DIR}/targets/cpu/*.cpp" + # GLOB CCMLIR_SRC "${CMAKE_CURRENT_SOURCE_DIR}/targets/rocm/*.cpp" +) +add_executable(mlp_mlir ${CCMLIR_SRC}) -#add_executable(standalone_mlir ${CCMLIR_SRC}) -# add_dependencies(standalone_mlir Ops ShapeInfer CombineIncGen) -add_dependencies(standalone_mlir ToyCh7OpsIncGen ToyCh7ShapeInferenceInterfaceIncGen CombineIncGen) +# add_dependencies(mlp_mlir Ops ShapeInfer CombineIncGen) +add_dependencies(mlp_mlir OpsIncGen ShapeInferenceInterfaceIncGen CombineIncGen) -llvm_update_compile_flags(standalone_mlir) +llvm_update_compile_flags(mlp_mlir) get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) target_link_libraries( - standalone_mlir + mlp_mlir PRIVATE ${dialect_libs} ${conversion_libs} ${extension_libs} @@ -300,10 +113,20 @@ target_link_libraries( MLIRBuiltinToLLVMIRTranslation MLIRTranslateLib MLIRExecutionEngine + MLIRArithTransforms MLIRCallInterfaces + MLIRBufferizationTransforms MLIRCastInterfaces MLIRFunctionInterfaces MLIRSideEffectInterfaces + MLIRArithDialect + MLIRFuncDialect + MLIRTensorDialect + MLIRLinalgDialect + MLIRBufferizationDialect + MLIRMemRefDialect + MLIRSCFDialect + LLVMSupport) -install(TARGETS standalone_mlir RUNTIME DESTINATION bin) \ No newline at end of file +install(TARGETS mlp_mlir RUNTIME DESTINATION bin) \ No newline at end of file diff --git a/include/Builder.h b/include/Builder.h new file mode 100644 index 0000000..ce10b4b --- /dev/null +++ b/include/Builder.h @@ -0,0 +1,30 @@ +#ifndef BUILDER_H +#define BUILDER_H + 
+#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" + +using namespace mlir; + +namespace builder { + +// Basic examples +func::FuncOp createMainFunction(MLIRContext &ctx, ModuleOp module); + +func::FuncOp createMainFunction(MLIRContext &ctx, ModuleOp module); + +func::FuncOp createAddFunction(MLIRContext &ctx, ModuleOp module); + +func::FuncOp createMulFunction(MLIRContext &ctx, ModuleOp module); + +// MLP examples +func::FuncOp createMLPAddFunction(MLIRContext &ctx, ModuleOp module); + +func::FuncOp createMLPReluFunction(MLIRContext &ctx, ModuleOp module); + +func::FuncOp createMLPLinearFunction(MLIRContext &ctx, ModuleOp module); + +} // namespace builder + +#endif // TOY_BUILDER_H diff --git a/include/CMakeLists.txt b/include/CMakeLists.txt index 37c89d0..142e3e8 100644 --- a/include/CMakeLists.txt +++ b/include/CMakeLists.txt @@ -1 +1,13 @@ -add_subdirectory(toy) +# Most dialects should use add_mlir_dialect(). See examples/standalone. +set(LLVM_TARGET_DEFINITIONS Ops.td) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) +add_public_tablegen_target(OpsIncGen) + +# Most dialects should use add_mlir_interfaces(). 
+set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td) +mlir_tablegen(ShapeInferenceOpInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(ShapeInferenceOpInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(ShapeInferenceInterfaceIncGen) diff --git a/include/toy/Dialect.h b/include/Dialect.h similarity index 54% rename from include/toy/Dialect.h rename to include/Dialect.h index 64094c3..726e461 100644 --- a/include/toy/Dialect.h +++ b/include/Dialect.h @@ -1,4 +1,4 @@ -//===- Dialect.h - Dialect definition for the Toy IR ----------------------===// +//===- Dialect.h - Dialect definition for the Mlp IR ----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,14 +6,15 @@ // //===----------------------------------------------------------------------===// // -// This file implements the IR Dialect for the Toy language. -// See docs/Tutorials/Toy/Ch-2.md for more information. +// This file implements the IR Dialect for the Mlp language. +// See docs/Tutorials/Ch-2.md for more information. 
// //===----------------------------------------------------------------------===// -#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_ -#define MLIR_TUTORIAL_TOY_DIALECT_H_ +#ifndef MLIR_TUTORIAL_MLP_DIALECT_H_ +#define MLIR_TUTORIAL_MLP_DIALECT_H_ +#include "ShapeInferenceInterface.h" #include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" @@ -22,61 +23,63 @@ #include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" -#include "toy/ShapeInferenceInterface.h" namespace mlir { -namespace toy { +namespace mlp { + namespace detail { + struct StructTypeStorage; } // namespace detail -} // namespace toy +} // namespace mlp } // namespace mlir -/// Include the auto-generated header file containing the declaration of the toy +/// Include the auto-generated header file containing the declaration of the mlp /// dialect. -#include "toy/Dialect.h.inc" +#include "Dialect.h.inc" //===----------------------------------------------------------------------===// -// Toy Operations +// Mlp Operations //===----------------------------------------------------------------------===// /// Include the auto-generated header file containing the declarations of the -/// toy operations. +/// Mlp operations. #define GET_OP_CLASSES -#include "toy/Ops.h.inc" +#include "Ops.h.inc" namespace mlir { -namespace toy { +namespace mlp { //===----------------------------------------------------------------------===// -// Toy Types +// Mlp Types //===----------------------------------------------------------------------===// -/// This class defines the Toy struct type. It represents a collection of +/// This class defines the Mlp struct type. It represents a collection of /// element types. All derived types in MLIR must inherit from the CRTP class /// 'Type::TypeBase'. 
It takes as template parameters the concrete type /// (StructType), the base class to use (Type), and the storage class /// (StructTypeStorage). -class StructType : public mlir::Type::TypeBase { -public: - /// Inherit some necessary constructors from 'TypeBase'. - using Base::Base; - - /// Create an instance of a `StructType` with the given element types. There - /// *must* be atleast one element type. - static StructType get(llvm::ArrayRef elementTypes); - - /// Returns the element types of this struct type. - llvm::ArrayRef getElementTypes(); - - /// Returns the number of element type held by this struct. - size_t getNumElementTypes() { return getElementTypes().size(); } - - /// The name of this struct type. - static constexpr StringLiteral name = "toy.struct"; -}; -} // namespace toy +// class StructType : public mlir::Type::TypeBase { +// public: +// /// Inherit some necessary constructors from 'TypeBase'. +// using Base::Base; + +// /// Create an instance of a `StructType` with the given element types. +// There +// /// *must* be atleast one element type. +// static StructType get(llvm::ArrayRef elementTypes); + +// /// Returns the element types of this struct type. +// llvm::ArrayRef getElementTypes(); + +// /// Returns the number of element type held by this struct. +// size_t getNumElementTypes() { return getElementTypes().size(); } + +// /// The name of this struct type. +// static constexpr StringLiteral name = "Mlp.struct"; +// }; +} // namespace mlp } // namespace mlir -#endif // MLIR_TUTORIAL_TOY_DIALECT_H_ +#endif // MLIR_TUTORIAL_MLP_DIALECT_H_ diff --git a/include/Jit.h b/include/Jit.h new file mode 100644 index 0000000..1b094fd --- /dev/null +++ b/include/Jit.h @@ -0,0 +1,63 @@ +//===- Jit.h - MLP Jit -------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the parser for the Toy language. It processes the Token +// provided by the Lexer and returns an AST. +// +//===----------------------------------------------------------------------===// + +#ifndef JIT_H +#define JIT_H + +#include "Dialect.h" +#include "Passes.h" +#include "mlir/Dialect/Affine/Transforms/Passes.h" +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" +#include "mlir/Dialect/LLVMIR/Transforms/Passes.h" +#include "mlir/ExecutionEngine/ExecutionEngine.h" +#include "mlir/ExecutionEngine/OptUtils.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Verifier.h" +#include "mlir/InitAllDialects.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Export.h" +#include "mlir/Transforms/Passes.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace jit { + +int runJit(mlir::ModuleOp module); +} // namespace jit + +#endif // MLP_PARSER_H diff --git a/include/Ops.td b/include/Ops.td new file mode 100644 index 
0000000..ce7913d --- /dev/null +++ b/include/Ops.td @@ -0,0 +1,488 @@ +//===- Ops.td - MLP dialect operation definitions ----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the operations of the MLP dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLP_OPS +#define MLP_OPS + +include "mlir/Interfaces/FunctionInterfaces.td" +include "mlir/IR/SymbolInterfaces.td" +include "mlir/Interfaces/CallInterfaces.td" +include "mlir/Interfaces/CastInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "ShapeInferenceInterface.td" + +// Provide a definition of the 'mlp' dialect in the ODS framework so that we +// can define our operations. +def MLP_Dialect : Dialect { + let name = "mlp"; + let cppNamespace = "::mlir::mlp"; + + // We set this bit to generate a declaration of the `materializeConstant` + // method so that we can materialize constants for our mlp operations. + //let hasConstantMaterializer = 1; Original is 1 + let hasConstantMaterializer = 0; + + // We set this bit to generate the declarations for the dialect's type parsing + // and printing hooks. + let useDefaultTypePrinterParser = 1; + +} + +// Base class for mlp dialect operations. This operation inherits from the base +// `Op` class in OpBase.td, and provides: +// * The parent dialect of the operation. +// * The mnemonic for the operation, or the name without the dialect prefix. +// * A list of traits for the operation. +class MLP_OP traits = []> : + Op; + +// Provide a definition for the MLP StructType for use in ODS. This allows for +// using StructType in a similar way to Tensor or MemRef. 
We use `DialectType` +// to demarcate the StructType as belonging to the MLP dialect. +def MLP_StructType : + DialectType($_self)">, + "MLP struct type">; + +// Provide a definition of the types that are used within the MLP dialect. +def MLP_Type : AnyTypeOf<[F64Tensor, MLP_StructType]>; + +//===----------------------------------------------------------------------===// +// MLP Operations +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// + +// We define a MLP operation by inheriting from our base 'MLP_OP' class above. +// Here we provide the mnemonic and a list of traits for the operation. The +// constant operation is marked as 'Pure' as it is a pure operation +// and may be removed if dead. +def ConstantOp : MLP_OP<"constant", + [Pure]> {// ConstantLike,Pure + // Provide a summary and description for this operation. This can be used to + // auto-generate documentation of the operations within our dialect. + let summary = "constant"; + let description = [{ + Constant operation turns a literal into an SSA value. The data is attached + to the operation as an attribute. For example: + + ```mlir + %0 = MLP.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> + : tensor<2x3xf64> + ``` + }]; + + // The constant operation takes an attribute as the only input. + let arguments = (ins F64ElementsAttr:$value); + + // The constant operation returns a single value of TensorType. + let results = (outs F64Tensor); + + // Indicate that the operation has a custom parser and printer method. + let hasCustomAssemblyFormat = 1; + + // Add custom build methods for the constant operation. These method populates + // the `state` that MLIR uses to create operations, i.e. these are used when + // using `ConstantOp::create(builder, ...)`. 
+ let builders = [ + // Build a constant with a given constant tensor value. + OpBuilder<(ins "DenseElementsAttr":$value), [{ + build($_builder, $_state, value.getType(), value); + }]>, + + // Build a constant with a given constant floating-point value. + OpBuilder<(ins "double":$value)> + ]; + + // Indicate that additional verification for this operation is necessary. + let hasVerifier = 1; + + // Set the folder bit so that we can implement constant folders. + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// AddOp +//===----------------------------------------------------------------------===// + +def AddOp : MLP_OP<"add", + [Pure]> { + let summary = "element-wise addition operation"; + let description = [{ + The "add" operation performs element-wise addition between two tensors. + The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Indicate that the operation has a custom parser and printer method. + let hasCustomAssemblyFormat = 1; + + // Allow building an AddOp with from the two input operands. + let builders = [ + OpBuilder<(ins "Value":$lhs, "Value":$rhs)> + ]; +} +//===----------------------------------------------------------------------===// +// MulOp +//===----------------------------------------------------------------------===// + +// def MulOp : MLP_OP<"mul", +// [Pure, DeclareOpInterfaceMethods]> { +// let summary = "element-wise multiplication operation"; +// let description = [{ +// The "mul" operation performs element-wise multiplication between two +// tensors. The shapes of the tensor operands are expected to match. +// }]; + +// let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); +// let results = (outs F64Tensor); + +// // Indicate that the operation has a custom parser and printer method. 
+// let hasCustomAssemblyFormat = 1; + +// // Allow building a MulOp with from the two input operands. +// let builders = [ +// OpBuilder<(ins "Value":$lhs, "Value":$rhs)> +// ]; +// } + +//===----------------------------------------------------------------------===// +// ReluOp +//===----------------------------------------------------------------------===// + +def ReluOp : MLP_OP<"relu", + [Pure]> { + + let summary = "element-wise ReLU operation"; + let description = [{ + The `relu` operation applies the Rectified Linear Unit function + element-wise to the input tensor: + + relu(x) = max(0, x) + + The input and output tensor shapes are identical. + }]; + + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor:$result); + + // Custom printer/parser (optional, but consistent with AddOp) + let hasCustomAssemblyFormat = 1; + + let builders = [ + OpBuilder<(ins "Value":$input)> + ]; +} + +// ===----------------------------------------------------------------------===// +// LeakyReLU +// ===----------------------------------------------------------------------===// +def LeakyReluOp : MLP_OP<"leaky_relu", [Pure]> { + let summary = "leaky rectified linear unit"; + let description = [{ + Applies max(x, alpha * x) element-wise. + }]; + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor:$result); + let hasCustomAssemblyFormat = 1; + let builders = [ OpBuilder<(ins "Value":$input)> ]; + + +} + +// ===----------------------------------------------------------------------===// +// ELU +// ===----------------------------------------------------------------------===// +def EluOp : MLP_OP<"elu", [Pure]> { + let summary = "exponential linear unit"; + let description = [{ + Applies x if x > 0, otherwise alpha * (exp(x) - 1). 
+ }]; + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor:$result); + let hasCustomAssemblyFormat = 1; + let builders = [ OpBuilder<(ins "Value":$input)> ]; + +} + +// ===----------------------------------------------------------------------===// +// Sigmoid +// ===----------------------------------------------------------------------===// +def SigmoidOp : MLP_OP<"sigmoid", [Pure]> { + let summary = "sigmoid activation"; + let description = [{ + Applies 1 / (1 + exp(-x)) element-wise. + }]; + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor:$result); + let hasCustomAssemblyFormat = 1; + let builders = [ OpBuilder<(ins "Value":$input)> ]; + +} + +// ===----------------------------------------------------------------------===// +// Tanh +// ===----------------------------------------------------------------------===// +def TanhOp : MLP_OP<"tanh", [Pure]> { + let summary = "hyperbolic tangent activation"; + let description = [{ + Applies tanh(x) element-wise. + }]; + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor:$result); + let hasCustomAssemblyFormat = 1; + let builders = [ OpBuilder<(ins "Value":$input)> ]; + +} + +// ===----------------------------------------------------------------------===// +// Softmax +// ===----------------------------------------------------------------------===// +def SoftmaxOp : MLP_OP<"softmax", [Pure]> { + let summary = "softmax normalization operation"; + let description = [{ + Normalizes values along the last dimension so that they sum to 1. 
+ }]; + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor:$result); + let hasCustomAssemblyFormat = 1; + let builders = [ OpBuilder<(ins "Value":$input)> ]; + +} + +// ===----------------------------------------------------------------------===// +// GELU +// ===----------------------------------------------------------------------===// +def GeluOp : MLP_OP<"gelu", [Pure]> { + let summary = "Gaussian Error Linear Unit"; + let description = [{ + Applies GELU(x) = x * 0.5 * (1 + erf(x / sqrt(2))). + }]; + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor:$result); + let hasCustomAssemblyFormat = 1; + let builders = [ OpBuilder<(ins "Value":$input)> ]; + +} + +// ===----------------------------------------------------------------------===// +// Swish +// ===----------------------------------------------------------------------===// +def SwishOp : MLP_OP<"swish", [Pure]> { + let summary = "swish activation"; + let description = [{ + Applies x * sigmoid(x) element-wise. + }]; + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor:$result); + let hasCustomAssemblyFormat = 1; + let builders = [ OpBuilder<(ins "Value":$input)> ]; + +} + +// ===----------------------------------------------------------------------===// +// Mish +// ===----------------------------------------------------------------------===// +def MishOp : MLP_OP<"mish", [Pure]> { + let summary = "mish activation"; + let description = [{ + Applies x * tanh(softplus(x)) element-wise. 
+ }]; + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor:$result); + let hasCustomAssemblyFormat = 1; + let builders = [ OpBuilder<(ins "Value":$input)> ]; + +} + +//===----------------------------------------------------------------------===// +// LinearOp +//===----------------------------------------------------------------------===// + +def LinearOp : MLP_OP<"linear", + [Pure]> { + let summary = "element-wise addition operation"; + let description = [{ + The "add" operation performs element-wise addition between two tensors. + The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$input, F64Tensor:$weight); + let results = (outs F64Tensor); + + // Indicate that the operation has a custom parser and printer method. + let hasCustomAssemblyFormat = 1; + + // Allow building an AddOp with from the two input operands. + let builders = [ + OpBuilder<(ins "Value":$input, "Value":$weight)> + ]; +} + +//===----------------------------------------------------------------------===// +// CastOp +//===----------------------------------------------------------------------===// + +// def CastOp : MLP_OP<"cast", [ +// DeclareOpInterfaceMethods, +// DeclareOpInterfaceMethods, +// Pure, +// SameOperandsAndResultShape +// ]> { +// let summary = "shape cast operation"; +// let description = [{ +// The "cast" operation converts a tensor from one type to an equivalent type +// without changing any data elements. The source and destination types must +// both be tensor types with the same element type. If both are ranked, then +// shape is required to match. The operation is invalid if converting to a +// mismatching constant dimension. 
+// }]; + +// let arguments = (ins F64Tensor:$input); +// let results = (outs F64Tensor:$output); + +// let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)"; +// } + +//===----------------------------------------------------------------------===// +// FuncOp +//===----------------------------------------------------------------------===// + +def FuncOp : MLP_OP<"func", [ + FunctionOpInterface, IsolatedFromAbove + ]> { + let summary = "user defined function operation"; + let description = [{ + The "mlp.func" operation represents a user defined function. These are + callable SSA-region operations that contain mlp computations. + + Example: + + ```mlir + mlp.func @main() { + %0 = mlp.constant dense<5.500000e+00> : tensor + %1 = mlp.reshape(%0 : tensor) to tensor<2x2xf64> + mlp.print %1 : tensor<2x2xf64> + mlp.return + } + ``` + }]; + + let arguments = (ins + SymbolNameAttr:$sym_name, + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); + let regions = (region AnyRegion:$body); + + let builders = [OpBuilder<(ins + "StringRef":$name, "FunctionType":$type, + CArg<"ArrayRef", "{}">:$attrs) + >]; + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + + /// Returns the result types of this function. 
+ ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + Region *getCallableRegion() { return &getBody(); } + }]; + let hasCustomAssemblyFormat = 1; + let skipDefaultBuilders = 1; +} + + + +//===----------------------------------------------------------------------===// +// PrintOp +//===----------------------------------------------------------------------===// + +def PrintOp : MLP_OP<"print"> { + let summary = "print operation"; + let description = [{ + The "print" builtin operation prints a given input tensor, and produces + no results. + }]; + + // The print operation takes an input tensor to print. + // We also allow a F64MemRef to enable interop during partial lowering. + let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input); + + let assemblyFormat = "$input attr-dict `:` type($input)"; +} + +//===----------------------------------------------------------------------===// +// ReshapeOp +//===----------------------------------------------------------------------===// + +// def ReshapeOp : MLP_OP<"reshape", [Pure]> { +// let summary = "tensor reshape operation"; +// let description = [{ +// Reshape operation is transforming its input tensor into a new tensor with +// the same number of elements but different shapes. For example: + +// ```mlir +// %0 = mlp.reshape (%arg1 : tensor<10xf64>) to tensor<5x2xf64> +// ``` +// }]; + +// let arguments = (ins F64Tensor:$input); + +// let assemblyFormat = [{ +// `(` $input `:` type($input) `)` attr-dict `to` type(results) +// }]; + +// // Enable registering canonicalization patterns with this operation. +// let hasCanonicalizer = 1; + +// // We expect that the reshape operation returns a statically shaped tensor. 
+// let results = (outs StaticShapeTensorOf<[F64]>); +// } + + +// //===----------------------------------------------------------------------===// +// // TransposeOp +// //===----------------------------------------------------------------------===// + +// def TransposeOp : MLP_OP<"transpose", +// [Pure, DeclareOpInterfaceMethods]> { +// let summary = "transpose operation"; + +// let arguments = (ins F64Tensor:$input); +// let results = (outs F64Tensor); + +// let assemblyFormat = [{ +// `(` $input `:` type($input) `)` attr-dict `to` type(results) +// }]; + +// // Enable registering canonicalization patterns with this operation. +// let hasCanonicalizer = 1; + +// // Allow building a TransposeOp with from the input operand. +// let builders = [ +// OpBuilder<(ins "Value":$input)> +// ]; + +// // Indicate that additional verification for this operation is necessary. +// let hasVerifier = 1; +// } + +#endif // MLP_OPS diff --git a/include/toy/Passes.h b/include/Passes.h similarity index 85% rename from include/toy/Passes.h rename to include/Passes.h index 62471dd..0b04892 100644 --- a/include/toy/Passes.h +++ b/include/Passes.h @@ -10,15 +10,15 @@ // //===----------------------------------------------------------------------===// -#ifndef TOY_PASSES_H -#define TOY_PASSES_H +#ifndef PASSES_H +#define PASSES_H #include namespace mlir { class Pass; -namespace toy { +namespace mlp { std::unique_ptr createShapeInferencePass(); /// Create a pass for lowering to operations in the `Affine` and `Std` dialects, @@ -29,7 +29,10 @@ std::unique_ptr createLowerToAffinePass(); /// well as `Affine` and `Std`, to the LLVM dialect for codegen. 
std::unique_ptr createLowerToLLVMPass(); -} // namespace toy +// Create pass mlp-to-linalg +std::unique_ptr createLowerToLinalgPass(); + +} // namespace mlp } // namespace mlir -#endif // TOY_PASSES_H +#endif // MLP_PASSES_H diff --git a/include/toy/ShapeInferenceInterface.h b/include/ShapeInferenceInterface.h similarity index 73% rename from include/toy/ShapeInferenceInterface.h rename to include/ShapeInferenceInterface.h index cfe5a87..da9cc1e 100644 --- a/include/toy/ShapeInferenceInterface.h +++ b/include/ShapeInferenceInterface.h @@ -11,18 +11,18 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ -#define MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ +#ifndef MLIR_TUTORIAL_MLP_SHAPEINFERENCEINTERFACE_H_ +#define MLIR_TUTORIAL_MLP_SHAPEINFERENCEINTERFACE_H_ #include "mlir/IR/OpDefinition.h" namespace mlir { -namespace toy { +namespace mlp { /// Include the auto-generated declarations. -#include "toy/ShapeInferenceOpInterfaces.h.inc" +#include "ShapeInferenceOpInterfaces.h.inc" -} // namespace toy +} // namespace mlp } // namespace mlir -#endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ +#endif // MLIR_TUTORIAL_MLP_SHAPEINFERENCEINTERFACE_H_ diff --git a/include/toy/ShapeInferenceInterface.td b/include/ShapeInferenceInterface.td similarity index 100% rename from include/toy/ShapeInferenceInterface.td rename to include/ShapeInferenceInterface.td diff --git a/include/toy/AST.h b/include/toy/AST.h deleted file mode 100644 index 42d64ed..0000000 --- a/include/toy/AST.h +++ /dev/null @@ -1,313 +0,0 @@ -//===- AST.h - Node definition for the Toy AST ----------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements the AST for the Toy language. It is optimized for -// simplicity, not efficiency. The AST forms a tree structure where each node -// references its children using std::unique_ptr<>. -// -//===----------------------------------------------------------------------===// - -#ifndef TOY_AST_H -#define TOY_AST_H - -#include "toy/Lexer.h" - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/Casting.h" -#include -#include -#include - -namespace toy { - -/// A variable type with either name or shape information. -struct VarType { - std::string name; - std::vector shape; -}; - -/// Base class for all expression nodes. -class ExprAST { -public: - enum ExprASTKind { - Expr_VarDecl, - Expr_Return, - Expr_Num, - Expr_Literal, - Expr_StructLiteral, - Expr_Var, - Expr_BinOp, - Expr_Call, - Expr_Print, - }; - - ExprAST(ExprASTKind kind, Location location) - : kind(kind), location(std::move(location)) {} - virtual ~ExprAST() = default; - - ExprASTKind getKind() const { return kind; } - - const Location &loc() { return location; } - -private: - const ExprASTKind kind; - Location location; -}; - -/// A block-list of expressions. -using ExprASTList = std::vector>; - -/// Expression class for numeric literals like "1.0". -class NumberExprAST : public ExprAST { - double val; - -public: - NumberExprAST(Location loc, double val) - : ExprAST(Expr_Num, std::move(loc)), val(val) {} - - double getValue() { return val; } - - /// LLVM style RTTI - static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; } -}; - -/// Expression class for a literal value. 
-class LiteralExprAST : public ExprAST { - std::vector> values; - std::vector dims; - -public: - LiteralExprAST(Location loc, std::vector> values, - std::vector dims) - : ExprAST(Expr_Literal, std::move(loc)), values(std::move(values)), - dims(std::move(dims)) {} - - llvm::ArrayRef> getValues() { return values; } - llvm::ArrayRef getDims() { return dims; } - - /// LLVM style RTTI - static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; } -}; - -/// Expression class for a literal struct value. -class StructLiteralExprAST : public ExprAST { - std::vector> values; - -public: - StructLiteralExprAST(Location loc, - std::vector> values) - : ExprAST(Expr_StructLiteral, std::move(loc)), values(std::move(values)) { - } - - llvm::ArrayRef> getValues() { return values; } - - /// LLVM style RTTI - static bool classof(const ExprAST *c) { - return c->getKind() == Expr_StructLiteral; - } -}; - -/// Expression class for referencing a variable, like "a". -class VariableExprAST : public ExprAST { - std::string name; - -public: - VariableExprAST(Location loc, llvm::StringRef name) - : ExprAST(Expr_Var, std::move(loc)), name(name) {} - - llvm::StringRef getName() { return name; } - - /// LLVM style RTTI - static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; } -}; - -/// Expression class for defining a variable. -class VarDeclExprAST : public ExprAST { - std::string name; - VarType type; - std::unique_ptr initVal; - -public: - VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, - std::unique_ptr initVal = nullptr) - : ExprAST(Expr_VarDecl, std::move(loc)), name(name), - type(std::move(type)), initVal(std::move(initVal)) {} - - llvm::StringRef getName() { return name; } - ExprAST *getInitVal() { return initVal.get(); } - const VarType &getType() { return type; } - - /// LLVM style RTTI - static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; } -}; - -/// Expression class for a return operator. 
-class ReturnExprAST : public ExprAST { - std::optional> expr; - -public: - ReturnExprAST(Location loc, std::optional> expr) - : ExprAST(Expr_Return, std::move(loc)), expr(std::move(expr)) {} - - std::optional getExpr() { - if (expr.has_value()) - return expr->get(); - return std::nullopt; - } - - /// LLVM style RTTI - static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; } -}; - -/// Expression class for a binary operator. -class BinaryExprAST : public ExprAST { - char op; - std::unique_ptr lhs, rhs; - -public: - char getOp() { return op; } - ExprAST *getLHS() { return lhs.get(); } - ExprAST *getRHS() { return rhs.get(); } - - BinaryExprAST(Location loc, char op, std::unique_ptr lhs, - std::unique_ptr rhs) - : ExprAST(Expr_BinOp, std::move(loc)), op(op), lhs(std::move(lhs)), - rhs(std::move(rhs)) {} - - /// LLVM style RTTI - static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; } -}; - -/// Expression class for function calls. -class CallExprAST : public ExprAST { - std::string callee; - std::vector> args; - -public: - CallExprAST(Location loc, const std::string &callee, - std::vector> args) - : ExprAST(Expr_Call, std::move(loc)), callee(callee), - args(std::move(args)) {} - - llvm::StringRef getCallee() { return callee; } - llvm::ArrayRef> getArgs() { return args; } - - /// LLVM style RTTI - static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; } -}; - -/// Expression class for builtin print calls. 
-class PrintExprAST : public ExprAST { - std::unique_ptr arg; - -public: - PrintExprAST(Location loc, std::unique_ptr arg) - : ExprAST(Expr_Print, std::move(loc)), arg(std::move(arg)) {} - - ExprAST *getArg() { return arg.get(); } - - /// LLVM style RTTI - static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; } -}; - -/// This class represents the "prototype" for a function, which captures its -/// name, and its argument names (thus implicitly the number of arguments the -/// function takes). -class PrototypeAST { - Location location; - std::string name; - std::vector> args; - -public: - PrototypeAST(Location location, const std::string &name, - std::vector> args) - : location(std::move(location)), name(name), args(std::move(args)) {} - - const Location &loc() { return location; } - llvm::StringRef getName() const { return name; } - llvm::ArrayRef> getArgs() { return args; } -}; - -/// This class represents a top level record in a module. -class RecordAST { -public: - enum RecordASTKind { - Record_Function, - Record_Struct, - }; - - RecordAST(RecordASTKind kind) : kind(kind) {} - virtual ~RecordAST() = default; - - RecordASTKind getKind() const { return kind; } - -private: - const RecordASTKind kind; -}; - -/// This class represents a function definition itself. -class FunctionAST : public RecordAST { - std::unique_ptr proto; - std::unique_ptr body; - -public: - FunctionAST(std::unique_ptr proto, - std::unique_ptr body) - : RecordAST(Record_Function), proto(std::move(proto)), - body(std::move(body)) {} - PrototypeAST *getProto() { return proto.get(); } - ExprASTList *getBody() { return body.get(); } - - /// LLVM style RTTI - static bool classof(const RecordAST *r) { - return r->getKind() == Record_Function; - } -}; - -/// This class represents a struct definition. 
-class StructAST : public RecordAST { - Location location; - std::string name; - std::vector> variables; - -public: - StructAST(Location location, const std::string &name, - std::vector> variables) - : RecordAST(Record_Struct), location(std::move(location)), name(name), - variables(std::move(variables)) {} - - const Location &loc() { return location; } - llvm::StringRef getName() const { return name; } - llvm::ArrayRef> getVariables() { - return variables; - } - - /// LLVM style RTTI - static bool classof(const RecordAST *r) { - return r->getKind() == Record_Struct; - } -}; - -/// This class represents a list of functions to be processed together -class ModuleAST { - std::vector> records; - -public: - ModuleAST(std::vector> records) - : records(std::move(records)) {} - - auto begin() { return records.begin(); } - auto end() { return records.end(); } -}; - -void dump(ModuleAST &); - -} // namespace toy - -#endif // TOY_AST_H diff --git a/include/toy/CMakeLists.txt b/include/toy/CMakeLists.txt deleted file mode 100644 index 7712e42..0000000 --- a/include/toy/CMakeLists.txt +++ /dev/null @@ -1,13 +0,0 @@ -# Most dialects should use add_mlir_dialect(). See examples/standalone. -set(LLVM_TARGET_DEFINITIONS Ops.td) -mlir_tablegen(Ops.h.inc -gen-op-decls) -mlir_tablegen(Ops.cpp.inc -gen-op-defs) -mlir_tablegen(Dialect.h.inc -gen-dialect-decls) -mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) -add_public_tablegen_target(ToyCh7OpsIncGen) - -# Most dialects should use add_mlir_interfaces(). 
-set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td) -mlir_tablegen(ShapeInferenceOpInterfaces.h.inc -gen-op-interface-decls) -mlir_tablegen(ShapeInferenceOpInterfaces.cpp.inc -gen-op-interface-defs) -add_public_tablegen_target(ToyCh7ShapeInferenceInterfaceIncGen) diff --git a/include/toy/Lexer.h b/include/toy/Lexer.h deleted file mode 100644 index f022c2f..0000000 --- a/include/toy/Lexer.h +++ /dev/null @@ -1,236 +0,0 @@ -//===- Lexer.h - Lexer for the Toy language -------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements a simple Lexer for the Toy language. -// -//===----------------------------------------------------------------------===// - -#ifndef TOY_LEXER_H -#define TOY_LEXER_H - -#include "llvm/ADT/StringRef.h" - -#include -#include -#include - -namespace toy { - -/// Structure definition a location in a file. -struct Location { - std::shared_ptr file; ///< filename. - int line; ///< line number. - int col; ///< column number. -}; - -// List of Token returned by the lexer. -enum Token : int { - tok_semicolon = ';', - tok_parenthese_open = '(', - tok_parenthese_close = ')', - tok_bracket_open = '{', - tok_bracket_close = '}', - tok_sbracket_open = '[', - tok_sbracket_close = ']', - - tok_eof = -1, - - // commands - tok_return = -2, - tok_var = -3, - tok_def = -4, - tok_struct = -5, - - // primary - tok_identifier = -6, - tok_number = -7, -}; - -/// The Lexer is an abstract base class providing all the facilities that the -/// Parser expects. It goes through the stream one token at a time and keeps -/// track of the location in the file for debugging purpose. -/// It relies on a subclass to provide a `readNextLine()` method. 
The subclass -/// can proceed by reading the next line from the standard input or from a -/// memory mapped file. -class Lexer { -public: - /// Create a lexer for the given filename. The filename is kept only for - /// debugging purpose (attaching a location to a Token). - Lexer(std::string filename) - : lastLocation( - {std::make_shared(std::move(filename)), 0, 0}) {} - virtual ~Lexer() = default; - - /// Look at the current token in the stream. - Token getCurToken() { return curTok; } - - /// Move to the next token in the stream and return it. - Token getNextToken() { return curTok = getTok(); } - - /// Move to the next token in the stream, asserting on the current token - /// matching the expectation. - void consume(Token tok) { - assert(tok == curTok && "consume Token mismatch expectation"); - getNextToken(); - } - - /// Return the current identifier (prereq: getCurToken() == tok_identifier) - llvm::StringRef getId() { - assert(curTok == tok_identifier); - return identifierStr; - } - - /// Return the current number (prereq: getCurToken() == tok_number) - double getValue() { - assert(curTok == tok_number); - return numVal; - } - - /// Return the location for the beginning of the current token. - Location getLastLocation() { return lastLocation; } - - // Return the current line in the file. - int getLine() { return curLineNum; } - - // Return the current column in the file. - int getCol() { return curCol; } - -private: - /// Delegate to a derived class fetching the next line. Returns an empty - /// string to signal end of file (EOF). Lines are expected to always finish - /// with "\n" - virtual llvm::StringRef readNextLine() = 0; - - /// Return the next character from the stream. This manages the buffer for the - /// current line and request the next line buffer to the derived class as - /// needed. - int getNextChar() { - // The current line buffer should not be empty unless it is the end of file. 
- if (curLineBuffer.empty()) - return EOF; - ++curCol; - auto nextchar = curLineBuffer.front(); - curLineBuffer = curLineBuffer.drop_front(); - if (curLineBuffer.empty()) - curLineBuffer = readNextLine(); - if (nextchar == '\n') { - ++curLineNum; - curCol = 0; - } - return nextchar; - } - - /// Return the next token from standard input. - Token getTok() { - // Skip any whitespace. - while (isspace(lastChar)) - lastChar = Token(getNextChar()); - - // Save the current location before reading the token characters. - lastLocation.line = curLineNum; - lastLocation.col = curCol; - - // Identifier: [a-zA-Z][a-zA-Z0-9_]* - if (isalpha(lastChar)) { - identifierStr = (char)lastChar; - while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') - identifierStr += (char)lastChar; - - if (identifierStr == "return") - return tok_return; - if (identifierStr == "def") - return tok_def; - if (identifierStr == "struct") - return tok_struct; - if (identifierStr == "var") - return tok_var; - return tok_identifier; - } - - // Number: [0-9] ([0-9.])* - if (isdigit(lastChar)) { - std::string numStr; - do { - numStr += lastChar; - lastChar = Token(getNextChar()); - } while (isdigit(lastChar) || lastChar == '.'); - - numVal = strtod(numStr.c_str(), nullptr); - return tok_number; - } - - if (lastChar == '#') { - // Comment until end of line. - do { - lastChar = Token(getNextChar()); - } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); - - if (lastChar != EOF) - return getTok(); - } - - // Check for end of file. Don't eat the EOF. - if (lastChar == EOF) - return tok_eof; - - // Otherwise, just return the character as its ascii value. - Token thisChar = Token(lastChar); - lastChar = Token(getNextChar()); - return thisChar; - } - - /// The last token read from the input. - Token curTok = tok_eof; - - /// Location for `curTok`. - Location lastLocation; - - /// If the current Token is an identifier, this string contains the value. 
- std::string identifierStr; - - /// If the current Token is a number, this contains the value. - double numVal = 0; - - /// The last value returned by getNextChar(). We need to keep it around as we - /// always need to read ahead one character to decide when to end a token and - /// we can't put it back in the stream after reading from it. - Token lastChar = Token(' '); - - /// Keep track of the current line number in the input stream - int curLineNum = 0; - - /// Keep track of the current column number in the input stream - int curCol = 0; - - /// Buffer supplied by the derived class on calls to `readNextLine()` - llvm::StringRef curLineBuffer = "\n"; -}; - -/// A lexer implementation operating on a buffer in memory. -class LexerBuffer final : public Lexer { -public: - LexerBuffer(const char *begin, const char *end, std::string filename) - : Lexer(std::move(filename)), current(begin), end(end) {} - -private: - /// Provide one line at a time to the Lexer, return an empty string when - /// reaching the end of the buffer. - llvm::StringRef readNextLine() override { - auto *begin = current; - while (current <= end && *current && *current != '\n') - ++current; - if (current <= end && *current) - ++current; - llvm::StringRef result{begin, static_cast(current - begin)}; - return result; - } - const char *current, *end; -}; -} // namespace toy - -#endif // TOY_LEXER_H diff --git a/include/toy/MLIRGen.h b/include/toy/MLIRGen.h deleted file mode 100644 index fe9dbe5..0000000 --- a/include/toy/MLIRGen.h +++ /dev/null @@ -1,35 +0,0 @@ -//===- MLIRGen.h - MLIR Generation from a Toy AST -------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file declares a simple interface to perform IR generation targeting MLIR -// from a Module AST for the Toy language. -// -//===----------------------------------------------------------------------===// - -#ifndef TOY_MLIRGEN_H -#define TOY_MLIRGEN_H - -#include - -namespace mlir { -class MLIRContext; -template -class OwningOpRef; -class ModuleOp; -} // namespace mlir - -namespace toy { -class ModuleAST; - -/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module -/// or nullptr on failure. -mlir::OwningOpRef mlirGen(mlir::MLIRContext &context, - ModuleAST &moduleAST); -} // namespace toy - -#endif // TOY_MLIRGEN_H diff --git a/include/toy/Ops.td b/include/toy/Ops.td deleted file mode 100644 index 9151396..0000000 --- a/include/toy/Ops.td +++ /dev/null @@ -1,459 +0,0 @@ -//===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Defines the operations of the Toy dialect. -// -//===----------------------------------------------------------------------===// - -#ifndef TOY_OPS -#define TOY_OPS - -include "mlir/Interfaces/FunctionInterfaces.td" -include "mlir/IR/SymbolInterfaces.td" -include "mlir/Interfaces/CallInterfaces.td" -include "mlir/Interfaces/CastInterfaces.td" -include "mlir/Interfaces/SideEffectInterfaces.td" -include "toy/ShapeInferenceInterface.td" - -// Provide a definition of the 'toy' dialect in the ODS framework so that we -// can define our operations. 
-def Toy_Dialect : Dialect { - let name = "toy"; - let cppNamespace = "::mlir::toy"; - - // We set this bit to generate a declaration of the `materializeConstant` - // method so that we can materialize constants for our toy operations. - let hasConstantMaterializer = 1; - - // We set this bit to generate the declarations for the dialect's type parsing - // and printing hooks. - let useDefaultTypePrinterParser = 1; - -} - -// Base class for toy dialect operations. This operation inherits from the base -// `Op` class in OpBase.td, and provides: -// * The parent dialect of the operation. -// * The mnemonic for the operation, or the name without the dialect prefix. -// * A list of traits for the operation. -class Toy_Op traits = []> : - Op; - -// Provide a definition for the Toy StructType for use in ODS. This allows for -// using StructType in a similar way to Tensor or MemRef. We use `DialectType` -// to demarcate the StructType as belonging to the Toy dialect. -def Toy_StructType : - DialectType($_self)">, - "Toy struct type">; - -// Provide a definition of the types that are used within the Toy dialect. -def Toy_Type : AnyTypeOf<[F64Tensor, Toy_StructType]>; - -//===----------------------------------------------------------------------===// -// Toy Operations -//===----------------------------------------------------------------------===// - -//===----------------------------------------------------------------------===// -// ConstantOp -//===----------------------------------------------------------------------===// - -// We define a toy operation by inheriting from our base 'Toy_Op' class above. -// Here we provide the mnemonic and a list of traits for the operation. The -// constant operation is marked as 'Pure' as it is a pure operation -// and may be removed if dead. -def ConstantOp : Toy_Op<"constant", - [ConstantLike, Pure, - DeclareOpInterfaceMethods]> { - // Provide a summary and description for this operation. 
This can be used to - // auto-generate documentation of the operations within our dialect. - let summary = "constant"; - let description = [{ - Constant operation turns a literal into an SSA value. The data is attached - to the operation as an attribute. For example: - - ```mlir - %0 = toy.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> - : tensor<2x3xf64> - ``` - }]; - - // The constant operation takes an attribute as the only input. - let arguments = (ins F64ElementsAttr:$value); - - // The constant operation returns a single value of TensorType. - let results = (outs F64Tensor); - - // Indicate that the operation has a custom parser and printer method. - let hasCustomAssemblyFormat = 1; - - // Add custom build methods for the constant operation. These method populates - // the `state` that MLIR uses to create operations, i.e. these are used when - // using `ConstantOp::create(builder, ...)`. - let builders = [ - // Build a constant with a given constant tensor value. - OpBuilder<(ins "DenseElementsAttr":$value), [{ - build($_builder, $_state, value.getType(), value); - }]>, - - // Build a constant with a given constant floating-point value. - OpBuilder<(ins "double":$value)> - ]; - - // Indicate that additional verification for this operation is necessary. - let hasVerifier = 1; - - // Set the folder bit so that we can implement constant folders. - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// AddOp -//===----------------------------------------------------------------------===// - -def AddOp : Toy_Op<"add", - [Pure, DeclareOpInterfaceMethods]> { - let summary = "element-wise addition operation"; - let description = [{ - The "add" operation performs element-wise addition between two tensors. - The shapes of the tensor operands are expected to match. 
- }]; - - let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); - let results = (outs F64Tensor); - - // Indicate that the operation has a custom parser and printer method. - let hasCustomAssemblyFormat = 1; - - // Allow building an AddOp with from the two input operands. - let builders = [ - OpBuilder<(ins "Value":$lhs, "Value":$rhs)> - ]; -} - -//===----------------------------------------------------------------------===// -// CastOp -//===----------------------------------------------------------------------===// - -def CastOp : Toy_Op<"cast", [ - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - Pure, - SameOperandsAndResultShape - ]> { - let summary = "shape cast operation"; - let description = [{ - The "cast" operation converts a tensor from one type to an equivalent type - without changing any data elements. The source and destination types must - both be tensor types with the same element type. If both are ranked, then - shape is required to match. The operation is invalid if converting to a - mismatching constant dimension. - }]; - - let arguments = (ins F64Tensor:$input); - let results = (outs F64Tensor:$output); - - let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)"; -} - -//===----------------------------------------------------------------------===// -// FuncOp -//===----------------------------------------------------------------------===// - -def FuncOp : Toy_Op<"func", [ - FunctionOpInterface, IsolatedFromAbove - ]> { - let summary = "user defined function operation"; - let description = [{ - The "toy.func" operation represents a user defined function. These are - callable SSA-region operations that contain toy computations. 
- - Example: - - ```mlir - toy.func @main() { - %0 = toy.constant dense<5.500000e+00> : tensor - %1 = toy.reshape(%0 : tensor) to tensor<2x2xf64> - toy.print %1 : tensor<2x2xf64> - toy.return - } - ``` - }]; - - let arguments = (ins - SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type, - OptionalAttr:$arg_attrs, - OptionalAttr:$res_attrs - ); - let regions = (region AnyRegion:$body); - - let builders = [OpBuilder<(ins - "StringRef":$name, "FunctionType":$type, - CArg<"ArrayRef", "{}">:$attrs) - >]; - let extraClassDeclaration = [{ - //===------------------------------------------------------------------===// - // FunctionOpInterface Methods - //===------------------------------------------------------------------===// - - /// Returns the argument types of this function. - ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } - - /// Returns the result types of this function. - ArrayRef getResultTypes() { return getFunctionType().getResults(); } - - Region *getCallableRegion() { return &getBody(); } - }]; - let hasCustomAssemblyFormat = 1; - let skipDefaultBuilders = 1; -} - -//===----------------------------------------------------------------------===// -// GenericCallOp -//===----------------------------------------------------------------------===// - -def GenericCallOp : Toy_Op<"generic_call", - [DeclareOpInterfaceMethods]> { - let summary = "generic call operation"; - let description = [{ - Generic calls represent calls to a user defined function that needs to - be specialized for the shape of its arguments. The callee name is attached - as a symbol reference via an attribute. The arguments list must match the - arguments expected by the callee. For example: - - ```mlir - %4 = toy.generic_call @my_func(%1, %3) - : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> - ``` - - This is only valid if a function named "my_func" exists and takes two - arguments. 
- }]; - - // The generic call operation takes a symbol reference attribute as the - // callee, and inputs for the call. - let arguments = (ins - FlatSymbolRefAttr:$callee, - Variadic:$inputs, - OptionalAttr:$arg_attrs, - OptionalAttr:$res_attrs - ); - - // The generic call operation returns a single value of TensorType or - // StructType. - let results = (outs Toy_Type); - - // Specialize assembly printing and parsing using a declarative format. - let assemblyFormat = [{ - $callee `(` $inputs `)` attr-dict `:` functional-type($inputs, results) - }]; - - // Add custom build methods for the generic call operation. - let builders = [ - OpBuilder<(ins "Type":$result_type, "StringRef":$callee, - "ArrayRef":$arguments)> - ]; -} - -//===----------------------------------------------------------------------===// -// MulOp -//===----------------------------------------------------------------------===// - -def MulOp : Toy_Op<"mul", - [Pure, DeclareOpInterfaceMethods]> { - let summary = "element-wise multiplication operation"; - let description = [{ - The "mul" operation performs element-wise multiplication between two - tensors. The shapes of the tensor operands are expected to match. - }]; - - let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); - let results = (outs F64Tensor); - - // Indicate that the operation has a custom parser and printer method. - let hasCustomAssemblyFormat = 1; - - // Allow building a MulOp with from the two input operands. - let builders = [ - OpBuilder<(ins "Value":$lhs, "Value":$rhs)> - ]; -} - -//===----------------------------------------------------------------------===// -// PrintOp -//===----------------------------------------------------------------------===// - -def PrintOp : Toy_Op<"print"> { - let summary = "print operation"; - let description = [{ - The "print" builtin operation prints a given input tensor, and produces - no results. - }]; - - // The print operation takes an input tensor to print. 
- // We also allow a F64MemRef to enable interop during partial lowering. - let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input); - - let assemblyFormat = "$input attr-dict `:` type($input)"; -} - -//===----------------------------------------------------------------------===// -// ReshapeOp -//===----------------------------------------------------------------------===// - -def ReshapeOp : Toy_Op<"reshape", [Pure]> { - let summary = "tensor reshape operation"; - let description = [{ - Reshape operation is transforming its input tensor into a new tensor with - the same number of elements but different shapes. For example: - - ```mlir - %0 = toy.reshape (%arg1 : tensor<10xf64>) to tensor<5x2xf64> - ``` - }]; - - let arguments = (ins F64Tensor:$input); - - let assemblyFormat = [{ - `(` $input `:` type($input) `)` attr-dict `to` type(results) - }]; - - // Enable registering canonicalization patterns with this operation. - let hasCanonicalizer = 1; - - // We expect that the reshape operation returns a statically shaped tensor. - let results = (outs StaticShapeTensorOf<[F64]>); -} - -//===----------------------------------------------------------------------===// -// ReturnOp -//===----------------------------------------------------------------------===// - -def ReturnOp : Toy_Op<"return", [Pure, HasParent<"FuncOp">, - Terminator]> { - let summary = "return operation"; - let description = [{ - The "return" operation represents a return operation within a function. - The operation takes an optional operand and produces no results. - The operand type must match the signature of the function that contains - the operation. For example: - - ```mlir - toy.func @foo() -> tensor<2xf64> { - ... - toy.return %0 : tensor<2xf64> - } - ``` - }]; - - // The return operation takes an optional input operand to return. This - // value must match the return type of the enclosing function. 
- let arguments = (ins Variadic:$input); - - // The return operation only emits the input in the format if it is present. - let assemblyFormat = "($input^ `:` type($input))? attr-dict "; - - // Allow building a ReturnOp with no return operand. - let builders = [ - OpBuilder<(ins), [{ build($_builder, $_state, {}); }]> - ]; - - // Provide extra utility definitions on the c++ operation class definition. - let extraClassDeclaration = [{ - bool hasOperand() { return getNumOperands() != 0; } - }]; - - // Indicate that additional verification for this operation is necessary. - let hasVerifier = 1; -} - -//===----------------------------------------------------------------------===// -// StructAccessOp -//===----------------------------------------------------------------------===// - -def StructAccessOp : Toy_Op<"struct_access", [Pure]> { - let summary = "struct access"; - let description = [{ - Access the Nth element of a value returning a struct type. - }]; - - let arguments = (ins Toy_StructType:$input, I64Attr:$index); - let results = (outs Toy_Type:$output); - - let assemblyFormat = [{ - $input `[` $index `]` attr-dict `:` type($input) `->` type($output) - }]; - - // Allow building a StructAccessOp with just a struct value and an index. - let builders = [ - OpBuilder<(ins "Value":$input, "size_t":$index)> - ]; - - // Indicate that additional verification for this operation is necessary. - let hasVerifier = 1; - - // Set the folder bit so that we can fold constant accesses. - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// StructConstantOp -//===----------------------------------------------------------------------===// - -def StructConstantOp : Toy_Op<"struct_constant", [ConstantLike, Pure]> { - let summary = "struct constant"; - let description = [{ - Constant operation turns a literal struct value into an SSA value. The data - is attached to the operation as an attribute. 
The struct constant is encoded - as an array of other constant values. For example: - - ```mlir - %0 = toy.struct_constant [ - dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> - ] : !toy.struct> - ``` - }]; - - let arguments = (ins ArrayAttr:$value); - let results = (outs Toy_StructType:$output); - - let assemblyFormat = "$value attr-dict `:` type($output)"; - - // Indicate that additional verification for this operation is necessary. - let hasVerifier = 1; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// TransposeOp -//===----------------------------------------------------------------------===// - -def TransposeOp : Toy_Op<"transpose", - [Pure, DeclareOpInterfaceMethods]> { - let summary = "transpose operation"; - - let arguments = (ins F64Tensor:$input); - let results = (outs F64Tensor); - - let assemblyFormat = [{ - `(` $input `:` type($input) `)` attr-dict `to` type(results) - }]; - - // Enable registering canonicalization patterns with this operation. - let hasCanonicalizer = 1; - - // Allow building a TransposeOp with from the input operand. - let builders = [ - OpBuilder<(ins "Value":$input)> - ]; - - // Indicate that additional verification for this operation is necessary. - let hasVerifier = 1; -} - -#endif // TOY_OPS diff --git a/include/toy/Parser.h b/include/toy/Parser.h deleted file mode 100644 index 7ba7b8f..0000000 --- a/include/toy/Parser.h +++ /dev/null @@ -1,683 +0,0 @@ -//===- Parser.h - Toy Language Parser -------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements the parser for the Toy language. It processes the Token -// provided by the Lexer and returns an AST. 
-// -//===----------------------------------------------------------------------===// - -#ifndef TOY_PARSER_H -#define TOY_PARSER_H - -#include "toy/AST.h" -#include "toy/Lexer.h" - -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/StringExtras.h" -#include "llvm/Support/raw_ostream.h" - -#include -#include -#include -#include - -namespace toy { - -/// This is a simple recursive parser for the Toy language. It produces a well -/// formed AST from a stream of Token supplied by the Lexer. No semantic checks -/// or symbol resolution is performed. For example, variables are referenced by -/// string and the code could reference an undeclared variable and the parsing -/// succeeds. -class Parser { -public: - /// Create a Parser for the supplied lexer. - Parser(Lexer &lexer) : lexer(lexer) {} - - /// Parse a full Module. A module is a list of function definitions. - std::unique_ptr parseModule() { - lexer.getNextToken(); // prime the lexer - - // Parse functions and structs one at a time and accumulate in this vector. - std::vector> records; - while (true) { - std::unique_ptr record; - switch (lexer.getCurToken()) { - case tok_eof: - break; - case tok_def: - record = parseDefinition(); - break; - case tok_struct: - record = parseStruct(); - break; - default: - return parseError("'def' or 'struct'", - "when parsing top level module records"); - } - if (!record) - break; - records.push_back(std::move(record)); - } - - // If we didn't reach EOF, there was an error during parsing - if (lexer.getCurToken() != tok_eof) - return parseError("nothing", "at end of module"); - - return std::make_unique(std::move(records)); - } - -private: - Lexer &lexer; - - /// Parse a return statement. 
- /// return :== return ; | return expr ; - std::unique_ptr parseReturn() { - auto loc = lexer.getLastLocation(); - lexer.consume(tok_return); - - // return takes an optional argument - std::optional> expr; - if (lexer.getCurToken() != ';') { - expr = parseExpression(); - if (!expr) - return nullptr; - } - return std::make_unique(std::move(loc), std::move(expr)); - } - - /// Parse a literal number. - /// numberexpr ::= number - std::unique_ptr parseNumberExpr() { - auto loc = lexer.getLastLocation(); - auto result = - std::make_unique(std::move(loc), lexer.getValue()); - lexer.consume(tok_number); - return std::move(result); - } - - /// Parse a literal array expression. - /// tensorLiteral ::= [ literalList ] | number - /// literalList ::= tensorLiteral | tensorLiteral, literalList - std::unique_ptr parseTensorLiteralExpr() { - auto loc = lexer.getLastLocation(); - lexer.consume(Token('[')); - - // Hold the list of values at this nesting level. - std::vector> values; - // Hold the dimensions for all the nesting inside this level. - std::vector dims; - do { - // We can have either another nested array or a number literal. - if (lexer.getCurToken() == '[') { - values.push_back(parseTensorLiteralExpr()); - if (!values.back()) - return nullptr; // parse error in the nested array. - } else { - if (lexer.getCurToken() != tok_number) - return parseError(" or [", "in literal expression"); - values.push_back(parseNumberExpr()); - } - - // End of this list on ']' - if (lexer.getCurToken() == ']') - break; - - // Elements are separated by a comma. - if (lexer.getCurToken() != ',') - return parseError("] or ,", "in literal expression"); - - lexer.getNextToken(); // eat , - } while (true); - if (values.empty()) - return parseError("", "to fill literal expression"); - lexer.getNextToken(); // eat ] - - /// Fill in the dimensions now. 
First the current nesting level: - dims.push_back(values.size()); - - /// If there is any nested array, process all of them and ensure that - /// dimensions are uniform. - if (llvm::any_of(values, [](std::unique_ptr &expr) { - return llvm::isa(expr.get()); - })) { - auto *firstLiteral = llvm::dyn_cast(values.front().get()); - if (!firstLiteral) - return parseError("uniform well-nested dimensions", - "inside literal expression"); - - // Append the nested dimensions to the current level - auto firstDims = firstLiteral->getDims(); - dims.insert(dims.end(), firstDims.begin(), firstDims.end()); - - // Sanity check that shape is uniform across all elements of the list. - for (auto &expr : values) { - auto *exprLiteral = llvm::cast(expr.get()); - if (!exprLiteral) - return parseError("uniform well-nested dimensions", - "inside literal expression"); - if (exprLiteral->getDims() != firstDims) - return parseError("uniform well-nested dimensions", - "inside literal expression"); - } - } - return std::make_unique(std::move(loc), std::move(values), - std::move(dims)); - } - - /// Parse a literal struct expression. - /// structLiteral ::= { (structLiteral | tensorLiteral)+ } - std::unique_ptr parseStructLiteralExpr() { - auto loc = lexer.getLastLocation(); - lexer.consume(Token('{')); - - // Hold the list of values. - std::vector> values; - do { - // We can have either another nested array or a number literal. - if (lexer.getCurToken() == '[') { - values.push_back(parseTensorLiteralExpr()); - if (!values.back()) - return nullptr; - } else if (lexer.getCurToken() == tok_number) { - values.push_back(parseNumberExpr()); - if (!values.back()) - return nullptr; - } else { - if (lexer.getCurToken() != '{') - return parseError("{, [, or number", - "in struct literal expression"); - values.push_back(parseStructLiteralExpr()); - } - - // End of this list on '}' - if (lexer.getCurToken() == '}') - break; - - // Elements are separated by a comma. 
- if (lexer.getCurToken() != ',') - return parseError("} or ,", "in struct literal expression"); - - lexer.getNextToken(); // eat , - } while (true); - if (values.empty()) - return parseError("", - "to fill struct literal expression"); - lexer.getNextToken(); // eat } - - return std::make_unique(std::move(loc), - std::move(values)); - } - - /// parenexpr ::= '(' expression ')' - std::unique_ptr parseParenExpr() { - lexer.getNextToken(); // eat (. - auto v = parseExpression(); - if (!v) - return nullptr; - - if (lexer.getCurToken() != ')') - return parseError(")", "to close expression with parentheses"); - lexer.consume(Token(')')); - return v; - } - - /// Parse a call expression. - std::unique_ptr parseCallExpr(llvm::StringRef name, - const Location &loc) { - lexer.consume(Token('(')); - std::vector> args; - if (lexer.getCurToken() != ')') { - while (true) { - if (auto arg = parseExpression()) - args.push_back(std::move(arg)); - else - return nullptr; - - if (lexer.getCurToken() == ')') - break; - - if (lexer.getCurToken() != ',') - return parseError(", or )", "in argument list"); - lexer.getNextToken(); - } - } - lexer.consume(Token(')')); - - // It can be a builtin call to print - if (name == "print") { - if (args.size() != 1) - return parseError("", "as argument to print()"); - - return std::make_unique(loc, std::move(args[0])); - } - - // Call to a user-defined function - return std::make_unique(loc, std::string(name), - std::move(args)); - } - - /// identifierexpr - /// ::= identifier - /// ::= identifier '(' expression ')' - std::unique_ptr parseIdentifierExpr() { - std::string name(lexer.getId()); - - auto loc = lexer.getLastLocation(); - lexer.getNextToken(); // eat identifier. - - if (lexer.getCurToken() != '(') // Simple variable ref. - return std::make_unique(std::move(loc), name); - - // This is a function call. 
- return parseCallExpr(name, loc); - } - - /// primary - /// ::= identifierexpr - /// ::= numberexpr - /// ::= parenexpr - /// ::= tensorliteral - std::unique_ptr parsePrimary() { - switch (lexer.getCurToken()) { - default: - llvm::errs() << "unknown token '" << lexer.getCurToken() - << "' when expecting an expression\n"; - return nullptr; - case tok_identifier: - return parseIdentifierExpr(); - case tok_number: - return parseNumberExpr(); - case '(': - return parseParenExpr(); - case '[': - return parseTensorLiteralExpr(); - case '{': - return parseStructLiteralExpr(); - case ';': - return nullptr; - case '}': - return nullptr; - } - } - - /// Recursively parse the right hand side of a binary expression, the ExprPrec - /// argument indicates the precedence of the current binary operator. - /// - /// binoprhs ::= ('+' primary)* - std::unique_ptr parseBinOpRHS(int exprPrec, - std::unique_ptr lhs) { - // If this is a binop, find its precedence. - while (true) { - int tokPrec = getTokPrecedence(); - - // If this is a binop that binds at least as tightly as the current binop, - // consume it, otherwise we are done. - if (tokPrec < exprPrec) - return lhs; - - // Okay, we know this is a binop. - int binOp = lexer.getCurToken(); - lexer.consume(Token(binOp)); - auto loc = lexer.getLastLocation(); - - // Parse the primary expression after the binary operator. - auto rhs = parsePrimary(); - if (!rhs) - return parseError("expression", "to complete binary operator"); - - // If BinOp binds less tightly with rhs than the operator after rhs, let - // the pending operator take rhs as its lhs. - int nextPrec = getTokPrecedence(); - if (tokPrec < nextPrec) { - rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); - if (!rhs) - return nullptr; - } - - // Merge lhs/RHS. 
- lhs = std::make_unique(std::move(loc), binOp, - std::move(lhs), std::move(rhs)); - } - } - - /// expression::= primary binop rhs - std::unique_ptr parseExpression() { - auto lhs = parsePrimary(); - if (!lhs) - return nullptr; - - return parseBinOpRHS(0, std::move(lhs)); - } - - /// type ::= < shape_list > - /// shape_list ::= num | num , shape_list - std::unique_ptr parseType() { - if (lexer.getCurToken() != '<') - return parseError("<", "to begin type"); - lexer.getNextToken(); // eat < - - auto type = std::make_unique(); - - while (lexer.getCurToken() == tok_number) { - type->shape.push_back(lexer.getValue()); - lexer.getNextToken(); - if (lexer.getCurToken() == ',') - lexer.getNextToken(); - } - - if (lexer.getCurToken() != '>') - return parseError(">", "to end type"); - lexer.getNextToken(); // eat > - return type; - } - - /// Parse either a variable declaration or a call expression. - std::unique_ptr parseDeclarationOrCallExpr() { - auto loc = lexer.getLastLocation(); - std::string id(lexer.getId()); - lexer.consume(tok_identifier); - - // Check for a call expression. - if (lexer.getCurToken() == '(') - return parseCallExpr(id, loc); - - // Otherwise, this is a variable declaration. - return parseTypedDeclaration(id, /*requiresInitializer=*/true, loc); - } - - /// Parse a typed variable declaration. - std::unique_ptr - parseTypedDeclaration(llvm::StringRef typeName, bool requiresInitializer, - const Location &loc) { - // Parse the variable name. - if (lexer.getCurToken() != tok_identifier) - return parseError("name", "in variable declaration"); - std::string id(lexer.getId()); - lexer.getNextToken(); // eat id - - // Parse the initializer. 
- std::unique_ptr expr; - if (requiresInitializer) { - if (lexer.getCurToken() != '=') - return parseError("initializer", - "in variable declaration"); - lexer.consume(Token('=')); - expr = parseExpression(); - } - - VarType type; - type.name = std::string(typeName); - return std::make_unique(loc, std::move(id), std::move(type), - std::move(expr)); - } - - /// Parse a variable declaration, for either a tensor value or a struct value, - /// with an optionally required initializer. - /// decl ::= var identifier [ type ] (= expr)? - /// decl ::= identifier identifier (= expr)? - std::unique_ptr parseDeclaration(bool requiresInitializer) { - // Check to see if this is a 'var' declaration. - if (lexer.getCurToken() == tok_var) - return parseVarDeclaration(requiresInitializer); - - // Parse the type name. - if (lexer.getCurToken() != tok_identifier) - return parseError("type name", "in variable declaration"); - auto loc = lexer.getLastLocation(); - std::string typeName(lexer.getId()); - lexer.getNextToken(); // eat id - - // Parse the rest of the declaration. - return parseTypedDeclaration(typeName, requiresInitializer, loc); - } - - /// Parse a variable declaration, it starts with a `var` keyword followed by - /// and identifier and an optional type (shape specification) before the - /// optionally required initializer. - /// decl ::= var identifier [ type ] (= expr)? 
- std::unique_ptr - parseVarDeclaration(bool requiresInitializer) { - if (lexer.getCurToken() != tok_var) - return parseError("var", "to begin declaration"); - auto loc = lexer.getLastLocation(); - lexer.getNextToken(); // eat var - - if (lexer.getCurToken() != tok_identifier) - return parseError("identified", - "after 'var' declaration"); - std::string id(lexer.getId()); - lexer.getNextToken(); // eat id - - std::unique_ptr type; // Type is optional, it can be inferred - if (lexer.getCurToken() == '<') { - type = parseType(); - if (!type) - return nullptr; - } - if (!type) - type = std::make_unique(); - - std::unique_ptr expr; - if (requiresInitializer) { - lexer.consume(Token('=')); - expr = parseExpression(); - } - return std::make_unique(std::move(loc), std::move(id), - std::move(*type), std::move(expr)); - } - - /// Parse a block: a list of expression separated by semicolons and wrapped in - /// curly braces. - /// - /// block ::= { expression_list } - /// expression_list ::= block_expr ; expression_list - /// block_expr ::= decl | "return" | expr - std::unique_ptr parseBlock() { - if (lexer.getCurToken() != '{') - return parseError("{", "to begin block"); - lexer.consume(Token('{')); - - auto exprList = std::make_unique(); - - // Ignore empty expressions: swallow sequences of semicolons. 
- while (lexer.getCurToken() == ';') - lexer.consume(Token(';')); - - while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { - if (lexer.getCurToken() == tok_identifier) { - // Variable declaration or call - auto expr = parseDeclarationOrCallExpr(); - if (!expr) - return nullptr; - exprList->push_back(std::move(expr)); - } else if (lexer.getCurToken() == tok_var) { - // Variable declaration - auto varDecl = parseDeclaration(/*requiresInitializer=*/true); - if (!varDecl) - return nullptr; - exprList->push_back(std::move(varDecl)); - } else if (lexer.getCurToken() == tok_return) { - // Return statement - auto ret = parseReturn(); - if (!ret) - return nullptr; - exprList->push_back(std::move(ret)); - } else { - // General expression - auto expr = parseExpression(); - if (!expr) - return nullptr; - exprList->push_back(std::move(expr)); - } - // Ensure that elements are separated by a semicolon. - if (lexer.getCurToken() != ';') - return parseError(";", "after expression"); - - // Ignore empty expressions: swallow sequences of semicolons. - while (lexer.getCurToken() == ';') - lexer.consume(Token(';')); - } - - if (lexer.getCurToken() != '}') - return parseError("}", "to close block"); - - lexer.consume(Token('}')); - return exprList; - } - - /// prototype ::= def id '(' decl_list ')' - /// decl_list ::= identifier | identifier, decl_list - std::unique_ptr parsePrototype() { - auto loc = lexer.getLastLocation(); - - if (lexer.getCurToken() != tok_def) - return parseError("def", "in prototype"); - lexer.consume(tok_def); - - if (lexer.getCurToken() != tok_identifier) - return parseError("function name", "in prototype"); - - std::string fnName(lexer.getId()); - lexer.consume(tok_identifier); - - if (lexer.getCurToken() != '(') - return parseError("(", "in prototype"); - lexer.consume(Token('(')); - - std::vector> args; - if (lexer.getCurToken() != ')') { - do { - VarType type; - std::string name; - - // Parse either the name of the variable, or its type. 
- std::string nameOrType(lexer.getId()); - auto loc = lexer.getLastLocation(); - lexer.consume(tok_identifier); - - // If the next token is an identifier, we just parsed the type. - if (lexer.getCurToken() == tok_identifier) { - type.name = std::move(nameOrType); - - // Parse the name. - name = std::string(lexer.getId()); - lexer.consume(tok_identifier); - } else { - // Otherwise, we just parsed the name. - name = std::move(nameOrType); - } - - args.push_back( - std::make_unique(std::move(loc), name, type)); - if (lexer.getCurToken() != ',') - break; - lexer.consume(Token(',')); - if (lexer.getCurToken() != tok_identifier) - return parseError( - "identifier", "after ',' in function parameter list"); - } while (true); - } - if (lexer.getCurToken() != ')') - return parseError(")", "to end function prototype"); - - // success. - lexer.consume(Token(')')); - return std::make_unique(std::move(loc), fnName, - std::move(args)); - } - - /// Parse a function definition, we expect a prototype initiated with the - /// `def` keyword, followed by a block containing a list of expressions. - /// - /// definition ::= prototype block - std::unique_ptr parseDefinition() { - auto proto = parsePrototype(); - if (!proto) - return nullptr; - - if (auto block = parseBlock()) - return std::make_unique(std::move(proto), std::move(block)); - return nullptr; - } - - /// Parse a struct definition, we expect a struct initiated with the - /// `struct` keyword, followed by a block containing a list of variable - /// declarations. 
- /// - /// definition ::= `struct` identifier `{` decl+ `}` - std::unique_ptr parseStruct() { - auto loc = lexer.getLastLocation(); - lexer.consume(tok_struct); - if (lexer.getCurToken() != tok_identifier) - return parseError("name", "in struct definition"); - std::string name(lexer.getId()); - lexer.consume(tok_identifier); - - // Parse: '{' - if (lexer.getCurToken() != '{') - return parseError("{", "in struct definition"); - lexer.consume(Token('{')); - - // Parse: decl+ - std::vector> decls; - do { - auto decl = parseDeclaration(/*requiresInitializer=*/false); - if (!decl) - return nullptr; - decls.push_back(std::move(decl)); - - if (lexer.getCurToken() != ';') - return parseError(";", - "after variable in struct definition"); - lexer.consume(Token(';')); - } while (lexer.getCurToken() != '}'); - - // Parse: '}' - lexer.consume(Token('}')); - return std::make_unique(loc, name, std::move(decls)); - } - - /// Get the precedence of the pending binary operator token. - int getTokPrecedence() { - if (!isascii(lexer.getCurToken())) - return -1; - - // 1 is lowest precedence. - switch (static_cast(lexer.getCurToken())) { - case '-': - return 20; - case '+': - return 20; - case '*': - return 40; - case '.': - return 60; - default: - return -1; - } - } - - /// Helper function to signal errors while parsing, it takes an argument - /// indicating the expected token and another argument giving more context. - /// Location is retrieved from the lexer to enrich the error message. 
- template - std::unique_ptr parseError(T &&expected, U &&context = "") { - auto curToken = lexer.getCurToken(); - llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " - << lexer.getLastLocation().col << "): expected '" << expected - << "' " << context << " but has Token " << curToken; - if (isprint(curToken)) - llvm::errs() << " '" << (char)curToken << "'"; - llvm::errs() << "\n"; - return nullptr; - } -}; - -} // namespace toy - -#endif // TOY_PARSER_H diff --git a/main.cpp b/main.cpp deleted file mode 100644 index ffd94bc..0000000 --- a/main.cpp +++ /dev/null @@ -1,333 +0,0 @@ -//===- toyc.cpp - The Toy Compiler ----------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements the entry point for the Toy compiler. 
-// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Func/Extensions/AllExtensions.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" -#include "toy/AST.h" -#include "toy/Dialect.h" -#include "toy/Lexer.h" -#include "toy/MLIRGen.h" -#include "toy/Parser.h" -#include "toy/Passes.h" - -#include "mlir/Dialect/Affine/Transforms/Passes.h" -#include "mlir/Dialect/LLVMIR/Transforms/Passes.h" -#include "mlir/ExecutionEngine/ExecutionEngine.h" -#include "mlir/ExecutionEngine/OptUtils.h" -#include "mlir/IR/AsmState.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Verifier.h" -#include "mlir/InitAllDialects.h" -#include "mlir/Parser/Parser.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Export.h" -#include "mlir/Transforms/Passes.h" - -#include "llvm/ADT/StringRef.h" -#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" -#include "llvm/IR/Module.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/ErrorOr.h" -#include "llvm/Support/MemoryBuffer.h" -#include "llvm/Support/SourceMgr.h" -#include "llvm/Support/TargetSelect.h" -#include "llvm/Support/raw_ostream.h" -#include -#include -#include -#include -#include - -using namespace toy; -namespace cl = llvm::cl; - -static cl::opt inputFilename(cl::Positional, - cl::desc(""), - cl::init("-"), - cl::value_desc("filename")); - -namespace { -enum InputType { Toy, MLIR }; -} // namespace -static cl::opt inputType( - "x", cl::init(Toy), cl::desc("Decided the kind of output desired"), - cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")), - cl::values(clEnumValN(MLIR, "mlir", - "load the input file as an MLIR file"))); - 
-namespace { -enum Action { - None, - DumpAST, - DumpMLIR, - DumpMLIRAffine, - DumpMLIRLLVM, - DumpLLVMIR, - RunJIT -}; -} // namespace -static cl::opt emitAction( - "emit", cl::desc("Select the kind of output desired"), - cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), - cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")), - cl::values(clEnumValN(DumpMLIRAffine, "mlir-affine", - "output the MLIR dump after affine lowering")), - cl::values(clEnumValN(DumpMLIRLLVM, "mlir-llvm", - "output the MLIR dump after llvm lowering")), - cl::values(clEnumValN(DumpLLVMIR, "llvm", "output the LLVM IR dump")), - cl::values( - clEnumValN(RunJIT, "jit", - "JIT the code and run it by invoking the main function"))); - -static cl::opt enableOpt("opt", cl::desc("Enable optimizations")); - -/// Returns a Toy AST resulting from parsing the file or a nullptr on error. -static std::unique_ptr -parseInputFile(llvm::StringRef filename) { - llvm::ErrorOr> fileOrErr = - llvm::MemoryBuffer::getFileOrSTDIN(filename); - if (std::error_code ec = fileOrErr.getError()) { - llvm::errs() << "Could not open input file: " << ec.message() << "\n"; - return nullptr; - } - auto buffer = fileOrErr.get()->getBuffer(); - LexerBuffer lexer(buffer.begin(), buffer.end(), std::string(filename)); - Parser parser(lexer); - return parser.parseModule(); -} - -static int loadMLIR(mlir::MLIRContext &context, - mlir::OwningOpRef &module) { - // Handle '.toy' input to the compiler. - if (inputType != InputType::MLIR && - !llvm::StringRef(inputFilename).ends_with(".mlir")) { - auto moduleAST = parseInputFile(inputFilename); - if (!moduleAST) - return 6; - module = mlirGen(context, *moduleAST); - return !module ? 1 : 0; - } - - // Otherwise, the input is '.mlir'. 
- llvm::ErrorOr> fileOrErr = - llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); - if (std::error_code ec = fileOrErr.getError()) { - llvm::errs() << "Could not open input file: " << ec.message() << "\n"; - return -1; - } - - // Parse the input mlir. - llvm::SourceMgr sourceMgr; - sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); - module = mlir::parseSourceFile(sourceMgr, &context); - if (!module) { - llvm::errs() << "Error can't load file " << inputFilename << "\n"; - return 3; - } - return 0; -} - -static int loadAndProcessMLIR(mlir::MLIRContext &context, - mlir::OwningOpRef &module) { - if (int error = loadMLIR(context, module)) - return error; - - mlir::PassManager pm(module.get()->getName()); - // Apply any generic pass manager command line options and run the pipeline. - if (mlir::failed(mlir::applyPassManagerCLOptions(pm))) - return 4; - - // Check to see what granularity of MLIR we are compiling to. - bool isLoweringToAffine = emitAction >= Action::DumpMLIRAffine; - bool isLoweringToLLVM = emitAction >= Action::DumpMLIRLLVM; - - if (enableOpt || isLoweringToAffine) { - // Inline all functions into main and then delete them. - pm.addPass(mlir::createInlinerPass()); - - // Now that there is only one function, we can infer the shapes of each of - // the operations. - mlir::OpPassManager &optPM = pm.nest(); - optPM.addPass(mlir::createCanonicalizerPass()); - optPM.addPass(mlir::toy::createShapeInferencePass()); - optPM.addPass(mlir::createCanonicalizerPass()); - optPM.addPass(mlir::createCSEPass()); - } - - if (isLoweringToAffine) { - // Partially lower the toy dialect. - pm.addPass(mlir::toy::createLowerToAffinePass()); - - // Add a few cleanups post lowering. - mlir::OpPassManager &optPM = pm.nest(); - optPM.addPass(mlir::createCanonicalizerPass()); - optPM.addPass(mlir::createCSEPass()); - - // Add optimizations if enabled. 
- if (enableOpt) { - optPM.addPass(mlir::affine::createLoopFusionPass()); - optPM.addPass(mlir::affine::createAffineScalarReplacementPass()); - } - } - - if (isLoweringToLLVM) { - // Finish lowering the toy IR to the LLVM dialect. - pm.addPass(mlir::toy::createLowerToLLVMPass()); - // This is necessary to have line tables emitted and basic - // debugger working. In the future we will add proper debug information - // emission directly from our frontend. - pm.addPass(mlir::LLVM::createDIScopeForLLVMFuncOpPass()); - } - - if (mlir::failed(pm.run(*module))) - return 4; - return 0; -} - -static int dumpAST() { - if (inputType == InputType::MLIR) { - llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n"; - return 5; - } - - auto moduleAST = parseInputFile(inputFilename); - if (!moduleAST) - return 1; - - dump(*moduleAST); - return 0; -} - -static int dumpLLVMIR(mlir::ModuleOp module) { - // Register the translation to LLVM IR with the MLIR context. - mlir::registerBuiltinDialectTranslation(*module->getContext()); - mlir::registerLLVMDialectTranslation(*module->getContext()); - - // Convert the module to LLVM IR in a new LLVM IR context. - llvm::LLVMContext llvmContext; - auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext); - if (!llvmModule) { - llvm::errs() << "Failed to emit LLVM IR\n"; - return -1; - } - - // Initialize LLVM targets. 
- llvm::InitializeNativeTarget(); - llvm::InitializeNativeTargetAsmPrinter(); - - // Create target machine and configure the LLVM Module - auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); - if (!tmBuilderOrError) { - llvm::errs() << "Could not create JITTargetMachineBuilder\n"; - return -1; - } - - auto tmOrError = tmBuilderOrError->createTargetMachine(); - if (!tmOrError) { - llvm::errs() << "Could not create TargetMachine\n"; - return -1; - } - mlir::ExecutionEngine::setupTargetTripleAndDataLayout(llvmModule.get(), - tmOrError.get().get()); - - /// Optionally run an optimization pipeline over the llvm module. - auto optPipeline = mlir::makeOptimizingTransformer( - /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0, - /*targetMachine=*/nullptr); - if (auto err = optPipeline(llvmModule.get())) { - llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; - return -1; - } - llvm::errs() << *llvmModule << "\n"; - return 0; -} - -static int runJit(mlir::ModuleOp module) { - // Initialize LLVM targets. - llvm::InitializeNativeTarget(); - llvm::InitializeNativeTargetAsmPrinter(); - - // Register the translation from MLIR to LLVM IR, which must happen before we - // can JIT-compile. - mlir::registerBuiltinDialectTranslation(*module->getContext()); - mlir::registerLLVMDialectTranslation(*module->getContext()); - - // An optimization pipeline to use within the execution engine. - auto optPipeline = mlir::makeOptimizingTransformer( - /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0, - /*targetMachine=*/nullptr); - - // Create an MLIR execution engine. The execution engine eagerly JIT-compiles - // the module. - mlir::ExecutionEngineOptions engineOptions; - engineOptions.transformer = optPipeline; - auto maybeEngine = mlir::ExecutionEngine::create(module, engineOptions); - assert(maybeEngine && "failed to construct an execution engine"); - auto &engine = maybeEngine.get(); - - // Invoke the JIT-compiled function. 
- auto invocationResult = engine->invokePacked("main"); - if (invocationResult) { - llvm::errs() << "JIT invocation failed\n"; - return -1; - } - - return 0; -} - -int main(int argc, char **argv) { - // Register any command line options. - mlir::registerAsmPrinterCLOptions(); - mlir::registerMLIRContextCLOptions(); - mlir::registerPassManagerCLOptions(); - - cl::ParseCommandLineOptions(argc, argv, "toy compiler\n"); - - if (emitAction == Action::DumpAST) - return dumpAST(); - - // If we aren't dumping the AST, then we are compiling with/to MLIR. - mlir::DialectRegistry registry; - mlir::func::registerAllExtensions(registry); - mlir::LLVM::registerInlinerInterface(registry); - - mlir::MLIRContext context(registry); - // Load our Dialect in this MLIR Context. - context.getOrLoadDialect(); - - mlir::OwningOpRef module; - if (int error = loadAndProcessMLIR(context, module)) - return error; - - // If we aren't exporting to non-mlir, then we are done. - bool isOutputingMLIR = emitAction <= Action::DumpMLIRLLVM; - if (isOutputingMLIR) { - module->dump(); - return 0; - } - - // Check to see if we are compiling to LLVM IR. - if (emitAction == Action::DumpLLVMIR) - return dumpLLVMIR(*module); - - // Otherwise, we must be running the jit. - if (emitAction == Action::RunJIT) - return runJit(*module); - - llvm::errs() << "No action specified (parsing only?), use -emit=\n"; - return -1; -} diff --git a/mlir/Dialect.cpp b/mlir/Dialect.cpp deleted file mode 100644 index 4d2f063..0000000 --- a/mlir/Dialect.cpp +++ /dev/null @@ -1,665 +0,0 @@ -//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements the dialect for the Toy IR: custom type parsing and -// operation verification. -// -//===----------------------------------------------------------------------===// - -#include "toy/Dialect.h" - -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/DialectImplementation.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/OperationSupport.h" -#include "mlir/IR/TypeSupport.h" -#include "mlir/IR/ValueRange.h" -#include "mlir/Interfaces/CallInterfaces.h" -#include "mlir/Interfaces/FunctionImplementation.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Transforms/InliningUtils.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/Hashing.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/Casting.h" -#include -#include -#include -#include -#include - -using namespace mlir; -using namespace mlir::toy; - -#include "toy/Dialect.cpp.inc" - -//===----------------------------------------------------------------------===// -// ToyInlinerInterface -//===----------------------------------------------------------------------===// - -/// This class defines the interface for handling inlining with Toy -/// operations. -struct ToyInlinerInterface : public DialectInlinerInterface { - using DialectInlinerInterface::DialectInlinerInterface; - - //===--------------------------------------------------------------------===// - // Analysis Hooks - //===--------------------------------------------------------------------===// - - /// All call operations within toy can be inlined. 
- bool isLegalToInline(Operation *call, Operation *callable, - bool wouldBeCloned) const final { - return true; - } - - /// All operations within toy can be inlined. - bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { - return true; - } - - // All functions within toy can be inlined. - bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final { - return true; - } - - //===--------------------------------------------------------------------===// - // Transformation Hooks - //===--------------------------------------------------------------------===// - - /// Handle the given inlined terminator(toy.return) by replacing it with a new - /// operation as necessary. - void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { - // Only "toy.return" needs to be handled here. - auto returnOp = cast(op); - - // Replace the values directly with the return operands. - assert(returnOp.getNumOperands() == valuesToRepl.size()); - for (const auto &it : llvm::enumerate(returnOp.getOperands())) - valuesToRepl[it.index()].replaceAllUsesWith(it.value()); - } - - /// Attempts to materialize a conversion for a type mismatch between a call - /// from this dialect, and a callable region. This method should generate an - /// operation that takes 'input' as the only operand, and produces a single - /// result of 'resultType'. If a conversion can not be generated, nullptr - /// should be returned. - Operation *materializeCallConversion(OpBuilder &builder, Value input, - Type resultType, - Location conversionLoc) const final { - return CastOp::create(builder, conversionLoc, resultType, input); - } -}; - -//===----------------------------------------------------------------------===// -// Toy Operations -//===----------------------------------------------------------------------===// - -/// A generalized parser for binary operations. This parses the different forms -/// of 'printBinaryOp' below. 
-static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser, - mlir::OperationState &result) { - SmallVector operands; - SMLoc operandsLoc = parser.getCurrentLocation(); - Type type; - if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(type)) - return mlir::failure(); - - // If the type is a function type, it contains the input and result types of - // this operation. - if (FunctionType funcType = llvm::dyn_cast(type)) { - if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc, - result.operands)) - return mlir::failure(); - result.addTypes(funcType.getResults()); - return mlir::success(); - } - - // Otherwise, the parsed type is the type of both operands and results. - if (parser.resolveOperands(operands, type, result.operands)) - return mlir::failure(); - result.addTypes(type); - return mlir::success(); -} - -/// A generalized printer for binary operations. It prints in two different -/// forms depending on if all of the types match. -static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) { - printer << " " << op->getOperands(); - printer.printOptionalAttrDict(op->getAttrs()); - printer << " : "; - - // If all of the types are the same, print the type directly. - Type resultType = *op->result_type_begin(); - if (llvm::all_of(op->getOperandTypes(), - [=](Type type) { return type == resultType; })) { - printer << resultType; - return; - } - - // Otherwise, print a functional type. - printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes()); -} - -//===----------------------------------------------------------------------===// -// ConstantOp -//===----------------------------------------------------------------------===// - -/// Build a constant operation. -/// The builder is passed as an argument, so is the state that this method is -/// expected to fill in order to build the operation. 
-void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - double value) { - auto dataType = RankedTensorType::get({}, builder.getF64Type()); - auto dataAttribute = DenseElementsAttr::get(dataType, value); - ConstantOp::build(builder, state, dataType, dataAttribute); -} - -/// The 'OpAsmParser' class provides a collection of methods for parsing -/// various punctuation, as well as attributes, operands, types, etc. Each of -/// these methods returns a `ParseResult`. This class is a wrapper around -/// `LogicalResult` that can be converted to a boolean `true` value on failure, -/// or `false` on success. This allows for easily chaining together a set of -/// parser rules. These rules are used to populate an `mlir::OperationState` -/// similarly to the `build` methods described above. -mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser &parser, - mlir::OperationState &result) { - mlir::DenseElementsAttr value; - if (parser.parseOptionalAttrDict(result.attributes) || - parser.parseAttribute(value, "value", result.attributes)) - return failure(); - - result.addTypes(value.getType()); - return success(); -} - -/// The 'OpAsmPrinter' class is a stream that allows for formatting -/// strings, attributes, operands, types, etc. -void ConstantOp::print(mlir::OpAsmPrinter &printer) { - printer << " "; - printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"}); - printer << getValue(); -} - -/// Verify that the given attribute value is valid for the given type. -static llvm::LogicalResult verifyConstantForType(mlir::Type type, - mlir::Attribute opaqueValue, - mlir::Operation *op) { - if (llvm::isa(type)) { - // Check that the value is an elements attribute. 
- auto attrValue = llvm::dyn_cast(opaqueValue); - if (!attrValue) - return op->emitError("constant of TensorType must be initialized by " - "a DenseFPElementsAttr, got ") - << opaqueValue; - - // If the return type of the constant is not an unranked tensor, the shape - // must match the shape of the attribute holding the data. - auto resultType = llvm::dyn_cast(type); - if (!resultType) - return success(); - - // Check that the rank of the attribute type matches the rank of the - // constant result type. - auto attrType = llvm::cast(attrValue.getType()); - if (attrType.getRank() != resultType.getRank()) { - return op->emitOpError("return type must match the one of the attached " - "value attribute: ") - << attrType.getRank() << " != " << resultType.getRank(); - } - - // Check that each of the dimensions match between the two types. - for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) { - if (attrType.getShape()[dim] != resultType.getShape()[dim]) { - return op->emitOpError( - "return type shape mismatches its attribute at dimension ") - << dim << ": " << attrType.getShape()[dim] - << " != " << resultType.getShape()[dim]; - } - } - return mlir::success(); - } - auto resultType = llvm::cast(type); - llvm::ArrayRef resultElementTypes = resultType.getElementTypes(); - - // Verify that the initializer is an Array. - auto attrValue = llvm::dyn_cast(opaqueValue); - if (!attrValue || attrValue.getValue().size() != resultElementTypes.size()) - return op->emitError("constant of StructType must be initialized by an " - "ArrayAttr with the same number of elements, got ") - << opaqueValue; - - // Check that each of the elements are valid. - llvm::ArrayRef attrElementValues = attrValue.getValue(); - for (const auto it : llvm::zip(resultElementTypes, attrElementValues)) - if (failed(verifyConstantForType(std::get<0>(it), std::get<1>(it), op))) - return mlir::failure(); - return mlir::success(); -} - -/// Verifier for the constant operation. 
This corresponds to the `::verify(...)` -/// in the op definition. -llvm::LogicalResult ConstantOp::verify() { - return verifyConstantForType(getResult().getType(), getValue(), *this); -} - -llvm::LogicalResult StructConstantOp::verify() { - return verifyConstantForType(getResult().getType(), getValue(), *this); -} - -/// Infer the output shape of the ConstantOp, this is required by the shape -/// inference interface. -void ConstantOp::inferShapes() { - getResult().setType(cast(getValue().getType())); -} - -//===----------------------------------------------------------------------===// -// AddOp -//===----------------------------------------------------------------------===// - -void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - mlir::Value lhs, mlir::Value rhs) { - state.addTypes(UnrankedTensorType::get(builder.getF64Type())); - state.addOperands({lhs, rhs}); -} - -mlir::ParseResult AddOp::parse(mlir::OpAsmParser &parser, - mlir::OperationState &result) { - return parseBinaryOp(parser, result); -} - -void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } - -/// Infer the output shape of the AddOp, this is required by the shape inference -/// interface. -void AddOp::inferShapes() { getResult().setType(getLhs().getType()); } - -//===----------------------------------------------------------------------===// -// CastOp -//===----------------------------------------------------------------------===// - -/// Infer the output shape of the CastOp, this is required by the shape -/// inference interface. -void CastOp::inferShapes() { getResult().setType(getInput().getType()); } - -/// Returns true if the given set of input and result types are compatible with -/// this cast operation. This is required by the `CastOpInterface` to verify -/// this operation and provide other additional utilities. 
-bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { - if (inputs.size() != 1 || outputs.size() != 1) - return false; - // The inputs must be Tensors with the same element type. - TensorType input = llvm::dyn_cast(inputs.front()); - TensorType output = llvm::dyn_cast(outputs.front()); - if (!input || !output || input.getElementType() != output.getElementType()) - return false; - // The shape is required to match if both types are ranked. - return !input.hasRank() || !output.hasRank() || input == output; -} - -//===----------------------------------------------------------------------===// -// FuncOp -//===----------------------------------------------------------------------===// - -void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - llvm::StringRef name, mlir::FunctionType type, - llvm::ArrayRef attrs) { - // FunctionOpInterface provides a convenient `build` method that will populate - // the state of our FuncOp, and create an entry block. - buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs()); -} - -mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser, - mlir::OperationState &result) { - // Dispatch to the FunctionOpInterface provided utility method that parses the - // function operation. - auto buildFuncType = - [](mlir::Builder &builder, llvm::ArrayRef argTypes, - llvm::ArrayRef results, - mlir::function_interface_impl::VariadicFlag, - std::string &) { return builder.getFunctionType(argTypes, results); }; - - return mlir::function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, - getFunctionTypeAttrName(result.name), buildFuncType, - getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); -} - -void FuncOp::print(mlir::OpAsmPrinter &p) { - // Dispatch to the FunctionOpInterface provided utility method that prints the - // function operation. 
- mlir::function_interface_impl::printFunctionOp( - p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), - getArgAttrsAttrName(), getResAttrsAttrName()); -} - -//===----------------------------------------------------------------------===// -// GenericCallOp -//===----------------------------------------------------------------------===// - -void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - mlir::Type resultType, StringRef callee, - ArrayRef arguments) { - state.addTypes(resultType); - state.addOperands(arguments); - state.addAttribute("callee", - mlir::SymbolRefAttr::get(builder.getContext(), callee)); -} - -/// Return the callee of the generic call operation, this is required by the -/// call interface. -CallInterfaceCallable GenericCallOp::getCallableForCallee() { - return (*this)->getAttrOfType("callee"); -} - -/// Set the callee for the generic call operation, this is required by the call -/// interface. -void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) { - (*this)->setAttr("callee", cast(callee)); -} - -/// Get the argument operands to the called function, this is required by the -/// call interface. -Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); } - -/// Get the argument operands to the called function as a mutable range, this is -/// required by the call interface. 
-MutableOperandRange GenericCallOp::getArgOperandsMutable() { - return getInputsMutable(); -} - -//===----------------------------------------------------------------------===// -// MulOp -//===----------------------------------------------------------------------===// - -void MulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - mlir::Value lhs, mlir::Value rhs) { - state.addTypes(UnrankedTensorType::get(builder.getF64Type())); - state.addOperands({lhs, rhs}); -} - -mlir::ParseResult MulOp::parse(mlir::OpAsmParser &parser, - mlir::OperationState &result) { - return parseBinaryOp(parser, result); -} - -void MulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } - -/// Infer the output shape of the MulOp, this is required by the shape inference -/// interface. -void MulOp::inferShapes() { getResult().setType(getLhs().getType()); } - -//===----------------------------------------------------------------------===// -// ReturnOp -//===----------------------------------------------------------------------===// - -llvm::LogicalResult ReturnOp::verify() { - // We know that the parent operation is a function, because of the 'HasParent' - // trait attached to the operation definition. - auto function = cast((*this)->getParentOp()); - - /// ReturnOps can only have a single optional operand. - if (getNumOperands() > 1) - return emitOpError() << "expects at most 1 return operand"; - - // The operand number and types must match the function signature. - const auto &results = function.getFunctionType().getResults(); - if (getNumOperands() != results.size()) - return emitOpError() << "does not return the same number of values (" - << getNumOperands() << ") as the enclosing function (" - << results.size() << ")"; - - // If the operation does not have an input, we are done. 
- if (!hasOperand()) - return mlir::success(); - - auto inputType = *operand_type_begin(); - auto resultType = results.front(); - - // Check that the result type of the function matches the operand type. - if (inputType == resultType || - llvm::isa(inputType) || - llvm::isa(resultType)) - return mlir::success(); - - return emitError() << "type of return operand (" << inputType - << ") doesn't match function result type (" << resultType - << ")"; -} - -//===----------------------------------------------------------------------===// -// StructAccessOp -//===----------------------------------------------------------------------===// - -void StructAccessOp::build(mlir::OpBuilder &b, mlir::OperationState &state, - mlir::Value input, size_t index) { - // Extract the result type from the input type. - StructType structTy = llvm::cast(input.getType()); - assert(index < structTy.getNumElementTypes()); - mlir::Type resultType = structTy.getElementTypes()[index]; - - // Call into the auto-generated build method. 
- build(b, state, resultType, input, b.getI64IntegerAttr(index)); -} - -llvm::LogicalResult StructAccessOp::verify() { - StructType structTy = llvm::cast(getInput().getType()); - size_t indexValue = getIndex(); - if (indexValue >= structTy.getNumElementTypes()) - return emitOpError() - << "index should be within the range of the input struct type"; - mlir::Type resultType = getResult().getType(); - if (resultType != structTy.getElementTypes()[indexValue]) - return emitOpError() << "must have the same result type as the struct " - "element referred to by the index"; - return mlir::success(); -} - -//===----------------------------------------------------------------------===// -// TransposeOp -//===----------------------------------------------------------------------===// - -void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - mlir::Value value) { - state.addTypes(UnrankedTensorType::get(builder.getF64Type())); - state.addOperands(value); -} - -void TransposeOp::inferShapes() { - auto arrayTy = llvm::cast(getOperand().getType()); - SmallVector dims(llvm::reverse(arrayTy.getShape())); - getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType())); -} - -llvm::LogicalResult TransposeOp::verify() { - auto inputType = llvm::dyn_cast(getOperand().getType()); - auto resultType = llvm::dyn_cast(getType()); - if (!inputType || !resultType) - return mlir::success(); - - auto inputShape = inputType.getShape(); - if (!std::equal(inputShape.begin(), inputShape.end(), - resultType.getShape().rbegin())) { - return emitError() - << "expected result shape to be a transpose of the input"; - } - return mlir::success(); -} - -//===----------------------------------------------------------------------===// -// Toy Types -//===----------------------------------------------------------------------===// - -namespace mlir { -namespace toy { -namespace detail { -/// This class represents the internal storage of the Toy `StructType`. 
-struct StructTypeStorage : public mlir::TypeStorage { - /// The `KeyTy` is a required type that provides an interface for the storage - /// instance. This type will be used when uniquing an instance of the type - /// storage. For our struct type, we will unique each instance structurally on - /// the elements that it contains. - using KeyTy = llvm::ArrayRef; - - /// A constructor for the type storage instance. - StructTypeStorage(llvm::ArrayRef elementTypes) - : elementTypes(elementTypes) {} - - /// Define the comparison function for the key type with the current storage - /// instance. This is used when constructing a new instance to ensure that we - /// haven't already uniqued an instance of the given key. - bool operator==(const KeyTy &key) const { return key == elementTypes; } - - /// Define a hash function for the key type. This is used when uniquing - /// instances of the storage, see the `StructType::get` method. - /// Note: This method isn't necessary as both llvm::ArrayRef and mlir::Type - /// have hash functions available, so we could just omit this entirely. - static llvm::hash_code hashKey(const KeyTy &key) { - return llvm::hash_value(key); - } - - /// Define a construction function for the key type from a set of parameters. - /// These parameters will be provided when constructing the storage instance - /// itself. - /// Note: This method isn't necessary because KeyTy can be directly - /// constructed with the given parameters. - static KeyTy getKey(llvm::ArrayRef elementTypes) { - return KeyTy(elementTypes); - } - - /// Define a construction method for creating a new instance of this storage. - /// This method takes an instance of a storage allocator, and an instance of a - /// `KeyTy`. The given allocator must be used for *all* necessary dynamic - /// allocations used to create the type storage and its internal. 
- static StructTypeStorage *construct(mlir::TypeStorageAllocator &allocator, - const KeyTy &key) { - // Copy the elements from the provided `KeyTy` into the allocator. - llvm::ArrayRef elementTypes = allocator.copyInto(key); - - // Allocate the storage instance and construct it. - return new (allocator.allocate()) - StructTypeStorage(elementTypes); - } - - /// The following field contains the element types of the struct. - llvm::ArrayRef elementTypes; -}; -} // namespace detail -} // namespace toy -} // namespace mlir - -/// Create an instance of a `StructType` with the given element types. There -/// *must* be at least one element type. -StructType StructType::get(llvm::ArrayRef elementTypes) { - assert(!elementTypes.empty() && "expected at least 1 element type"); - - // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance - // of this type. The first parameter is the context to unique in. The - // parameters after the context are forwarded to the storage instance. - mlir::MLIRContext *ctx = elementTypes.front().getContext(); - return Base::get(ctx, elementTypes); -} - -/// Returns the element types of this struct type. -llvm::ArrayRef StructType::getElementTypes() { - // 'getImpl' returns a pointer to the internal storage instance. - return getImpl()->elementTypes; -} - -/// Parse an instance of a type registered to the toy dialect. -mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const { - // Parse a struct type in the following form: - // struct-type ::= `struct` `<` type (`,` type)* `>` - - // NOTE: All MLIR parser function return a ParseResult. This is a - // specialization of LogicalResult that auto-converts to a `true` boolean - // value on failure to allow for chaining, but may be used with explicit - // `mlir::failed/mlir::succeeded` as desired. - - // Parse: `struct` `<` - if (parser.parseKeyword("struct") || parser.parseLess()) - return Type(); - - // Parse the element types of the struct. 
- SmallVector elementTypes; - do { - // Parse the current element type. - SMLoc typeLoc = parser.getCurrentLocation(); - mlir::Type elementType; - if (parser.parseType(elementType)) - return nullptr; - - // Check that the type is either a TensorType or another StructType. - if (!llvm::isa(elementType)) { - parser.emitError(typeLoc, "element type for a struct must either " - "be a TensorType or a StructType, got: ") - << elementType; - return Type(); - } - elementTypes.push_back(elementType); - - // Parse the optional: `,` - } while (succeeded(parser.parseOptionalComma())); - - // Parse: `>` - if (parser.parseGreater()) - return Type(); - return StructType::get(elementTypes); -} - -/// Print an instance of a type registered to the toy dialect. -void ToyDialect::printType(mlir::Type type, - mlir::DialectAsmPrinter &printer) const { - // Currently the only toy type is a struct type. - StructType structType = llvm::cast(type); - - // Print the struct type according to the parser format. - printer << "struct<"; - llvm::interleaveComma(structType.getElementTypes(), printer); - printer << '>'; -} - -//===----------------------------------------------------------------------===// -// TableGen'd op method definitions -//===----------------------------------------------------------------------===// - -#define GET_OP_CLASSES -#include "toy/Ops.cpp.inc" - -//===----------------------------------------------------------------------===// -// ToyDialect -//===----------------------------------------------------------------------===// - -/// Dialect initialization, the instance will be owned by the context. This is -/// the point of registration of types and operations for the dialect. 
-void ToyDialect::initialize() { - addOperations< -#define GET_OP_LIST -#include "toy/Ops.cpp.inc" - >(); - addInterfaces(); - addTypes(); -} - -mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder, - mlir::Attribute value, - mlir::Type type, - mlir::Location loc) { - if (llvm::isa(type)) - return StructConstantOp::create(builder, loc, type, - llvm::cast(value)); - return ConstantOp::create(builder, loc, type, - llvm::cast(value)); -} diff --git a/mlir/LowerToAffineLoops.cpp b/mlir/LowerToAffineLoops.cpp deleted file mode 100644 index cbe4236..0000000 --- a/mlir/LowerToAffineLoops.cpp +++ /dev/null @@ -1,368 +0,0 @@ -//====- LowerToAffineLoops.cpp - Partial lowering from Toy to Affine+Std --===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements a partial lowering of Toy operations to a combination of -// affine loops, memref operations and standard operations. This lowering -// expects that all calls have been inlined, and all shapes have been resolved. 
-// -//===----------------------------------------------------------------------===// - -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinDialect.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/DialectRegistry.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/ValueRange.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/TypeID.h" -#include "toy/Dialect.h" -#include "toy/Passes.h" - -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/Sequence.h" -#include "llvm/Support/Casting.h" -#include -#include -#include -#include -#include - -using namespace mlir; - -//===----------------------------------------------------------------------===// -// ToyToAffine Conversion Patterns -//===----------------------------------------------------------------------===// - -/// Convert the given RankedTensorType into the corresponding MemRefType. -static MemRefType convertTensorToMemRef(RankedTensorType type) { - return MemRefType::get(type.getShape(), type.getElementType()); -} - -/// Insert an allocation and deallocation for the given MemRefType. -static Value insertAllocAndDealloc(MemRefType type, Location loc, - PatternRewriter &rewriter) { - auto alloc = memref::AllocOp::create(rewriter, loc, type); - - // Make sure to allocate at the beginning of the block. - auto *parentBlock = alloc->getBlock(); - alloc->moveBefore(&parentBlock->front()); - - // Make sure to deallocate this alloc at the end of the block. This is fine - // as toy functions have no control flow. 
- auto dealloc = memref::DeallocOp::create(rewriter, loc, alloc); - dealloc->moveBefore(&parentBlock->back()); - return alloc; -} - -/// This defines the function type used to process an iteration of a lowered -/// loop. It takes as input an OpBuilder and the range of loop induction -/// variables for the iteration. It returns a value to store at the current -/// index of the iteration. -using LoopIterationFn = - function_ref; - -static void lowerOpToLoops(Operation *op, PatternRewriter &rewriter, - LoopIterationFn processIteration) { - auto tensorType = llvm::cast((*op->result_type_begin())); - auto loc = op->getLoc(); - - // Insert an allocation and deallocation for the result of this operation. - auto memRefType = convertTensorToMemRef(tensorType); - auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); - - // Create a nest of affine loops, with one loop per dimension of the shape. - // The buildAffineLoopNest function takes a callback that is used to construct - // the body of the innermost loop given a builder, a location and a range of - // loop induction variables. - SmallVector lowerBounds(tensorType.getRank(), /*Value=*/0); - SmallVector steps(tensorType.getRank(), /*Value=*/1); - affine::buildAffineLoopNest( - rewriter, loc, lowerBounds, tensorType.getShape(), steps, - [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) { - // Call the processing function with the rewriter - // and the loop induction variables. This function will return the value - // to store at the current index. - Value valueToStore = processIteration(nestedBuilder, ivs); - affine::AffineStoreOp::create(nestedBuilder, loc, valueToStore, alloc, - ivs); - }); - - // Replace this operation with the generated alloc. 
- rewriter.replaceOp(op, alloc); -} - -namespace { -//===----------------------------------------------------------------------===// -// ToyToAffine Conversion Patterns: Binary operations -//===----------------------------------------------------------------------===// - -template -struct BinaryOpLowering : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - using OpAdaptor = typename OpConversionPattern::OpAdaptor; - - LogicalResult - matchAndRewrite(BinaryOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - auto loc = op->getLoc(); - lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) { - // Generate loads for the element of 'lhs' and 'rhs' at the - // inner loop. - auto loadedLhs = - affine::AffineLoadOp::create(builder, loc, adaptor.getLhs(), loopIvs); - auto loadedRhs = - affine::AffineLoadOp::create(builder, loc, adaptor.getRhs(), loopIvs); - - // Create the binary operation performed on the loaded - // values. - return LoweredBinaryOp::create(builder, loc, loadedLhs, loadedRhs); - }); - return success(); - } -}; -using AddOpLowering = BinaryOpLowering; -using MulOpLowering = BinaryOpLowering; - -//===----------------------------------------------------------------------===// -// ToyToAffine Conversion Patterns: Constant operations -//===----------------------------------------------------------------------===// - -struct ConstantOpLowering : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(toy::ConstantOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - DenseElementsAttr constantValue = op.getValue(); - Location loc = op.getLoc(); - - // When lowering the constant operation, we allocate and assign the constant - // values to a corresponding memref allocation. 
- auto tensorType = llvm::cast(op.getType()); - auto memRefType = convertTensorToMemRef(tensorType); - auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); - - // We will be generating constant indices up-to the largest dimension. - // Create these constants up-front to avoid large amounts of redundant - // operations. - auto valueShape = memRefType.getShape(); - SmallVector constantIndices; - - if (!valueShape.empty()) { - for (auto i : llvm::seq(0, *llvm::max_element(valueShape))) - constantIndices.push_back( - arith::ConstantIndexOp::create(rewriter, loc, i)); - } else { - // This is the case of a tensor of rank 0. - constantIndices.push_back( - arith::ConstantIndexOp::create(rewriter, loc, 0)); - } - - // The constant operation represents a multi-dimensional constant, so we - // will need to generate a store for each of the elements. The following - // functor recursively walks the dimensions of the constant shape, - // generating a store when the recursion hits the base case. - SmallVector indices; - auto valueIt = constantValue.value_begin(); - std::function storeElements = [&](uint64_t dimension) { - // The last dimension is the base case of the recursion, at this point - // we store the element at the given index. - if (dimension == valueShape.size()) { - affine::AffineStoreOp::create( - rewriter, loc, arith::ConstantOp::create(rewriter, loc, *valueIt++), - alloc, llvm::ArrayRef(indices)); - return; - } - - // Otherwise, iterate over the current dimension and add the indices to - // the list. - for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) { - indices.push_back(constantIndices[i]); - storeElements(dimension + 1); - indices.pop_back(); - } - }; - - // Start the element storing recursion from the first dimension. - storeElements(/*dimension=*/0); - - // Replace this operation with the generated alloc. 
- rewriter.replaceOp(op, alloc); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// ToyToAffine Conversion Patterns: Func operations -//===----------------------------------------------------------------------===// - -struct FuncOpLowering : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(toy::FuncOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - // We only lower the main function as we expect that all other functions - // have been inlined. - if (op.getName() != "main") - return failure(); - - // Verify that the given main has no inputs and results. - if (op.getNumArguments() || op.getFunctionType().getNumResults()) { - return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { - diag << "expected 'main' to have 0 inputs and 0 results"; - }); - } - - // Create a new non-toy function, with the same region. - auto func = mlir::func::FuncOp::create(rewriter, op.getLoc(), op.getName(), - op.getFunctionType()); - rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end()); - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// ToyToAffine Conversion Patterns: Print operations -//===----------------------------------------------------------------------===// - -struct PrintOpLowering : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(toy::PrintOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - // We don't lower "toy.print" in this pass, but we need to update its - // operands. 
- rewriter.modifyOpInPlace(op, - [&] { op->setOperands(adaptor.getOperands()); }); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// ToyToAffine Conversion Patterns: Return operations -//===----------------------------------------------------------------------===// - -struct ReturnOpLowering : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(toy::ReturnOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - // During this lowering, we expect that all function calls have been - // inlined. - if (op.hasOperand()) - return failure(); - - // We lower "toy.return" directly to "func.return". - rewriter.replaceOpWithNewOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// ToyToAffine Conversion Patterns: Transpose operations -//===----------------------------------------------------------------------===// - -struct TransposeOpLowering : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(toy::TransposeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - auto loc = op->getLoc(); - lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) { - Value input = adaptor.getInput(); - - // Transpose the elements by generating a load from the - // reverse indices. 
- SmallVector reverseIvs(llvm::reverse(loopIvs)); - return affine::AffineLoadOp::create(builder, loc, input, reverseIvs); - }); - return success(); - } -}; - -} // namespace - -//===----------------------------------------------------------------------===// -// ToyToAffineLoweringPass -//===----------------------------------------------------------------------===// - -/// This is a partial lowering to affine loops of the toy operations that are -/// computationally intensive (like matmul for example...) while keeping the -/// rest of the code in the Toy dialect. -namespace { -struct ToyToAffineLoweringPass - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ToyToAffineLoweringPass) - StringRef getArgument() const override { return "toy-to-affine"; } - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - void runOnOperation() final; -}; -} // namespace - -void ToyToAffineLoweringPass::runOnOperation() { - // The first thing to define is the conversion target. This will define the - // final target for this lowering. - ConversionTarget target(getContext()); - - // We define the specific operations, or dialects, that are legal targets for - // this lowering. In our case, we are lowering to a combination of the - // `Affine`, `Arith`, `Func`, and `MemRef` dialects. - target.addLegalDialect(); - - // We also define the Toy dialect as Illegal so that the conversion will fail - // if any of these operations are *not* converted. Given that we actually want - // a partial lowering, we explicitly mark the Toy operations that don't want - // to lower, `toy.print`, as `legal`. `toy.print` will still need its operands - // to be updated though (as we convert from TensorType to MemRefType), so we - // only treat it as `legal` if its operands are legal. 
- target.addIllegalDialect(); - target.addDynamicallyLegalOp([](toy::PrintOp op) { - return llvm::none_of(op->getOperandTypes(), - [](Type type) { return llvm::isa(type); }); - }); - - // Now that the conversion target has been defined, we just need to provide - // the set of patterns that will lower the Toy operations. - RewritePatternSet patterns(&getContext()); - patterns.add( - &getContext()); - - // With the target and rewrite patterns defined, we can now attempt the - // conversion. The conversion will signal failure if any of our `illegal` - // operations were not converted successfully. - if (failed( - applyPartialConversion(getOperation(), target, std::move(patterns)))) - signalPassFailure(); -} - -/// Create a pass for lowering operations in the `Affine` and `Std` dialects, -/// for a subset of the Toy IR (e.g. matmul). -std::unique_ptr mlir::toy::createLowerToAffinePass() { - return std::make_unique(); -} diff --git a/mlir/MLIRGen.cpp b/mlir/MLIRGen.cpp deleted file mode 100644 index 7313324..0000000 --- a/mlir/MLIRGen.cpp +++ /dev/null @@ -1,691 +0,0 @@ -//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements a simple IR generation targeting MLIR from a Module AST -// for the Toy language. 
-// -//===----------------------------------------------------------------------===// - -#include "toy/MLIRGen.h" -#include "mlir/IR/Block.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/Value.h" -#include "toy/AST.h" -#include "toy/Dialect.h" - -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Verifier.h" -#include "toy/Lexer.h" - -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/ScopedHashTable.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringMap.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/ADT/Twine.h" -#include "llvm/Support/ErrorHandling.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace mlir::toy; -using namespace toy; - -using llvm::ArrayRef; -using llvm::cast; -using llvm::dyn_cast; -using llvm::isa; -using llvm::ScopedHashTableScope; -using llvm::SmallVector; -using llvm::StringRef; -using llvm::Twine; - -namespace { - -/// Implementation of a simple MLIR emission from the Toy AST. -/// -/// This will emit operations that are specific to the Toy language, preserving -/// the semantics of the language and (hopefully) allow to perform accurate -/// analysis and transformation based on these high level semantics. -class MLIRGenImpl { -public: - MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {} - - /// Public API: convert the AST for a Toy module (source file) to an MLIR - /// Module operation. - mlir::ModuleOp mlirGen(ModuleAST &moduleAST) { - // We create an empty MLIR module and codegen functions one at a time and - // add them to the module. 
- theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); - - for (auto &record : moduleAST) { - if (FunctionAST *funcAST = llvm::dyn_cast(record.get())) { - mlir::toy::FuncOp func = mlirGen(*funcAST); - if (!func) - return nullptr; - functionMap.insert({func.getName(), func}); - } else if (StructAST *str = llvm::dyn_cast(record.get())) { - if (failed(mlirGen(*str))) - return nullptr; - } else { - llvm_unreachable("unknown record type"); - } - } - - // Verify the module after we have finished constructing it, this will check - // the structural properties of the IR and invoke any specific verifiers we - // have on the Toy operations. - if (failed(mlir::verify(theModule))) { - theModule.emitError("module verification error"); - return nullptr; - } - - return theModule; - } - -private: - /// A "module" matches a Toy source file: containing a list of functions. - mlir::ModuleOp theModule; - - /// The builder is a helper class to create IR inside a function. The builder - /// is stateful, in particular it keeps an "insertion point": this is where - /// the next operations will be introduced. - mlir::OpBuilder builder; - - /// The symbol table maps a variable name to a value in the current scope. - /// Entering a function creates a new scope, and the function arguments are - /// added to the mapping. When the processing of a function is terminated, the - /// scope is destroyed and the mappings created in this scope are dropped. - llvm::ScopedHashTable> - symbolTable; - using SymbolTableScopeT = - llvm::ScopedHashTableScope>; - - /// A mapping for the functions that have been code generated to MLIR. - llvm::StringMap functionMap; - - /// A mapping for named struct types to the underlying MLIR type and the - /// original AST node. - llvm::StringMap> structMap; - - /// Helper conversion for a Toy AST location to an MLIR location. 
- mlir::Location loc(const Location &loc) { - return mlir::FileLineColLoc::get(builder.getStringAttr(*loc.file), loc.line, - loc.col); - } - - /// Declare a variable in the current scope, return success if the variable - /// wasn't declared yet. - llvm::LogicalResult declare(VarDeclExprAST &var, mlir::Value value) { - if (symbolTable.count(var.getName())) - return mlir::failure(); - symbolTable.insert(var.getName(), {value, &var}); - return mlir::success(); - } - - /// Create an MLIR type for the given struct. - llvm::LogicalResult mlirGen(StructAST &str) { - if (structMap.count(str.getName())) - return emitError(loc(str.loc())) << "error: struct type with name `" - << str.getName() << "' already exists"; - - auto variables = str.getVariables(); - std::vector elementTypes; - elementTypes.reserve(variables.size()); - for (auto &variable : variables) { - if (variable->getInitVal()) - return emitError(loc(variable->loc())) - << "error: variables within a struct definition must not have " - "initializers"; - if (!variable->getType().shape.empty()) - return emitError(loc(variable->loc())) - << "error: variables within a struct definition must not have " - "initializers"; - - mlir::Type type = getType(variable->getType(), variable->loc()); - if (!type) - return mlir::failure(); - elementTypes.push_back(type); - } - - structMap.try_emplace(str.getName(), StructType::get(elementTypes), &str); - return mlir::success(); - } - - /// Create the prototype for an MLIR function with as many arguments as the - /// provided Toy AST prototype. - mlir::toy::FuncOp mlirGen(PrototypeAST &proto) { - auto location = loc(proto.loc()); - - // This is a generic function, the return type will be inferred later. 
- llvm::SmallVector argTypes; - argTypes.reserve(proto.getArgs().size()); - for (auto &arg : proto.getArgs()) { - mlir::Type type = getType(arg->getType(), arg->loc()); - if (!type) - return nullptr; - argTypes.push_back(type); - } - auto funcType = builder.getFunctionType(argTypes, /*results=*/{}); - return mlir::toy::FuncOp::create(builder, location, proto.getName(), - funcType); - } - - /// Emit a new function and add it to the MLIR module. - mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) { - // Create a scope in the symbol table to hold variable declarations. - SymbolTableScopeT varScope(symbolTable); - - // Create an MLIR function for the given prototype. - builder.setInsertionPointToEnd(theModule.getBody()); - mlir::toy::FuncOp function = mlirGen(*funcAST.getProto()); - if (!function) - return nullptr; - - // Let's start the body of the function now! - mlir::Block &entryBlock = function.front(); - auto protoArgs = funcAST.getProto()->getArgs(); - - // Declare all the function arguments in the symbol table. - for (const auto nameValue : - llvm::zip(protoArgs, entryBlock.getArguments())) { - if (failed(declare(*std::get<0>(nameValue), std::get<1>(nameValue)))) - return nullptr; - } - - // Set the insertion point in the builder to the beginning of the function - // body, it will be used throughout the codegen to create operations in this - // function. - builder.setInsertionPointToStart(&entryBlock); - - // Emit the body of the function. - if (mlir::failed(mlirGen(*funcAST.getBody()))) { - function.erase(); - return nullptr; - } - - // Implicitly return void if no return statement was emitted. 
- // FIXME: we may fix the parser instead to always return the last expression - // (this would possibly help the REPL case later) - ReturnOp returnOp; - if (!entryBlock.empty()) - returnOp = dyn_cast(entryBlock.back()); - if (!returnOp) { - ReturnOp::create(builder, loc(funcAST.getProto()->loc())); - } else if (returnOp.hasOperand()) { - // Otherwise, if this return operation has an operand then add a result to - // the function. - function.setType( - builder.getFunctionType(function.getFunctionType().getInputs(), - *returnOp.operand_type_begin())); - } - - // If this function isn't main, then set the visibility to private. - if (funcAST.getProto()->getName() != "main") - function.setPrivate(); - - return function; - } - - /// Return the struct type that is the result of the given expression, or null - /// if it cannot be inferred. - StructAST *getStructFor(ExprAST *expr) { - llvm::StringRef structName; - if (auto *decl = llvm::dyn_cast(expr)) { - auto varIt = symbolTable.lookup(decl->getName()); - if (!varIt.first) - return nullptr; - structName = varIt.second->getType().name; - } else if (auto *access = llvm::dyn_cast(expr)) { - if (access->getOp() != '.') - return nullptr; - // The name being accessed should be in the RHS. - auto *name = llvm::dyn_cast(access->getRHS()); - if (!name) - return nullptr; - StructAST *parentStruct = getStructFor(access->getLHS()); - if (!parentStruct) - return nullptr; - - // Get the element within the struct corresponding to the name. - VarDeclExprAST *decl = nullptr; - for (auto &var : parentStruct->getVariables()) { - if (var->getName() == name->getName()) { - decl = var.get(); - break; - } - } - if (!decl) - return nullptr; - structName = decl->getType().name; - } - if (structName.empty()) - return nullptr; - - // If the struct name was valid, check for an entry in the struct map. 
- auto structIt = structMap.find(structName); - if (structIt == structMap.end()) - return nullptr; - return structIt->second.second; - } - - /// Return the numeric member index of the given struct access expression. - std::optional getMemberIndex(BinaryExprAST &accessOp) { - assert(accessOp.getOp() == '.' && "expected access operation"); - - // Lookup the struct node for the LHS. - StructAST *structAST = getStructFor(accessOp.getLHS()); - if (!structAST) - return std::nullopt; - - // Get the name from the RHS. - VariableExprAST *name = llvm::dyn_cast(accessOp.getRHS()); - if (!name) - return std::nullopt; - - auto structVars = structAST->getVariables(); - const auto *it = llvm::find_if(structVars, [&](auto &var) { - return var->getName() == name->getName(); - }); - if (it == structVars.end()) - return std::nullopt; - return it - structVars.begin(); - } - - /// Emit a binary operation - mlir::Value mlirGen(BinaryExprAST &binop) { - // First emit the operations for each side of the operation before emitting - // the operation itself. For example if the expression is `a + foo(a)` - // 1) First it will visiting the LHS, which will return a reference to the - // value holding `a`. This value should have been emitted at declaration - // time and registered in the symbol table, so nothing would be - // codegen'd. If the value is not in the symbol table, an error has been - // emitted and nullptr is returned. - // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted - // and the result value is returned. If an error occurs we get a nullptr - // and propagate. - // - mlir::Value lhs = mlirGen(*binop.getLHS()); - if (!lhs) - return nullptr; - auto location = loc(binop.loc()); - - // If this is an access operation, handle it immediately. 
- if (binop.getOp() == '.') { - std::optional accessIndex = getMemberIndex(binop); - if (!accessIndex) { - emitError(location, "invalid access into struct expression"); - return nullptr; - } - return StructAccessOp::create(builder, location, lhs, *accessIndex); - } - - // Otherwise, this is a normal binary op. - mlir::Value rhs = mlirGen(*binop.getRHS()); - if (!rhs) - return nullptr; - - // Derive the operation name from the binary operator. At the moment we only - // support '+' and '*'. - switch (binop.getOp()) { - case '+': - return AddOp::create(builder, location, lhs, rhs); - case '*': - return MulOp::create(builder, location, lhs, rhs); - } - - emitError(location, "invalid binary operator '") << binop.getOp() << "'"; - return nullptr; - } - - /// This is a reference to a variable in an expression. The variable is - /// expected to have been declared and so should have a value in the symbol - /// table, otherwise emit an error and return nullptr. - mlir::Value mlirGen(VariableExprAST &expr) { - if (auto variable = symbolTable.lookup(expr.getName()).first) - return variable; - - emitError(loc(expr.loc()), "error: unknown variable '") - << expr.getName() << "'"; - return nullptr; - } - - /// Emit a return operation. This will return failure if any generation fails. - llvm::LogicalResult mlirGen(ReturnExprAST &ret) { - auto location = loc(ret.loc()); - - // 'return' takes an optional expression, handle that case here. - mlir::Value expr = nullptr; - if (ret.getExpr().has_value()) { - if (!(expr = mlirGen(**ret.getExpr()))) - return mlir::failure(); - } - - // Otherwise, this return operation has zero operands. - ReturnOp::create(builder, location, - expr ? ArrayRef(expr) : ArrayRef()); - return mlir::success(); - } - - /// Emit a constant for a literal/constant array. It will be emitted as a - /// flattened array of data in an Attribute attached to a `toy.constant` - /// operation. See documentation on [Attributes](LangRef.md#attributes) for - /// more details. 
Here is an excerpt: - /// - /// Attributes are the mechanism for specifying constant data in MLIR in - /// places where a variable is never allowed [...]. They consist of a name - /// and a concrete attribute value. The set of expected attributes, their - /// structure, and their interpretation are all contextually dependent on - /// what they are attached to. - /// - /// Example, the source level statement: - /// var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; - /// will be converted to: - /// %0 = "toy.constant"() {value: dense, - /// [[1.000000e+00, 2.000000e+00, 3.000000e+00], - /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64> - /// - mlir::DenseElementsAttr getConstantAttr(LiteralExprAST &lit) { - // The attribute is a vector with a floating point value per element - // (number) in the array, see `collectData()` below for more details. - std::vector data; - data.reserve(llvm::product_of(lit.getDims())); - collectData(lit, data); - - // The type of this attribute is tensor of 64-bit floating-point with the - // shape of the literal. - mlir::Type elementType = builder.getF64Type(); - auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType); - - // This is the actual attribute that holds the list of values for this - // tensor literal. - return mlir::DenseElementsAttr::get(dataType, llvm::ArrayRef(data)); - } - mlir::DenseElementsAttr getConstantAttr(NumberExprAST &lit) { - // The type of this attribute is tensor of 64-bit floating-point with no - // shape. - mlir::Type elementType = builder.getF64Type(); - auto dataType = mlir::RankedTensorType::get({}, elementType); - - // This is the actual attribute that holds the list of values for this - // tensor literal. - return mlir::DenseElementsAttr::get(dataType, - llvm::ArrayRef(lit.getValue())); - } - /// Emit a constant for a struct literal. It will be emitted as an array of - /// other literals in an Attribute attached to a `toy.struct_constant` - /// operation. 
This function returns the generated constant, along with the - /// corresponding struct type. - std::pair - getConstantAttr(StructLiteralExprAST &lit) { - std::vector attrElements; - std::vector typeElements; - - for (auto &var : lit.getValues()) { - if (auto *number = llvm::dyn_cast(var.get())) { - attrElements.push_back(getConstantAttr(*number)); - typeElements.push_back(getType(/*shape=*/{})); - } else if (auto *lit = llvm::dyn_cast(var.get())) { - attrElements.push_back(getConstantAttr(*lit)); - typeElements.push_back(getType(/*shape=*/{})); - } else { - auto *structLit = llvm::cast(var.get()); - auto attrTypePair = getConstantAttr(*structLit); - attrElements.push_back(attrTypePair.first); - typeElements.push_back(attrTypePair.second); - } - } - mlir::ArrayAttr dataAttr = builder.getArrayAttr(attrElements); - mlir::Type dataType = StructType::get(typeElements); - return std::make_pair(dataAttr, dataType); - } - - /// Emit an array literal. - mlir::Value mlirGen(LiteralExprAST &lit) { - mlir::Type type = getType(lit.getDims()); - mlir::DenseElementsAttr dataAttribute = getConstantAttr(lit); - - // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` - // method. - return ConstantOp::create(builder, loc(lit.loc()), type, dataAttribute); - } - - /// Emit a struct literal. It will be emitted as an array of - /// other literals in an Attribute attached to a `toy.struct_constant` - /// operation. - mlir::Value mlirGen(StructLiteralExprAST &lit) { - mlir::ArrayAttr dataAttr; - mlir::Type dataType; - std::tie(dataAttr, dataType) = getConstantAttr(lit); - - // Build the MLIR op `toy.struct_constant`. This invokes the - // `StructConstantOp::build` method. - return StructConstantOp::create(builder, loc(lit.loc()), dataType, - dataAttr); - } - - /// Recursive helper function to accumulate the data that compose an array - /// literal. It flattens the nested structure in the supplied vector. 
For - /// example with this array: - /// [[1, 2], [3, 4]] - /// we will generate: - /// [ 1, 2, 3, 4 ] - /// Individual numbers are represented as doubles. - /// Attributes are the way MLIR attaches constant to operations. - void collectData(ExprAST &expr, std::vector &data) { - if (auto *lit = dyn_cast(&expr)) { - for (auto &value : lit->getValues()) - collectData(*value, data); - return; - } - - assert(isa(expr) && "expected literal or number expr"); - data.push_back(cast(expr).getValue()); - } - - /// Emit a call expression. It emits specific operations for the `transpose` - /// builtin. Other identifiers are assumed to be user-defined functions. - mlir::Value mlirGen(CallExprAST &call) { - llvm::StringRef callee = call.getCallee(); - auto location = loc(call.loc()); - - // Codegen the operands first. - SmallVector operands; - for (auto &expr : call.getArgs()) { - auto arg = mlirGen(*expr); - if (!arg) - return nullptr; - operands.push_back(arg); - } - - // Builtin calls have their custom operation, meaning this is a - // straightforward emission. - if (callee == "transpose") { - if (call.getArgs().size() != 1) { - emitError(location, "MLIR codegen encountered an error: toy.transpose " - "does not accept multiple arguments"); - return nullptr; - } - return TransposeOp::create(builder, location, operands[0]); - } - - // Otherwise this is a call to a user-defined function. Calls to - // user-defined functions are mapped to a custom call that takes the callee - // name as an attribute. - auto calledFuncIt = functionMap.find(callee); - if (calledFuncIt == functionMap.end()) { - emitError(location) << "no defined function found for '" << callee << "'"; - return nullptr; - } - mlir::toy::FuncOp calledFunc = calledFuncIt->second; - return GenericCallOp::create(builder, location, - calledFunc.getFunctionType().getResult(0), - callee, operands); - } - - /// Emit a print expression. It emits specific operations for two builtins: - /// transpose(x) and print(x). 
- llvm::LogicalResult mlirGen(PrintExprAST &call) { - auto arg = mlirGen(*call.getArg()); - if (!arg) - return mlir::failure(); - - PrintOp::create(builder, loc(call.loc()), arg); - return mlir::success(); - } - - /// Emit a constant for a single number (FIXME: semantic? broadcast?) - mlir::Value mlirGen(NumberExprAST &num) { - return ConstantOp::create(builder, loc(num.loc()), num.getValue()); - } - - /// Dispatch codegen for the right expression subclass using RTTI. - mlir::Value mlirGen(ExprAST &expr) { - switch (expr.getKind()) { - case toy::ExprAST::Expr_BinOp: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Var: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Literal: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_StructLiteral: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Call: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Num: - return mlirGen(cast(expr)); - default: - emitError(loc(expr.loc())) - << "MLIR codegen encountered an unhandled expr kind '" - << Twine(expr.getKind()) << "'"; - return nullptr; - } - } - - /// Handle a variable declaration, we'll codegen the expression that forms the - /// initializer and record the value in the symbol table before returning it. - /// Future expressions will be able to reference this variable through symbol - /// table lookup. - mlir::Value mlirGen(VarDeclExprAST &vardecl) { - auto *init = vardecl.getInitVal(); - if (!init) { - emitError(loc(vardecl.loc()), - "missing initializer in variable declaration"); - return nullptr; - } - - mlir::Value value = mlirGen(*init); - if (!value) - return nullptr; - - // Handle the case where we are initializing a struct value. - VarType varType = vardecl.getType(); - if (!varType.name.empty()) { - // Check that the initializer type is the same as the variable - // declaration. 
- mlir::Type type = getType(varType, vardecl.loc()); - if (!type) - return nullptr; - if (type != value.getType()) { - emitError(loc(vardecl.loc())) - << "struct type of initializer is different than the variable " - "declaration. Got " - << value.getType() << ", but expected " << type; - return nullptr; - } - - // Otherwise, we have the initializer value, but in case the variable was - // declared with specific shape, we emit a "reshape" operation. It will - // get optimized out later as needed. - } else if (!varType.shape.empty()) { - value = ReshapeOp::create(builder, loc(vardecl.loc()), - getType(varType.shape), value); - } - - // Register the value in the symbol table. - if (failed(declare(vardecl, value))) - return nullptr; - return value; - } - - /// Codegen a list of expression, return failure if one of them hit an error. - llvm::LogicalResult mlirGen(ExprASTList &blockAST) { - SymbolTableScopeT varScope(symbolTable); - for (auto &expr : blockAST) { - // Specific handling for variable declarations, return statement, and - // print. These can only appear in block list and not in nested - // expressions. - if (auto *vardecl = dyn_cast(expr.get())) { - if (!mlirGen(*vardecl)) - return mlir::failure(); - continue; - } - if (auto *ret = dyn_cast(expr.get())) - return mlirGen(*ret); - if (auto *print = dyn_cast(expr.get())) { - if (mlir::failed(mlirGen(*print))) - return mlir::success(); - continue; - } - - // Generic expression dispatch codegen. - if (!mlirGen(*expr)) - return mlir::failure(); - } - return mlir::success(); - } - - /// Build a tensor type from a list of shape dimensions. - mlir::Type getType(ArrayRef shape) { - // If the shape is empty, then this type is unranked. - if (shape.empty()) - return mlir::UnrankedTensorType::get(builder.getF64Type()); - - // Otherwise, we use the given shape. 
- return mlir::RankedTensorType::get(shape, builder.getF64Type()); - } - - /// Build an MLIR type from a Toy AST variable type (forward to the generic - /// getType above for non-struct types). - mlir::Type getType(const VarType &type, const Location &location) { - if (!type.name.empty()) { - auto it = structMap.find(type.name); - if (it == structMap.end()) { - emitError(loc(location)) - << "error: unknown struct type '" << type.name << "'"; - return nullptr; - } - return it->second.first; - } - - return getType(type.shape); - } -}; - -} // namespace - -namespace toy { - -// The public API for codegen. -mlir::OwningOpRef mlirGen(mlir::MLIRContext &context, - ModuleAST &moduleAST) { - return MLIRGenImpl(context).mlirGen(moduleAST); -} - -} // namespace toy diff --git a/mlir/ToyCombine.cpp b/mlir/ToyCombine.cpp deleted file mode 100644 index 1d8cf74..0000000 --- a/mlir/ToyCombine.cpp +++ /dev/null @@ -1,89 +0,0 @@ -//===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements a set of simple combiners for optimizing operations in -// the Toy dialect. -// -//===----------------------------------------------------------------------===// - -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/Value.h" -#include "toy/Dialect.h" -#include "llvm/Support/Casting.h" -#include -using namespace mlir; -using namespace toy; - -namespace { -/// Include the patterns defined in the Declarative Rewrite framework. -#include "ToyCombine.inc" -} // namespace - -/// Fold constants. 
-OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } - -/// Fold struct constants. -OpFoldResult StructConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } - -/// Fold simple struct access operations that access into a constant. -OpFoldResult StructAccessOp::fold(FoldAdaptor adaptor) { - auto structAttr = - llvm::dyn_cast_if_present(adaptor.getInput()); - if (!structAttr) - return nullptr; - - size_t elementIndex = getIndex(); - return structAttr[elementIndex]; -} - -/// This is an example of a c++ rewrite pattern for the TransposeOp. It -/// optimizes the following scenario: transpose(transpose(x)) -> x -struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { - /// We register this pattern to match every toy.transpose in the IR. - /// The "benefit" is used by the framework to order the patterns and process - /// them in order of profitability. - SimplifyRedundantTranspose(mlir::MLIRContext *context) - : OpRewritePattern(context, /*benefit=*/1) {} - - /// This method attempts to match a pattern and rewrite it. The rewriter - /// argument is the orchestrator of the sequence of rewrites. The pattern is - /// expected to interact with it to perform any changes to the IR from here. - llvm::LogicalResult - matchAndRewrite(TransposeOp op, - mlir::PatternRewriter &rewriter) const override { - // Look through the input of the current transpose. - mlir::Value transposeInput = op.getOperand(); - TransposeOp transposeInputOp = transposeInput.getDefiningOp(); - - // Input defined by another transpose? If not, no match. - if (!transposeInputOp) - return failure(); - - // Otherwise, we have a redundant transpose. Use the rewriter. - rewriter.replaceOp(op, {transposeInputOp.getOperand()}); - return success(); - } -}; - -/// Register our patterns as "canonicalization" patterns on the TransposeOp so -/// that they can be picked up by the Canonicalization framework. 
-void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); -} - -/// Register our patterns as "canonicalization" patterns on the ReshapeOp so -/// that they can be picked up by the Canonicalization framework. -void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); -} diff --git a/parser/AST.cpp b/parser/AST.cpp deleted file mode 100644 index aa2c784..0000000 --- a/parser/AST.cpp +++ /dev/null @@ -1,274 +0,0 @@ -//===- AST.cpp - Helper for printing out the Toy AST ----------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements the AST dump for the Toy language. -// -//===----------------------------------------------------------------------===// - -#include "toy/AST.h" - -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/Twine.h" -#include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Casting.h" -#include "llvm/Support/raw_ostream.h" -#include - -using namespace toy; - -namespace { - -// RAII helper to manage increasing/decreasing the indentation as we traverse -// the AST -struct Indent { - Indent(int &level) : level(level) { ++level; } - ~Indent() { --level; } - int &level; -}; - -/// Helper class that implement the AST tree traversal and print the nodes along -/// the way. The only data member is the current indentation level. 
-class ASTDumper { -public: - void dump(ModuleAST *node); - -private: - void dump(const VarType &type); - void dump(VarDeclExprAST *varDecl); - void dump(ExprAST *expr); - void dump(ExprASTList *exprList); - void dump(NumberExprAST *num); - void dump(LiteralExprAST *node); - void dump(StructLiteralExprAST *node); - void dump(VariableExprAST *node); - void dump(ReturnExprAST *node); - void dump(BinaryExprAST *node); - void dump(CallExprAST *node); - void dump(PrintExprAST *node); - void dump(PrototypeAST *node); - void dump(FunctionAST *node); - void dump(StructAST *node); - - // Actually print spaces matching the current indentation level - void indent() { - for (int i = 0; i < curIndent; i++) - llvm::errs() << " "; - } - int curIndent = 0; -}; - -} // namespace - -/// Return a formatted string for the location of any node -template -static std::string loc(T *node) { - const auto &loc = node->loc(); - return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" + - llvm::Twine(loc.col)) - .str(); -} - -// Helper Macro to bump the indentation level and print the leading spaces for -// the current indentations -#define INDENT() \ - Indent level_(curIndent); \ - indent(); - -/// Dispatch to a generic expressions to the appropriate subclass using RTTI -void ASTDumper::dump(ExprAST *expr) { - llvm::TypeSwitch(expr) - .Case([&](auto *node) { this->dump(node); }) - .Default([&](ExprAST *) { - // No match, fallback to a generic message - INDENT(); - llvm::errs() << "getKind() << ">\n"; - }); -} - -/// A variable declaration is printing the variable name, the type, and then -/// recurse in the initializer value. 
-void ASTDumper::dump(VarDeclExprAST *varDecl) { - INDENT(); - llvm::errs() << "VarDecl " << varDecl->getName(); - dump(varDecl->getType()); - llvm::errs() << " " << loc(varDecl) << "\n"; - if (auto *initVal = varDecl->getInitVal()) - dump(initVal); -} - -/// A "block", or a list of expression -void ASTDumper::dump(ExprASTList *exprList) { - INDENT(); - llvm::errs() << "Block {\n"; - for (auto &expr : *exprList) - dump(expr.get()); - indent(); - llvm::errs() << "} // Block\n"; -} - -/// A literal number, just print the value. -void ASTDumper::dump(NumberExprAST *num) { - INDENT(); - llvm::errs() << num->getValue() << " " << loc(num) << "\n"; -} - -/// Helper to print recursively a literal. This handles nested array like: -/// [ [ 1, 2 ], [ 3, 4 ] ] -/// We print out such array with the dimensions spelled out at every level: -/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] -static void printLitHelper(ExprAST *litOrNum) { - // Inside a literal expression we can have either a number or another literal - if (auto *num = llvm::dyn_cast(litOrNum)) { - llvm::errs() << num->getValue(); - return; - } - auto *literal = llvm::cast(litOrNum); - - // Print the dimension for this literal first - llvm::errs() << "<"; - llvm::interleaveComma(literal->getDims(), llvm::errs()); - llvm::errs() << ">"; - - // Now print the content, recursing on every element of the list - llvm::errs() << "[ "; - llvm::interleaveComma(literal->getValues(), llvm::errs(), - [&](auto &elt) { printLitHelper(elt.get()); }); - llvm::errs() << "]"; -} - -/// Print a literal, see the recursive helper above for the implementation. -void ASTDumper::dump(LiteralExprAST *node) { - INDENT(); - llvm::errs() << "Literal: "; - printLitHelper(node); - llvm::errs() << " " << loc(node) << "\n"; -} - -/// Print a struct literal. 
-void ASTDumper::dump(StructLiteralExprAST *node) { - INDENT(); - llvm::errs() << "Struct Literal: "; - for (auto &value : node->getValues()) - dump(value.get()); - indent(); - llvm::errs() << " " << loc(node) << "\n"; -} - -/// Print a variable reference (just a name). -void ASTDumper::dump(VariableExprAST *node) { - INDENT(); - llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n"; -} - -/// Return statement print the return and its (optional) argument. -void ASTDumper::dump(ReturnExprAST *node) { - INDENT(); - llvm::errs() << "Return\n"; - if (node->getExpr().has_value()) - return dump(*node->getExpr()); - { - INDENT(); - llvm::errs() << "(void)\n"; - } -} - -/// Print a binary operation, first the operator, then recurse into LHS and RHS. -void ASTDumper::dump(BinaryExprAST *node) { - INDENT(); - llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n"; - dump(node->getLHS()); - dump(node->getRHS()); -} - -/// Print a call expression, first the callee name and the list of args by -/// recursing into each individual argument. -void ASTDumper::dump(CallExprAST *node) { - INDENT(); - llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n"; - for (auto &arg : node->getArgs()) - dump(arg.get()); - indent(); - llvm::errs() << "]\n"; -} - -/// Print a builtin print call, first the builtin name and then the argument. -void ASTDumper::dump(PrintExprAST *node) { - INDENT(); - llvm::errs() << "Print [ " << loc(node) << "\n"; - dump(node->getArg()); - indent(); - llvm::errs() << "]\n"; -} - -/// Print type: only the shape is printed in between '<' and '>' -void ASTDumper::dump(const VarType &type) { - llvm::errs() << "<"; - if (!type.name.empty()) - llvm::errs() << type.name; - else - llvm::interleaveComma(type.shape, llvm::errs()); - llvm::errs() << ">"; -} - -/// Print a function prototype, first the function name, and then the list of -/// parameters names. 
-void ASTDumper::dump(PrototypeAST *node) { - INDENT(); - llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "\n"; - indent(); - llvm::errs() << "Params: ["; - llvm::interleaveComma(node->getArgs(), llvm::errs(), - [](auto &arg) { llvm::errs() << arg->getName(); }); - llvm::errs() << "]\n"; -} - -/// Print a function, first the prototype and then the body. -void ASTDumper::dump(FunctionAST *node) { - INDENT(); - llvm::errs() << "Function \n"; - dump(node->getProto()); - dump(node->getBody()); -} - -/// Print a struct. -void ASTDumper::dump(StructAST *node) { - INDENT(); - llvm::errs() << "Struct: " << node->getName() << " " << loc(node) << "\n"; - - { - INDENT(); - llvm::errs() << "Variables: [\n"; - for (auto &variable : node->getVariables()) - dump(variable.get()); - indent(); - llvm::errs() << "]\n"; - } -} - -/// Print a module, actually loop over the functions and print them in sequence. -void ASTDumper::dump(ModuleAST *node) { - INDENT(); - llvm::errs() << "Module:\n"; - for (auto &record : *node) { - if (FunctionAST *function = llvm::dyn_cast(record.get())) - dump(function); - else if (StructAST *str = llvm::dyn_cast(record.get())) - dump(str); - else - llvm::errs() << "getKind() << ">\n"; - } -} - -namespace toy { - -// Public API -void dump(ModuleAST &module) { ASTDumper().dump(&module); } - -} // namespace toy diff --git a/readme.md b/readme.md new file mode 100644 index 0000000..862831e --- /dev/null +++ b/readme.md @@ -0,0 +1,95 @@ +# +**MLP-MLIR** – An MLIR-based Compiler Infrastructure for Neural Networks + +![MLIR](https://img.shields.io/badge/MLIR-LLVM-blue) +![C++](https://img.shields.io/badge/C%2B%2B-17-green) +![Status](https://img.shields.io/badge/Status-Research%20Prototype-orange) + +--- + +## Requirements + +**llvm-mlir** use branch release/22.x + +--- +## Overview + +**MLP-MLIR** is an experimental compiler framework built on LLVM MLIR that focuses on neural network representation, transformation, and lowering. 
The project currently targets **Multi-Layer Perceptrons (MLPs)** and provides a foundation for future Transformer and LLM support. + +The current implementation targets **Multi-Layer Perceptrons (MLPs)** and demonstrates: + +- Programmatic construction of MLIR IR in C++ +- Custom dialect operations for neural network primitives +- Tensor-level computation and inspection +- A standalone MLIR execution pipeline + +The project is designed with a **long-term roadmap** toward **Transformers, Large Language Models (LLMs), and heterogeneous hardware backends**. + +--- + +## Goals + +- ✅ Provide a clean MLIR-based representation for **MLP workloads** +- ✅ Experiment with **custom neural network dialect extensions** +- 🔜 Extend to **Transformer architectures** +- 🔜 Support **LLM-scale graphs** +- 🔜 Enable **multi-hardware lowering** (CPU, GPU, accelerators) +- 🔜 Explore **optimization passes and graph-level rewrites** + +--- + +## Why MLIR? + +MLIR provides: +- Multi-level abstraction (tensor → linalg → LLVM) +- Dialect extensibility +- Hardware-agnostic IR design +- First-class support for compiler transformations + +This makes it ideal for **machine learning compilers** that must evolve across: +- models +- hardware targets +- optimization strategies + +--- + +## Current Capabilities + +### Implemented +- MLP-style operations (Add, Mul, ReLU) +- Tensor constants and elementwise ops +- Print operations for debugging +- MLIR IR construction via C++ builders +- Standalone MLIR driver + +### In Progress +- Shape-aware ops +- TOSA / Linalg interoperability +- Type consistency (f32 / f64) +- Modular builder separation + +--- + +## Roadmap + +### Phase 1 — MLP Foundations (Current) +- Custom ops +- Tensor algebra +- MLIR builder utilities + +### Phase 2 — Transformer Support +- Linear layers +- Attention mechanisms +- Softmax, LayerNorm +- Sequence modeling + +### Phase 3 — LLM-Scale Compilation +- Graph-level optimizations +- Memory planning +- Operator fusion + +### Phase 4 — 
Multi-Hardware Lowering +- CPU (LLVM) +- GPU (NVVM / ROCm) +- Accelerator targets +- Backend-specific optimization passes diff --git a/results/affine_lowered.png b/results/affine_lowered.png new file mode 100644 index 0000000..d3da261 Binary files /dev/null and b/results/affine_lowered.png differ diff --git a/results/mlp_dialect.png b/results/mlp_dialect.png new file mode 100644 index 0000000..94a6078 Binary files /dev/null and b/results/mlp_dialect.png differ diff --git a/results/mlp_to_linalg.png b/results/mlp_to_linalg.png new file mode 100644 index 0000000..f0eb2ba Binary files /dev/null and b/results/mlp_to_linalg.png differ diff --git a/results/relu-to-linalg.png b/results/relu-to-linalg.png new file mode 100644 index 0000000..37daa80 Binary files /dev/null and b/results/relu-to-linalg.png differ diff --git a/scripts/build.sh b/scripts/build.sh new file mode 100755 index 0000000..cd60e4a --- /dev/null +++ b/scripts/build.sh @@ -0,0 +1,4 @@ +cd build + + +cmake .. && make && ./standalone_mlir \ No newline at end of file diff --git a/scripts/git-push.sh b/scripts/git-push.sh new file mode 100755 index 0000000..946163d --- /dev/null +++ b/scripts/git-push.sh @@ -0,0 +1,3 @@ +git add -A +git commit -m "refactor" # Add msg +git push --force \ No newline at end of file diff --git a/src/Builder.cpp b/src/Builder.cpp new file mode 100644 index 0000000..253930f --- /dev/null +++ b/src/Builder.cpp @@ -0,0 +1,458 @@ +#include "Builder.h" +#include "Dialect.h" +#include "Passes.h" +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h" +#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/Block.h" +#include 
"mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "mlir/Transforms/DialectConversion.h" +#include + +#ifdef PRINT +#define PRINT +#endif +// #define PRINT + +using namespace mlir; +// using namespace dbs; +using namespace mlir::mlp; + +namespace builder +{ +/** + * createMainFunction + * + * Construct a top-level "main" function inside the provided MLIR module. + * + * Behavior: + * - Creates a func::FuncOp named "main" with an empty function type (no + * arguments, no results). + * - Adds the function to the supplied ModuleOp and inserts a new entry block. + * - Establishes an OpBuilder insertion point at the start of the entry block. + * - Emits two f32 constants (1.0 and 2.0) into the function body. + * - Emits a func::ReturnOp to terminate the function. + * + * Conditional behavior (enabled when PRINT is defined at compile time): + * - Declares a helper function "print_f32" (f32 -> void) at module scope and + * marks it private. + * - Attempts to emit calls to "print_f32" with the results of intermediate + * operations (intended to print computed f32 values). Note: those intermediate + * values (e.g. add, mul) must exist in the function body for the print calls + * to be valid. + * + * Parameters: + * - ctx: The MLIRContext used to construct operations and types. + * - module: The ModuleOp into which the "main" function will be inserted. + * + * Side effects: + * - Mutates the provided ModuleOp by appending the newly created FuncOp and, + * when PRINT is defined, by inserting a private "print_f32" declaration. 
+ * - Creates operations and IR in the provided MLIRContext. + * + * Returns: + * - The created func::FuncOp corresponding to "main". + * + * Notes: + * - All created operations use builder.getUnknownLoc() for locations. + * - The function currently has no arguments and returns nothing; callers that + * expect different signatures should modify the function type accordingly. + */ + func::FuncOp createMainFunction(MLIRContext &ctx, ModuleOp module) + { + + mlir::OpBuilder builder(&ctx); + + [[maybe_unused]] auto f32 = builder.getF32Type(); + auto funcType = builder.getFunctionType({}, {}); + + // auto func = + // func::FuncOp::create(builder,builder.getUnknownLoc(), "main", + // funcType); + auto func = + mlir::func::FuncOp::create(builder.getUnknownLoc(), "main", funcType); + + module.push_back(func); + + auto *entry = func.addEntryBlock(); + builder.setInsertionPointToStart(entry); + +#ifdef PRINT + // ---- Declare print function ---- + { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(module.getBody()); + + auto printType = builder.getFunctionType({f32}, {}); + builder + .create(builder.getUnknownLoc(), "print_f32", printType) + .setPrivate(); + } +#endif + + Value c1 = arith::ConstantOp::create(builder, builder.getUnknownLoc(), + builder.getF32FloatAttr(1.0)); + Value c2 = arith::ConstantOp::create(builder, builder.getUnknownLoc(), + builder.getF32FloatAttr(2.0)); + + // [[maybe_unused]] Value add = builder + // .create(builder.getUnknownLoc(), "my_mul", + // builder.getF32Type(), ValueRange{c1, c2}) ->getResult(0); + // Value sum = builder.create(builder.getUnknownLoc(), "print_f32", + TypeRange{}, ValueRange{add}); +#endif + + // // --- Call my_mul --- + // Value c3 = arith::ConstantOp::create(builder,builder.getUnknownLoc(), + // builder.getF32FloatAttr(3.0)); + // Value c4 = arith::ConstantOp::create(builder,builder.getUnknownLoc(), + // builder.getF32FloatAttr(4.0)); + + // [[maybe_unused]] Value mul = + // builder + // 
.create(builder.getUnknownLoc(), "my_mul", + // builder.getF32Type(), ValueRange{c3, c4}) + // ->getResult(0); +#ifdef PRINT + // Print mul result + builder.create(builder.getUnknownLoc(), "print_f32", + TypeRange{}, ValueRange{mul}); +#endif + + // auto c5 = arith::ConstantOp::create(builder,builder.getUnknownLoc(), + // builder.getF32FloatAttr(1.0)); + // auto c6 = arith::ConstantOp::create(builder,builder.getUnknownLoc(), + // builder.getF32FloatAttr(2.0)); + + // // Use your custom dialect op + // auto myAdd = builder.create f32 + // --------------------------------------------- + func::FuncOp createMulFunction(MLIRContext &ctx, ModuleOp module) + { + OpBuilder builder(&ctx); + + auto f32 = builder.getF32Type(); + auto funcType = builder.getFunctionType({f32, f32}, {f32}); + + auto func = + func::FuncOp::create(builder, builder.getUnknownLoc(), "my_mul", funcType); + + func.setVisibility(mlir::SymbolTable::Visibility::Public); + + module.push_back(func); + + Block *entry = func.addEntryBlock(); + builder.setInsertionPointToStart(entry); + + Value a = entry->getArgument(0); + Value b = entry->getArgument(1); + + Value prod = arith::MulFOp::create(builder, builder.getUnknownLoc(), a, b); + + func::ReturnOp::create(builder, builder.getUnknownLoc(), prod); + return func; + } + + // --------------------------------------------- + // Create MLIR function: add(a : f32, b : f32) -> f32 + // --------------------------------------------- + + func::FuncOp createAddFunction(MLIRContext &ctx, ModuleOp module) + { + OpBuilder builder(&ctx); + + auto f32 = builder.getF32Type(); + auto funcType = builder.getFunctionType({f32, f32}, {f32}); + + auto func = + func::FuncOp::create(builder, builder.getUnknownLoc(), "my_add", funcType); + + func.setVisibility(mlir::SymbolTable::Visibility::Public); + + module.push_back(func); + + // Create entry block + mlir::Block *entry = func.addEntryBlock(); + builder.setInsertionPointToStart(entry); + Value a = entry->getArgument(0); + Value b 
= entry->getArgument(1); + + Value sum = arith::AddFOp::create(builder, builder.getUnknownLoc(), a, b); + + func::ReturnOp::create(builder, builder.getUnknownLoc(), sum); + return func; + } + + func::FuncOp createMLPAddFunction(MLIRContext &ctx, + ModuleOp module) + { // WORKING ON THIS + OpBuilder builder(&ctx); + Location loc = builder.getUnknownLoc(); + + // auto f32 = builder.getF32Type(); + auto f64 = builder.getF64Type(); + + auto rankedtensorf64Ty = RankedTensorType::get({2}, f64); + // auto rankedtensorf32Ty = RankedTensorType::get({2}, f32); + + auto funcType = builder.getFunctionType({}, {}); + + auto func = + func::FuncOp::create(builder, builder.getUnknownLoc(), "main", funcType); + + func.setVisibility(mlir::SymbolTable::Visibility::Public); + + module.push_back(func); + // Create entry block + mlir::Block *entry = func.addEntryBlock(); + builder.setInsertionPointToStart(entry); + + llvm::SmallVector vals1 = {llvm::APFloat(5.0), + llvm::APFloat(7.0)}; + llvm::SmallVector vals2 = {llvm::APFloat(10.0), + llvm::APFloat(5.0)}; + + auto denseAttr1 = mlir::DenseElementsAttr::get(rankedtensorf64Ty, vals1); + auto denseAttr2 = mlir::DenseElementsAttr::get(rankedtensorf64Ty, vals2); + + Value c1 = + mlp::ConstantOp::create(builder, loc, rankedtensorf64Ty, denseAttr1); + + Value c2 = + mlp::ConstantOp::create(builder, loc, rankedtensorf64Ty, denseAttr2); + + Value add = mlp::AddOp::create(builder, loc, rankedtensorf64Ty, c1, c2); + // Value add = builder.create vals1D = {llvm::APFloat(3.0), + llvm::APFloat(1.0)}; + llvm::SmallVector vals1 = { + llvm::APFloat(3.0), llvm::APFloat(1.0), llvm::APFloat(2.0), + llvm::APFloat(2.0)}; + llvm::SmallVector vals2 = { + llvm::APFloat(1.0), llvm::APFloat(5.0), llvm::APFloat(5.0), + llvm::APFloat(2.0)}; + + int64_t N = vals1D.size(); + auto tensor1DTy = RankedTensorType::get({1, N}, f64); // Batch is 1 + // auto tensor1DTy = RankedTensorType::get({1,N}, f64); // Batch is 1 + auto denseAttr1D = 
mlir::DenseElementsAttr::get(tensor1DTy, vals1D); + Value c1D = mlp::ConstantOp::create(builder, loc, tensor1DTy, denseAttr1D); + + auto denseAttr1 = mlir::DenseElementsAttr::get(rankedtensorfTy, vals1); + auto denseAttr2 = mlir::DenseElementsAttr::get(rankedtensorfTy, vals2); + + Value c1 = mlp::ConstantOp::create(builder, loc, rankedtensorfTy, denseAttr1); + Value c2 = mlp::ConstantOp::create(builder, loc, rankedtensorfTy, denseAttr2); + // mlp::PrintOp::create(builder, loc, c1); + // mlp::PrintOp::create(builder, loc, c2); + + Value lin = mlp::LinearOp::create(builder, loc, rankedtensorfTy, c1, c2); + + // mlp::PrintOp::create(builder, loc, lin); + + // Value relu = builder.create vals1 = {llvm::APFloat(5.0), + // // llvm::APFloat(7.0)}; + // // llvm::SmallVector vals2 = {llvm::APFloat(10.0), + // // llvm::APFloat(5.0)}; + + // // auto denseAttr1 = mlir::DenseElementsAttr::get(rankedtensorTy, vals1); + // // auto denseAttr2 = mlir::DenseElementsAttr::get(rankedtensorTy, vals2); + // std::vector vals1 = {5, 7}; + // std::vector vals2 = {10, 5}; + + // // Convert to FloatAttr + // llvm::SmallVector attrs1; + // for (auto v : vals1) + // attrs1.push_back(builder.getFloatAttr(f64, v)); + + // // llvm::SmallVector attrs2; + // // for (auto v : vals2) + // // attrs2.push_back(builder.getFloatAttr(i32, v)); + + // auto denseAttr1 = mlir::DenseElementsAttr::get(rankedtensor64Ty, attrs1); + // // auto denseAttr2 = mlir::DenseElementsAttr::get(rankedtensorTy, attrs2); + + // ///////////////////////////// + // // 1) f64 tensor constant + // auto denseAttr5 = DenseElementsAttr::get(rankedtensor64Ty, attrs1); + + // // Value c64 = + // // mlp::ConstantOp::create(builder,loc, rankedtensor64Ty, denseAttr5); + + // // // 2) cast f64 -> f32 (tensor-level) + // // Value c32 = + // // builder.create(loc, rankedtensor32Ty, c64); + + // // // 3) use f32 result + // // mlp::PrintOp::create(builder,loc, c32); + + // ///////////////////////////// + + // // Value c1 = 
mlp::ConstantOp::create(builder,loc, rankedtensor32Ty, + // // denseAttr1); + + // // Value c2 = builder.create attrs64; + // for (float v : vals1) + // attrs64.push_back(builder.getFloatAttr(f64, v)); + + // auto dense32 = DenseElementsAttr::get(t64, attrs64); + // Value c32 = mlp::ConstantOp::create(builder,loc, t32, dense32); + // mlp::PrintOp::create(builder,builder.getUnknownLoc(), c32); + // builder.create vals1 = { + llvm::APFloat(5.0), llvm::APFloat(-7.0), llvm::APFloat(7.0), + llvm::APFloat(10.0)}; + + auto denseAttr1 = mlir::DenseElementsAttr::get(rankedtensorTy, vals1); + + Value c11 = mlp::ConstantOp::create(builder, builder.getUnknownLoc(), + rankedtensorTy, denseAttr1); + + Value relu = + mlp::ReluOp::create(builder, builder.getUnknownLoc(), rankedtensorTy, c11); + + mlp::PrintOp::create(builder, builder.getUnknownLoc(), c11); + + mlp::PrintOp::create(builder, builder.getUnknownLoc(), relu); + + func::ReturnOp::create(builder, builder.getUnknownLoc()); + return func; + } + +} // namespace builder \ No newline at end of file diff --git a/src/Dialect.cpp b/src/Dialect.cpp new file mode 100644 index 0000000..254923e --- /dev/null +++ b/src/Dialect.cpp @@ -0,0 +1,767 @@ +//===- Dialect.cpp - mlp dialect (only add) -----------------------------===// +// +// Minimal mlp dialect implementation with a single AddOp. 
+// +//===----------------------------------------------------------------------===// + +#include "Dialect.h" + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/InliningUtils.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" + +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace mlir::mlp; + +/// Generated dialect definitions (MLPDialect, etc.). +#include "Dialect.cpp.inc" + +//===----------------------------------------------------------------------===// +// MlpInlinerInterface +//===----------------------------------------------------------------------===// + +/// This class defines the interface for handling inlining with Mlp +/// operations. +struct MlpInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + //===--------------------------------------------------------------------===// + // Analysis Hooks + //===--------------------------------------------------------------------===// + + /// All call operations within mlp can be inlined. + bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const final { + return true; + } + + /// All operations within mlp can be inlined. 
+ bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { + return true; + } + + // All functions within mlp can be inlined. + bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final { + return true; + } + + //===--------------------------------------------------------------------===// + // Transformation Hooks + //===--------------------------------------------------------------------===// + + /// Handle the given inlined terminator(mlp.return) by replacing it with a new + /// operation as necessary. + // void handleTerminator(Operation *op, + // ArrayRef valuesToRepl) const final { + // // Only "mlp.return" needs to be handled here. + // auto returnOp = cast(op); + + // // Replace the values directly with the return operands. + // assert(returnOp.getNumOperands() == valuesToRepl.size()); + // for (const auto &it : llvm::enumerate(returnOp.getOperands())) + // valuesToRepl[it.index()].replaceAllUsesWith(it.value()); + // } + + /// Attempts to materialize a conversion for a type mismatch between a call + /// from this dialect, and a callable region. This method should generate an + /// operation that takes 'input' as the only operand, and produces a single + /// result of 'resultType'. If a conversion can not be generated, nullptr + /// should be returned. + // Operation *materializeCallConversion(OpBuilder &builder, Value input, + // Type resultType, + // Location conversionLoc) const final { + // return builder.create(conversionLoc, resultType, input); + // } +}; + +/// Generated op method definitions (for AddOp). +#define GET_OP_CLASSES +#include "Ops.cpp.inc" +#undef GET_OP_CLASSES + +/// Dialect initialization: register AddOp only. 
+void MLPDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "Ops.cpp.inc" + // #undef GET_OP_LIST + >(); + // addInterfaces(); + // addTypes(); +} + +// void MlpDialect::initialize() { +// addOperations< +// #define GET_OP_LIST +// #include "Ops.cpp.inc" +// >(); +// addInterfaces(); +// addTypes(); +// } + +// mlir::Operation *MlpDialect::materializeConstant(mlir::OpBuilder &builder, +// mlir::Attribute value, +// mlir::Type type, +// mlir::Location loc) { +// return builder.create(loc, type, +// llvm::cast(value)); +// } + +/// A generalized parser for binary operations. This parses the different forms +/// of 'printBinaryOp' below. + +//===----------------------------------------------------------------------===// +// MLPDialect type parsing / printing +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// MLPDialect type parsing / printing +//===----------------------------------------------------------------------===// + +mlir::Type MLPDialect::parseType(mlir::DialectAsmParser &parser) const { + // If you do NOT want custom types, just reject all: + parser.emitError(parser.getCurrentLocation(), + "mlp dialect has no custom types"); + return Type(); +} + +void MLPDialect::printType(mlir::Type type, + mlir::DialectAsmPrinter &printer) const { + // We should never be asked to print a mlp-specific type in this minimal + // setup. + llvm_unreachable("mlp dialect has no custom types to print"); +} + +//===----------------------------------------------------------------------===// +// MLPDialect constant materializer +//===----------------------------------------------------------------------===// + +// mlir::Operation *MLPDialect::materializeConstant(mlir::OpBuilder &builder, +// mlir::Attribute value, +// mlir::Type type, +// mlir::Location loc) { +// // If you do not use mlp.constant anymore, just return nullptr. 
+// return nullptr; +// } + +//===----------------------------------------------------------------------===// +// AddOp +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace mlp { + + //===----------------------------------------------------------------------===// +// Helpers for unary ops +//===----------------------------------------------------------------------===// + +/// Parse a unary operation: one operand, one result. +static mlir::ParseResult parseUnaryOp(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + mlir::OpAsmParser::UnresolvedOperand operand; + mlir::Type type; + if (parser.parseOperand(operand) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(type)) + return mlir::failure(); + + if (parser.resolveOperand(operand, type, result.operands)) + return mlir::failure(); + result.addTypes(type); + return mlir::success(); +} + +/// Print a unary operation +static void printUnaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) { + printer << " " << op->getOperand(0); + printer.printOptionalAttrDict(op->getAttrs()); + printer << " : " << *op->result_type_begin(); +} + +/// Parse a unary op with an extra float attribute (for LeakyRelu, ELU, etc.) 
+static mlir::ParseResult parseUnaryOpWithAttr(mlir::OpAsmParser &parser, + mlir::OperationState &result, + StringRef attrName) { + mlir::OpAsmParser::UnresolvedOperand operand; + mlir::Type type; + mlir::Attribute attr; + if (parser.parseOperand(operand) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(type) || + parser.parseAttribute(attr, attrName, result.attributes)) + return mlir::failure(); + + if (parser.resolveOperand(operand, type, result.operands)) + return mlir::failure(); + result.addTypes(type); + return mlir::success(); +} + +/// Print unary op with extra attribute +static void printUnaryOpWithAttr(mlir::OpAsmPrinter &printer, + mlir::Operation *op, StringRef attrName) { + printer << " " << op->getOperand(0); + printer.printOptionalAttrDict(op->getAttrs()); + printer << " : " << *op->result_type_begin(); +} + +static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + + SmallVector operands; + SMLoc operandsLoc = parser.getCurrentLocation(); + Type type; + if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(type)) + return mlir::failure(); + + // If the type is a function type, it contains the input and result types of + // this operation. + if (FunctionType funcType = llvm::dyn_cast(type)) { + if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc, + result.operands)) + return mlir::failure(); + result.addTypes(funcType.getResults()); + return mlir::success(); + } + + // Otherwise, the parsed type is the type of both operands and results. + if (parser.resolveOperands(operands, type, result.operands)) + return mlir::failure(); + result.addTypes(type); + return mlir::success(); +} + +// / A generalized printer for binary operations. It prints in two different +// / forms depending on if all of the types match. 
+static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) { + printer << " " << op->getOperands(); + printer.printOptionalAttrDict(op->getAttrs()); + printer << " : "; + + // If all of the types are the same, print the type directly. + Type resultType = *op->result_type_begin(); + if (llvm::all_of(op->getOperandTypes(), + [=](Type type) { return type == resultType; })) { + printer << resultType; + return; + } + + // Otherwise, print a functional type. + printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes()); +} + +//===----------------------------------------------------------------------===// +// AddOp +//===----------------------------------------------------------------------===// + +// void AddOp::build(OpBuilder &builder, OperationState &state, Value lhs, Value +// rhs) { +// auto resultType = lhs.getType().cast(); +// state.addOperands({lhs, rhs}); +// state.addTypes(resultType); +// } + +void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value lhs, mlir::Value rhs) { + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands({lhs, rhs}); +} + +ParseResult AddOp::parse(OpAsmParser &parser, OperationState &result) { + return parseBinaryOp(parser, result); +} + +void AddOp::print(OpAsmPrinter &p) { printBinaryOp(p, *this); } + +// void AddOp::inferShapes() { getResult().setType(getLhs().getType()); } + +//===----------------------------------------------------------------------===// +// LinearOp +//===----------------------------------------------------------------------===// + +void mlir::mlp::LinearOp::build(OpBuilder &builder, OperationState &state, + Value lhs, Value rhs) { + + state.addTypes(lhs.getType()); + state.addOperands({lhs, rhs}); +} + +ParseResult LinearOp::parse(OpAsmParser &parser, OperationState &result) { + return parseBinaryOp(parser, result); +} + +void LinearOp::print(OpAsmPrinter &p) { printBinaryOp(p, *this); } + + + 
+//===----------------------------------------------------------------------===// +// ReluOp +//===----------------------------------------------------------------------===// + +void ReluOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value input) { + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands({input}); +} + +ParseResult ReluOp::parse(OpAsmParser &parser, OperationState &result) { + return parseUnaryOp(parser, result); +} + +void ReluOp::print(OpAsmPrinter &p) { printUnaryOp(p, *this); } + +//===----------------------------------------------------------------------===// +// LeakyReluOp +//===----------------------------------------------------------------------===// + +void LeakyReluOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, mlir::Value input) { + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands({input}); +} + +ParseResult LeakyReluOp::parse(OpAsmParser &parser, OperationState &result) { + return parseUnaryOpWithAttr(parser, result, "alpha"); +} + +void LeakyReluOp::print(OpAsmPrinter &p) { printUnaryOpWithAttr(p, *this, "alpha"); } + +//===----------------------------------------------------------------------===// +// EluOp +//===----------------------------------------------------------------------===// + +void EluOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value input) { + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands({input}); +} + +ParseResult EluOp::parse(OpAsmParser &parser, OperationState &result) { + return parseUnaryOpWithAttr(parser, result, "alpha"); +} + +void EluOp::print(OpAsmPrinter &p) { printUnaryOpWithAttr(p, *this, "alpha"); } + +//===----------------------------------------------------------------------===// +// SigmoidOp +//===----------------------------------------------------------------------===// + +void SigmoidOp::build(mlir::OpBuilder &builder, 
mlir::OperationState &state, + mlir::Value input) { + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands({input}); +} + +ParseResult SigmoidOp::parse(OpAsmParser &parser, OperationState &result) { + return parseUnaryOp(parser, result); +} + +void SigmoidOp::print(OpAsmPrinter &p) { printUnaryOp(p, *this); } + +//===----------------------------------------------------------------------===// +// TanhOp +//===----------------------------------------------------------------------===// + +void TanhOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value input) { + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands({input}); +} + +ParseResult TanhOp::parse(OpAsmParser &parser, OperationState &result) { + return parseUnaryOp(parser, result); +} + +void TanhOp::print(OpAsmPrinter &p) { printUnaryOp(p, *this); } + +//===----------------------------------------------------------------------===// +// SoftmaxOp +//===----------------------------------------------------------------------===// + +void SoftmaxOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value input) { + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands({input}); +} + +ParseResult SoftmaxOp::parse(OpAsmParser &parser, OperationState &result) { + return parseUnaryOp(parser, result); +} + +void SoftmaxOp::print(OpAsmPrinter &p) { printUnaryOp(p, *this); } + +//===----------------------------------------------------------------------===// +// GeluOp +//===----------------------------------------------------------------------===// + +void GeluOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value input) { + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands({input}); +} + +ParseResult GeluOp::parse(OpAsmParser &parser, OperationState &result) { + return parseUnaryOp(parser, result); +} + +void 
GeluOp::print(OpAsmPrinter &p) { printUnaryOp(p, *this); } + +//===----------------------------------------------------------------------===// +// SwishOp +//===----------------------------------------------------------------------===// + +void SwishOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value input) { + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands({input}); +} + +ParseResult SwishOp::parse(OpAsmParser &parser, OperationState &result) { + return parseUnaryOp(parser, result); +} + +void SwishOp::print(OpAsmPrinter &p) { printUnaryOp(p, *this); } + +//===----------------------------------------------------------------------===// +// MishOp +//===----------------------------------------------------------------------===// + +void MishOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value input) { + state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + state.addOperands({input}); +} + +ParseResult MishOp::parse(OpAsmParser &parser, OperationState &result) { + return parseUnaryOp(parser, result); +} + +void MishOp::print(OpAsmPrinter &p) { printUnaryOp(p, *this); } + +//===----------------------------------------------------------------------===// +// FuncOp +//===----------------------------------------------------------------------===// + +void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + llvm::StringRef name, mlir::FunctionType type, + llvm::ArrayRef attrs) { + // FunctionOpInterface provides a convenient `build` method that will populate + // the state of our FuncOp, and create an entry block. + buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs()); +} + +mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + // Dispatch to the FunctionOpInterface provided utility method that parses the + // function operation. 
+ auto buildFuncType = + [](mlir::Builder &builder, llvm::ArrayRef argTypes, + llvm::ArrayRef results, + mlir::function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(argTypes, results); }; + + return mlir::function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); +} + +void FuncOp::print(mlir::OpAsmPrinter &p) { + // Dispatch to the FunctionOpInterface provided utility method that prints the + // function operation. + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); +} + + + +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// + +/// Build a constant operation. +/// The builder is passed as an argument, so is the state that this method is +/// expected to fill in order to build the operation. +void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + double value) { + auto dataType = RankedTensorType::get({}, builder.getF64Type()); + auto dataAttribute = DenseElementsAttr::get(dataType, value); + ConstantOp::build(builder, state, dataType, dataAttribute); +} + +/// The 'OpAsmParser' class provides a collection of methods for parsing +/// various punctuation, as well as attributes, operands, types, etc. Each of +/// these methods returns a `ParseResult`. This class is a wrapper around +/// `LogicalResult` that can be converted to a boolean `true` value on failure, +/// or `false` on success. This allows for easily chaining together a set of +/// parser rules. These rules are used to populate an `mlir::OperationState` +/// similarly to the `build` methods described above. 
+mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + mlir::DenseElementsAttr value; + if (parser.parseOptionalAttrDict(result.attributes) || + parser.parseAttribute(value, "value", result.attributes)) + return failure(); + + result.addTypes(value.getType()); + return success(); +} + +/// The 'OpAsmPrinter' class is a stream that allows for formatting +/// strings, attributes, operands, types, etc. +void ConstantOp::print(mlir::OpAsmPrinter &printer) { + printer << " "; + printer.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"value"}); + printer << getValue(); +} + +// Verify that the given attribute value is valid for the given type. +static mlir::LogicalResult verifyConstantForType(mlir::Type type, + mlir::Attribute opaqueValue, + mlir::Operation *op) { + if (llvm::isa(type)) { + // Check that the value is an elements attribute. + auto attrValue = llvm::dyn_cast(opaqueValue); + if (!attrValue) + return op->emitError("constant of TensorType must be initialized by a " + "DenseFPElementsAttr, got ") + << opaqueValue; + + // If the return type of the constant is not an unranked tensor, the shape + // must match the shape of the attribute holding the data. + auto resultType = llvm::dyn_cast(type); + if (!resultType) + return success(); + + // Check that the rank of the attribute type matches the rank of the + // constant result type. + auto attrType = llvm::cast(attrValue.getType()); + if (attrType.getRank() != resultType.getRank()) { + return op->emitOpError("return type must match the one of the attached " + "value attribute: ") + << attrType.getRank() << " != " << resultType.getRank(); + } + + // Check that each of the dimensions match between the two types. 
+ for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) { + if (attrType.getShape()[dim] != resultType.getShape()[dim]) { + return op->emitOpError( + "return type shape mismatches its attribute at dimension") + << dim << ": " << attrType.getShape()[dim] + << " != " << resultType.getShape()[dim]; + } + } + return mlir::success(); + } + + // auto resultType = llvm::cast(type); + // llvm::ArrayRef resultElementTypes = + // resultType.getElementTypes(); + + // Verify that the initializer is an Array. + // auto attrValue = llvm::dyn_cast(opaqueValue); + // if (!attrValue || attrValue.getValue().size() != resultElementTypes.size()) + // return op->emitError("constant of StructType must be initialized by an " + // "ArrayAttr with the same number of elements, got ") + // << opaqueValue; + + // Check that each of the elements are valid. + // llvm::ArrayRef attrElementValues = attrValue.getValue(); + // for (const auto it : llvm::zip(resultElementTypes, attrElementValues)) + // if (failed(verifyConstantForType(std::get<0>(it), std::get<1>(it), op))) + // return mlir::failure(); + // return mlir::success(); +} + +// Verifier for the constant operation. This corresponds to the `::verify(...)` +// in the op definition. +mlir::LogicalResult ConstantOp::verify() { + return verifyConstantForType(getResult().getType(), getValue(), *this); +} + + +//===----------------------------------------------------------------------===// +// CastOp +//===----------------------------------------------------------------------===// + +// /// Infer the output shape of the CastOp, this is required by the shape +// /// inference interface. +// void CastOp::inferShapes() { getResult().setType(getInput().getType()); } + +// /// Returns true if the given set of input and result types are compatible +// with +// /// this cast operation. This is required by the `CastOpInterface` to verify +// /// this operation and provide other additional utilities. 
+// bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { +// if (inputs.size() != 1 || outputs.size() != 1) +// return false; +// // The inputs must be Tensors with the same element type. +// TensorType input = llvm::dyn_cast(inputs.front()); +// TensorType output = llvm::dyn_cast(outputs.front()); +// if (!input || !output || input.getElementType() != output.getElementType()) +// return false; +// // The shape is required to match if both types are ranked. +// return !input.hasRank() || !output.hasRank() || input == output; +// } + +} // namespace mlp +} // namespace mlir + +//===----------------------------------------------------------------------===// +// mlp Types +//===----------------------------------------------------------------------===// + +// namespace mlir { +// namespace mlp { +// namespace detail { +// /// This class represents the internal storage of the mlp `StructType`. +// struct StructTypeStorage : public mlir::TypeStorage { +// /// The `KeyTy` is a required type that provides an interface for the +// storage +// /// instance. This type will be used when uniquing an instance of the type +// /// storage. For our struct type, we will unique each instance structurally +// on +// /// the elements that it contains. +// using KeyTy = llvm::ArrayRef; + +// /// A constructor for the type storage instance. +// StructTypeStorage(llvm::ArrayRef elementTypes) +// : elementTypes(elementTypes) {} + +// /// Define the comparison function for the key type with the current +// storage +// /// instance. This is used when constructing a new instance to ensure that +// we +// /// haven't already uniqued an instance of the given key. +// bool operator==(const KeyTy &key) const { return key == elementTypes; } + +// /// Define a hash function for the key type. This is used when uniquing +// /// instances of the storage, see the `StructType::get` method. 
+// /// Note: This method isn't necessary as both llvm::ArrayRef and mlir::Type +// /// have hash functions available, so we could just omit this entirely. +// static llvm::hash_code hashKey(const KeyTy &key) { +// return llvm::hash_value(key); +// } + +// /// Define a construction function for the key type from a set of +// parameters. +// /// These parameters will be provided when constructing the storage +// instance +// /// itself. +// /// Note: This method isn't necessary because KeyTy can be directly +// /// constructed with the given parameters. +// static KeyTy getKey(llvm::ArrayRef elementTypes) { +// return KeyTy(elementTypes); +// } + +// /// Define a construction method for creating a new instance of this +// storage. +// /// This method takes an instance of a storage allocator, and an instance +// of a +// /// `KeyTy`. The given allocator must be used for *all* necessary dynamic +// /// allocations used to create the type storage and its internal. +// static StructTypeStorage *construct(mlir::TypeStorageAllocator &allocator, +// const KeyTy &key) { +// // Copy the elements from the provided `KeyTy` into the allocator. +// llvm::ArrayRef elementTypes = allocator.copyInto(key); + +// // Allocate the storage instance and construct it. +// return new (allocator.allocate()) +// StructTypeStorage(elementTypes); +// } + +// /// The following field contains the element types of the struct. +// llvm::ArrayRef elementTypes; +// }; +// } // namespace detail +// } // namespace mlp +// } // namespace mlir + +/// Create an instance of a `StructType` with the given element types. There +/// *must* be at least one element type. +// StructType StructType::get(llvm::ArrayRef elementTypes) { +// assert(!elementTypes.empty() && "expected at least 1 element type"); + +// // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance +// // of this type. The first parameter is the context to unique in. 
The +// // parameters after the context are forwarded to the storage instance. +// mlir::MLIRContext *ctx = elementTypes.front().getContext(); +// return Base::get(ctx, elementTypes); +// } + +/// Returns the element types of this struct type. +// llvm::ArrayRef StructType::getElementTypes() { +// // 'getImpl' returns a pointer to the internal storage instance. +// return getImpl()->elementTypes; +// } + +/// Parse an instance of a type registered to the mlp dialect. +// mlir::Type MLPDialect::parseType(mlir::DialectAsmParser &parser) const { +// // Parse a struct type in the following form: +// // struct-type ::= `struct` `<` type (`,` type)* `>` + +// // NOTE: All MLIR parser function return a ParseResult. This is a +// // specialization of LogicalResult that auto-converts to a `true` boolean +// // value on failure to allow for chaining, but may be used with explicit +// // `mlir::failed/mlir::succeeded` as desired. + +// // Parse: `struct` `<` +// if (parser.parseKeyword("struct") || parser.parseLess()) +// return Type(); + +// // Parse the element types of the struct. +// SmallVector elementTypes; +// do { +// // Parse the current element type. +// SMLoc typeLoc = parser.getCurrentLocation(); +// mlir::Type elementType; +// if (parser.parseType(elementType)) +// return nullptr; + +// // Check that the type is either a TensorType or another StructType. 
+// if (!llvm::isa(elementType)) { +// parser.emitError(typeLoc, "element type for a struct must either " +// "be a TensorType or a StructType, got: ") +// << elementType; +// return Type(); +// } +// elementTypes.push_back(elementType); + +// // Parse the optional: `,` +// } while (succeeded(parser.parseOptionalComma())); + +// // Parse: `>` +// if (parser.parseGreater()) +// return Type(); +// return StructType::get(elementTypes); +// } \ No newline at end of file diff --git a/src/Jit.cpp b/src/Jit.cpp new file mode 100644 index 0000000..7e52d18 --- /dev/null +++ b/src/Jit.cpp @@ -0,0 +1,113 @@ + +#include "Jit.h" +#include "Dialect.h" +#include "Passes.h" + +#include "mlir/Dialect/Affine/Transforms/Passes.h" +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" +#include "mlir/Dialect/LLVMIR/Transforms/Passes.h" +#include "mlir/ExecutionEngine/ExecutionEngine.h" +#include "mlir/ExecutionEngine/OptUtils.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Verifier.h" +#include "mlir/InitAllDialects.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Export.h" +#include "mlir/Transforms/Passes.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include +#include +/** + * runJit + * + * JIT-compiles and executes the 
"main" function from the provided MLIR module. + * + * Behavior: + * - Initializes the LLVM native target and native target assembly printer. + * - Registers MLIR-to-LLVM IR translations for both builtin and LLVM dialects + * on the module's MLIRContext (must be done prior to JIT compilation). + * - Constructs an optimization transformer pipeline via makeOptimizingTransformer. + * The pipeline's optimization level is controlled by the local `enableOpt` + * flag (currently disabled by default). + * - Creates an mlir::ExecutionEngine using the module and the chosen transformer. + * The creation is asserted to succeed; if it fails the program will terminate + * due to the assert. + * - Invokes the JIT-compiled function named "main" with no arguments using + * invokePacked(). If invocation fails, an error message is written to + * llvm::errs() and the function returns -1. + * + * Parameters: + * - module: mlir::ModuleOp representing the MLIR module to JIT-compile and run. + * The module's MLIRContext (and any required dialect registrations) + * must remain valid for the duration of this call. + * + * Return value: + * - 0 on successful invocation of "main". + * - -1 if the JIT invocation fails (errors are emitted to llvm::errs()). + * + * Notes: + * - The function currently asserts on failure to construct the ExecutionEngine + * rather than returning an error code; callers should be aware that a failed + * creation will abort the process. + * - To enable compilation optimizations, set `enableOpt` to a truthy value. + */ +namespace jit +{ + int runJit(mlir::ModuleOp module) + { + + int enableOpt = false; + + // Initialize LLVM targets. + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + + // Register the translation from MLIR to LLVM IR, which must happen before can + // JIT-compile. 
+ mlir::registerBuiltinDialectTranslation(*module->getContext()); + mlir::registerLLVMDialectTranslation(*module->getContext()); + + // An optimization pipeline to use within the execution engine. + auto optPipeline = mlir::makeOptimizingTransformer( + /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0, + /*targetMachine=*/nullptr); + + // Create an MLIR execution engine. The execution engine eagerly JIT-compiles + // the module. + mlir::ExecutionEngineOptions engineOptions; + engineOptions.transformer = optPipeline; + auto maybeEngine = mlir::ExecutionEngine::create(module, engineOptions); + assert(maybeEngine && "failed to construct an execution engine"); + auto &engine = maybeEngine.get(); + + // Invoke the JIT-compiled function. + auto invocationResult = engine->invokePacked("main"); + if (invocationResult) + { + llvm::errs() << "JIT invocation failed\n"; + return -1; + } + + return 0; + } + +} // namespace jit diff --git a/src/LowerToAffineLoops.cpp b/src/LowerToAffineLoops.cpp new file mode 100644 index 0000000..023c10f --- /dev/null +++ b/src/LowerToAffineLoops.cpp @@ -0,0 +1,395 @@ +//====- LowerToAffineLoops.cpp - Partial lowering from Toy to Affine+Std --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a partial lowering of Toy operations to a combination of +// affine loops, memref operations and standard operations. This lowering +// expects that all calls have been inlined, and all shapes have been resolved. 
+// +//===----------------------------------------------------------------------===// + +#include "Dialect.h" +#include "Passes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/Support/Casting.h" +#include +#include +#include +#include +#include + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// ToyToAffine Conversion Patterns +//===----------------------------------------------------------------------===// + +/// Convert the given RankedTensorType into the corresponding MemRefType. +static MemRefType convertTensorToMemRef(RankedTensorType type) +{ + return MemRefType::get(type.getShape(), type.getElementType()); +} + +/// Insert an allocation and deallocation for the given MemRefType. +static Value insertAllocAndDealloc(MemRefType type, Location loc, + PatternRewriter &rewriter) +{ + auto alloc = memref::AllocOp::create(rewriter, loc, type); + + // Make sure to allocate at the beginning of the block. + auto *parentBlock = alloc->getBlock(); + alloc->moveBefore(&parentBlock->front()); + + // Make sure to deallocate this alloc at the end of the block. This is fine + // as mlp functions have no control flow. 
+ auto dealloc = memref::DeallocOp::create(rewriter, loc, alloc); + dealloc->moveBefore(&parentBlock->back()); + return alloc; +} + +/// This defines the function type used to process an iteration of a lowered +/// loop. It takes as input an OpBuilder and the range of loop induction +/// variables for the iteration. It returns a value to store at the current +/// index of the iteration. +using LoopIterationFn = + function_ref; + +static void lowerOpToLoops(Operation *op, PatternRewriter &rewriter, + LoopIterationFn processIteration) +{ + auto tensorType = llvm::cast((*op->result_type_begin())); + auto loc = op->getLoc(); + + // Insert an allocation and deallocation for the result of this operation. + auto memRefType = convertTensorToMemRef(tensorType); + auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); + + // Create a nest of affine loops, with one loop per dimension of the shape. + // The buildAffineLoopNest function takes a callback that is used to construct + // the body of the innermost loop given a builder, a location and a range of + // loop induction variables. + SmallVector lowerBounds(tensorType.getRank(), /*Value=*/0); + SmallVector steps(tensorType.getRank(), /*Value=*/1); + affine::buildAffineLoopNest( + rewriter, loc, lowerBounds, tensorType.getShape(), steps, + [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) + { + // Call the processing function with the rewriter + // and the loop induction variables. This function will return the value + // to store at the current index. + Value valueToStore = processIteration(nestedBuilder, ivs); + affine::AffineStoreOp::create(nestedBuilder, loc, valueToStore, alloc, + ivs); + }); + + // Replace this operation with the generated alloc. 
+ rewriter.replaceOp(op, alloc); +} + +namespace +{ + //===----------------------------------------------------------------------===// + // MlpToAffine Conversion Patterns: Binary operations + //===----------------------------------------------------------------------===// + + template + struct BinaryOpLowering : public OpConversionPattern + { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename OpConversionPattern::OpAdaptor; + + LogicalResult + matchAndRewrite(BinaryOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final + { + auto loc = op->getLoc(); + lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) + { + // Generate loads for the element of 'lhs' and 'rhs' at the + // inner loop. + auto loadedLhs = + affine::AffineLoadOp::create(builder, loc, adaptor.getLhs(), loopIvs); + auto loadedRhs = + affine::AffineLoadOp::create(builder, loc, adaptor.getRhs(), loopIvs); + + // Create the binary operation performed on the loaded + // values. + return LoweredBinaryOp::create(builder, loc, loadedLhs, loadedRhs); }); + return success(); + } + }; + + using AddOpLowering = BinaryOpLowering; + // using MulOpLowering = BinaryOpLowering; + + //===----------------------------------------------------------------------===// + // MlpToAffine Conversion Patterns: Constant operations + //===----------------------------------------------------------------------===// + + struct ConstantOpLowering : public OpConversionPattern + { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mlp::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final + { + DenseElementsAttr constantValue = op.getValue(); + Location loc = op.getLoc(); + + // When lowering the constant operation, we allocate and assign the constant + // values to a corresponding memref allocation. 
+ auto tensorType = llvm::cast(op.getType()); + auto memRefType = convertTensorToMemRef(tensorType); + auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); + + // We will be generating constant indices up-to the largest dimension. + // Create these constants up-front to avoid large amounts of redundant + // operations. + auto valueShape = memRefType.getShape(); + SmallVector constantIndices; + + if (!valueShape.empty()) + { + for (auto i : llvm::seq(0, *llvm::max_element(valueShape))) + constantIndices.push_back( + arith::ConstantIndexOp::create(rewriter, loc, i)); + } + else + { + // This is the case of a tensor of rank 0. + constantIndices.push_back( + arith::ConstantIndexOp::create(rewriter, loc, 0)); + } + + // The constant operation represents a multi-dimensional constant, so we + // will need to generate a store for each of the elements. The following + // functor recursively walks the dimensions of the constant shape, + // generating a store when the recursion hits the base case. + SmallVector indices; + auto valueIt = constantValue.value_begin(); + std::function storeElements = [&](uint64_t dimension) + { + // The last dimension is the base case of the recursion, at this point + // we store the element at the given index. + if (dimension == valueShape.size()) + { + affine::AffineStoreOp::create( + rewriter, loc, arith::ConstantOp::create(rewriter, loc, *valueIt++), + alloc, llvm::ArrayRef(indices)); + return; + } + + // Otherwise, iterate over the current dimension and add the indices to + // the list. + for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) + { + indices.push_back(constantIndices[i]); + storeElements(dimension + 1); + indices.pop_back(); + } + }; + + // Start the element storing recursion from the first dimension. + storeElements(/*dimension=*/0); + + // Replace this operation with the generated alloc. 
+ rewriter.replaceOp(op, alloc); + return success(); + } + }; + + //===----------------------------------------------------------------------===// + // MlpToAffine Conversion Patterns: Func operations + //===----------------------------------------------------------------------===// + + struct FuncOpLowering : public OpConversionPattern + { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mlp::FuncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final + { + // We only lower the main function as we expect that all other functions + // have been inlined. + if (op.getName() != "main") + return failure(); + + // Verify that the given main has no inputs and results. + if (op.getNumArguments() || op.getFunctionType().getNumResults()) + { + return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) + { diag << "expected 'main' to have 0 inputs and 0 results"; }); + } + + // Create a new non-mlp function, with the same region. + auto func = mlir::func::FuncOp::create(rewriter, op.getLoc(), op.getName(), + op.getFunctionType()); + rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end()); + rewriter.eraseOp(op); + return success(); + } + }; + + //===----------------------------------------------------------------------===// + // MlpToAffine Conversion Patterns: Print operations + //===----------------------------------------------------------------------===// + + struct PrintOpLowering : public OpConversionPattern + { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mlp::PrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final + { + // We don't lower "mlp.print" in this pass, but we need to update its + // operands. 
+ rewriter.modifyOpInPlace(op, + [&] + { op->setOperands(adaptor.getOperands()); }); + return success(); + } + }; + + // ===----------------------------------------------------------------------===// + // MlpToAffine Conversion Patterns: Return operations + // ===----------------------------------------------------------------------===// + + // struct ReturnOpLowering : public OpConversionPattern { + // using OpConversionPattern::OpConversionPattern; + + // LogicalResult + // matchAndRewrite(mlp::ReturnOp op, OpAdaptor adaptor, + // ConversionPatternRewriter &rewriter) const final { + // // During this lowering, we expect that all function calls have been + // // inlined. + // if (op.hasOperand()) + // return failure(); + + // // We lower "mlp.return" directly to "func.return". + // rewriter.replaceOpWithNewOp(op); + // return success(); + // } + // }; + + // ===----------------------------------------------------------------------===// + // MlpToAffine Conversion Patterns: Transpose operations + // ===----------------------------------------------------------------------===// + + // struct TransposeOpLowering : public OpConversionPattern { + // using OpConversionPattern::OpConversionPattern; + + // LogicalResult + // matchAndRewrite(mlp::TransposeOp op, OpAdaptor adaptor, + // ConversionPatternRewriter &rewriter) const final { + // auto loc = op->getLoc(); + // lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) + // { + // Value input = adaptor.getInput(); + + // // Transpose the elements by generating a load from the + // // reverse indices. 
+ // SmallVector reverseIvs(llvm::reverse(loopIvs)); + // return affine::AffineLoadOp::create(builder, loc, input, reverseIvs); + // }); + // return success(); + // } + // }; + +} // namespace + +//===----------------------------------------------------------------------===// +// MlpToAffineLoweringPass +//===----------------------------------------------------------------------===// + +/// This is a partial lowering to affine loops of the mlp operations that are +/// computationally intensive (like matmul for example...) while keeping the +/// rest of the code in the Mlp dialect. +namespace +{ + struct MlpToAffineLoweringPass + : public PassWrapper> + { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MlpToAffineLoweringPass) + StringRef getArgument() const override { return "mlp-to-affine"; } + + void getDependentDialects(DialectRegistry ®istry) const override + { + registry.insert(); + } + void runOnOperation() final; + }; +} // namespace + +void MlpToAffineLoweringPass::runOnOperation() +{ + // The first thing to define is the conversion target. This will define the + // final target for this lowering. + ConversionTarget target(getContext()); + + // We define the specific operations, or dialects, that are legal targets for + // this lowering. In our case, we are lowering to a combination of the + // `Affine`, `Arith`, `Func`, and `MemRef` dialects. + target.addLegalDialect(); + + // We also define the Mlp dialect as Illegal so that the conversion will fail + // if any of these operations are *not* converted. Given that we actually want + // a partial lowering, we explicitly mark the Mlp operations that don't want + // to lower, `mlp.print`, as `legal`. `mlp.print` will still need its operands + // to be updated though (as we convert from TensorType to MemRefType), so we + // only treat it as `legal` if its operands are legal. 
+ target.addIllegalDialect(); + target.addDynamicallyLegalOp([](mlp::PrintOp op) + { return llvm::none_of(op->getOperandTypes(), + [](Type type) + { return llvm::isa(type); }); }); + + // Now that the conversion target has been defined, we just need to provide + // the set of patterns that will lower the Mlp operations. + RewritePatternSet patterns(&getContext()); + patterns + .add< + AddOpLowering, ConstantOpLowering, FuncOpLowering, PrintOpLowering /*,MulOpLowering, ReturnOpLowering, TransposeOpLowering*/>( + &getContext()); + + // With the target and rewrite patterns defined, we can now attempt the + // conversion. The conversion will signal failure if any of our `illegal` + // operations were not converted successfully. + if (failed( + applyPartialConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); +} + +// Create a pass for lowering operations in the `Affine` and `Std` dialects, for +// a subset of the Mlp IR (e.g. matmul). +std::unique_ptr mlir::mlp::createLowerToAffinePass() +{ + return std::make_unique(); +} diff --git a/mlir/LowerToLLVM.cpp b/src/LowerToLLVM.cpp similarity index 90% rename from mlir/LowerToLLVM.cpp rename to src/LowerToLLVM.cpp index 8b48a8f..e28b779 100644 --- a/mlir/LowerToLLVM.cpp +++ b/src/LowerToLLVM.cpp @@ -7,8 +7,8 @@ //===----------------------------------------------------------------------===// // // This file implements full lowering of Toy operations to LLVM MLIR dialect. -// 'toy.print' is lowered to a loop nest that calls `printf` on each element of -// the input array. The file also sets up the ToyToLLVMLoweringPass. This pass +// 'mlp.print' is lowered to a loop nest that calls `printf` on each element of +// the input array. The file also sets up the MlpToLLVMLoweringPass. 
This pass // lowers the combination of Arithmetic + Affine + SCF + Func dialects to the // LLVM one: // @@ -18,10 +18,12 @@ // Arithmetic + Func --> LLVM (Dialect) // ^ // | -// 'toy.print' --> Loop (SCF) -- +// 'mlp.print' --> Loop (SCF) -- // //===----------------------------------------------------------------------===// +#include "Dialect.h" +#include "Passes.h" #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/BuiltinAttributes.h" @@ -29,8 +31,6 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/TypeID.h" -#include "toy/Dialect.h" -#include "toy/Passes.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" @@ -59,14 +59,14 @@ using namespace mlir; //===----------------------------------------------------------------------===// namespace { -/// Lowers `toy.print` to a loop nest calling `printf` on each of the individual +/// Lowers `mlp.print` to a loop nest calling `printf` on each of the individual /// elements of the array. 
-class PrintOpLowering : public OpConversionPattern { +class PrintOpLowering : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(toy::PrintOp op, OpAdaptor adaptor, + matchAndRewrite(mlp::PrintOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto *context = rewriter.getContext(); auto memRefType = llvm::cast((*op->operand_type_begin())); @@ -174,14 +174,14 @@ class PrintOpLowering : public OpConversionPattern { } // namespace //===----------------------------------------------------------------------===// -// ToyToLLVMLoweringPass +// MlpToLLVMLoweringPass //===----------------------------------------------------------------------===// namespace { -struct ToyToLLVMLoweringPass - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ToyToLLVMLoweringPass) - StringRef getArgument() const override { return "toy-to-llvm"; } +struct MlpToLLVMLoweringPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MlpToLLVMLoweringPass) + StringRef getArgument() const override { return "mlp-to-llvm"; } void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -190,7 +190,7 @@ struct ToyToLLVMLoweringPass }; } // namespace -void ToyToLLVMLoweringPass::runOnOperation() { +void MlpToLLVMLoweringPass::runOnOperation() { // The first thing to define is the conversion target. This will define the // final target for this lowering. For this lowering, we are only targeting // the LLVM dialect. @@ -206,7 +206,7 @@ void ToyToLLVMLoweringPass::runOnOperation() { // Now that the conversion target has been defined, we need to provide the // patterns used for lowering. At this point of the compilation process, we - // have a combination of `toy`, `affine`, and `std` operations. Luckily, there + // have a combination of `mlp`, `affine`, and `std` operations. 
Luckily, there // are already exists a set of patterns to transform `affine` and `std` // dialects. These patterns lowering in multiple stages, relying on transitive // lowerings. Transitive lowering, or A->B->C lowering, is when multiple @@ -221,7 +221,7 @@ void ToyToLLVMLoweringPass::runOnOperation() { cf::populateAssertToLLVMConversionPattern(typeConverter, patterns); populateFuncToLLVMConversionPatterns(typeConverter, patterns); - // The only remaining operation to lower from the `toy` dialect, is the + // The only remaining operation to lower from the `mlp` dialect, is the // PrintOp. patterns.add(&getContext()); @@ -234,6 +234,6 @@ void ToyToLLVMLoweringPass::runOnOperation() { /// Create a pass for lowering operations the remaining `Toy` operations, as /// well as `Affine` and `Std`, to the LLVM dialect for codegen. -std::unique_ptr mlir::toy::createLowerToLLVMPass() { - return std::make_unique(); +std::unique_ptr mlir::mlp::createLowerToLLVMPass() { + return std::make_unique(); } diff --git a/src/LowerToLinalg.cpp b/src/LowerToLinalg.cpp new file mode 100644 index 0000000..611deeb --- /dev/null +++ b/src/LowerToLinalg.cpp @@ -0,0 +1,419 @@ +#include "Dialect.h" +#include "Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Pass/Pass.h" +#include 
"mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Support/TypeID.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" +#include +#include + +using namespace mlir; + +namespace +{ + + //===----------------------------------------------------------------------===// + // MLPTolinalg RewritePatterns: Constant operations + //===----------------------------------------------------------------------===// + + // struct ConstantOpToArith + // : public mlir::OpConversionPattern { + // using OpConversionPattern::OpConversionPattern; + + // mlir::LogicalResult + // matchAndRewrite(mlir::mlp::ConstantOp op, OpAdaptor adaptor, + // mlir::ConversionPatternRewriter &rewriter) const override { + + // auto attr = llvm::dyn_cast(op.getValue()); + // if (!attr) + // return rewriter.notifyMatchFailure(op, "expected DenseElementsAttr"); + + // // Convert the result type via the type converter + // auto resultType = getTypeConverter()->convertType(op.getType()); + // if (!resultType) + // return failure(); + + // rewriter.replaceOpWithNewOp(op, resultType, + // attr); + + // return mlir::success(); + // } + // }; + + struct ConstantOpToArith : public OpConversionPattern + { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mlp::ConstantOp op, OpAdaptor, + ConversionPatternRewriter &rewriter) const override + { + auto attr = llvm::dyn_cast(op.getValue()); + + if (!attr) + return failure(); + rewriter.replaceOpWithNewOp(op, op.getType(), attr); + return success(); + } + }; + + //===----------------------------------------------------------------------===// + // MLPTolinalg RewritePatterns: Print operations + //===----------------------------------------------------------------------===// + + struct PrintOpLowering : public OpConversionPattern + { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mlp::PrintOp op, OpAdaptor 
adaptor, + ConversionPatternRewriter &rewriter) const final + { + // We don't lower "MLP.print" in this pass, but we need to update its + // operands. + rewriter.modifyOpInPlace(op, + [&] + { op->setOperands(adaptor.getOperands()); }); + return success(); + } + }; + + struct LinearOpToLinalg : public OpConversionPattern + { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mlp::LinearOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + Location loc = op.getLoc(); + + // auto resultTy = op.getType().dyn_cast(); + auto resultTy = llvm::dyn_cast(op.getType()); + + if (!resultTy) + return rewriter.notifyMatchFailure(op, "expected ranked tensor result"); + + Value lhs = adaptor.getInput(); // %0 + Value rhs = adaptor.getWeight(); // %1 + + // Create zero-init tensor for matmul output + auto zeroAttr = rewriter.getZeroAttr(resultTy); + Value init = arith::ConstantOp::create(rewriter, loc, resultTy, zeroAttr); + + auto linear = + linalg::MatmulOp::create(rewriter, loc, + /*resultTensorTypes=*/TypeRange{resultTy}, + /*inputs=*/ValueRange{lhs, rhs}, + /*outputs=*/ValueRange{init}); + + rewriter.replaceOp(op, linear.getResult(0)); + return success(); + } + }; + + struct AddOpToLinalg : public OpConversionPattern + { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mlp::AddOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + Location loc = op.getLoc(); + + // auto resultTy = op.getType().dyn_cast(); + auto resultTy = llvm::dyn_cast(op.getType()); + + if (!resultTy) + return rewriter.notifyMatchFailure(op, "expected ranked tensor result"); + + Value rhs = adaptor.getRhs(); // %0 + Value lhs = adaptor.getLhs(); // %1 + + // Create zero-init tensor for matmul output + auto zeroAttr = rewriter.getZeroAttr(resultTy); + Value init = arith::ConstantOp::create(rewriter, loc, resultTy, zeroAttr); + + auto add = linalg::AddOp::create(rewriter, loc, 
+ /*resultTensorTypes=*/TypeRange{resultTy}, + /*inputs=*/ValueRange{lhs, rhs}, + /*outputs=*/ValueRange{init}); + + rewriter.replaceOp(op, add.getResult(0)); + return success(); + } + }; + + struct ReluOpToLinalg : public mlir::OpConversionPattern + { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mlp::ReluOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + Location loc = op.getLoc(); + + auto resultTy = llvm::dyn_cast(op.getType()); + + if (!resultTy) + return rewriter.notifyMatchFailure(op, "expected ranked tensor"); + + Value input = adaptor.getInput(); + + // ------------------------------------------------------------ + // Create init tensor (zero-filled) + // ------------------------------------------------------------ + auto zeroAttr = rewriter.getZeroAttr(resultTy); + Value init = arith::ConstantOp::create(rewriter, loc, resultTy, zeroAttr); + + // ------------------------------------------------------------ + // Build indexing maps + // ------------------------------------------------------------ + auto identity = rewriter.getMultiDimIdentityMap(resultTy.getRank()); + // DBS_PRINT(identity); + + SmallVector indexingMaps = {identity, identity}; + // ------------------------------------------------------------ + // Iterator types (all parallel for ReLU) + // ------------------------------------------------------------ + SmallVector iteratorTypes( + resultTy.getRank(), utils::IteratorType::parallel); + + // ------------------------------------------------------------ + // Create linalg.generic + // ------------------------------------------------------------ + + auto genericOp = linalg::GenericOp::create( + rewriter, loc, + /*resultTensorTypes=*/TypeRange{resultTy}, + /*inputs=*/ValueRange{input}, + /*outputs=*/ValueRange{init}, indexingMaps, iteratorTypes, + [&](OpBuilder &builder, Location loc, ValueRange args) + { + Value x = args[0]; + + Value zero = arith::ConstantOp::create( + 
builder, loc, builder.getFloatAttr(x.getType(), 0.0)); + + Value relu = arith::MaximumFOp::create(builder, loc, x, zero); + + linalg::YieldOp::create(builder, loc, relu); + }); + + // auto genericOp = rewriter.create( + // loc, + // /*resultTensorTypes=*/TypeRange{resultTy}, + // /*inputs=*/ValueRange{input}, + // /*outputs=*/ValueRange{init}, indexingMaps, iteratorTypes, + // [&](OpBuilder &builder, Location loc, ValueRange args) { + // Value x = args[0]; + + // Value zero = builder.create( + // loc, builder.getFloatAttr(x.getType(), 0.0)); + + // Value relu = builder.create(loc, x, zero); + + // builder.create(loc, relu); + // }); + + rewriter.replaceOp(op, genericOp.getResult(0)); + return success(); + } + }; + + // struct SoftmaxToLinalg : public mlir::OpConversionPattern { + // using OpConversionPattern::OpConversionPattern; + + // LogicalResult + // matchAndRewrite(mlp::SoftmaxOp op, OpAdaptor adaptor, + // ConversionPatternRewriter &rewriter) const override { + // Location loc = op.getLoc(); + // auto resultTy = op.getType().dyn_cast(); + // if (!resultTy || resultTy.getRank() != 2) + // return rewriter.notifyMatchFailure(op, + // "Softmax expects a 1D ranked + // tensor"); + + // Value input = adaptor.getInput(); + // Type elemTy = resultTy.getElementType(); + // MLIRContext *ctx = rewriter.getContext(); + + // // 1. Create initialization tensors + // // ------------------------------------------------------------ + // auto zeroAttr = rewriter.getZeroAttr(resultTy); + // Value initVec = rewriter.create(loc, resultTy, + // zeroAttr); + + // RankedTensorType scalarTy = RankedTensorType::get({}, elemTy); + // auto zeroScalarAttr = + // DenseElementsAttr::get(scalarTy, rewriter.getFloatAttr(elemTy, 0.0)); + // Value initScalar = + // rewriter.create(loc, scalarTy, zeroScalarAttr); + + // // 2. 
Step 1: Compute exp(x) elementwise + // // ------------------------------------------------------------ + // auto map1D = rewriter.getMultiDimIdentityMap(resultTy.getRank()); + // SmallVector expMaps = {map1D, map1D}; + // SmallVector parallelIter = { + // utils::IteratorType::parallel, utils::IteratorType::parallel}; + // // DBS_PRINT(expMaps[0] << expMaps[1] ); + // auto expOp = rewriter.create( + // loc, resultTy, ValueRange{input}, ValueRange{initVec}, expMaps, + // parallelIter, [&](OpBuilder &b, Location loc, ValueRange args) { + // Value x = args[0]; + // Value ex = b.create(loc, x); // Compute e^x + // b.create(loc, ex); + // }); + + // // 3. Step 2: Compute Sum Reduction of exp(x) + // // ------------------------------------------------------------ + // SmallVector reductionDims = {0}; + // auto sumOp = rewriter.create( + // loc, expOp.getResult(0), initScalar, reductionDims, + // [&](OpBuilder &b, Location loc, ValueRange args) { + // Value val = args[0]; + // Value acc = args[1]; + // Value sum = b.create(loc, acc, val); + // b.create(loc, sum); + // }); + + // // 4. 
Step 3: Divide exp(x) by the sum (Normalization) + // // ------------------------------------------------------------ + // // The "Magic": Map scalar sum (0D) to the vector (1D) using (d0) -> () + // auto mapScalar = AffineMap::get(1, 0, ctx); // This creates (d0) -> () + // SmallVector divMaps = { + // map1D, // expOp result (1D) + // mapScalar, // sumOp result (0D broadcasted to 1D) + // map1D // Output (1D) + // }; + + // auto divOp = rewriter.create( + // loc, resultTy, ValueRange{expOp.getResult(0), sumOp.getResult(0)}, + // ValueRange{initVec}, divMaps, parallelIter, + // [&](OpBuilder &b, Location loc, ValueRange args) { + // Value ex = args[0]; + // Value totalSum = args[1]; + // Value result = b.create(loc, ex, totalSum); + // b.create(loc, result); + // }); + + // // Replace the original op with the final divided result + // rewriter.replaceOp(op, expOp.getResult(0)); + + // return success(); + // } + // }; +} // namespace + +//===----------------------------------------------------------------------===// +// MlpTolinalgLoweringPass +//===----------------------------------------------------------------------===// + +// / This is a partial lowering to linalg loops of the MLP operations that are +// / computationally intensive (like matmul for example...) while keeping the +// / rest of the code in the MLP dialect. 
+ +//===----------------------------------------------------------------------===// +// // MLPToLinalgLoweringPass +// +//===----------------------------------------------------------------------===// + +namespace +{ + + struct MLPToLinalgLoweringPass + : public PassWrapper> + { + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MLPToLinalgLoweringPass) + + StringRef getArgument() const override { return "mlp-lower-to-linalg"; } + StringRef getDescription() const override + { + return "Lower MLP dialect matmul operations to Linalg dialect"; + } + + void getDependentDialects(DialectRegistry ®istry) const override + { + registry.insert(); + } + + void runOnOperation() override + { + ConversionTarget target(getContext()); + // target.addLegalDialect(); + + target.addLegalDialect(); + ModuleOp module = getOperation(); + MLIRContext *ctx = module.getContext(); + + RewritePatternSet patterns(ctx); + + target.addIllegalDialect(); + + // target.addDynamicallyLegalOp([](mlp::PrintOp op) { + // return llvm::none_of(op->getOperandTypes(), [](Type type) { + // return llvm::isa(type); + // }); + // }); + + // target.addDynamicallyLegalOp([](mlp::PrintOp op) { + // return llvm::none_of(op->getOperandTypes(), [](Type type) { + // return llvm::isa(type); + // }); + // }); + target.addDynamicallyLegalOp([](mlp::PrintOp op) + { return llvm::none_of(op->getOperandTypes(), [](Type type) + { return llvm::isa(type); }); }); + + patterns.add(ctx); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + signalPassFailure(); + } + }; + +} // namespace + +//===----------------------------------------------------------------------===// +// // Pass Registration +// +//===----------------------------------------------------------------------===// + +std::unique_ptr mlir::mlp::createLowerToLinalgPass() +{ + return std::make_unique(); +} \ No newline at end of file diff --git a/src/MlpCombine.cpp b/src/MlpCombine.cpp new file mode 100644 index 0000000..209852c --- 
/dev/null +++ b/src/MlpCombine.cpp @@ -0,0 +1,86 @@ +//===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===// +// + +//===----------------------------------------------------------------------===// +// +// This file implements a set of simple combiners for optimizing operations in +// the Toy dialect. +// +//===----------------------------------------------------------------------===// + +#include "Dialect.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "llvm/Support/Casting.h" +#include +using namespace mlir; +using namespace mlp; + +namespace { +/// Include the patterns defined in the Declarative Rewrite framework. +#include "MlpCombine.inc" +} // namespace + +// /// Fold constants. +OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } + +// /// Fold struct constants. +// OpFoldResult StructConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } + +// /// Fold simple struct access operations that access into a constant. +// OpFoldResult StructAccessOp::fold(FoldAdaptor adaptor) { +// auto structAttr = +// llvm::dyn_cast_if_present(adaptor.getInput()); +// if (!structAttr) +// return nullptr; + +// size_t elementIndex = getIndex(); +// return structAttr[elementIndex]; +// } + +// /// This is an example of a c++ rewrite pattern for the TransposeOp. It +// /// optimizes the following scenario: transpose(transpose(x)) -> x +// struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { +// /// We register this pattern to match every mlp.transpose in the IR. +// /// The "benefit" is used by the framework to order the patterns and process +// /// them in order of profitability. +// SimplifyRedundantTranspose(mlir::MLIRContext *context) +// : OpRewritePattern(context, /*benefit=*/1) {} + +// /// This method attempts to match a pattern and rewrite it. 
The rewriter +// /// argument is the orchestrator of the sequence of rewrites. The pattern is +// /// expected to interact with it to perform any changes to the IR from here. +// llvm::LogicalResult +// matchAndRewrite(TransposeOp op, +// mlir::PatternRewriter &rewriter) const override { +// // Look through the input of the current transpose. +// mlir::Value transposeInput = op.getOperand(); +// TransposeOp transposeInputOp = transposeInput.getDefiningOp(); + +// // Input defined by another transpose? If not, no match. +// if (!transposeInputOp) +// return failure(); + +// // Otherwise, we have a redundant transpose. Use the rewriter. +// rewriter.replaceOp(op, {transposeInputOp.getOperand()}); +// return success(); +// } +// }; + +// /// Register our patterns as "canonicalization" patterns on the TransposeOp so +// /// that they can be picked up by the Canonicalization framework. +// void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results, +// MLIRContext *context) { +// results.add(context); +// } + +// /// Register our patterns as "canonicalization" patterns on the ReshapeOp so +// /// that they can be picked up by the Canonicalization framework. 
+// void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results, +// MLIRContext *context) { +// results.add(context); +// } diff --git a/mlir/ToyCombine.td b/src/MlpCombine.td similarity index 73% rename from mlir/ToyCombine.td rename to src/MlpCombine.td index 11d7831..e305be3 100644 --- a/mlir/ToyCombine.td +++ b/src/MlpCombine.td @@ -11,11 +11,11 @@ // //===----------------------------------------------------------------------===// -#ifndef TOY_COMBINE -#define TOY_COMBINE +#ifndef MLP_COMBINE +#define MLP_COMBINE include "mlir/IR/PatternBase.td" -include "toy/Ops.td" +include "Ops.td" /// Note: The DRR definition used for defining patterns is shown below: /// @@ -30,8 +30,8 @@ include "toy/Ops.td" //===----------------------------------------------------------------------===// // Reshape(Reshape(x)) = Reshape(x) -def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)), - (ReshapeOp $arg)>; +// def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)), +// (ReshapeOp $arg)>; //===----------------------------------------------------------------------===// // Pattern-Match and Rewrite using Native Code Call @@ -41,11 +41,11 @@ def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)), // C++ and C++ helper functions. // Reshape(Constant(x)) = x' -def ReshapeConstant : - NativeCodeCall<"$0.reshape(::llvm::cast($1.getType()))">; -def FoldConstantReshapeOptPattern : Pat< - (ReshapeOp:$res (ConstantOp $arg)), - (ConstantOp (ReshapeConstant $arg, $res))>; +// def ReshapeConstant : +// NativeCodeCall<"$0.reshape(::llvm::cast($1.getType()))">; +// def FoldConstantReshapeOptPattern : Pat< +// (ReshapeOp:$res (ConstantOp $arg)), +// (ConstantOp (ReshapeConstant $arg, $res))>; //===----------------------------------------------------------------------===// // Pattern-Match and Rewrite with Constraints @@ -55,9 +55,9 @@ def FoldConstantReshapeOptPattern : Pat< // on operand properties. 
// Reshape(x) = x, where input and output shapes are identical -def TypesAreIdentical : Constraint>; -def RedundantReshapeOptPattern : Pat< - (ReshapeOp:$res $arg), (replaceWithValue $arg), - [(TypesAreIdentical $res, $arg)]>; +// def TypesAreIdentical : Constraint>; +// def RedundantReshapeOptPattern : Pat< +// (ReshapeOp:$res $arg), (replaceWithValue $arg), +// [(TypesAreIdentical $res, $arg)]>; -#endif // TOY_COMBINE +#endif // MLP_COMBINE diff --git a/mlir/ShapeInferencePass.cpp b/src/ShapeInferencePass.cpp similarity index 88% rename from mlir/ShapeInferencePass.cpp rename to src/ShapeInferencePass.cpp index a552e1f..9ab4780 100644 --- a/mlir/ShapeInferencePass.cpp +++ b/src/ShapeInferencePass.cpp @@ -1,9 +1,5 @@ //===- ShapeInferencePass.cpp - Shape Inference ---------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// + //===----------------------------------------------------------------------===// // // This file implements a Function level pass performing interprocedural @@ -11,15 +7,15 @@ // //===----------------------------------------------------------------------===// +#include "Dialect.h" +#include "Passes.h" +#include "ShapeInferenceInterface.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Types.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/TypeID.h" -#include "toy/Dialect.h" -#include "toy/Passes.h" -#include "toy/ShapeInferenceInterface.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Casting.h" @@ -30,10 +26,10 @@ #define DEBUG_TYPE "shape-inference" using namespace mlir; -using namespace toy; +using namespace mlp; /// Include the auto-generated definitions for the shape inference interfaces. 
-#include "toy/ShapeInferenceOpInterfaces.cpp.inc" +#include "ShapeInferenceOpInterfaces.cpp.inc" namespace { /// The ShapeInferencePass is a pass that performs intra-procedural @@ -53,9 +49,9 @@ namespace { /// 3) If the worklist is empty, the algorithm succeeded. /// struct ShapeInferencePass - : public mlir::PassWrapper> { + : public mlir::PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ShapeInferencePass) - StringRef getArgument() const override { return "toy-shape-inference"; } + StringRef getArgument() const override { return "mlp-shape-inference"; } void runOnOperation() override { auto f = getOperation(); @@ -118,6 +114,6 @@ struct ShapeInferencePass } // namespace /// Create a Shape Inference pass. -std::unique_ptr mlir::toy::createShapeInferencePass() { +std::unique_ptr mlir::mlp::createShapeInferencePass() { return std::make_unique(); } diff --git a/src/main.cpp b/src/main.cpp new file mode 100644 index 0000000..2edd2bd --- /dev/null +++ b/src/main.cpp @@ -0,0 +1,527 @@ + +#include "Builder.h" +#include "Dialect.h" +#include "Jit.h" +#include "Passes.h" +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" +#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +#include "mlir/Conversion/TosaToArith/TosaToArith.h" +#include "mlir/Conversion/UBToLLVM/UBToLLVM.h" +#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include 
"mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" +#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +#include "mlir/Conversion/TosaToArith/TosaToArith.h" +#include "mlir/Conversion/UBToLLVM/UBToLLVM.h" +#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Types.h" +#include "mlir/InitAllDialects.h" +// #include "mlir/Dialect/Affine/Passes.h" +#include "mlir/Dialect/Affine/Transforms/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/Func/Extensions/InlinerExtension.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" +#include "mlir/Dialect/LLVMIR/Transforms/Passes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/SCF/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/ExecutionEngine/ExecutionEngine.h" +#include "mlir/ExecutionEngine/OptUtils.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/DialectRegistry.h" +#include 
"mlir/IR/MLIRContext.h" +#include "mlir/IR/Verifier.h" +#include "mlir/InitAllExtensions.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Export.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/Passes.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include +#include +#include +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" +#include 
"mlir/Dialect/Bufferization/IR/Bufferization.h" + +#include "mlir/Dialect/Func/Extensions/InlinerExtension.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" + +#include "mlir/Dialect/Affine/Transforms/Passes.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h" +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/Func/Extensions/InlinerExtension.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" +#include "mlir/Dialect/LLVMIR/Transforms/Passes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/SCF/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/ExecutionEngine/ExecutionEngine.h" +#include "mlir/ExecutionEngine/OptUtils.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Verifier.h" +#include "mlir/InitAllDialects.h" +#include "mlir/InitAllDialects.h" + +#include "mlir/InitAllExtensions.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" 
+#include "mlir/Target/LLVMIR/Export.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/Passes.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include +#include +#include + +using namespace mlp; +using namespace builder; +namespace cl = llvm::cl; + +static cl::opt inputFilename(cl::Positional, + cl::desc(""), + cl::init("-"), + cl::value_desc("filename")); + +namespace +{ + enum InputType + { + MLP, + MLIR + }; +} // namespace +static cl::opt inputType( + "x", cl::init(MLP), cl::desc("Decided the kind of output desired"), + cl::values(clEnumValN(MLP, "mlp", "load the input file as a mlp source.")), + cl::values(clEnumValN(MLIR, "mlir", + "load the input file as an MLIR file"))); + +namespace +{ + enum Action + { + None, + DumpAST, + DumpMLIR, + DumpMLIRAffine, + DumpMLIRLinalg, + DumpMLIRLLVM, + DumpLLVMIR, + RunJIT + }; +} // namespace +static cl::opt emitAction( + "emit", cl::desc("Select the kind of output desired"), + cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), + cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")), + cl::values(clEnumValN(DumpMLIRAffine, "mlir-affine", + "output the MLIR dump after affine lowering")), + cl::values(clEnumValN(DumpMLIRLinalg, "mlir-linalg", + "output the MLIR dump after linalg lowering")), + cl::values(clEnumValN(DumpMLIRLLVM, "mlir-llvm", + "output the MLIR dump after llvm lowering")), + cl::values(clEnumValN(DumpLLVMIR, "llvm", "output the LLVM IR dump")), + cl::values( + clEnumValN(RunJIT, "jit", + "JIT the code and run 
it by invoking the main function"))); + +static cl::opt enableOpt("opt", cl::desc("Enable optimizations")); + +static int loadMLIR(mlir::MLIRContext &context, + mlir::OwningOpRef &module) +{ + + // CREATE MODULE FIRST + module = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); + + context.getOrLoadDialect(); + context.getOrLoadDialect(); + + createMLPLinearFunction(context, *module); + // createMLPAddFunction(context, *module); + // createMLPReluFunction(context, *module); + return 0; +} + +static int loadAndProcessMLIR(mlir::MLIRContext &context, + mlir::OwningOpRef &module) +{ + if (int error = loadMLIR(context, module)) + return error; + + mlir::PassManager pm(module.get()->getName()); + // Apply any generic pass manager command line options and run the pipeline. + if (mlir::failed(mlir::applyPassManagerCLOptions(pm))) + return 4; + + // Check to see what granularity of MLIR we are compiling to. + bool isLoweringToAffine = emitAction >= Action::DumpMLIRAffine; + bool isLoweringToLinalg = emitAction >= Action::DumpMLIRLinalg; + bool isLoweringToLLVM = emitAction >= Action::DumpMLIRLLVM; + + if (enableOpt || isLoweringToAffine) + { + // Inline all functions into main and then delete them. + // pm.addPass(mlir::createInlinerPass()); + + // Now that there is only one function, we can infer the shapes of each of + // the operations. + mlir::OpPassManager &optPM = pm.nest(); + optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(mlir::mlp::createShapeInferencePass()); + optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(mlir::createCSEPass()); + } + + // if (isLoweringToAffine) { + // // Partially lower the mlp dialect. + // pm.addPass(mlir::mlp::createLowerToAffinePass()); + + // // Add a few cleanups post lowering. + // mlir::OpPassManager &optPM = pm.nest(); + // optPM.addPass(mlir::createCanonicalizerPass()); + // optPM.addPass(mlir::createCSEPass()); + + // // Add optimizations if enabled. 
+ // if (enableOpt) { + // optPM.addPass(mlir::affine::createLoopFusionPass()); + // optPM.addPass(mlir::affine::createAffineScalarReplacementPass()); + // } + // } + + if (isLoweringToLinalg) + { + pm.addPass(mlir::mlp::createLowerToLinalgPass()); + + // Tensor → MemRef + pm.addPass(mlir::bufferization::createOneShotBufferizePass()); + pm.addPass(mlir::bufferization::createBufferDeallocationSimplificationPass()); + + // Linalg → loops + pm.addPass(mlir::createConvertLinalgToLoopsPass()); + + // SCF → CFG + pm.addPass(mlir::createSCFToControlFlowPass()); + + // LLVM lowering + // pm.addPass(mlir::createConvertArithToLLVMPass()); + // pm.addPass(mlir::createConvertMemRefToLLVMPass()); + // pm.addPass(mlir::createConvertFuncToLLVMPass()); + } + + // if (isLoweringToLinalg) + // { + // // Partially lower the mlp dialect. + // pm.addPass(mlir::mlp::createLowerToLinalgPass()); + + // // Add a few cleanups post lowering. + // // mlir::OpPassManager &optPM = pm.nest(); + // // optPM.addPass(mlir::createCanonicalizerPass()); + // // optPM.addPass(mlir::createCSEPass()); + + // // Add optimizations if enabled. + // if (enableOpt) + // { + // // optPM.addPass(mlir::affine::createLoopFusionPass()); + // // optPM.addPass(mlir::affine::createAffineScalarReplacementPass()); + // } + // // 1. 
One-Shot Bufferization (Converts Tensor constants to MemRef globals) + + // mlir::bufferization::OneShotBufferizationOptions options; + // options.allowReturnAllocsFromLoops = true; + // //pm.addPass(mlir::bufferization::createOneShotBufferizePass()); + // // //pm.addPass(mlir::bufferization::createBufferDeallocationPass()); + // // pm.addPass(mlir::createConvertLinalgToLoopsPass()); + // // pm.addPass(mlir::createConvertSCFToCFPass()); + // // ------------------------------------------------------------ + // // pm.addPass(mlir::createConvertArithToLLVMPass()); + // // pm.addPass(mlir::createConvertMemRefToLLVMPass()); + // // pm.addPass(mlir::createConvertFuncToLLVMPass()); + // // pm.addPass(mlir::tosa::createTosaToArith()); + // llvm::errs() << "\n=== PASS PIPELINE ===\n"; + // pm.printAsTextualPipeline(llvm::errs()); + // llvm::errs() << "\n====================\n"; + + // } + + if (isLoweringToLLVM) + { + // Finish lowering the mlp IR to the LLVM dialect. + pm.addPass(mlir::mlp::createLowerToLLVMPass()); + // This is necessary to have line tables emitted and basic + // debugger working. In the future we will add proper debug information + // emission directly from our frontend. + pm.addPass(mlir::LLVM::createDIScopeForLLVMFuncOpPass()); + } + + if (mlir::failed(pm.run(*module))) + return 4; + return 0; +} + +static int dumpAST() +{ + if (inputType == InputType::MLIR) + { + llvm::errs() << "Can't dump a mlp AST when the input is MLIR\n"; + return 5; + } + + // auto moduleAST = parseInputFile(inputFilename); + // if (!moduleAST) + // return 1; + + // dump(*moduleAST); + return 0; +} + +static int dumpLLVMIR(mlir::ModuleOp module) +{ + // Register the translation to LLVM IR with the MLIR context. + mlir::registerBuiltinDialectTranslation(*module->getContext()); + mlir::registerLLVMDialectTranslation(*module->getContext()); + + // Convert the module to LLVM IR in a new LLVM IR context. 
+ llvm::LLVMContext llvmContext; + auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext); + if (!llvmModule) + { + llvm::errs() << "Failed to emit LLVM IR\n"; + return -1; + } + + // Initialize LLVM targets. + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + + // Create target machine and configure the LLVM Module + auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); + if (!tmBuilderOrError) + { + llvm::errs() << "Could not create JITTargetMachineBuilder\n"; + return -1; + } + + auto tmOrError = tmBuilderOrError->createTargetMachine(); + if (!tmOrError) + { + llvm::errs() << "Could not create TargetMachine\n"; + return -1; + } + mlir::ExecutionEngine::setupTargetTripleAndDataLayout(llvmModule.get(), + tmOrError.get().get()); + + /// Optionally run an optimization pipeline over the llvm module. + auto optPipeline = mlir::makeOptimizingTransformer( + /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0, + /*targetMachine=*/nullptr); + if (auto err = optPipeline(llvmModule.get())) + { + llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; + return -1; + } + llvm::errs() << *llvmModule << "\n"; + return 0; +} + +static int runJit(mlir::ModuleOp module) +{ + // Initialize LLVM targets. + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + + // Register the translation from MLIR to LLVM IR, which must happen before we + // can JIT-compile. + mlir::registerBuiltinDialectTranslation(*module->getContext()); + mlir::registerLLVMDialectTranslation(*module->getContext()); + + // An optimization pipeline to use within the execution engine. + auto optPipeline = mlir::makeOptimizingTransformer( + /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0, + /*targetMachine=*/nullptr); + + // Create an MLIR execution engine. The execution engine eagerly JIT-compiles + // the module. 
+ mlir::ExecutionEngineOptions engineOptions; + engineOptions.transformer = optPipeline; + auto maybeEngine = mlir::ExecutionEngine::create(module, engineOptions); + assert(maybeEngine && "failed to construct an execution engine"); + auto &engine = maybeEngine.get(); + + // Invoke the JIT-compiled function. + auto invocationResult = engine->invokePacked("main"); + if (invocationResult) + { + llvm::errs() << "JIT invocation failed\n"; + return -1; + } + + return 0; +} + +int main(int argc, char **argv) +{ + // Register any command line options. + mlir::registerAsmPrinterCLOptions(); + mlir::registerMLIRContextCLOptions(); + mlir::registerPassManagerCLOptions(); + + cl::ParseCommandLineOptions(argc, argv, "mlp compiler\n"); + + if (emitAction == Action::DumpAST) + return dumpAST(); + + // If we aren't dumping the AST, then we are compiling with/to MLIR. + mlir::DialectRegistry registry; + registry + .insert(); + + + MLIRContext context(registry); + + + // Register external models for bufferization + // mlir::arith::registerBufferizableOpInterfaceExternalModels(registry); + // mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry); + // mlir::linalg::registerBufferizableOpInterfaceExternalModels(registry); + // mlir::bufferization::registerOneShotBufferizePass(); + + //mlir::registerAllDialects(registry); + + mlir::func::registerAllExtensions(registry); + mlir::LLVM::registerInlinerInterface(registry); + + // Load our Dialect in this MLIR Context. + context.getOrLoadDialect(); + context.loadAllAvailableDialects(); + + // context.getOrLoadDialect(); + + mlir::OwningOpRef module; + if (int error = loadAndProcessMLIR(context, module)) + return error; + + // If we aren't exporting to non-mlir, then we are done. + bool isOutputingMLIR = emitAction <= Action::DumpMLIRLLVM; + if (isOutputingMLIR) + { + module->dump(); + return 0; + } + + // Check to see if we are compiling to LLVM IR. 
+ if (emitAction == Action::DumpLLVMIR) + return dumpLLVMIR(*module); + + // Otherwise, we must be running the jit. + if (emitAction == Action::RunJIT) + return runJit(*module); + + llvm::errs() << "No action specified (parsing only?), use -emit=\n"; + return -1; +} diff --git a/src/main_backup.cpp b/src/main_backup.cpp new file mode 100644 index 0000000..4477b75 --- /dev/null +++ b/src/main_backup.cpp @@ -0,0 +1,820 @@ +// //===- mlpc.cpp - The Mlp Compiler ----------------------------------------===// +// // +// // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// // See https://llvm.org/LICENSE.txt for license information. +// // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// // +// //===----------------------------------------------------------------------===// +// // +// // This file implements the entry point for the Mlp compiler. +// // +// //===----------------------------------------------------------------------===// + +// #include "Builder.h" +// #include "Dialect.h" +// #include "Jit.h" +// #include "Passes.h" +// #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +// #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +// #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +// #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" +// #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" +// #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +// #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +// #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +// #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" +// #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +// #include "mlir/Conversion/TosaToArith/TosaToArith.h" +// #include "mlir/Conversion/UBToLLVM/UBToLLVM.h" +// #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" +// #include "mlir/Dialect/Affine/IR/AffineOps.h" +// #include "mlir/Dialect/Func/Extensions/AllExtensions.h" +// 
#include "mlir/Dialect/Func/IR/FuncOps.h" +// #include "mlir/IR/Types.h" +// #include "mlir/InitAllDialects.h" +// // #include "mlir/Dialect/Affine/Passes.h" +// #include "mlir/Dialect/Affine/Transforms/Passes.h" +// #include "mlir/Dialect/Arith/IR/Arith.h" +// #include "mlir/Dialect/Arith/Transforms/Passes.h" +// #include "mlir/Dialect/Bufferization/Transforms/Passes.h" +// #include "mlir/Dialect/Func/Extensions/AllExtensions.h" +// #include "mlir/Dialect/Func/Extensions/InlinerExtension.h" +// #include "mlir/Dialect/Func/IR/FuncOps.h" +// #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +// #include "mlir/Dialect/LLVMIR/LLVMTypes.h" +// #include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" +// #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" +// #include "mlir/Dialect/Linalg/Passes.h" +// #include "mlir/Dialect/MemRef/IR/MemRef.h" +// #include "mlir/Dialect/MemRef/Transforms/Passes.h" +// #include "mlir/Dialect/SCF/Transforms/Passes.h" +// #include "mlir/Dialect/Tosa/Transforms/Passes.h" +// #include "mlir/ExecutionEngine/ExecutionEngine.h" +// #include "mlir/ExecutionEngine/OptUtils.h" +// #include "mlir/IR/AsmState.h" +// #include "mlir/IR/Builders.h" +// #include "mlir/IR/BuiltinOps.h" +// #include "mlir/IR/DialectRegistry.h" +// #include "mlir/IR/MLIRContext.h" +// #include "mlir/IR/Verifier.h" +// #include "mlir/InitAllExtensions.h" +// #include "mlir/Parser/Parser.h" +// #include "mlir/Pass/PassManager.h" +// #include "mlir/Support/LLVM.h" +// #include "mlir/Support/LogicalResult.h" +// #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +// #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +// #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" +// #include "mlir/Target/LLVMIR/Export.h" +// #include "mlir/Target/LLVMIR/ModuleTranslation.h" +// #include "mlir/Transforms/DialectConversion.h" +// #include "mlir/Transforms/Passes.h" +// #include "llvm/ADT/StringRef.h" +// #include 
"llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +// #include "llvm/IR/LLVMContext.h" +// #include "llvm/IR/Module.h" +// #include "llvm/IR/Verifier.h" +// #include "llvm/Support/CommandLine.h" +// #include "llvm/Support/ErrorOr.h" +// #include "llvm/Support/MemoryBuffer.h" +// #include "llvm/Support/SourceMgr.h" +// #include "llvm/Support/TargetSelect.h" +// #include "llvm/Support/raw_ostream.h" +// #include +// #include +// #include +// #include +// #include +// #include + +// // using namespace mlp; +// namespace cl = llvm::cl; + +// // static cl::opt inputFilename(cl::Positional, +// // cl::desc(""), +// // cl::init("-"), +// // cl::value_desc("filename")); + +// // namespace { +// // enum InputType { Mlp, MLIR }; +// // } // namespace + +// // static cl::opt inputType( +// // "x", cl::init(Mlp), cl::desc("Decided the kind of output desired"), +// // cl::values(clEnumValN(Mlp, "mlp", "load the input file as a Mlp +// // source.")), cl::values(clEnumValN(MLIR, "mlir", +// // "load the input file as an MLIR file"))); + +// // namespace { +// // enum Action { +// // None, +// // DumpAST, +// // DumpMLIR, +// // DumpMLIRAffine, +// // DumpMLIRLLVM, +// // DumpLLVMIR, +// // RunJIT +// // }; +// // } // namespace +// // static cl::opt emitAction( +// // "emit", cl::desc("Select the kind of output desired"), +// // cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), +// // cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")), +// // cl::values(clEnumValN(DumpMLIRAffine, "mlir-affine", +// // "output the MLIR dump after affine lowering")), +// // cl::values(clEnumValN(DumpMLIRLLVM, "mlir-llvm", +// // "output the MLIR dump after llvm lowering")), +// // cl::values(clEnumValN(DumpLLVMIR, "llvm", "output the LLVM IR dump")), +// // cl::values( +// // clEnumValN(RunJIT, "jit", +// // "JIT the code and run it by invoking the main +// // function"))); + +// // static cl::opt enableOpt("opt", cl::desc("Enable optimizations")); + +// // /// Returns 
a Mlp AST resulting from parsing the file or a nullptr on error. +// // static std::unique_ptr +// // parseInputFile(llvm::StringRef filename) { +// // llvm::ErrorOr> fileOrErr = +// // llvm::MemoryBuffer::getFileOrSTDIN(filename); +// // if (std::error_code ec = fileOrErr.getError()) { +// // llvm::errs() << "Could not open input file: " << ec.message() << "\n"; +// // return nullptr; +// // } +// // auto buffer = fileOrErr.get()->getBuffer(); +// // LexerBuffer lexer(buffer.begin(), buffer.end(), std::string(filename)); +// // Parser parser(lexer); +// // return parser.parseModule(); +// // } + +// // static int loadMLIR(mlir::MLIRContext &context, +// // mlir::OwningOpRef &module) { +// // // Handle '.mlp' input to the compiler. +// // if (inputType != InputType::MLIR && +// // !llvm::StringRef(inputFilename).ends_with(".mlir")) { +// // auto moduleAST = parseInputFile(inputFilename); +// // if (!moduleAST) +// // return 6; +// // module = mlirGen(context, *moduleAST); +// // return !module ? 1 : 0; +// // } + +// // // Otherwise, the input is '.mlir'. +// // llvm::ErrorOr> fileOrErr = +// // llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); +// // if (std::error_code ec = fileOrErr.getError()) { +// // llvm::errs() << "Could not open input file: " << ec.message() << "\n"; +// // return -1; +// // } + +// // // Parse the input mlir. +// // llvm::SourceMgr sourceMgr; +// // sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); +// // module = mlir::parseSourceFile(sourceMgr, &context); +// // if (!module) { +// // llvm::errs() << "Error can't load file " << inputFilename << "\n"; +// // return 3; +// // } +// // return 0; +// // } + +// // static int loadAndProcessMLIR(mlir::MLIRContext &context, +// // mlir::OwningOpRef &module) { +// // if (int error = loadMLIR(context, module)) +// // return error; + +// // mlir::PassManager pm(module.get()->getName()); +// // // Apply any generic pass manager command line options and run the +// // pipeline. 
if (mlir::failed(mlir::applyPassManagerCLOptions(pm))) +// // return 4; + +// // // Check to see what granularity of MLIR we are compiling to. +// // bool isLoweringToAffine = emitAction >= Action::DumpMLIRAffine; +// // bool isLoweringToLLVM = emitAction >= Action::DumpMLIRLLVM; + +// // if (enableOpt || isLoweringToAffine) { +// // // Inline all functions into main and then delete them. +// // pm.addPass(mlir::createInlinerPass()); + +// // // Now that there is only one function, we can infer the shapes of each +// // of +// // // the operations. +// // mlir::OpPassManager &optPM = pm.nest(); +// // optPM.addPass(mlir::createCanonicalizerPass()); +// // optPM.addPass(mlir::mlp::createShapeInferencePass()); +// // optPM.addPass(mlir::createCanonicalizerPass()); +// // optPM.addPass(mlir::createCSEPass()); +// // } + +// // if (isLoweringToAffine) { +// // // Partially lower the mlp dialect. +// // pm.addPass(mlir::mlp::createLowerToAffinePass()); + +// // // Add a few cleanups post lowering. +// // mlir::OpPassManager &optPM = pm.nest(); +// // optPM.addPass(mlir::createCanonicalizerPass()); +// // optPM.addPass(mlir::createCSEPass()); + +// // // Add optimizations if enabled. +// // if (enableOpt) { +// // optPM.addPass(mlir::affine::createLoopFusionPass()); +// // optPM.addPass(mlir::affine::createAffineScalarReplacementPass()); +// // } +// // } + +// // if (isLoweringToLLVM) { +// // // Finish lowering the mlp IR to the LLVM dialect. +// // pm.addPass(mlir::mlp::createLowerToLLVMPass()); +// // // This is necessary to have line tables emitted and basic +// // // debugger working. In the future we will add proper debug information +// // // emission directly from our frontend. 
+// // pm.addPass(mlir::LLVM::createDIScopeForLLVMFuncOpPass()); +// // } + +// // if (mlir::failed(pm.run(*module))) +// // return 4; +// // return 0; +// // } + +// // static int dumpAST() { +// // if (inputType == InputType::MLIR) { +// // llvm::errs() << "Can't dump a Mlp AST when the input is MLIR\n"; +// // return 5; +// // } + +// // auto moduleAST = parseInputFile(inputFilename); +// // if (!moduleAST) +// // return 1; + +// // dump(*moduleAST); +// // return 0; +// // } + +// // static int dumpLLVMIR(mlir::ModuleOp module) { +// // // Register the translation to LLVM IR with the MLIR context. +// // mlir::registerBuiltinDialectTranslation(*module->getContext()); +// // mlir::registerLLVMDialectTranslation(*module->getContext()); + +// // // Convert the module to LLVM IR in a new LLVM IR context. +// // llvm::LLVMContext llvmContext; +// // auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext); +// // if (!llvmModule) { +// // llvm::errs() << "Failed to emit LLVM IR\n"; +// // return -1; +// // } + +// // // Initialize LLVM targets. +// // llvm::InitializeNativeTarget(); +// // llvm::InitializeNativeTargetAsmPrinter(); + +// // // Create target machine and configure the LLVM Module +// // auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); +// // if (!tmBuilderOrError) { +// // llvm::errs() << "Could not create JITTargetMachineBuilder\n"; +// // return -1; +// // } + +// // auto tmOrError = tmBuilderOrError->createTargetMachine(); +// // if (!tmOrError) { +// // llvm::errs() << "Could not create TargetMachine\n"; +// // return -1; +// // } +// // mlir::ExecutionEngine::setupTargetTripleAndDataLayout(llvmModule.get(), +// // tmOrError.get().get()); + +// // /// Optionally run an optimization pipeline over the llvm module. +// // auto optPipeline = mlir::makeOptimizingTransformer( +// // /*optLevel=*/enableOpt ? 
3 : 0, /*sizeLevel=*/0, +// // /*targetMachine=*/nullptr); +// // if (auto err = optPipeline(llvmModule.get())) { +// // llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; +// // return -1; +// // } +// // llvm::errs() << *llvmModule << "\n"; +// // return 0; +// // } + +// // static int runJit(mlir::ModuleOp module) { +// // // Initialize LLVM targets. +// // llvm::InitializeNativeTarget(); +// // llvm::InitializeNativeTargetAsmPrinter(); + +// // // Register the translation from MLIR to LLVM IR, which must happen before +// // we +// // // can JIT-compile. +// // mlir::registerBuiltinDialectTranslation(*module->getContext()); +// // mlir::registerLLVMDialectTranslation(*module->getContext()); + +// // // An optimization pipeline to use within the execution engine. +// // auto optPipeline = mlir::makeOptimizingTransformer( +// // /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0, +// // /*targetMachine=*/nullptr); + +// // // Create an MLIR execution engine. The execution engine eagerly +// // JIT-compiles +// // // the module. +// // mlir::ExecutionEngineOptions engineOptions; +// // engineOptions.transformer = optPipeline; +// // auto maybeEngine = mlir::ExecutionEngine::create(module, engineOptions); +// // assert(maybeEngine && "failed to construct an execution engine"); +// // auto &engine = maybeEngine.get(); + +// // // Invoke the JIT-compiled function. +// // auto invocationResult = engine->invokePacked("main"); +// // if (invocationResult) { +// // llvm::errs() << "JIT invocation failed\n"; +// // return -1; +// // } + +// // return 0; +// // } + +// // int main(int argc, char **argv) { +// // Register any command line options. 
+// // mlir::registerAsmPrinterCLOptions(); +// // mlir::registerMLIRContextCLOptions(); +// // mlir::registerPassManagerCLOptions(); + +// // cl::ParseCommandLineOptions(argc, argv, "mlp compiler\n"); + +// // if (emitAction == Action::DumpAST) +// // return dumpAST(); + +// // // If we aren't dumping the AST, then we are compiling with/to MLIR. +// // mlir::DialectRegistry registry; +// // mlir::func::registerAllExtensions(registry); +// // mlir::LLVM::registerInlinerInterface(registry); + +// // mlir::MLIRContext context(registry); +// // // Load our Dialect in this MLIR Context. +// // context.getOrLoadDialect(); + +// // mlir::OwningOpRef module; +// // if (int error = loadAndProcessMLIR(context, module)) +// // return error; + +// // // If we aren't exporting to non-mlir, then we are done. +// // bool isOutputingMLIR = emitAction <= Action::DumpMLIRLLVM; +// // if (isOutputingMLIR) { +// // module->dump(); +// // return 0; +// // } + +// // // Check to see if we are compiling to LLVM IR. +// // if (emitAction == Action::DumpLLVMIR) +// // return dumpLLVMIR(*module); + +// // // Otherwise, we must be running the jit. +// // if (emitAction == Action::RunJIT) +// // return runJit(*module); + +// // llvm::errs() << "No action specified (parsing only?), use -emit=\n"; +// // return -1; +// //} + +// #ifdef PRINT +// #define PRINT +// #endif +// // #define PRINT + +// using namespace mlir; +// using namespace builder1; +// using namespace jit; +// // using namespace dbs; +// // using namespace mlp; + +// int main() { +// // Register any command line options. 
+// mlir::registerAsmPrinterCLOptions(); +// mlir::registerMLIRContextCLOptions(); +// mlir::registerPassManagerCLOptions(); + +// mlir::DialectRegistry registry; + +// registry +// .insert(); + +// mlir::func::registerAllExtensions(registry); +// mlir::registerBuiltinDialectTranslation(registry); +// mlir::registerLLVMDialectTranslation(registry); +// // mlir::registerAllDialects(registry); + +// mlir::MLIRContext ctx(registry); +// // ctx.appendDialectRegistry(registry); +// ctx.getOrLoadDialect(); + +// ctx.loadAllAvailableDialects(); + +// // registry.insert(); +// // registry.insert(); + +// // Correct way to create a module: +// mlir::OwningOpRef module = +// mlir::ModuleOp::create(mlir::UnknownLoc::get(&ctx)); + +// // createMainFunction(ctx, *module); +// // createAddFunction(ctx, *module); +// // createMulFunction(ctx, *module); +// createMLPAddFunction(ctx, *module); + +// // createMLPAddTOSAFunction(ctx, *module); +// // createMLPReluFunction(ctx, *module); +// // createMLPTESTFunction(ctx, *module); + +// llvm::outs() +// << "\n===================== mlir dialect ========================\n"; + +// module->print(llvm::outs()); +// llvm::outs() << "\n"; + +// // run pass pipeline (use ctx variable from your main) +// PassManager pm(&ctx); + +// // if (1) { +// // // Inline all functions into main and then delete them. +// // // pm.addPass(mlir::createInlinerPass()); + +// // // Now that there is only one function, we can infer the shapes of each +// // of +// // // the operations. + +// // mlir::OpPassManager &optPM = pm.nest(); +// // optPM.addPass(mlir::createCanonicalizerPass()); +// // optPM.addPass(mlir::mlp::createShapeInferencePass()); +// // optPM.addPass(mlir::createCanonicalizerPass()); +// // optPM.addPass(mlir::createCSEPass()); +// // } + +// // if (1) { +// // // Partially lower the mlp dialect. 
+// // pm.addPass(mlir::dbs::createLowerToLinalgPass()); + +// // mlir::bufferization::OneShotBufferizationOptions options; +// // options.allowReturnAllocsFromLoops = true; +// // // pm.addPass(mlir::bufferization::createOneShotBufferizePass(options)); +// // // //pm.addPass(mlir::bufferization::createBufferDeallocationPass()); +// // // pm.addPass(mlir::createConvertLinalgToLoopsPass()); +// // // pm.addPass(mlir::createConvertSCFToCFPass()); +// // // ------------------------------------------------------------ +// // // pm.addPass(mlir::createConvertArithToLLVMPass()); +// // // pm.addPass(mlir::createConvertMemRefToLLVMPass()); +// // // pm.addPass(mlir::createConvertFuncToLLVMPass()); +// // // pm.addPass(mlir::tosa::createTosaToArith()); + +// // // // Add a few cleanups post lowering. +// // // mlir::OpPassManager &optPM = pm.nest(); +// // // optPM.addPass(mlir::createCanonicalizerPass()); +// // // optPM.addPass(mlir::createCSEPass()); + +// // // // Add optimizations if enabled. +// // // if (auto opt = true; opt) { +// // // optPM.addPass(mlir::affine::createLoopFusionPass()); +// // // optPM.addPass(mlir::affine::createAffineScalarReplacementPass()); +// // // } +// // } + +// // if (0) { +// // // Partially lower the mlp dialect. +// // pm.addPass(mlir::dbs::createLowerToAffinePass()); + +// // // Add a few cleanups post lowering. +// // mlir::OpPassManager &optPM = pm.nest(); +// // optPM.addPass(mlir::createCanonicalizerPass()); +// // optPM.addPass(mlir::createCSEPass()); + +// // // Add optimizations if enabled. +// // if (auto opt = true; opt) { +// // optPM.addPass(mlir::affine::createLoopFusionPass()); +// // optPM.addPass(mlir::affine::createAffineScalarReplacementPass()); +// // } +// // } + +// // if (0) { +// // // Finish lowering the mlp IR to the LLVM dialect. +// // pm.addPass(mlir::mlp::createLowerToLLVMPass()); +// // // This is necessary to have line tables emitted and basic +// // // debugger working. 
In the future we will add proper debug information +// // // emission directly from our frontend. +// // pm.addNestedPass( +// // mlir::LLVM::createDIScopeForLLVMFuncOpPass()); +// // } + +// if (mlir::failed(pm.run(*module))) { +// llvm::errs() << "Lowering pipeline failed!\n"; +// return 1; +// } +// llvm::outs() +// << "\n===================== lowered dialect =====================\n"; +// module->print(llvm::outs()); + +// // translate and print LLVM IR +// if (0) { + +// llvm::LLVMContext llvmCtx; +// std::unique_ptr llvmModule = +// mlir::translateModuleToLLVMIR(*module, llvmCtx); + +// if (!llvmModule) { +// llvm::errs() << "translateModuleToLLVMIR failed\n"; +// return 1; +// } +// llvm::outs() << "\n=== LLVM IR ===\n"; +// llvmModule->dump(); + +// if (llvm::verifyModule(*llvmModule, &llvm::errs())) { +// llvm::errs() << "IR verification failed\n"; +// return 0; +// } +// } + +// // Otherwise, we must be running the jit. +// // if (emitAction == Action::RunJIT) +// // runJit(*module); +// std::cout << std::endl; +// // llvm::outs() << "\n=== MLIR TEST dialect ===\n"; +// // module->print(llvm::outs()); +// return 0; +// } + + +////////////////////////////////////////////////// +// //===- main.cpp - The Mlp Compiler ----------------------------------------===// +// // +// //===----------------------------------------------------------------------===// +// // +// // This file implements the entry point for the Mlp compiler. 
+// // +// //===----------------------------------------------------------------------===// + +// #include "Builder.h" +// #include "Dialect.h" +// #include "Jit.h" +// #include "Passes.h" +// #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +// #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +// #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +// #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" +// #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" +// #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +// #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +// #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +// #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" +// #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +// #include "mlir/Conversion/TosaToArith/TosaToArith.h" +// #include "mlir/Conversion/UBToLLVM/UBToLLVM.h" +// #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" +// #include "mlir/Dialect/Affine/IR/AffineOps.h" +// #include "mlir/Dialect/Affine/Transforms/Passes.h" +// #include "mlir/Dialect/Arith/IR/Arith.h" +// #include "mlir/Dialect/Arith/Transforms/Passes.h" +// #include "mlir/Dialect/Bufferization/Transforms/Passes.h" +// #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +// #include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h" +// #include "mlir/Dialect/Func/Extensions/AllExtensions.h" +// #include "mlir/Dialect/Func/Extensions/InlinerExtension.h" +// #include "mlir/Dialect/Func/IR/FuncOps.h" +// #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +// #include "mlir/Dialect/LLVMIR/LLVMTypes.h" +// #include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" +// #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" +// #include "mlir/Dialect/Linalg/Passes.h" +// #include "mlir/Dialect/Math/IR/Math.h" +// #include "mlir/Dialect/MemRef/IR/MemRef.h" +// #include "mlir/Dialect/MemRef/Transforms/Passes.h" +// #include 
"mlir/Dialect/SCF/Transforms/Passes.h" +// #include "mlir/Dialect/Tosa/Transforms/Passes.h" +// #include "mlir/ExecutionEngine/ExecutionEngine.h" +// #include "mlir/ExecutionEngine/OptUtils.h" +// #include "mlir/IR/AsmState.h" +// #include "mlir/IR/Builders.h" +// #include "mlir/IR/BuiltinOps.h" +// #include "mlir/IR/DialectRegistry.h" +// #include "mlir/IR/MLIRContext.h" +// #include "mlir/IR/Types.h" +// #include "mlir/IR/Verifier.h" +// #include "mlir/InitAllDialects.h" +// #include "mlir/InitAllExtensions.h" +// #include "mlir/Parser/Parser.h" +// #include "mlir/Pass/PassManager.h" +// #include "mlir/Support/LLVM.h" +// #include "mlir/Support/LogicalResult.h" +// #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +// #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +// #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" +// #include "mlir/Target/LLVMIR/Export.h" +// #include "mlir/Target/LLVMIR/ModuleTranslation.h" +// #include "mlir/Transforms/DialectConversion.h" +// #include "mlir/Transforms/Passes.h" +// #include "llvm/ADT/StringRef.h" +// #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +// #include "llvm/IR/LLVMContext.h" +// #include "llvm/IR/Module.h" +// #include "llvm/IR/Verifier.h" +// #include "llvm/Support/CommandLine.h" +// #include "llvm/Support/ErrorOr.h" +// #include "llvm/Support/MemoryBuffer.h" +// #include "llvm/Support/SourceMgr.h" +// #include "llvm/Support/TargetSelect.h" +// #include "llvm/Support/raw_ostream.h" +// #include +// #include +// #include +// #include +// #include +// #include + +// // using namespace mlp; +// // namespace cl = llvm::cl; + +// #ifdef PRINT +// #define PRINT +// #endif +// // #define PRINT + +// using namespace mlir; +// using namespace builder; +// using namespace jit; + +// void registerLLVMTranslations(mlir::MLIRContext &context) { +// mlir::registerBuiltinDialectTranslation(context); +// 
mlir::registerLLVMDialectTranslation(context); +// } + +// int main() { +// // Register any command line options. +// mlir::registerAsmPrinterCLOptions(); +// mlir::registerMLIRContextCLOptions(); +// mlir::registerPassManagerCLOptions(); + +// mlir::DialectRegistry registry; + +// registry +// .insert(); + +// mlir::func::registerAllExtensions(registry); +// mlir::registerBuiltinDialectTranslation(registry); +// mlir::registerLLVMDialectTranslation(registry); +// // mlir::registerAllDialects(registry); + +// mlir::MLIRContext ctx(registry); +// ctx.getOrLoadDialect(); + +// ctx.loadAllAvailableDialects(); + +// registry.insert(); +// registry.insert(); + +// registerLLVMTranslations(ctx); +// // Correct way to create a module: +// mlir::OwningOpRef module = +// mlir::ModuleOp::create(mlir::UnknownLoc::get(&ctx)); + +// // createMainFunction(ctx, *module); +// // createAddFunction(ctx, *module); +// // createMulFunction(ctx, *module); +// // createMLPAddFunction(ctx, *module); + +// // createMLPAddTOSAFunction(ctx, *module); +// // createMLPReluFunction(ctx, *module); +// createMLPLinearFunction(ctx, *module); + +// llvm::outs() +// << "\n===================== mlir dialect ========================\n"; + +// module->print(llvm::outs()); +// llvm::outs() << "\n"; + +// // run pass pipeline (use ctx variable from your main) +// PassManager pm(&ctx); + +// if (1) { +// // Inline all functions into main and then delete them. +// pm.addPass(mlir::createInlinerPass()); + +// // Now that there is only one function, we can infer the shapes of each of +// // the operations. + +// mlir::OpPassManager &optPM = pm.nest(); +// optPM.addPass(mlir::createCanonicalizerPass()); +// optPM.addPass(mlir::mlp::createShapeInferencePass()); +// optPM.addPass(mlir::createCanonicalizerPass()); +// optPM.addPass(mlir::createCSEPass()); +// } + +// if (0) { +// // Partially lower the mlp dialect. 
+// pm.addPass(mlir::mlp::createLowerToLinalgPass()); + +// // mlir::bufferization::OneShotBufferizationOptions options; +// // options.allowReturnAllocsFromLoops = true; +// // pm.addPass(mlir::bufferization::createOneShotBufferizePass(options)); +// // //pm.addPass(mlir::bufferization::createBufferDeallocationPass()); +// // pm.addPass(mlir::createConvertLinalgToLoopsPass()); +// // pm.addPass(mlir::createConvertSCFToCFPass()); +// // ------------------------------------------------------------ +// // pm.addPass(mlir::createConvertArithToLLVMPass()); +// // pm.addPass(mlir::createConvertMemRefToLLVMPass()); +// // pm.addPass(mlir::createConvertFuncToLLVMPass()); +// // pm.addPass(mlir::tosa::createTosaToArith()); + +// // // Add a few cleanups post lowering. +// mlir::OpPassManager &optPM = pm.nest(); +// optPM.addPass(mlir::createCanonicalizerPass()); +// optPM.addPass(mlir::createCSEPass()); + +// // Add optimizations if enabled. +// if (auto opt = true; opt) { +// optPM.addPass(mlir::affine::createLoopFusionPass()); +// optPM.addPass(mlir::affine::createAffineScalarReplacementPass()); +// } +// } + +// if (1) { +// // Partially lower the toy dialect. +// pm.addPass(mlir::mlp::createLowerToAffinePass()); + +// // Add a few cleanups post lowering. +// mlir::OpPassManager &optPM = pm.nest(); +// optPM.addPass(mlir::createCanonicalizerPass()); +// optPM.addPass(mlir::createCSEPass()); + +// // Add optimizations if enabled. +// if (0) { +// optPM.addPass(mlir::affine::createLoopFusionPass()); +// optPM.addPass(mlir::affine::createAffineScalarReplacementPass()); +// } +// } + +// // if (1) { +// // // Partially lower the mlp dialect. +// // pm.addPass(mlir::mlp::createLowerToAffinePass()); +// // // Add a few cleanups post lowering. +// // mlir::OpPassManager &optPM = pm.nest(); +// // optPM.addPass(mlir::createCanonicalizerPass()); +// // optPM.addPass(mlir::createCSEPass()); +// // // Add optimizations if enabled. 
+// // if (auto opt = true; opt) { +// // optPM.addPass(mlir::affine::createLoopFusionPass()); +// // optPM.addPass(mlir::affine::createAffineScalarReplacementPass()); +// // } +// // } + +// if (1) { + +// // Finish lowering the mlp IR to the LLVM dialect. +// pm.addPass(mlir::mlp::createLowerToLLVMPass()); +// // This is necessary to have line tables emitted and basic +// // debugger working. In the future we will add proper debug information +// // emission directly from our frontend. +// pm.addNestedPass( +// mlir::LLVM::createDIScopeForLLVMFuncOpPass()); +// } + +// if (mlir::failed(pm.run(*module))) { +// llvm::errs() << "Lowering pipeline failed!\n"; +// return 1; +// } +// llvm::outs() +// << "\n===================== lowered dialect =====================\n"; +// module->print(llvm::outs()); + +// // translate and print LLVM IR +// if (0) { + +// llvm::LLVMContext llvmCtx; +// std::unique_ptr llvmModule = +// mlir::translateModuleToLLVMIR(*module, llvmCtx); + +// if (!llvmModule) { +// llvm::errs() << "translateModuleToLLVMIR failed\n"; +// return 1; +// } +// llvm::outs() << "\n=== LLVM IR ===\n"; +// llvmModule->dump(); + +// if (llvm::verifyModule(*llvmModule, &llvm::errs())) { +// llvm::errs() << "IR verification failed\n"; +// return 0; +// } +// } + +// // Otherwise, we must be running the jit. 
+// // if (emitAction == Action::RunJIT) +// // runJit(*module); +// std::cout << std::endl; +// // llvm::outs() << "\n=== MLIR TEST dialect ===\n"; +// // module->print(llvm::outs()); +// return 0; +// } \ No newline at end of file diff --git a/targets/CMakeLists.txt b/targets/CMakeLists.txt new file mode 100644 index 0000000..fdf8149 --- /dev/null +++ b/targets/CMakeLists.txt @@ -0,0 +1,20 @@ + +# Automatically detect all backend directories under targets/ +file(GLOB BACKEND_DIRS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/*) + +foreach(BACKEND ${BACKEND_DIRS}) + set(BACKEND_PATH ${CMAKE_CURRENT_SOURCE_DIR}/${BACKEND}) + if(IS_DIRECTORY ${BACKEND_PATH}) + # Expect each backend dir to provide one source file named ${BACKEND}Backend.cpp (NOTE(review): the committed files — LLVMBackend.cpp, CudaBackend.cpp, MetalBackend.cpp, Backend.cpp, ROCmBackend.cpp — do not match this pattern and will all be skipped; confirm intended naming) + set(BACKEND_SRC ${BACKEND_PATH}/${BACKEND}Backend.cpp) + + if(EXISTS ${BACKEND_SRC}) + message(STATUS "Adding backend: ${BACKEND}") + add_library(ppytorch_${BACKEND} ${BACKEND_SRC}) + target_include_directories(ppytorch_${BACKEND} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../include) + target_link_libraries(ppytorch_${BACKEND} PUBLIC ppytorch_core MLIRIR MLIRLLVM) + else() + message(WARNING "No source file found for backend ${BACKEND}, skipping...") + endif() + endif() +endforeach() diff --git a/targets/cpu/LLVMBackend.cpp b/targets/cpu/LLVMBackend.cpp new file mode 100644 index 0000000..d74eb83 --- /dev/null +++ b/targets/cpu/LLVMBackend.cpp @@ -0,0 +1,19 @@ +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/Pass.h" +#include "mlir/IR/BuiltinOps.h" +#include "Passes.h" + +namespace mlir +{ + namespace mlp + { + namespace hardware + { + + void registerCPUPasses() + { + } + + } // namespace hardware + } // namespace mlp +} // namespace mlir diff --git a/targets/gpu/CudaBackend.cpp b/targets/gpu/CudaBackend.cpp new file mode 100644 index 0000000..75848ab --- /dev/null +++ b/targets/gpu/CudaBackend.cpp @@ -0,0 +1,19 @@ +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/Pass.h" +#include "mlir/IR/BuiltinOps.h"
+#include "Passes.h" + +namespace mlir +{ + namespace mlp + { + namespace hardware + { + + void registerCudaPasses() + { + } + + } // namespace hardware + } // namespace mlp +} // namespace mlir diff --git a/targets/metal/MetalBackend.cpp b/targets/metal/MetalBackend.cpp new file mode 100644 index 0000000..8edc102 --- /dev/null +++ b/targets/metal/MetalBackend.cpp @@ -0,0 +1,19 @@ +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/Pass.h" +#include "mlir/IR/BuiltinOps.h" +#include "Passes.h" + +namespace mlir +{ + namespace mlp + { + namespace hardware + { + + void registerMetalPasses() + { + } + + } // namespace hardware + } // namespace mlp +} // namespace mlir diff --git a/targets/riscv/Backend.cpp b/targets/riscv/Backend.cpp new file mode 100644 index 0000000..cf37fb8 --- /dev/null +++ b/targets/riscv/Backend.cpp @@ -0,0 +1,19 @@ +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/Pass.h" +#include "mlir/IR/BuiltinOps.h" +#include "Passes.h" + +namespace mlir +{ + namespace mlp + { + namespace hardware + { + + void registerRISCVPasses() + { + } + + } // namespace hardware + } // namespace mlp +} // namespace mlir diff --git a/targets/rocm/ROCmBackend.cpp b/targets/rocm/ROCmBackend.cpp new file mode 100644 index 0000000..b7e8361 --- /dev/null +++ b/targets/rocm/ROCmBackend.cpp @@ -0,0 +1,19 @@ +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/Pass.h" +#include "mlir/IR/BuiltinOps.h" +#include "Passes.h" + +namespace mlir +{ + namespace mlp + { + namespace hardware + { + + void registerROCMPasses() + { + } + + } // namespace hardware + } // namespace mlp +} // namespace mlir