diff --git a/water/include/water/Transforms/Passes.td b/water/include/water/Transforms/Passes.td
index c64609557..3263951fd 100644
--- a/water/include/water/Transforms/Passes.td
+++ b/water/include/water/Transforms/Passes.td
@@ -113,4 +113,25 @@ def WaterGPUToGPURuntime : Pass<"water-gpu-to-gpu-runtime", "::mlir::ModuleOp">
   let dependentDialects = ["::mlir::LLVM::LLVMDialect"];
 }
 
+// Declares the `water-gpu-module-to-binary` pass and its command-line options:
+//   toolkit                - toolkit path (string, empty by default)
+//   l                      - list of extra bitcode files to link
+//   dump-intermediates     - directory that receives intermediate files
+//   override-intermediates - directory searched for replacement intermediates
+// NOTE(review): the operation name is the empty string; presumably this makes
+// the pass runnable on any operation -- confirm against the MLIR pass docs.
+def WaterGPUModuleToBinary : Pass<"water-gpu-module-to-binary", ""> {
+  let summary = "Transforms GPU modules into binaries.";
+  let description = [{
+    This pass searches for all nested GPU modules with target attributes
+    and serializes them to binary format, producing a GPU binary operation.
+
+    This is a simplified version of the upstream gpu-module-to-binary pass,
+    tailored for the Wave project. Currently supports ROCDL targets only.
+  }];
+  let options = [
+    Option<"toolkitPath", "toolkit", "std::string", [{""}],
+           "Toolkit path.">,
+    ListOption<"linkFiles", "l", "std::string",
+               "Extra bitcode files to link to.">,
+    Option<"dumpIntermediates", "dump-intermediates", "std::string", [{""}],
+           "Directory to dump intermediate compilation files (LLVM IR, ISA).">,
+    Option<"overrideIntermediates", "override-intermediates", "std::string", [{""}],
+           "Directory containing intermediate files to use instead of generating them.">,
+  ];
+}
+
 #endif // WATER_PASSES
