diff --git a/src/enzyme_ad/jax/Utils.cpp b/src/enzyme_ad/jax/Utils.cpp
index d3813e215..c22594dd1 100644
--- a/src/enzyme_ad/jax/Utils.cpp
+++ b/src/enzyme_ad/jax/Utils.cpp
@@ -33,6 +33,8 @@
 #include "stablehlo/dialect/ChloOps.h"
 #include "stablehlo/dialect/StablehloOps.h"
 
+#include "Interfaces/AutoDiffTypeInterface.h"
+
 #include
 
 using namespace mlir;
@@ -1241,6 +1243,231 @@ bool mayReadFrom(Operation *op, Value val) {
   return true;
 }
 
+mlir::func::FuncOp
+adaptToCallingConvention(mlir::func::FuncOp f,
+                         ArrayRef<mlir::Type> inputTensorTypes,
+                         ArrayRef<int64_t> byteOffsets) {
+  // Get the original function type
+  auto originalFuncType = f.getFunctionType();
+  size_t numInputs = originalFuncType.getNumInputs();
+
+  // Validate inputs
+  assert(inputTensorTypes.size() == numInputs &&
+         "Number of input tensor types must match function inputs");
+  assert(byteOffsets.size() == numInputs &&
+         "Number of byte offsets must match function inputs");
+
+  // Create the new function type using the outer specification types
+  auto context = f.getContext();
+  auto loc = f.getLoc();
+  auto newFuncType = mlir::FunctionType::get(context, inputTensorTypes,
+                                             originalFuncType.getResults());
+
+  // Create a new function with a unique name
+  std::string wrapperName = (f.getName() + "_adapted").str();
+  OpBuilder builder(context);
+  builder.setInsertionPoint(f);
+
+  auto wrapperFunc =
+      builder.create<func::FuncOp>(loc, wrapperName, newFuncType);
+
+  // Add entry block to the wrapper function
+  auto &entryBlock = *wrapperFunc.addEntryBlock();
+  builder.setInsertionPointToStart(&entryBlock);
+
+  // Process each argument
+  SmallVector<Value> adaptedArgs;
+  for (size_t i = 0; i < numInputs; ++i) {
+    Value arg = entryBlock.getArgument(i);
+    auto outerType = dyn_cast<RankedTensorType>(inputTensorTypes[i]);
+    auto innerType = dyn_cast<RankedTensorType>(originalFuncType.getInput(i));
+
+    if (!outerType || !innerType) {
+      // If not tensor types, pass through as-is
+      adaptedArgs.push_back(arg);
+      continue;
+    }
+
+    Value adaptedArg = arg;
+
+    // Handle byte offset if non-zero
+    int64_t byteOffset = byteOffsets[i];
+    if (byteOffset != 0) {
+      // Calculate element offset from byte offset
+      auto elementType = outerType.getElementType();
+
+      // Get element size in bytes using AutoDiffTypeInterface
+      size_t elementSizeBytes =
+          cast<AutoDiffTypeInterface>(elementType).getApproxSize();
+
+      // Verify byte offset aligns with element boundaries
+      assert(byteOffset % elementSizeBytes == 0 &&
+             "Byte offset must be aligned to element boundaries");
+
+      int64_t elementOffset = byteOffset / elementSizeBytes;
+
+      auto outerShape = outerType.getShape();
+      auto innerShape = innerType.getShape();
+
+      // Convert linear element offset to multi-dimensional start indices
+      SmallVector<int64_t> startIndices;
+      SmallVector<int64_t> limitIndices;
+      SmallVector<int64_t> strides(outerShape.size(), 1);
+
+      int64_t remainingOffset = elementOffset;
+
+      // Calculate strides for each dimension (row-major order)
+      for (size_t j = 0; j < outerShape.size(); ++j) {
+        // Calculate the stride for this dimension
+        int64_t dimStride = 1;
+        for (size_t k = j + 1; k < outerShape.size(); ++k) {
+          dimStride *= outerShape[k];
+        }
+
+        // Calculate the index for this dimension
+        int64_t dimIndex = remainingOffset / dimStride;
+        startIndices.push_back(dimIndex);
+
+        // Calculate the limit based on the inner shape
+        int64_t innerDim = (j < innerShape.size()) ? innerShape[j] : 1;
+        int64_t limitIndex = dimIndex + innerDim;
+
+        // Ensure limit doesn't exceed outer dimension bounds
+        assert(limitIndex <= outerShape[j] &&
+               "Byte offset results in out-of-bounds access");
+        limitIndices.push_back(limitIndex);
+
+        // Update remaining offset for next dimension
+        remainingOffset = remainingOffset % dimStride;
+      }
+
+      auto slicedType =
+          RankedTensorType::get(innerShape, outerType.getElementType());
+      adaptedArg = builder.create<stablehlo::SliceOp>(
+          loc, slicedType, adaptedArg,
+          builder.getDenseI64ArrayAttr(startIndices),
+          builder.getDenseI64ArrayAttr(limitIndices),
+          builder.getDenseI64ArrayAttr(strides));
+    }
+
+    // Handle element type conversion if needed using bitcast_convert
+    auto currentType = cast<RankedTensorType>(adaptedArg.getType());
+    if (currentType.getElementType() != innerType.getElementType()) {
+      auto currentElemType = currentType.getElementType();
+      auto targetElemType = innerType.getElementType();
+
+      // Calculate element sizes in bytes using AutoDiffTypeInterface
+      size_t currentSizeBytes =
+          cast<AutoDiffTypeInterface>(currentElemType).getApproxSize();
+      size_t targetSizeBytes =
+          cast<AutoDiffTypeInterface>(targetElemType).getApproxSize();
+
+      assert(currentSizeBytes > 0 && targetSizeBytes > 0 &&
+             "Element types must have valid size for conversion");
+
+      Value res;
+      auto currentShape = currentType.getShape();
+      auto targetShape = innerType.getShape();
+
+      // Scalar i32 tensor type for shape constants
+      auto scalarI32Type = RankedTensorType::get({}, builder.getI32Type());
+
+      if (currentSizeBytes == targetSizeBytes) {
+        // Same size: direct bitcast
+        auto convertedType = RankedTensorType::get(targetShape, targetElemType);
+        res = builder.create<stablehlo::BitcastConvertOp>(loc, convertedType,
+                                                          adaptedArg);
+      } else if (targetSizeBytes < currentSizeBytes) {
+        // Target element is smaller: add dimension at the end
+        assert(currentSizeBytes % targetSizeBytes == 0 &&
+               "Current element size must be divisible by target element size");
+        size_t sizeRatio = currentSizeBytes / targetSizeBytes;
+
+        SmallVector<int64_t> intermediateShape = llvm::to_vector(targetShape);
+        auto lastIdx = intermediateShape.size();
+        intermediateShape.push_back(sizeRatio);
+
+        // Adjust the last dimension if needed
+        if (lastIdx > 0 &&
+            intermediateShape[lastIdx - 1] != ShapedType::kDynamic) {
+          intermediateShape[lastIdx - 1] /= sizeRatio;
+        }
+
+        auto intermediateType =
+            RankedTensorType::get(intermediateShape, targetElemType);
+        res = builder.create<stablehlo::BitcastConvertOp>(
+            loc, intermediateType, adaptedArg);
+
+        // Always use dynamic reshape with GetDimensionSizeOp (will be
+        // optimized away for static shapes)
+        SmallVector<Value> shapeValues;
+        for (size_t i = 0; i < targetShape.size(); ++i) {
+          auto dimValue =
+              builder.create<stablehlo::GetDimensionSizeOp>(loc, res, i);
+          shapeValues.push_back(dimValue);
+        }
+        auto shapeOp =
+            builder.create<stablehlo::ConcatenateOp>(loc, shapeValues, 0);
+        res = builder.create<stablehlo::DynamicReshapeOp>(
+            loc, RankedTensorType::get(targetShape, targetElemType), res,
+            shapeOp);
+      } else {
+        // Target element is larger: reshape first, then bitcast
+        assert(targetSizeBytes % currentSizeBytes == 0 &&
+               "Target element size must be divisible by current element size");
+        size_t sizeRatio = targetSizeBytes / currentSizeBytes;
+
+        SmallVector<int64_t> intermediateShape = llvm::to_vector(currentShape);
+        auto lastIdx = intermediateShape.size();
+        intermediateShape.push_back(sizeRatio);
+
+        // Adjust the last dimension if needed
+        if (lastIdx > 0 &&
+            intermediateShape[lastIdx - 1] != ShapedType::kDynamic) {
+          intermediateShape[lastIdx - 1] /= sizeRatio;
+        }
+
+        // Always use dynamic reshape with GetDimensionSizeOp (will be
+        // optimized away for static shapes)
+        SmallVector<Value> shapeValues;
+        for (size_t i = 0; i < intermediateShape.size(); ++i) {
+          if (i < currentShape.size()) {
+            auto dimValue = builder.create<stablehlo::GetDimensionSizeOp>(
+                loc, adaptedArg, i);
+            shapeValues.push_back(dimValue);
+          } else {
+            // This is the added dimension
+            auto constValue = builder.create<stablehlo::ConstantOp>(
+                loc, cast<ElementsAttr>(makeAttr(scalarI32Type, sizeRatio)));
+            shapeValues.push_back(constValue);
+          }
+        }
+        auto shapeOp =
+            builder.create<stablehlo::ConcatenateOp>(loc, shapeValues, 0);
+        Value reshaped = builder.create<stablehlo::DynamicReshapeOp>(
+            loc, RankedTensorType::get(intermediateShape, currentElemType),
+            adaptedArg, shapeOp);
+
+        // Now bitcast to target type
+        res = builder.create<stablehlo::BitcastConvertOp>(
+            loc, RankedTensorType::get(targetShape, targetElemType), reshaped);
+      }
+
+      adaptedArg = res;
+    }
+
+    adaptedArgs.push_back(adaptedArg);
+  }
+
+  // Call the original function with adapted arguments
+  auto callOp = builder.create<func::CallOp>(loc, f, adaptedArgs);
+
+  // Return the results
+  builder.create<func::ReturnOp>(loc, callOp.getResults());
+
+  return wrapperFunc;
+}
+
 } // namespace enzyme
 
 namespace stablehlo {
diff --git a/src/enzyme_ad/jax/Utils.h b/src/enzyme_ad/jax/Utils.h
index 9f52f2371..e75994a26 100644
--- a/src/enzyme_ad/jax/Utils.h
+++ b/src/enzyme_ad/jax/Utils.h
@@ -12,6 +12,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/IntegerSet.h"
@@ -840,6 +841,18 @@ bool areValidInsertionDims(RankedTensorType inputType,
 
 bool isOnlyUsedInOperation(Operation *operation, Operation *parentOp);
 
+/// Adapt a function to a calling convention with different element types and
+/// byte offsets. Creates a new wrapper function that performs necessary
+/// conversions and slicing before calling the original function.
+/// \param f The original MLIR function to wrap
+/// \param inputTensorTypes The tensor types for the wrapper function arguments
+/// \param byteOffsets Byte offsets for each argument (0 means no offset)
+/// \return A new function that adapts the calling convention
+mlir::func::FuncOp
+adaptToCallingConvention(mlir::func::FuncOp f,
+                         ArrayRef<mlir::Type> inputTensorTypes,
+                         ArrayRef<int64_t> byteOffsets);
+
 } // namespace enzyme
 
 namespace stablehlo {
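
For reference, a rough sketch of how the new helper could be driven from a pass. The wrapper
function name `makeWrapper`, the shapes, and the include path are invented for this example;
only the byte-offset/slice path is exercised here (the bitcast-convert path is triggered the
same way, by passing an outer element type that differs from the kernel's).

#include "Utils.h" // assumed include path for src/enzyme_ad/jax/Utils.h

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Suppose `kernel` has type (tensor<4x4xf32>) -> tensor<4x4xf32>, but callers
// hand us a tensor<8x4xf32> buffer whose payload starts 64 bytes in.
func::FuncOp makeWrapper(func::FuncOp kernel) {
  OpBuilder b(kernel.getContext());

  SmallVector<Type> outerTypes{RankedTensorType::get({8, 4}, b.getF32Type())};
  // 64 bytes / 4 bytes per f32 = element offset 16 = start indices [4, 0], so
  // the generated wrapper slices [4:8, 0:4] out of the outer buffer and then
  // calls `kernel` on the resulting tensor<4x4xf32>.
  SmallVector<int64_t> byteOffsets{64};

  return enzyme::adaptToCallingConvention(kernel, outerTypes, byteOffsets);
}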