Draft (changes from 12 commits)
216 changes: 216 additions & 0 deletions src/enzyme_ad/jax/Utils.cpp
@@ -33,6 +33,8 @@
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"

#include "Interfaces/AutoDiffTypeInterface.h"

#include <set>

using namespace mlir;
@@ -1241,6 +1243,220 @@ bool mayReadFrom(Operation *op, Value val) {
return true;
}

mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
ArrayRef<mlir::Type> inputTensorTypes,
ArrayRef<int64_t> byteOffsets) {
// Get the original function type
auto originalFuncType = f.getFunctionType();
size_t numInputs = originalFuncType.getNumInputs();

// Validate inputs
assert(inputTensorTypes.size() == numInputs &&
"Number of input tensor types must match function inputs");
assert(byteOffsets.size() == numInputs &&
"Number of byte offsets must match function inputs");

// Create the new function type using the outer specification types
auto context = f.getContext();
auto loc = f.getLoc();
auto newFuncType = mlir::FunctionType::get(
context, inputTensorTypes, originalFuncType.getResults());

// Create the wrapper function, named after the original with an "_adapted" suffix
std::string wrapperName = (f.getName() + "_adapted").str();
OpBuilder builder(context);
builder.setInsertionPoint(f);

auto wrapperFunc = builder.create<mlir::func::FuncOp>(loc, wrapperName, newFuncType);

// Add entry block to the wrapper function
auto &entryBlock = *wrapperFunc.addEntryBlock();
builder.setInsertionPointToStart(&entryBlock);

// Process each argument
SmallVector<Value> adaptedArgs;
for (size_t i = 0; i < numInputs; ++i) {
Value arg = entryBlock.getArgument(i);
auto outerType = dyn_cast<RankedTensorType>(inputTensorTypes[i]);
auto innerType = dyn_cast<RankedTensorType>(originalFuncType.getInput(i));

if (!outerType || !innerType) {
// If not tensor types, pass through as-is
adaptedArgs.push_back(arg);
continue;
}

Value adaptedArg = arg;

// Handle byte offset if non-zero
int64_t byteOffset = byteOffsets[i];
if (byteOffset != 0) {
// Calculate element offset from byte offset
auto elementType = outerType.getElementType();

// Get element size in bytes using AutoDiffTypeInterface
size_t elementSizeBytes =
cast<AutoDiffTypeInterface>(elementType).getApproxSize();

// Verify byte offset aligns with element boundaries
assert(byteOffset % elementSizeBytes == 0 &&
"Byte offset must be aligned to element boundaries");

int64_t elementOffset = byteOffset / elementSizeBytes;

auto outerShape = outerType.getShape();
auto innerShape = innerType.getShape();

// Convert linear element offset to multi-dimensional start indices
SmallVector<int64_t> startIndices;
SmallVector<int64_t> limitIndices;
SmallVector<int64_t> strides(outerShape.size(), 1);

int64_t remainingOffset = elementOffset;

// Calculate strides for each dimension (row-major order)
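// (Worked example, purely illustrative: with outerShape = {4, 8} and
// elementOffset = 10, this loop produces startIndices = {1, 2}.)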
for (size_t j = 0; j < outerShape.size(); ++j) {
// Calculate the stride for this dimension
int64_t dimStride = 1;
for (size_t k = j + 1; k < outerShape.size(); ++k) {
dimStride *= outerShape[k];
}

// Calculate the index for this dimension
int64_t dimIndex = remainingOffset / dimStride;
startIndices.push_back(dimIndex);

// Calculate the limit based on the inner shape
int64_t innerDim = (j < innerShape.size()) ? innerShape[j] : 1;
int64_t limitIndex = dimIndex + innerDim;

// Ensure limit doesn't exceed outer dimension bounds
assert(limitIndex <= outerShape[j] &&
"Byte offset results in out-of-bounds access");
limitIndices.push_back(limitIndex);

// Update remaining offset for next dimension
remainingOffset = remainingOffset % dimStride;
}

auto slicedType = RankedTensorType::get(innerShape, outerType.getElementType());
adaptedArg = builder.create<stablehlo::SliceOp>(
loc, slicedType, adaptedArg,
builder.getDenseI64ArrayAttr(startIndices),
builder.getDenseI64ArrayAttr(limitIndices),
builder.getDenseI64ArrayAttr(strides));
}

// Handle element type conversion if needed using bitcast_convert
auto currentType = cast<RankedTensorType>(adaptedArg.getType());
if (currentType.getElementType() != innerType.getElementType()) {
auto currentElemType = currentType.getElementType();
auto targetElemType = innerType.getElementType();

// Calculate element sizes in bytes using AutoDiffTypeInterface
size_t currentSizeBytes =
cast<AutoDiffTypeInterface>(currentElemType).getApproxSize();
size_t targetSizeBytes =
cast<AutoDiffTypeInterface>(targetElemType).getApproxSize();

assert(currentSizeBytes > 0 && targetSizeBytes > 0 &&
"Element types must have valid size for conversion");

Value res;
auto currentShape = currentType.getShape();
auto targetShape = innerType.getShape();

// Single-element i32 tensor type used to assemble dynamic shape operands
auto shapeComponentType = RankedTensorType::get({1}, builder.getI32Type());

if (currentSizeBytes == targetSizeBytes) {
// Same size: direct bitcast
auto convertedType = RankedTensorType::get(targetShape, targetElemType);
res = builder.create<stablehlo::BitcastConvertOp>(loc, convertedType, adaptedArg);
} else if (targetSizeBytes < currentSizeBytes) {
// Target element is smaller: add dimension at the end
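// Per the StableHLO spec, bitcast_convert to a narrower element type appends a
// trailing dimension of size (current element size / target element size).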
assert(currentSizeBytes % targetSizeBytes == 0 &&
"Current element size must be divisible by target element size");
size_t sizeRatio = currentSizeBytes / targetSizeBytes;

SmallVector<int64_t> intermediateShape = llvm::to_vector(targetShape);
auto lastIdx = intermediateShape.size();
intermediateShape.push_back(sizeRatio);

// Adjust the last dimension if needed
if (lastIdx > 0 && intermediateShape[lastIdx - 1] != ShapedType::kDynamic) {
intermediateShape[lastIdx - 1] /= sizeRatio;
}

auto intermediateType = RankedTensorType::get(intermediateShape, targetElemType);
res = builder.create<stablehlo::BitcastConvertOp>(loc, intermediateType, adaptedArg);

// Always use dynamic reshape with GetDimensionSizeOp (will be optimized away for static shapes)
SmallVector<Value> shapeValues;
for (size_t i = 0; i < targetShape.size(); ++i) {
Value dimValue = builder.create<stablehlo::GetDimensionSizeOp>(
loc, res, i);
// get_dimension_size yields a rank-0 tensor<i32>; reshape it to tensor<1xi32>
// so the per-dimension sizes can be concatenated into a 1-D shape operand
dimValue = builder.create<stablehlo::ReshapeOp>(loc, shapeComponentType, dimValue);
shapeValues.push_back(dimValue);
}
auto shapeOp = builder.create<stablehlo::ConcatenateOp>(
loc, shapeValues, 0);
res = builder.create<stablehlo::DynamicReshapeOp>(
loc, RankedTensorType::get(targetShape, targetElemType), res, shapeOp);
} else {
// Target element is larger: reshape first, then bitcast
assert(targetSizeBytes % currentSizeBytes == 0 &&
"Target element size must be divisible by current element size");
size_t sizeRatio = targetSizeBytes / currentSizeBytes;
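// Per the StableHLO spec, bitcast_convert to a wider element type consumes a
// trailing operand dimension equal to this ratio, so a reshape introduces that
// trailing dimension before the bitcast.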

SmallVector<int64_t> intermediateShape = llvm::to_vector(currentShape);
auto lastIdx = intermediateShape.size();
intermediateShape.push_back(sizeRatio);

// Adjust the last dimension if needed
if (lastIdx > 0 && intermediateShape[lastIdx - 1] != ShapedType::kDynamic) {
intermediateShape[lastIdx - 1] /= sizeRatio;
}

// Always use dynamic reshape with GetDimensionSizeOp (will be optimized away for static shapes)
SmallVector<Value> shapeValues;
for (size_t i = 0; i < intermediateShape.size(); ++i) {
if (i < currentShape.size()) {
Value dimValue = builder.create<stablehlo::GetDimensionSizeOp>(
loc, adaptedArg, i);
// Reshape the rank-0 dimension size to tensor<1xi32> so it can be concatenated
dimValue = builder.create<stablehlo::ReshapeOp>(loc, shapeComponentType, dimValue);
shapeValues.push_back(dimValue);
} else {
// This is the added dimension; emit its extent as a tensor<1xi32> constant
auto constValue = builder.create<stablehlo::ConstantOp>(
loc, cast<ElementsAttr>(makeAttr(shapeComponentType, sizeRatio)));
shapeValues.push_back(constValue);
}
}
auto shapeOp = builder.create<stablehlo::ConcatenateOp>(
loc, shapeValues, 0);
Value reshaped = builder.create<stablehlo::DynamicReshapeOp>(
loc, RankedTensorType::get(intermediateShape, currentElemType),
adaptedArg, shapeOp);

// Now bitcast to target type
res = builder.create<stablehlo::BitcastConvertOp>(
loc, RankedTensorType::get(targetShape, targetElemType), reshaped);
}

adaptedArg = res;
}

adaptedArgs.push_back(adaptedArg);
}

// Call the original function with adapted arguments
auto callOp = builder.create<mlir::func::CallOp>(loc, f, adaptedArgs);

// Return the results
builder.create<mlir::func::ReturnOp>(loc, callOp.getResults());

return wrapperFunc;
}

} // namespace enzyme

namespace stablehlo {
12 changes: 12 additions & 0 deletions src/enzyme_ad/jax/Utils.h
@@ -12,6 +12,7 @@
#include "llvm/ADT/SmallVector.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/IntegerSet.h"
@@ -840,6 +841,17 @@ bool areValidInsertionDims(RankedTensorType inputType,

bool isOnlyUsedInOperation(Operation *operation, Operation *parentOp);

/// Adapt a function to a calling convention with different element types and
/// byte offsets. Creates a new wrapper function that performs the necessary
/// slicing and element-type conversions before calling the original function.
/// \param f The original MLIR function to wrap
/// \param inputTensorTypes The tensor types for the wrapper function arguments
/// \param byteOffsets Byte offsets for each argument (0 means no offset)
/// \return A new function that adapts the calling convention
mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
Member:
@copilot

Use --sandbox_debug to see verbose messages from the sandbox and retain the sandbox build root for debugging
In file included from src/enzyme_ad/jax/Dialect/Ops.cpp:14:
./src/enzyme_ad/jax/Utils.h:850:7: error: no member named 'func' in namespace 'mlir'
  850 | mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
      | ~~~~~~^
./src/enzyme_ad/jax/Utils.h:850:51: error: no member named 'func' in namespace 'mlir'
  850 | mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
      |                                             ~~~~~~^
src/enzyme_ad/jax/Dialect/Ops.cpp:1776:9: warning: use of bitwise '|' with boolean operands [-Wbitwise-instead-of-logical]
 1776 |     if (!srcMemRefType.hasStaticShape() | !resMemRefType.hasStaticShape())
      |         ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
      |                                         ||
src/enzyme_ad/jax/Dialect/Ops.cpp:1776:9: note: cast one or both operands to int to silence this warning
1 warning and 2 errors generated.

Author:
Fixed the compilation error by adding the missing include mlir/Dialect/Func/IR/FuncOps.h to Utils.h. This resolves the "no member named 'func' in namespace 'mlir'" error.

Fixed in commit 7fce353.

ArrayRef<mlir::Type> inputTensorTypes,
ArrayRef<int64_t> byteOffsets);
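
For orientation, a minimal usage sketch (editorial, not part of this diff; module, builder, and @callee are assumed placeholders, namespace qualification is abbreviated, and f32 is assumed to report a 4-byte element size via getApproxSize):

// @callee expects tensor<4xf32>; the caller supplies a tensor<8xf32> whose
// payload begins 16 bytes (4 f32 elements) into the buffer.
mlir::func::FuncOp callee = module.lookupSymbol<mlir::func::FuncOp>("callee");
mlir::Type outerTy = mlir::RankedTensorType::get({8}, builder.getF32Type());
mlir::func::FuncOp wrapper = enzyme::adaptToCallingConvention(
callee, /*inputTensorTypes=*/{outerTy}, /*byteOffsets=*/{16});
// wrapper takes tensor<8xf32>, emits a stablehlo.slice over elements [4, 8),
// and forwards the resulting tensor<4xf32> to @callee.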

} // namespace enzyme

namespace stablehlo {