-
Notifications
You must be signed in to change notification settings - Fork 24
Add adaptToCallingConvention utility for element type conversion and byte offset handling #1709
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 10 commits
89bb73e
e034410
5d6cc36
255768b
18cb37a
2d431e9
30715f5
3f51246
7fce353
a4c219a
6af39ce
33cf3f6
f132335
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,6 +12,7 @@ | |
| #include "llvm/ADT/SmallVector.h" | ||
|
|
||
| #include "mlir/Dialect/Affine/IR/AffineOps.h" | ||
| #include "mlir/Dialect/Func/IR/FuncOps.h" | ||
| #include "mlir/Dialect/SCF/IR/SCF.h" | ||
| #include "mlir/IR/IRMapping.h" | ||
| #include "mlir/IR/IntegerSet.h" | ||
|
|
@@ -840,6 +841,17 @@ bool areValidInsertionDims(RankedTensorType inputType, | |
|
|
||
| bool isOnlyUsedInOperation(Operation *operation, Operation *parentOp); | ||
|
|
||
| /// Adapt a function to a calling convention with different element types and | ||
| /// byte offsets. Creates a new wrapper function that performs necessary | ||
| /// conversions and slicing before calling the original function. | ||
| /// \param f The original MLIR function to wrap | ||
| /// \param inputTensorTypes The tensor types for the wrapper function arguments | ||
| /// \param byteOffsets Byte offsets for each argument (0 means no offset) | ||
| /// \return A new function that adapts the calling convention | ||
| mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f, | ||
|
||
| ArrayRef<mlir::Type> inputTensorTypes, | ||
| ArrayRef<int64_t> byteOffsets); | ||
|
|
||
| } // namespace enzyme | ||
|
|
||
| namespace stablehlo { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@copilot can you use the version that doesn't explicitly pass the return type here [it will be deduced automatically].
moreover, no need to have a separate path for dynamic vs constant, you can always use getdimensionsize as it will get optimized away later
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Simplified the reshape logic:
GetDimensionSizeOpfor all dimensions (will be optimized away for static shapes)ConstantOp::create(type is deduced automatically)This reduces code complexity from ~73 lines to ~27 lines while maintaining the same functionality.
Fixed in commit 33cf3f6.