Add adaptToCallingConvention utility for element type conversion and byte offset handling #1709
@@ -1241,6 +1241,140 @@ bool mayReadFrom(Operation *op, Value val) {
   return true;
 }

 mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
                                             ArrayRef<mlir::Type> inputTensorTypes,
                                             ArrayRef<int64_t> byteOffsets) {
   // Get the original function type
   auto originalFuncType = f.getFunctionType();
   size_t numInputs = originalFuncType.getNumInputs();

   // Validate inputs
   assert(inputTensorTypes.size() == numInputs &&
          "Number of input tensor types must match function inputs");
   assert(byteOffsets.size() == numInputs &&
          "Number of byte offsets must match function inputs");

   // Create the new function type using the outer specification types
   auto context = f.getContext();
   auto loc = f.getLoc();
   auto newFuncType = mlir::FunctionType::get(
       context, inputTensorTypes, originalFuncType.getResults());

   // Create a new function with a unique name
   std::string wrapperName = (f.getName() + "_adapted").str();
   OpBuilder builder(context);
   builder.setInsertionPoint(f);

   auto wrapperFunc =
       builder.create<mlir::func::FuncOp>(loc, wrapperName, newFuncType);

   // Add entry block to the wrapper function
   auto &entryBlock = *wrapperFunc.addEntryBlock();
   builder.setInsertionPointToStart(&entryBlock);

   // Process each argument
   SmallVector<Value> adaptedArgs;
   for (size_t i = 0; i < numInputs; ++i) {
     Value arg = entryBlock.getArgument(i);
     auto outerType = dyn_cast<RankedTensorType>(inputTensorTypes[i]);
     auto innerType = dyn_cast<RankedTensorType>(originalFuncType.getInput(i));

     if (!outerType || !innerType) {
       // If not tensor types, pass through as-is
       adaptedArgs.push_back(arg);
       continue;
     }

     Value adaptedArg = arg;

     // Handle byte offset if non-zero
     int64_t byteOffset = byteOffsets[i];
     if (byteOffset != 0) {
       // Calculate element offset from byte offset
       auto elementType = outerType.getElementType();

       // Get element size in bytes
       int64_t elementBytes = 0;
       if (auto complexType = dyn_cast<ComplexType>(elementType)) {
         // Complex types have two components of the underlying element type
         auto componentType = complexType.getElementType();
         unsigned componentBitWidth = componentType.getIntOrFloatBitWidth();
         elementBytes = 2 * ((componentBitWidth + 7) / 8);
       } else {
         unsigned elementBitWidth = elementType.getIntOrFloatBitWidth();
         assert(elementBitWidth > 0 &&
                "Element type must have valid bit width for byte offset calculation");
         elementBytes = (elementBitWidth + 7) / 8;
       }

       // Verify byte offset aligns with element boundaries
       assert(byteOffset % elementBytes == 0 &&
              "Byte offset must be aligned to element boundaries");

       int64_t elementOffset = byteOffset / elementBytes;

       auto outerShape = outerType.getShape();
       auto innerShape = innerType.getShape();

       // Convert linear element offset to multi-dimensional start indices
       SmallVector<int64_t> startIndices;
       SmallVector<int64_t> limitIndices;
       SmallVector<int64_t> strides(outerShape.size(), 1);

       int64_t remainingOffset = elementOffset;

       // Calculate strides for each dimension (row-major order)
       for (size_t j = 0; j < outerShape.size(); ++j) {
         // Calculate the stride for this dimension
         int64_t dimStride = 1;
         for (size_t k = j + 1; k < outerShape.size(); ++k) {
           dimStride *= outerShape[k];
         }

         // Calculate the index for this dimension
         int64_t dimIndex = remainingOffset / dimStride;
         startIndices.push_back(dimIndex);

         // Calculate the limit based on the inner shape
         int64_t innerDim = (j < innerShape.size()) ? innerShape[j] : 1;
         int64_t limitIndex = dimIndex + innerDim;

         // Ensure limit doesn't exceed outer dimension bounds
         assert(limitIndex <= outerShape[j] &&
                "Byte offset results in out-of-bounds access");
         limitIndices.push_back(limitIndex);

         // Update remaining offset for next dimension
         remainingOffset = remainingOffset % dimStride;
       }

       auto slicedType =
           RankedTensorType::get(innerShape, outerType.getElementType());
       adaptedArg = builder.create<stablehlo::SliceOp>(
           loc, slicedType, adaptedArg,
           builder.getDenseI64ArrayAttr(startIndices),
           builder.getDenseI64ArrayAttr(limitIndices),
           builder.getDenseI64ArrayAttr(strides));
     }

     // Handle element type conversion if needed
     if (outerType.getElementType() != innerType.getElementType()) {
       auto convertedType = RankedTensorType::get(
           cast<RankedTensorType>(adaptedArg.getType()).getShape(),
           innerType.getElementType());
       adaptedArg =
           builder.create<stablehlo::ConvertOp>(loc, convertedType, adaptedArg);
     }

     adaptedArgs.push_back(adaptedArg);
   }

   // Call the original function with adapted arguments
   auto callOp = builder.create<mlir::func::CallOp>(loc, f, adaptedArgs);

   // Return the results
   builder.create<mlir::func::ReturnOp>(loc, callOp.getResults());

   return wrapperFunc;
 }

 } // namespace enzyme

 namespace stablehlo {
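To make the offset decomposition above concrete, here is a small standalone C++ sketch (an illustration only, not part of the patch) that mirrors the row-major index computation. The shapes and offset are chosen for the example: an outer tensor of shape 4x8 with 4-byte (f32) elements, an inner shape of 2x4, and a byte offset of 36, which decomposes into start indices [1, 1] and limit indices [3, 5].

  // Standalone mirror of the row-major offset decomposition used above.
  #include <cassert>
  #include <cstddef>
  #include <cstdint>
  #include <iostream>
  #include <vector>

  int main() {
    std::vector<int64_t> outerShape = {4, 8};
    std::vector<int64_t> innerShape = {2, 4};
    int64_t elementBytes = 4; // f32
    int64_t byteOffset = 36;  // must be a multiple of elementBytes

    assert(byteOffset % elementBytes == 0);
    int64_t remainingOffset = byteOffset / elementBytes; // 9 elements

    std::vector<int64_t> startIndices, limitIndices;
    for (std::size_t j = 0; j < outerShape.size(); ++j) {
      // Row-major stride of dimension j: product of the trailing dimensions.
      int64_t dimStride = 1;
      for (std::size_t k = j + 1; k < outerShape.size(); ++k)
        dimStride *= outerShape[k];

      int64_t dimIndex = remainingOffset / dimStride;
      startIndices.push_back(dimIndex);
      limitIndices.push_back(dimIndex + innerShape[j]);
      assert(limitIndices.back() <= outerShape[j]);

      remainingOffset %= dimStride;
    }

    // Prints: start = [1, 1], limit = [3, 5]
    std::cout << "start = [" << startIndices[0] << ", " << startIndices[1]
              << "], limit = [" << limitIndices[0] << ", " << limitIndices[1]
              << "]\n";
    return 0;
  }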
@@ -840,6 +840,17 @@ bool areValidInsertionDims(RankedTensorType inputType,

 bool isOnlyUsedInOperation(Operation *operation, Operation *parentOp);

 /// Adapt a function to a calling convention with different element types and
 /// byte offsets. Creates a new wrapper function that performs necessary
 /// conversions and slicing before calling the original function.
 /// \param f The original MLIR function to wrap
 /// \param inputTensorTypes The tensor types for the wrapper function arguments
 /// \param byteOffsets Byte offsets for each argument (0 means no offset)
 /// \return A new function that adapts the calling convention
 mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
                                             ArrayRef<mlir::Type> inputTensorTypes,
                                             ArrayRef<int64_t> byteOffsets);

 } // namespace enzyme

 namespace stablehlo {
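As a rough sketch of how a pass might call the declared utility (a hypothetical example: the name kernel, the shapes, the 64-byte offset, and the mlir::enzyme:: qualification are assumptions, not taken from this diff; the include for the Enzyme-side header is omitted because its path is not shown here):

  #include "mlir/Dialect/Func/IR/FuncOps.h"
  #include "mlir/IR/BuiltinTypes.h"

  mlir::func::FuncOp wrapForOffsetBuffer(mlir::func::FuncOp kernel) {
    // `kernel` is assumed to take tensor<2x4xf32>; the caller supplies a
    // tensor<4x8xf64> buffer whose payload starts 64 bytes in
    // (64 / 8 bytes per f64 = 8 elements => start indices [1, 0] in a 4x8 tensor).
    mlir::MLIRContext *ctx = kernel.getContext();
    mlir::Type outerTy =
        mlir::RankedTensorType::get({4, 8}, mlir::Float64Type::get(ctx));

    llvm::SmallVector<mlir::Type> inputTensorTypes{outerTy};
    llvm::SmallVector<int64_t> byteOffsets{64};

    // Returns `@<kernel>_adapted`, which slices the 2x4 window at [1, 0],
    // converts f64 -> f32, and forwards the argument to `kernel`.
    return mlir::enzyme::adaptToCallingConvention(kernel, inputTensorTypes,
                                                  byteOffsets);
  }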
Review comment:
@copilot you can use AutodiffTypeInterface to simplify this
Reply:
Simplified to use AutoDiffTypeInterface::getApproxSize() instead of manual bit width calculation. This handles all element types, including complex types, uniformly. Fixed in commit 30715f5.
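The simplified code is not shown in this diff, but it presumably collapses the scalar and complex branches into something roughly like the sketch below (assumptions: the element type implements AutoDiffTypeInterface and getApproxSize() reports the element size in bytes):

  // Hedged sketch of the getApproxSize()-based element size query; not the
  // actual committed code.
  int64_t elementBytes = 0;
  if (auto adTy = dyn_cast<AutoDiffTypeInterface>(elementType))
    elementBytes = adTy.getApproxSize();
  assert(elementBytes > 0 &&
         "Element type must report a size for byte offset calculation");
  assert(byteOffset % elementBytes == 0 &&
         "Byte offset must be aligned to element boundaries");
  int64_t elementOffset = byteOffset / elementBytes;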