diff --git a/water/lib/Transforms/AssembleISA.cpp b/water/lib/Transforms/AssembleISA.cpp
new file mode 100644
index 000000000..6a931e7ff
--- /dev/null
+++ b/water/lib/Transforms/AssembleISA.cpp
@@ -0,0 +1,143 @@
+// Copyright 2025 The Wave Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "AssembleISA.h" + +#include "llvm/MC/MCAsmBackend.h" +#include "llvm/MC/MCAsmInfo.h" +#include "llvm/MC/MCCodeEmitter.h" +#include "llvm/MC/MCContext.h" +#include "llvm/MC/MCInstrInfo.h" +#include "llvm/MC/MCObjectFileInfo.h" +#include "llvm/MC/MCObjectWriter.h" +#include "llvm/MC/MCParser/MCAsmParser.h" +#include "llvm/MC/MCParser/MCTargetAsmParser.h" +#include "llvm/MC/MCRegisterInfo.h" +#include "llvm/MC/MCStreamer.h" +#include "llvm/MC/MCSubtargetInfo.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/FileUtilities.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/Program.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Target/TargetMachine.h" + +using namespace mlir; + +namespace mlir::water { + +void initializeAMDGPUTarget() { + static bool initialized = []() { + LLVMInitializeAMDGPUTarget(); + LLVMInitializeAMDGPUTargetInfo(); + LLVMInitializeAMDGPUTargetMC(); + LLVMInitializeAMDGPUAsmParser(); + LLVMInitializeAMDGPUAsmPrinter(); + return true; + }(); + (void)initialized; +} + +FailureOr> +assembleISAToHSACO(Operation *op, StringRef isa, + llvm::TargetMachine &targetMachine, StringRef toolkitPath) { + initializeAMDGPUTarget(); + + // Step 1: Assemble ISA to object file using MC infrastructure. + llvm::Triple triple = targetMachine.getTargetTriple(); + std::string error; + const llvm::Target *target = + llvm::TargetRegistry::lookupTarget(triple, error); + if (!target) + return op->emitError() << "Failed to lookup target: " << error; + + // Set up MC infrastructure. 
+ llvm::SourceMgr srcMgr; + srcMgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(isa), + llvm::SMLoc()); + + const llvm::MCTargetOptions mcOptions; + std::unique_ptr mri(target->createMCRegInfo(triple)); + std::unique_ptr mai( + target->createMCAsmInfo(*mri, triple, mcOptions)); + std::unique_ptr sti( + target->createMCSubtargetInfo(triple, targetMachine.getTargetCPU(), + targetMachine.getTargetFeatureString())); + + SmallVector objectBuffer; + llvm::raw_svector_ostream os(objectBuffer); + + llvm::MCContext ctx(triple, mai.get(), mri.get(), sti.get(), &srcMgr, + &mcOptions); + std::unique_ptr mofi(target->createMCObjectFileInfo( + ctx, /*PIC=*/false, /*LargeCodeModel=*/false)); + ctx.setObjectFileInfo(mofi.get()); + + std::unique_ptr mcii(target->createMCInstrInfo()); + llvm::MCCodeEmitter *ce = target->createMCCodeEmitter(*mcii, ctx); + llvm::MCAsmBackend *mab = target->createMCAsmBackend(*sti, *mri, mcOptions); + std::unique_ptr mcStreamer(target->createMCObjectStreamer( + triple, ctx, std::unique_ptr(mab), + mab->createObjectWriter(os), std::unique_ptr(ce), + *sti)); + + std::unique_ptr parser( + createMCAsmParser(srcMgr, ctx, *mcStreamer, *mai)); + std::unique_ptr tap( + target->createMCAsmParser(*sti, *parser, *mcii, mcOptions)); + + if (!tap) + return op->emitError("Assembler initialization error"); + + parser->setTargetParser(*tap); + if (parser->Run(false)) + return op->emitError("Assembly parsing failed"); + + // Step 2: Link object file to create HSACO. + // Write object to temporary file. + int tempObjFd = -1; + SmallString<128> tempObjFilename; + if (llvm::sys::fs::createTemporaryFile("kernel%%", "o", tempObjFd, + tempObjFilename)) + return op->emitError("Failed to create temporary file for object"); + + llvm::FileRemover cleanupObj(tempObjFilename); + { + llvm::raw_fd_ostream tempObjOs(tempObjFd, true); + tempObjOs << StringRef(objectBuffer.data(), objectBuffer.size()); + tempObjOs.flush(); + } + + // Create temporary file for HSACO. 
+ SmallString<128> tempHsacoFilename; + if (llvm::sys::fs::createTemporaryFile("kernel", "hsaco", tempHsacoFilename)) + return op->emitError("Failed to create temporary file for HSACO"); + + llvm::FileRemover cleanupHsaco(tempHsacoFilename); + + // Link using ld.lld. + SmallString<128> lldPath(toolkitPath); + llvm::sys::path::append(lldPath, "llvm", "bin", "ld.lld"); + int lldResult = llvm::sys::ExecuteAndWait( + lldPath, {"ld.lld", "-shared", tempObjFilename, "-o", tempHsacoFilename}); + if (lldResult != 0) + return op->emitError("ld.lld invocation failed"); + + // Read HSACO file. + auto hsacoFile = + llvm::MemoryBuffer::getFile(tempHsacoFilename, /*IsText=*/false); + if (!hsacoFile) + return op->emitError("Failed to read HSACO from temporary file"); + + StringRef buffer = (*hsacoFile)->getBuffer(); + return SmallVector(buffer.begin(), buffer.end()); +} + +} // namespace mlir::water diff --git a/water/lib/Transforms/AssembleISA.h b/water/lib/Transforms/AssembleISA.h new file mode 100644 index 000000000..657065ed6 --- /dev/null +++ b/water/lib/Transforms/AssembleISA.h @@ -0,0 +1,42 @@ +// Copyright 2025 The Wave Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef WATER_LIB_TRANSFORMS_ASSEMBLEISA_H +#define WATER_LIB_TRANSFORMS_ASSEMBLEISA_H + +#include "mlir/IR/Operation.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" + +namespace llvm { +class TargetMachine; +} // namespace llvm + +namespace mlir::water { + +/// Initializes the LLVM AMDGPU target. Safe to call multiple times. +void initializeAMDGPUTarget(); + +/// Assembles ISA (assembly code) to HSACO (HSA Code Object) binary. +/// +/// This function: +/// 1. Parses the ISA using LLVM MC infrastructure +/// 2. Assembles it to an ELF object file +/// 3. 
Links the object file using ld.lld to create an HSACO +/// +/// \param op Operation for error reporting +/// \param isa Assembly code to assemble +/// \param targetMachine Target machine for MC infrastructure setup +/// \param toolkitPath Path to toolkit containing ld.lld +/// \return Binary data of the HSACO file, or failure +FailureOr> +assembleISAToHSACO(Operation *op, StringRef isa, + llvm::TargetMachine &targetMachine, StringRef toolkitPath); + +} // namespace mlir::water + +#endif // WATER_LIB_TRANSFORMS_ASSEMBLEISA_H diff --git a/water/lib/Transforms/CMakeLists.txt b/water/lib/Transforms/CMakeLists.txt index c441b192b..7ad8c3f99 100644 --- a/water/lib/Transforms/CMakeLists.txt +++ b/water/lib/Transforms/CMakeLists.txt @@ -1,6 +1,8 @@ add_mlir_dialect_library(MLIRWaterTransforms AccessCheckers.cpp + AssembleISA.cpp CheckStaticAssertions.cpp + GPUModuleToBinary.cpp GPUToGPURuntime.cpp SLPVectorizer.cpp @@ -20,6 +22,7 @@ add_mlir_dialect_library(MLIRWaterTransforms MLIRLLVMDialect MLIRMemRefDialect MLIRPass + MLIRROCDLTarget MLIRRewrite MLIRTransformUtils MLIRVectorDialect diff --git a/water/lib/Transforms/GPUModuleToBinary.cpp b/water/lib/Transforms/GPUModuleToBinary.cpp new file mode 100644 index 000000000..f2fd8f970 --- /dev/null +++ b/water/lib/Transforms/GPUModuleToBinary.cpp @@ -0,0 +1,494 @@ +// Copyright 2025 The Wave Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "water/Transforms/Passes.h" + +#include "AssembleISA.h" + +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/ExecutionEngine/OptUtils.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Target/LLVM/ROCDL/Utils.h" +#include "mlir/Target/LLVMIR/Export.h" + +#include "llvm/ADT/StringSet.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Linker/Linker.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/ToolOutputFile.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Transforms/IPO/Internalize.h" + +using namespace mlir; +using namespace mlir::gpu; + +namespace mlir::water { +#define GEN_PASS_DEF_WATERGPUMODULETOBINARY +#include "water/Transforms/Passes.h.inc" +} // namespace mlir::water + +namespace { +class WaterGPUModuleToBinaryPass + : public water::impl::WaterGPUModuleToBinaryBase< + WaterGPUModuleToBinaryPass> { +public: + using Base::Base; + void runOnOperation() final; + +private: + LogicalResult serializeModule(GPUModuleOp mod); + + // Helper methods + std::unique_ptr loadBitcodeFile(llvm::LLVMContext &context, + StringRef path); + LogicalResult + linkBitcodeFiles(llvm::Module &mod, + SmallVector> &&libs); + FailureOr> + createTargetMachine(Attribute targetAttr); + LogicalResult optimizeModule(llvm::Module &mod, + llvm::TargetMachine &targetMachine); + FailureOr compileToISA(llvm::Module &mod, + llvm::TargetMachine &targetMachine); + + // Dump helpers + LogicalResult dumpLLVMModule(llvm::Module &mod, StringRef modName, + 
StringRef suffix); + LogicalResult dumpText(StringRef text, StringRef modName, StringRef suffix); + LogicalResult dumpBinary(ArrayRef data, StringRef modName, + StringRef suffix); + + // Override helpers + FailureOr> + tryLoadOverrideLLVM(llvm::LLVMContext &context, StringRef modName, + StringRef suffix); + FailureOr> tryLoadOverrideText(StringRef modName, + StringRef suffix); + FailureOr>> + tryLoadOverrideBinary(StringRef modName, StringRef suffix); +}; +} // namespace + +LogicalResult WaterGPUModuleToBinaryPass::serializeModule(GPUModuleOp mod) { + // Check that there is exactly one target. + if (!mod.getTargetsAttr() || mod.getTargetsAttr().size() != 1) + return mod.emitError("GPU module must have exactly one target attribute"); + + // Get the target attribute. + Attribute targetAttr = mod.getTargetsAttr()[0]; + if (!targetAttr) + return mod.emitError("Target attribute cannot be null"); + + // Step 1: Translate GPU module to LLVM IR. + llvm::LLVMContext llvmContext; + std::unique_ptr llvmModule = + translateModuleToLLVMIR(mod, llvmContext); + + if (!llvmModule) + return mod.emitError("Failed to translate GPU module to LLVM IR"); + + // Create dump directory if specified. + if (!dumpIntermediates.empty()) { + std::error_code ec = llvm::sys::fs::create_directories(dumpIntermediates); + if (ec) + return mod.emitError() + << "Failed to create dump directory: " << dumpIntermediates << ": " + << ec.message(); + } + + auto dumpAndOverrideLLVM = [&](StringRef suffix) -> LogicalResult { + StringRef modName = mod.getName(); + if (failed(dumpLLVMModule(*llvmModule, modName, suffix))) + return failure(); + + auto overrideLLVM = tryLoadOverrideLLVM(llvmContext, modName, suffix); + if (failed(overrideLLVM)) + return failure(); + + if (*overrideLLVM) + llvmModule = std::move(*overrideLLVM); + + return success(); + }; + // Dump/override original LLVM IR. + if (failed(dumpAndOverrideLLVM("_original"))) + return failure(); + + // Step 2: Load and link device libraries. 
+ SmallVector> bitcodeLibs; + for (const std::string &path : linkFiles) { + auto lib = loadBitcodeFile(llvmContext, path); + if (!lib) + return mod.emitError("Failed to load bitcode file: " + path); + bitcodeLibs.push_back(std::move(lib)); + } + + if (failed(linkBitcodeFiles(*llvmModule, std::move(bitcodeLibs)))) + return mod.emitError("Failed to link bitcode libraries"); + + // Dump/override linked LLVM IR. + if (failed(dumpAndOverrideLLVM("_linked"))) + return failure(); + + // Step 3: Create target machine and set data layout. + FailureOr> targetMachine = + createTargetMachine(targetAttr); + if (failed(targetMachine)) + return mod.emitError("Failed to create target machine"); + + // Set the data layout and target triple to match the target machine. + llvmModule->setDataLayout((*targetMachine)->createDataLayout()); + llvmModule->setTargetTriple((*targetMachine)->getTargetTriple()); + + // Step 4: Optimize LLVM IR. + if (failed(optimizeModule(*llvmModule, **targetMachine))) + return mod.emitError("Failed to optimize LLVM IR"); + + // Dump optimized LLVM IR. + if (failed(dumpAndOverrideLLVM("_optimized"))) + return failure(); + + // Step 5: Compile to ISA. + FailureOr isa = compileToISA(*llvmModule, **targetMachine); + if (failed(isa)) + return mod.emitError("Failed to compile to ISA"); + + auto dumpAndOverrideISA = [&](StringRef suffix) -> LogicalResult { + StringRef modName = mod.getName(); + if (failed(dumpText(*isa, modName, suffix))) + return failure(); + + auto overrideISA = tryLoadOverrideText(modName, suffix); + if (failed(overrideISA)) + return failure(); + + if (*overrideISA) + isa = std::move(**overrideISA); + + return success(); + }; + + // Dump/override ISA. + if (failed(dumpAndOverrideISA(".s"))) + return failure(); + + // Step 6: Assemble to binary. + // Use ROCM_PATH environment variable if toolkitPath is not provided. 
+ StringRef actualToolkitPath = toolkitPath; + if (actualToolkitPath.empty()) + actualToolkitPath = ROCDL::getROCMPath(); + + FailureOr> binary = + water::assembleISAToHSACO(mod, *isa, **targetMachine, actualToolkitPath); + if (failed(binary)) + return mod.emitError("Failed to assemble to binary"); + + SmallVector binaryData = std::move(*binary); + + auto dumpAndOverrideBinary = [&](StringRef suffix) -> LogicalResult { + StringRef modName = mod.getName(); + if (failed(dumpBinary(binaryData, modName, suffix))) + return failure(); + + auto overrideBinary = tryLoadOverrideBinary(modName, suffix); + if (failed(overrideBinary)) + return failure(); + + if (*overrideBinary) + binaryData = std::move(**overrideBinary); + + return success(); + }; + + // Dump/override HSACO binary. + if (failed(dumpAndOverrideBinary(".hsaco"))) + return failure(); + + // Create object attribute. + Builder attrBuilder(mod.getContext()); + StringAttr binaryAttr = attrBuilder.getStringAttr( + StringRef(binaryData.data(), binaryData.size())); + + DictionaryAttr properties{}; + gpu::KernelTableAttr kernels; + + Attribute objectAttr = attrBuilder.getAttr( + targetAttr, gpu::CompilationTarget::Binary, binaryAttr, properties, + kernels); + + // Create gpu.binary op. + OpBuilder builder(mod.getContext()); + builder.setInsertionPointAfter(mod); + gpu::BinaryOp::create(builder, mod.getLoc(), mod.getName(), + /*offloadingHandler=*/nullptr, + builder.getArrayAttr({objectAttr})); + + // Erase the original module. 
+ mod->erase(); + return success(); +} + +std::unique_ptr +WaterGPUModuleToBinaryPass::loadBitcodeFile(llvm::LLVMContext &context, + StringRef path) { + llvm::SMDiagnostic error; + std::unique_ptr library = + llvm::getLazyIRFileModule(path, error, context); + if (!library) { + getOperation()->emitError() << "Failed loading bitcode file from " << path + << ", error: " << error.getMessage(); + return nullptr; + } + return library; +} + +LogicalResult WaterGPUModuleToBinaryPass::linkBitcodeFiles( + llvm::Module &mod, SmallVector> &&libs) { + if (libs.empty()) + return success(); + + llvm::Linker linker(mod); + for (std::unique_ptr &libModule : libs) { + // Link the library, importing only needed symbols. + bool err = linker.linkInModule( + std::move(libModule), llvm::Linker::Flags::LinkOnlyNeeded, + [](llvm::Module &m, const StringSet<> &gvs) { + llvm::internalizeModule(m, [&gvs](const llvm::GlobalValue &gv) { + return !gv.hasName() || (gvs.count(gv.getName()) == 0); + }); + }); + + if (err) { + getOperation()->emitError("Failed during bitcode linking"); + return failure(); + } + } + return success(); +} + +FailureOr> +WaterGPUModuleToBinaryPass::createTargetMachine(Attribute targetAttr) { + water::initializeAMDGPUTarget(); + + auto rocdlTarget = dyn_cast(targetAttr); + if (!rocdlTarget) + return getOperation()->emitError( + "Only ROCDL targets are currently supported"); + + std::string error; + llvm::Triple triple(llvm::Triple::normalize(rocdlTarget.getTriple())); + const llvm::Target *llvmTarget = + llvm::TargetRegistry::lookupTarget(triple, error); + + if (!llvmTarget) + return getOperation()->emitError() + << "Failed to lookup target for triple '" << rocdlTarget.getTriple() + << "': " << error; + + std::unique_ptr targetMachine( + llvmTarget->createTargetMachine(triple, rocdlTarget.getChip(), + rocdlTarget.getFeatures(), {}, {})); + if (!targetMachine) + return getOperation()->emitError("Failed to create target machine"); + + // Set optimization level from target 
attribute. + targetMachine->setOptLevel( + static_cast(rocdlTarget.getO())); + + return targetMachine; +} + +LogicalResult +WaterGPUModuleToBinaryPass::optimizeModule(llvm::Module &mod, + llvm::TargetMachine &targetMachine) { + // Get optimization level from target machine. + int optLevel = static_cast(targetMachine.getOptLevel()); + + auto transformer = + makeOptimizingTransformer(optLevel, /*sizeLevel=*/0, &targetMachine); + auto error = transformer(&mod); + if (error) { + InFlightDiagnostic mlirError = getOperation()->emitError(); + llvm::handleAllErrors( + std::move(error), [&mlirError](const llvm::ErrorInfoBase &ei) { + mlirError << "Failed to optimize LLVM IR: " << ei.message(); + }); + return failure(); + } + return success(); +} + +FailureOr +WaterGPUModuleToBinaryPass::compileToISA(llvm::Module &mod, + llvm::TargetMachine &targetMachine) { + SmallVector isaBuffer; + llvm::raw_svector_ostream stream(isaBuffer); + + llvm::legacy::PassManager codegen; + if (targetMachine.addPassesToEmitFile(codegen, stream, nullptr, + llvm::CodeGenFileType::AssemblyFile)) + return getOperation()->emitError("Target machine cannot emit assembly"); + + codegen.run(mod); + return std::string(isaBuffer.begin(), isaBuffer.end()); +} + +LogicalResult WaterGPUModuleToBinaryPass::dumpLLVMModule(llvm::Module &mod, + StringRef modName, + StringRef suffix) { + if (dumpIntermediates.empty()) + return success(); + + SmallString<128> path(dumpIntermediates); + llvm::sys::path::append(path, modName + suffix + ".ll"); + + std::error_code ec; + llvm::ToolOutputFile outputFile(path, ec, llvm::sys::fs::OF_None); + if (ec) + return getOperation()->emitError() + << "Failed to open file for dumping: " << path << ": " + << ec.message(); + + mod.print(outputFile.os(), nullptr); + outputFile.keep(); + return success(); +} + +LogicalResult WaterGPUModuleToBinaryPass::dumpText(StringRef text, + StringRef modName, + StringRef suffix) { + if (dumpIntermediates.empty()) + return success(); + + 
SmallString<128> path(dumpIntermediates); + llvm::sys::path::append(path, modName + suffix); + + std::error_code ec; + llvm::ToolOutputFile outputFile(path, ec, llvm::sys::fs::OF_None); + if (ec) + return getOperation()->emitError() + << "Failed to open file for dumping: " << path << ": " + << ec.message(); + + outputFile.os() << text; + outputFile.keep(); + return success(); +} + +LogicalResult WaterGPUModuleToBinaryPass::dumpBinary(ArrayRef data, + StringRef modName, + StringRef suffix) { + if (dumpIntermediates.empty()) + return success(); + + SmallString<128> path(dumpIntermediates); + llvm::sys::path::append(path, modName + suffix); + + std::error_code ec; + llvm::ToolOutputFile outputFile(path, ec, llvm::sys::fs::OF_None); + if (ec) + return getOperation()->emitError() + << "Failed to open file for dumping: " << path << ": " + << ec.message(); + + outputFile.os().write(data.data(), data.size()); + outputFile.keep(); + return success(); +} + +FailureOr> +WaterGPUModuleToBinaryPass::tryLoadOverrideLLVM(llvm::LLVMContext &context, + StringRef modName, + StringRef suffix) { + if (overrideIntermediates.empty()) + return std::unique_ptr(nullptr); + + SmallString<128> path(overrideIntermediates); + llvm::sys::path::append(path, modName + suffix + ".ll"); + + if (!llvm::sys::fs::exists(path)) + return std::unique_ptr(nullptr); + + llvm::SMDiagnostic error; + std::unique_ptr module = + llvm::parseIRFile(path, error, context); + if (!module) + return getOperation()->emitError() + << "Failed to load override LLVM IR from " << path << ": " + << error.getMessage(); + + return module; +} + +FailureOr> +WaterGPUModuleToBinaryPass::tryLoadOverrideText(StringRef modName, + StringRef suffix) { + if (overrideIntermediates.empty()) + return std::optional(std::nullopt); + + SmallString<128> path(overrideIntermediates); + llvm::sys::path::append(path, modName + suffix); + + if (!llvm::sys::fs::exists(path)) + return std::optional(std::nullopt); + + auto bufferOrError = 
llvm::MemoryBuffer::getFile(path); + if (!bufferOrError) + return getOperation()->emitError() + << "Failed to load override file from " << path << ": " + << bufferOrError.getError().message(); + + return std::optional(bufferOrError.get()->getBuffer().str()); +} + +FailureOr>> +WaterGPUModuleToBinaryPass::tryLoadOverrideBinary(StringRef modName, + StringRef suffix) { + if (overrideIntermediates.empty()) + return std::optional>(std::nullopt); + + SmallString<128> path(overrideIntermediates); + llvm::sys::path::append(path, modName + suffix); + + if (!llvm::sys::fs::exists(path)) + return std::optional>(std::nullopt); + + auto bufferOrError = llvm::MemoryBuffer::getFile(path); + if (!bufferOrError) + return getOperation()->emitError() + << "Failed to load override binary from " << path << ": " + << bufferOrError.getError().message(); + + StringRef data = bufferOrError.get()->getBuffer(); + SmallVector result(data.begin(), data.end()); + return std::optional>(std::move(result)); +} + +void WaterGPUModuleToBinaryPass::runOnOperation() { + // Walk all regions and blocks looking for GPUModuleOp instances. + for (Region ®ion : getOperation()->getRegions()) { + for (Block &block : region.getBlocks()) { + // Use early_inc_range since we're erasing modules during iteration. 
+ for (auto module : + llvm::make_early_inc_range(block.getOps())) { + if (failed(serializeModule(module))) + return signalPassFailure(); + } + } + } +} diff --git a/water/test/Transforms/gpu-module-to-binary-dump.mlir b/water/test/Transforms/gpu-module-to-binary-dump.mlir new file mode 100644 index 000000000..b21329299 --- /dev/null +++ b/water/test/Transforms/gpu-module-to-binary-dump.mlir @@ -0,0 +1,20 @@ +// RUN: rm -rf %t +// RUN: water-opt %s --water-gpu-module-to-binary="dump-intermediates=%t" | FileCheck %s +// RUN: test -f %t/kernel_module_original.ll +// RUN: test -f %t/kernel_module_linked.ll +// RUN: test -f %t/kernel_module_optimized.ll +// RUN: test -f %t/kernel_module.s +// RUN: test -f %t/kernel_module.hsaco + +// Test that the pass dumps intermediate compilation files when dump-intermediates is specified + +// CHECK-LABEL: module attributes {gpu.container_module} +module attributes {gpu.container_module} { + // CHECK-NOT: gpu.module + // CHECK: gpu.binary @kernel_module [#gpu.object<#rocdl.target, bin = + gpu.module @kernel_module [#rocdl.target] { + llvm.func @simple_kernel(%arg0: f32) attributes {gpu.kernel} { + llvm.return + } + } +} diff --git a/water/test/Transforms/gpu-module-to-binary-override.mlir b/water/test/Transforms/gpu-module-to-binary-override.mlir new file mode 100644 index 000000000..f34e03ec4 --- /dev/null +++ b/water/test/Transforms/gpu-module-to-binary-override.mlir @@ -0,0 +1,23 @@ +// RUN: rm -rf %t && mkdir -p %t/dump1 %t/dump2 %t/override +// RUN: water-opt %s --water-gpu-module-to-binary="dump-intermediates=%t/dump1" | FileCheck %s +// RUN: cp %t/dump1/kernel_module_linked.ll %t/override/kernel_module_linked.ll +// RUN: sed -i 's/i32/i64/g' %t/override/kernel_module_linked.ll +// RUN: water-opt %s --water-gpu-module-to-binary="dump-intermediates=%t/dump2 override-intermediates=%t/override" | FileCheck %s +// RUN: grep "define.*i64" %t/dump2/kernel_module_optimized.ll + +// Test that override-intermediates works by: +// 1. 
First run dumps all intermediates to dump1 +// 2. Copy linked LLVM IR to override directory and modify it (i32 -> i64) +// 3. Second run uses the modified linked IR from override directory +// 4. Verify the modification appears in the optimized IR (next stage after linked) + +// CHECK-LABEL: module attributes {gpu.container_module} +module attributes {gpu.container_module} { + // CHECK-NOT: gpu.module + // CHECK: gpu.binary @kernel_module [#gpu.object<#rocdl.target, bin = + gpu.module @kernel_module [#rocdl.target] { + llvm.func @simple_kernel(%arg0: i32) attributes {gpu.kernel} { + llvm.return + } + } +} diff --git a/water/test/Transforms/gpu-module-to-binary.mlir b/water/test/Transforms/gpu-module-to-binary.mlir new file mode 100644 index 000000000..bbcf0acab --- /dev/null +++ b/water/test/Transforms/gpu-module-to-binary.mlir @@ -0,0 +1,19 @@ +// RUN: water-opt %s --water-gpu-module-to-binary | FileCheck %s + +// Test that the pass converts a gpu.module with ROCDL target to a gpu.binary +// The gpu.module contains already-lowered LLVM IR +// +// This test requires ROCm to be installed. It uses mlir::ROCDL::getROCMPath() +// which checks ROCM_PATH, ROCM_ROOT, ROCM_HOME environment variables or uses +// the CMake-detected path. + +// CHECK-LABEL: module attributes {gpu.container_module} +module attributes {gpu.container_module} { + // CHECK-NOT: gpu.module + // CHECK: gpu.binary @kernel_module [#gpu.object<#rocdl.target, bin = + gpu.module @kernel_module [#rocdl.target] { + llvm.func @simple_kernel(%arg0: f32) attributes {gpu.kernel} { + llvm.return + } + } +}