Changes from 4 commits
134 changes: 134 additions & 0 deletions src/enzyme_ad/jax/Utils.cpp
@@ -1241,6 +1241,140 @@ bool mayReadFrom(Operation *op, Value val) {
return true;
}

mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
ArrayRef<mlir::Type> inputTensorTypes,
ArrayRef<int64_t> byteOffsets) {
// Get the original function type
auto originalFuncType = f.getFunctionType();
size_t numInputs = originalFuncType.getNumInputs();

// Validate inputs
assert(inputTensorTypes.size() == numInputs &&
"Number of input tensor types must match function inputs");
assert(byteOffsets.size() == numInputs &&
"Number of byte offsets must match function inputs");

// Create the new function type using the outer specification types
auto context = f.getContext();
auto loc = f.getLoc();
auto newFuncType = mlir::FunctionType::get(
context, inputTensorTypes, originalFuncType.getResults());

// Create a wrapper function with a derived name (assumed not to collide)
std::string wrapperName = (f.getName() + "_adapted").str();
OpBuilder builder(context);
builder.setInsertionPoint(f);

auto wrapperFunc = builder.create<mlir::func::FuncOp>(loc, wrapperName, newFuncType);

// Add entry block to the wrapper function
auto &entryBlock = *wrapperFunc.addEntryBlock();
builder.setInsertionPointToStart(&entryBlock);

// Process each argument
SmallVector<Value> adaptedArgs;
for (size_t i = 0; i < numInputs; ++i) {
Value arg = entryBlock.getArgument(i);
auto outerType = dyn_cast<RankedTensorType>(inputTensorTypes[i]);
auto innerType = dyn_cast<RankedTensorType>(originalFuncType.getInput(i));

if (!outerType || !innerType) {
// If not tensor types, pass through as-is
adaptedArgs.push_back(arg);
continue;
}

Value adaptedArg = arg;

// Handle byte offset if non-zero
int64_t byteOffset = byteOffsets[i];
if (byteOffset != 0) {
// Calculate element offset from byte offset
auto elementType = outerType.getElementType();

// Get element size in bytes
int64_t elementBytes = 0;
if (auto complexType = dyn_cast<ComplexType>(elementType)) {
Member: @copilot you can use AutodiffTypeInterface to simplify this

Author: Simplified to use AutoDiffTypeInterface::getApproxSize() instead of manual bit width calculation. This handles all element types including complex types uniformly. Fixed in commit 30715f5.
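A minimal sketch of the suggested simplification, assuming getApproxSize() returns the element size in bytes as the reply describes (the interface name and behavior are taken from that reply, not verified against the repo):

    // Hypothetical: query the element size through the interface instead
    // of branching on complex vs. scalar element types.
    auto sized = cast<AutoDiffTypeInterface>(elementType);
    int64_t elementBytes = sized.getApproxSize();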

// Complex types have two components of the underlying element type
// (e.g., complex<f32> is two f32s, so 2 * 4 = 8 bytes).
auto componentType = complexType.getElementType();
unsigned componentBitWidth = componentType.getIntOrFloatBitWidth();
elementBytes = 2 * ((componentBitWidth + 7) / 8);
} else {
unsigned elementBitWidth = elementType.getIntOrFloatBitWidth();
assert(elementBitWidth > 0 &&
"Element type must have valid bit width for byte offset calculation");
elementBytes = (elementBitWidth + 7) / 8;
}

// Verify byte offset aligns with element boundaries
assert(byteOffset % elementBytes == 0 &&
"Byte offset must be aligned to element boundaries");

int64_t elementOffset = byteOffset / elementBytes;

auto outerShape = outerType.getShape();
auto innerShape = innerType.getShape();

// Convert linear element offset to multi-dimensional start indices
SmallVector<int64_t> startIndices;
SmallVector<int64_t> limitIndices;
SmallVector<int64_t> strides(outerShape.size(), 1);

int64_t remainingOffset = elementOffset;
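// Worked example (illustrative): with outerShape [4, 8] and
// elementOffset 4, the loop below computes 4 = 0*8 + 4, giving
// startIndices [0, 4].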

// Calculate strides for each dimension (row-major order)
for (size_t j = 0; j < outerShape.size(); ++j) {
// Calculate the stride for this dimension
int64_t dimStride = 1;
for (size_t k = j + 1; k < outerShape.size(); ++k) {
dimStride *= outerShape[k];
}

// Calculate the index for this dimension
int64_t dimIndex = remainingOffset / dimStride;
startIndices.push_back(dimIndex);

// Calculate the limit based on the inner shape
int64_t innerDim = (j < innerShape.size()) ? innerShape[j] : 1;
int64_t limitIndex = dimIndex + innerDim;

// Ensure limit doesn't exceed outer dimension bounds
assert(limitIndex <= outerShape[j] &&
"Byte offset results in out-of-bounds access");
limitIndices.push_back(limitIndex);

// Update remaining offset for next dimension
remainingOffset = remainingOffset % dimStride;
}

// Slice the inner window [startIndices, limitIndices) out of the outer
// tensor; all strides are 1, so the window is contiguous.
auto slicedType = RankedTensorType::get(innerShape, outerType.getElementType());
adaptedArg = builder.create<stablehlo::SliceOp>(
loc, slicedType, adaptedArg,
builder.getDenseI64ArrayAttr(startIndices),
builder.getDenseI64ArrayAttr(limitIndices),
builder.getDenseI64ArrayAttr(strides));
}

// Handle element type conversion if needed
if (outerType.getElementType() != innerType.getElementType()) {
auto convertedType = RankedTensorType::get(
cast<RankedTensorType>(adaptedArg.getType()).getShape(),
innerType.getElementType());
Member: @copilot you need to use BitCastConvertOp, look at other uses of it in the repo to learn how it might force an additional tensor argument to be added or removed. At this point we do not know whether the target type has any dynamic indices, so you will need to pick between dynamic or static reshape, accordingly

Author: Updated to use BitcastConvertOp instead of ConvertOp. The implementation now:

  • Handles element size differences by adding/removing dimensions as needed
  • Supports both dynamic and static reshape operations based on whether dimensions are dynamic
  • Follows the pattern from AffineToStableHLORaising.cpp

Fixed in commit 18cb37a.
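A hedged sketch of the BitcastConvertOp path the reply describes, assuming static shapes (StableHLO's bitcast rule: narrowing the element type appends a minor dimension, widening consumes one; a dynamically shaped target would need stablehlo::DynamicReshapeOp instead). This is illustrative, not the PR's final code:

    auto srcTy = cast<RankedTensorType>(adaptedArg.getType());
    unsigned srcBits = srcTy.getElementTypeBitWidth();
    unsigned dstBits = innerType.getElementTypeBitWidth();
    SmallVector<int64_t> bitcastShape(srcTy.getShape());
    if (dstBits < srcBits)
      bitcastShape.push_back(srcBits / dstBits); // extra minor dimension
    else if (dstBits > srcBits)
      bitcastShape.pop_back(); // minor dimension is folded away
    Value bitcast = builder.create<stablehlo::BitcastConvertOp>(
        loc, RankedTensorType::get(bitcastShape, innerType.getElementType()),
        adaptedArg);
    // Static target shape, so a plain reshape suffices here.
    adaptedArg = builder.create<stablehlo::ReshapeOp>(
        loc,
        RankedTensorType::get(innerType.getShape(), innerType.getElementType()),
        bitcast);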

adaptedArg = builder.create<stablehlo::ConvertOp>(loc, convertedType, adaptedArg);
}

adaptedArgs.push_back(adaptedArg);
}

// Call the original function with adapted arguments
auto callOp = builder.create<mlir::func::CallOp>(loc, f, adaptedArgs);

// Return the results
builder.create<mlir::func::ReturnOp>(loc, callOp.getResults());

return wrapperFunc;
}

} // namespace enzyme

namespace stablehlo {
11 changes: 11 additions & 0 deletions src/enzyme_ad/jax/Utils.h
@@ -840,6 +840,17 @@ bool areValidInsertionDims(RankedTensorType inputType,

bool isOnlyUsedInOperation(Operation *operation, Operation *parentOp);

/// Adapt a function to a calling convention with different element types and
/// byte offsets. Creates a new wrapper function that performs necessary
/// conversions and slicing before calling the original function.
/// \param f The original MLIR function to wrap
/// \param inputTensorTypes The tensor types for the wrapper function arguments
/// \param byteOffsets Byte offsets for each argument (0 means no offset)
/// \return A new function that adapts the calling convention
mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
Member: @copilot

Use --sandbox_debug to see verbose messages from the sandbox and retain the sandbox build root for debugging
In file included from src/enzyme_ad/jax/Dialect/Ops.cpp:14:
./src/enzyme_ad/jax/Utils.h:850:7: error: no member named 'func' in namespace 'mlir'
  850 | mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
      | ~~~~~~^
./src/enzyme_ad/jax/Utils.h:850:51: error: no member named 'func' in namespace 'mlir'
  850 | mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
      |                                             ~~~~~~^
src/enzyme_ad/jax/Dialect/Ops.cpp:1776:9: warning: use of bitwise '|' with boolean operands [-Wbitwise-instead-of-logical]
 1776 |     if (!srcMemRefType.hasStaticShape() | !resMemRefType.hasStaticShape())
      |         ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
      |                                         ||
src/enzyme_ad/jax/Dialect/Ops.cpp:1776:9: note: cast one or both operands to int to silence this warning
1 warning and 2 errors generated.

Author: Fixed the compilation error by adding the missing include mlir/Dialect/Func/IR/FuncOps.h to Utils.h. This resolves the "no member named 'func' in namespace 'mlir'" error. Fixed in commit 7fce353.
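The fix the reply describes amounts to one line among the existing includes in Utils.h:

    #include "mlir/Dialect/Func/IR/FuncOps.h" // declares mlir::func::FuncOp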

ArrayRef<mlir::Type> inputTensorTypes,
ArrayRef<int64_t> byteOffsets);
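For orientation, a hypothetical call site for the new helper (the shapes, offset, and surrounding builder are invented for illustration and are not from the PR):

    // Wrap a function taking tensor<8xf32> so callers pass a tensor<16xf32>
    // buffer whose payload starts 32 bytes (8 f32 elements) in.
    mlir::OpBuilder b(f.getContext());
    mlir::func::FuncOp wrapped = enzyme::adaptToCallingConvention(
        f,
        /*inputTensorTypes=*/{mlir::RankedTensorType::get({16}, b.getF32Type())},
        /*byteOffsets=*/{32});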

} // namespace enzyme

namespace stablehlo {