-
Notifications
You must be signed in to change notification settings - Fork 24
Add adaptToCallingConvention utility for element type conversion and byte offset handling #1709
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
- Added function declaration to Utils.h - Implemented adaptToCallingConvention in Utils.cpp - Function creates wrapper that handles element type conversion and byte offsets - Uses stablehlo::ConvertOp for type conversion - Uses stablehlo::SliceOp for byte offset handling Co-authored-by: wsmoses <[email protected]>
- Convert linear byte offset to proper multi-dimensional indices - Use row-major (C-style) layout for offset calculation - Handle arbitrary tensor shapes correctly Co-authored-by: wsmoses <[email protected]>
- Remove trailing whitespace - Add proper handling for complex types in byte offset calculation - Add validation that byte offset aligns with element boundaries - Add bounds checking to prevent out-of-bounds slice access - Add assertions for invalid element types Co-authored-by: wsmoses <[email protected]>
src/enzyme_ad/jax/Utils.cpp
Outdated
| if (outerType.getElementType() != innerType.getElementType()) { | ||
| auto convertedType = RankedTensorType::get( | ||
| cast<RankedTensorType>(adaptedArg.getType()).getShape(), | ||
| innerType.getElementType()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@copilot you need to use BitCastConvertOp, look at other uses of it in the repo to learn how it might force an additional tensor argument to be added or removed. At this point we do not know whether the target type has any dynamic indices, so you will need to pick between dynamic or static reshape, accordingly
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated to use BitcastConvertOp instead of ConvertOp. The implementation now:
- Handles element size differences by adding/removing dimensions as needed
- Supports both dynamic and static reshape operations based on whether dimensions are dynamic
- Follows the pattern from
AffineToStableHLORaising.cpp
Fixed in commit 18cb37a.
src/enzyme_ad/jax/Utils.cpp
Outdated
|
|
||
| // Get element size in bytes | ||
| int64_t elementBytes = 0; | ||
| if (auto complexType = dyn_cast<ComplexType>(elementType)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@copilot you can use AutodiffTypeInterface to simplify this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Simplified to use AutoDiffTypeInterface::getApproxSize() instead of manual bit width calculation. This handles all element types including complex types uniformly.
Fixed in commit 30715f5.
- Use BitcastConvertOp instead of ConvertOp as requested - Handle cases where element sizes differ (add/remove dimensions) - Support both dynamic and static reshape operations - Add element size calculation for complex types - Match pattern from AffineToStableHLORaising.cpp Co-authored-by: wsmoses <[email protected]>
- Add validation that element sizes divide evenly - Use consistent loop variable types (size_t) - Extract size ratio to avoid repeated division - Fix potential integer truncation issues Co-authored-by: wsmoses <[email protected]>
- Add include for Interfaces/AutoDiffTypeInterface.h - Replace manual bit width calculation with getApproxSize() - Simplifies byte offset calculation for all element types - Simplifies size calculation for BitcastConvertOp logic Co-authored-by: wsmoses <[email protected]>
- Rename elementBytes to elementSizeBytes - Rename currentSize/targetSize to currentSizeBytes/targetSizeBytes - Makes it clear that sizes are in bytes, not bits Co-authored-by: wsmoses <[email protected]>
src/enzyme_ad/jax/Utils.h
Outdated
| /// \param inputTensorTypes The tensor types for the wrapper function arguments | ||
| /// \param byteOffsets Byte offsets for each argument (0 means no offset) | ||
| /// \return A new function that adapts the calling convention | ||
| mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use --sandbox_debug to see verbose messages from the sandbox and retain the sandbox build root for debugging
In file included from src/enzyme_ad/jax/Dialect/Ops.cpp:14:
./src/enzyme_ad/jax/Utils.h:850:7: error: no member named 'func' in namespace 'mlir'
850 | mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
| ~~~~~~^
./src/enzyme_ad/jax/Utils.h:850:51: error: no member named 'func' in namespace 'mlir'
850 | mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
| ~~~~~~^
src/enzyme_ad/jax/Dialect/Ops.cpp:1776:9: warning: use of bitwise '|' with boolean operands [-Wbitwise-instead-of-logical]
1776 | if (!srcMemRefType.hasStaticShape() | !resMemRefType.hasStaticShape())
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
| ||
src/enzyme_ad/jax/Dialect/Ops.cpp:1776:9: note: cast one or both operands to int to silence this warning
1 warning and 2 errors generated.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed the compilation error by adding the missing include mlir/Dialect/Func/IR/FuncOps.h to Utils.h. This resolves the "no member named 'func' in namespace 'mlir'" error.
Fixed in commit 7fce353.
- Add mlir/Dialect/Func/IR/FuncOps.h include to Utils.h - Fixes compilation error: no member named 'func' in namespace 'mlir' Co-authored-by: wsmoses <[email protected]>
|
- Change ConstantOp to use RankedTensorType::get({}, i32Type) for scalar tensors
- Use cast<ElementsAttr>(makeAttr(...)) for creating constant values
- Fixes compilation error with stablehlo::ConstantOp::build
- Also fix GetDimensionSizeOp to return tensor type instead of scalar
Co-authored-by: wsmoses <[email protected]>
|
... Fixed the compilation error in Fixed in commit a4c219a. |
src/enzyme_ad/jax/Utils.cpp
Outdated
| // Get dynamic dimension from original tensor | ||
| auto dimValue = builder.create<stablehlo::GetDimensionSizeOp>( | ||
| loc, builder.getI32Type(), adaptedArg, i); | ||
| loc, scalarI32Type, adaptedArg, i); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@copilot can you use the version that doesn't explicitly pass the return type here [it will be deduced automatically].
moreover, no need to have a separate path for dynamic vs constant, you can always use getdimensionsize as it will get optimized away later
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Simplified the reshape logic:
- Removed conditional branches for static vs dynamic shapes
- Always use
GetDimensionSizeOpfor all dimensions (will be optimized away for static shapes) - Removed explicit type parameter from
ConstantOp::create(type is deduced automatically)
This reduces code complexity from ~73 lines to ~27 lines while maintaining the same functionality.
Fixed in commit 33cf3f6.
- Define scalarI32Type once at the beginning of type conversion block - Remove duplicate definitions within nested scopes - Improves code maintainability and reduces duplication Co-authored-by: wsmoses <[email protected]>
- Remove conditional branches for static vs dynamic reshape - Always use GetDimensionSizeOp for all dimensions (optimized away for static shapes) - Remove explicit type parameter from ConstantOp::create (type deduced automatically) - Simplifies code and relies on compiler optimizations Co-authored-by: wsmoses <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
EnzymeJAX Benchmarks
| Benchmark suite | Current: 33cf3f6 | Previous: ead4414 | Ratio |
|---|---|---|---|
scatter_sum / JaX / cpu / Primal |
0.000006275461000041105 s |
0.000004329831799987005 s |
1.45 |
scatter_sum / JaXPipe / cpu / Primal |
0.000003951574000166147 s |
0.000003912584799763863 s |
1.01 |
scatter_sum / JaX / tpu / Primal |
0.0001316881838996 s |
0.0001478510524 s |
0.89 |
scatter_sum / JaXPipe / tpu / Primal |
0.0001495134879001 s |
0.0001495474493 s |
1.00 |
This comment was automatically generated by workflow using github-action-benchmark.
|
@copilot the format is wrong. see ci |
Run clang-format on Utils.cpp and Utils.h to fix formatting issues Co-authored-by: wsmoses <[email protected]>
Implementation Complete ✅
Changes Made:
adaptToCallingConventionfunction declaration toUtils.hadaptToCallingConventioninUtils.cppFunction Capabilities:
The
adaptToCallingConventionutility function:stablehlo::BitcastConvertOpfor element type conversionAutoDiffTypeInterface::getApproxSize()for clean element size calculationstablehlo::SliceOpImplementation Details for Type Conversion:
The implementation follows the pattern from
AffineToStableHLORaising.cppwith improved validation, type safety, clean maintainable code usingAutoDiffTypeInterface, proper MLIR tensor types, and simplified reshape logic that relies on compiler optimizations.Build Status:
✅ All compilation errors fixed
✅ Simplified implementation using uniform GetDimensionSizeOp approach
✅ Code formatting fixed
Original prompt
✨ Let Copilot coding agent set things up for you — coding agent works faster and does higher quality work when set up for your repo.