src/enzyme_ad/jax/Utils.cpp: 261 additions, 0 deletions
@@ -33,6 +33,8 @@
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"

#include "Interfaces/AutoDiffTypeInterface.h"

#include <set>

using namespace mlir;
@@ -1241,6 +1243,265 @@ bool mayReadFrom(Operation *op, Value val) {
return true;
}

mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
ArrayRef<mlir::Type> inputTensorTypes,
ArrayRef<int64_t> byteOffsets) {
// Get the original function type
auto originalFuncType = f.getFunctionType();
size_t numInputs = originalFuncType.getNumInputs();

// Validate inputs
assert(inputTensorTypes.size() == numInputs &&
"Number of input tensor types must match function inputs");
assert(byteOffsets.size() == numInputs &&
"Number of byte offsets must match function inputs");

// Create the new function type using the outer specification types
auto context = f.getContext();
auto loc = f.getLoc();
auto newFuncType = mlir::FunctionType::get(
context, inputTensorTypes, originalFuncType.getResults());

// Create a new function with a unique name
std::string wrapperName = (f.getName() + "_adapted").str();
OpBuilder builder(context);
builder.setInsertionPoint(f);

auto wrapperFunc = builder.create<mlir::func::FuncOp>(loc, wrapperName, newFuncType);

// Add entry block to the wrapper function
auto &entryBlock = *wrapperFunc.addEntryBlock();
builder.setInsertionPointToStart(&entryBlock);

// Process each argument
SmallVector<Value> adaptedArgs;
for (size_t i = 0; i < numInputs; ++i) {
Value arg = entryBlock.getArgument(i);
auto outerType = dyn_cast<RankedTensorType>(inputTensorTypes[i]);
auto innerType = dyn_cast<RankedTensorType>(originalFuncType.getInput(i));

if (!outerType || !innerType) {
// If not tensor types, pass through as-is
adaptedArgs.push_back(arg);
continue;
}

Value adaptedArg = arg;

// Handle byte offset if non-zero
int64_t byteOffset = byteOffsets[i];
if (byteOffset != 0) {
// Calculate element offset from byte offset
auto elementType = outerType.getElementType();

// Get element size in bytes using AutoDiffTypeInterface
size_t elementSizeBytes =
cast<AutoDiffTypeInterface>(elementType).getApproxSize();

// Verify byte offset aligns with element boundaries
assert(byteOffset % elementSizeBytes == 0 &&
"Byte offset must be aligned to element boundaries");

int64_t elementOffset = byteOffset / elementSizeBytes;
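// Illustrative example (not from the source): with f32 elements (4 bytes),
// a byteOffset of 16 corresponds to elementOffset = 4.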

auto outerShape = outerType.getShape();
auto innerShape = innerType.getShape();

// Convert linear element offset to multi-dimensional start indices
SmallVector<int64_t> startIndices;
SmallVector<int64_t> limitIndices;
SmallVector<int64_t> strides(outerShape.size(), 1);

int64_t remainingOffset = elementOffset;

// Decompose the linear element offset into per-dimension start indices
// (row-major order), computing each dimension's stride along the way.
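// Illustrative example (not from the source): outerShape = [4, 8] and
// elementOffset = 10 give a dim-0 stride of 8, so startIndices = [1, 2]
// (10 / 8 = 1, remainder 2); with innerShape = [2, 4] the computed
// limitIndices are [3, 6].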
for (size_t j = 0; j < outerShape.size(); ++j) {
// Calculate the stride for this dimension
int64_t dimStride = 1;
for (size_t k = j + 1; k < outerShape.size(); ++k) {
dimStride *= outerShape[k];
}

// Calculate the index for this dimension
int64_t dimIndex = remainingOffset / dimStride;
startIndices.push_back(dimIndex);

// Calculate the limit based on the inner shape
int64_t innerDim = (j < innerShape.size()) ? innerShape[j] : 1;
int64_t limitIndex = dimIndex + innerDim;

// Ensure limit doesn't exceed outer dimension bounds
assert(limitIndex <= outerShape[j] &&
"Byte offset results in out-of-bounds access");
limitIndices.push_back(limitIndex);

// Update remaining offset for next dimension
remainingOffset = remainingOffset % dimStride;
}

auto slicedType = RankedTensorType::get(innerShape, outerType.getElementType());
adaptedArg = builder.create<stablehlo::SliceOp>(
loc, slicedType, adaptedArg,
builder.getDenseI64ArrayAttr(startIndices),
builder.getDenseI64ArrayAttr(limitIndices),
builder.getDenseI64ArrayAttr(strides));
}

// Handle element type conversion if needed using bitcast_convert
auto currentType = cast<RankedTensorType>(adaptedArg.getType());
if (currentType.getElementType() != innerType.getElementType()) {
auto currentElemType = currentType.getElementType();
auto targetElemType = innerType.getElementType();

// Calculate element sizes in bytes using AutoDiffTypeInterface
size_t currentSizeBytes =
cast<AutoDiffTypeInterface>(currentElemType).getApproxSize();
size_t targetSizeBytes =
cast<AutoDiffTypeInterface>(targetElemType).getApproxSize();

assert(currentSizeBytes > 0 && targetSizeBytes > 0 &&
"Element types must have valid size for conversion");

Value res;
auto currentShape = currentType.getShape();
auto targetShape = innerType.getShape();

if (currentSizeBytes == targetSizeBytes) {
// Same size: direct bitcast
auto convertedType = RankedTensorType::get(targetShape, targetElemType);
res = builder.create<stablehlo::BitcastConvertOp>(loc, convertedType, adaptedArg);
} else if (targetSizeBytes < currentSizeBytes) {
// Target element is smaller: add dimension at the end
assert(currentSizeBytes % targetSizeBytes == 0 &&
"Current element size must be divisible by target element size");
size_t sizeRatio = currentSizeBytes / targetSizeBytes;

SmallVector<int64_t> intermediateShape = llvm::to_vector(targetShape);
auto lastIdx = intermediateShape.size();
intermediateShape.push_back(sizeRatio);

// Adjust the last dimension if needed
if (lastIdx > 0 && intermediateShape[lastIdx - 1] != ShapedType::kDynamic) {
intermediateShape[lastIdx - 1] /= sizeRatio;
}
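// Illustrative example (not from the source): currentShape = [4] of f64 with
// target f32 gives sizeRatio = 2 and targetShape = [8]; intermediateShape
// becomes [4, 2], so the bitcast produces tensor<4x2xf32>, which is then
// reshaped to tensor<8xf32>.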

auto intermediateType = RankedTensorType::get(intermediateShape, targetElemType);
res = builder.create<stablehlo::BitcastConvertOp>(loc, intermediateType, adaptedArg);

// Check if we need dynamic or static reshape
bool anyDynamic = false;
for (auto dim : intermediateShape) {
if (dim == ShapedType::kDynamic) {
anyDynamic = true;
break;
}
}

if (anyDynamic) {
// Use dynamic reshape
SmallVector<Value> shapeValues;
auto scalarI32Type = RankedTensorType::get({}, builder.getI32Type());
for (size_t i = 0; i < targetShape.size(); ++i) {
if (targetShape[i] == ShapedType::kDynamic) {
// Get dynamic dimension from original tensor
auto dimValue = builder.create<stablehlo::GetDimensionSizeOp>(
loc, scalarI32Type, adaptedArg, i);
Review comment (Member):

@copilot can you use the version that doesn't explicitly pass the return type here [it will be deduced automatically].

moreover, no need to have a separate path for dynamic vs constant, you can always use getdimensionsize as it will get optimized away later

Reply (Author):

Simplified the reshape logic:

  • Removed conditional branches for static vs dynamic shapes
  • Always use GetDimensionSizeOp for all dimensions (will be optimized away for static shapes)
  • Removed explicit type parameter from ConstantOp::create (type is deduced automatically)

This reduces code complexity from ~73 lines to ~27 lines while maintaining the same functionality.

Fixed in commit 33cf3f6.
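
A minimal sketch of that always-dynamic shape construction (illustrative only, not the code in commit 33cf3f6; the reshape of each scalar size to a 1-D tensor and the local variable names are assumptions):

  // Sketch: query every dimension with GetDimensionSizeOp and let later
  // canonicalization fold the sizes that are statically known.
  SmallVector<Value> shapeValues;
  for (int64_t d = 0, e = targetShape.size(); d < e; ++d) {
    // The result type is deduced automatically (a 0-D i32 tensor).
    Value dimSize =
        builder.create<stablehlo::GetDimensionSizeOp>(loc, adaptedArg, d);
    // Assumed step: make each size a 1-D tensor so it can be concatenated.
    dimSize = builder.create<stablehlo::ReshapeOp>(
        loc, RankedTensorType::get({1}, builder.getI32Type()), dimSize);
    shapeValues.push_back(dimSize);
  }
  Value outputShape =
      builder.create<stablehlo::ConcatenateOp>(loc, shapeValues, /*dimension=*/0);
  res = builder.create<stablehlo::DynamicReshapeOp>(
      loc, RankedTensorType::get(targetShape, targetElemType), res, outputShape);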

shapeValues.push_back(dimValue);
} else {
auto constValue = builder.create<stablehlo::ConstantOp>(
loc, scalarI32Type,
cast<ElementsAttr>(makeAttr(scalarI32Type, targetShape[i])));
shapeValues.push_back(constValue);
}
}
auto shapeOp = builder.create<stablehlo::ConcatenateOp>(
loc, shapeValues, 0);
res = builder.create<stablehlo::DynamicReshapeOp>(
loc, RankedTensorType::get(targetShape, targetElemType), res, shapeOp);
} else {
// Use static reshape
res = builder.create<stablehlo::ReshapeOp>(
loc, RankedTensorType::get(targetShape, targetElemType), res);
}
} else {
// Target element is larger: reshape first, then bitcast
assert(targetSizeBytes % currentSizeBytes == 0 &&
"Target element size must be divisible by current element size");
size_t sizeRatio = targetSizeBytes / currentSizeBytes;

SmallVector<int64_t> intermediateShape = llvm::to_vector(currentShape);
auto lastIdx = intermediateShape.size();
intermediateShape.push_back(sizeRatio);

// Adjust the last dimension if needed
if (lastIdx > 0 && intermediateShape[lastIdx - 1] != ShapedType::kDynamic) {
intermediateShape[lastIdx - 1] /= sizeRatio;
}
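// Illustrative example (not from the source): currentShape = [8] of f32 with
// target f64 gives sizeRatio = 2; intermediateShape becomes [4, 2], so
// tensor<8xf32> is first reshaped to tensor<4x2xf32> and then bitcast to
// tensor<4xf64>.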

Value reshaped;
// Check if we need dynamic reshape
bool anyDynamic = false;
for (auto dim : intermediateShape) {
if (dim == ShapedType::kDynamic) {
anyDynamic = true;
break;
}
}

if (anyDynamic) {
// Use dynamic reshape
SmallVector<Value> shapeValues;
auto scalarI32Type = RankedTensorType::get({}, builder.getI32Type());
for (size_t i = 0; i < intermediateShape.size(); ++i) {
if (intermediateShape[i] == ShapedType::kDynamic) {
if (i < currentShape.size()) {
auto dimValue = builder.create<stablehlo::GetDimensionSizeOp>(
loc, scalarI32Type, adaptedArg, i);
shapeValues.push_back(dimValue);
} else {
// This is the added dimension
auto constValue = builder.create<stablehlo::ConstantOp>(
loc, scalarI32Type,
cast<ElementsAttr>(makeAttr(scalarI32Type, sizeRatio)));
shapeValues.push_back(constValue);
}
} else {
auto constValue = builder.create<stablehlo::ConstantOp>(
loc, scalarI32Type,
cast<ElementsAttr>(makeAttr(scalarI32Type, intermediateShape[i])));
shapeValues.push_back(constValue);
}
}
auto shapeOp = builder.create<stablehlo::ConcatenateOp>(
loc, shapeValues, 0);
reshaped = builder.create<stablehlo::DynamicReshapeOp>(
loc, RankedTensorType::get(intermediateShape, currentElemType),
adaptedArg, shapeOp);
} else {
reshaped = builder.create<stablehlo::ReshapeOp>(
loc, RankedTensorType::get(intermediateShape, currentElemType), adaptedArg);
}

// Now bitcast to target type
res = builder.create<stablehlo::BitcastConvertOp>(
loc, RankedTensorType::get(targetShape, targetElemType), reshaped);
}

adaptedArg = res;
}

adaptedArgs.push_back(adaptedArg);
}

// Call the original function with adapted arguments
auto callOp = builder.create<mlir::func::CallOp>(loc, f, adaptedArgs);

// Return the results
builder.create<mlir::func::ReturnOp>(loc, callOp.getResults());

return wrapperFunc;
}

} // namespace enzyme

namespace stablehlo {
src/enzyme_ad/jax/Utils.h: 12 additions, 0 deletions
@@ -12,6 +12,7 @@
#include "llvm/ADT/SmallVector.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/IntegerSet.h"
@@ -840,6 +841,17 @@ bool areValidInsertionDims(RankedTensorType inputType,

bool isOnlyUsedInOperation(Operation *operation, Operation *parentOp);

/// Adapt a function to a calling convention whose arguments may use different
/// element types and byte offsets. Creates a new wrapper function that performs
/// the necessary slicing and element-type conversions before calling the
/// original function.
/// \param f The original MLIR function to wrap
/// \param inputTensorTypes The tensor types for the wrapper function arguments
/// \param byteOffsets Byte offsets for each argument (0 means no offset)
/// \return A new function that adapts the calling convention
mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
Review comment (Member):

@copilot

Use --sandbox_debug to see verbose messages from the sandbox and retain the sandbox build root for debugging
In file included from src/enzyme_ad/jax/Dialect/Ops.cpp:14:
./src/enzyme_ad/jax/Utils.h:850:7: error: no member named 'func' in namespace 'mlir'
  850 | mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
      | ~~~~~~^
./src/enzyme_ad/jax/Utils.h:850:51: error: no member named 'func' in namespace 'mlir'
  850 | mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
      |                                             ~~~~~~^
src/enzyme_ad/jax/Dialect/Ops.cpp:1776:9: warning: use of bitwise '|' with boolean operands [-Wbitwise-instead-of-logical]
 1776 |     if (!srcMemRefType.hasStaticShape() | !resMemRefType.hasStaticShape())
      |         ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
      |                                         ||
src/enzyme_ad/jax/Dialect/Ops.cpp:1776:9: note: cast one or both operands to int to silence this warning
1 warning and 2 errors generated.

Reply (Author):

Fixed the compilation error by adding the missing include mlir/Dialect/Func/IR/FuncOps.h to Utils.h. This resolves the "no member named 'func' in namespace 'mlir'" error.

Fixed in commit 7fce353.

ArrayRef<mlir::Type> inputTensorTypes,
ArrayRef<int64_t> byteOffsets);
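
For reference, an illustrative usage sketch (not part of this diff; the module contents, the function name "inner", and the enzyme:: qualification are assumptions):

  // Assume `module` (an mlir::ModuleOp) contains:
  //   func.func @inner(%arg0: tensor<8xf32>) -> tensor<8xf32>
  auto inner = module.lookupSymbol<mlir::func::FuncOp>("inner");
  auto *ctx = inner.getContext();

  // The wrapper takes a larger f32 buffer; the callee's data starts 16 bytes
  // (four f32 elements) into it, so the wrapper slices before calling @inner.
  mlir::Type outerTy =
      mlir::RankedTensorType::get({16}, mlir::Float32Type::get(ctx));
  mlir::func::FuncOp wrapper =
      enzyme::adaptToCallingConvention(inner, {outerTy}, /*byteOffsets=*/{16});
  // wrapper has type (tensor<16xf32>) -> tensor<8xf32> and forwards to @inner.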

} // namespace enzyme

namespace stablehlo {