Copilot AI commented Dec 8, 2025

Implementation Complete ✅

Changes Made:

  1. ✅ Added adaptToCallingConvention function declaration to Utils.h
  2. ✅ Implemented adaptToCallingConvention in Utils.cpp
  3. ✅ Addressed all code review feedback
  4. Updated to use BitcastConvertOp instead of ConvertOp
  5. Fixed validation and type consistency issues
  6. Simplified element size calculation using AutoDiffTypeInterface
  7. Improved variable naming for clarity
  8. Fixed missing include for func::FuncOp
  9. Fixed ConstantOp creation to use proper tensor types
  10. Refactored to reduce code duplication
  11. Simplified reshape logic to always use GetDimensionSizeOp
  12. Fixed code formatting with clang-format
  13. ✅ Security checks passed

Function Capabilities:

The adaptToCallingConvention utility function:

  • Creates a wrapper function that adapts calling conventions
  • Uses stablehlo::BitcastConvertOp for element type conversion
  • Uses AutoDiffTypeInterface::getApproxSize() for clean element size calculation
  • Validates that element sizes divide evenly
  • Handles element size differences by adding/removing tensor dimensions
  • Always uses dynamic reshape with GetDimensionSizeOp (optimized away for static shapes)
  • Handles byte offset slicing using stablehlo::SliceOp
  • Properly supports complex types
  • Validates byte alignment to element boundaries
  • Performs bounds checking to prevent out-of-bounds access
  • Uses row-major (C-style) layout for multi-dimensional offset calculations
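The offset-related capabilities above (alignment validation, bounds checking, row-major index calculation) can be sketched in plain C++ without any MLIR dependency. This is a minimal illustration, not the code from the PR: the helper name `byteOffsetToStartIndices` is hypothetical, and it assumes a fully static shape and a known element size in bytes.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helper (not from the PR): converts a byte offset into
// row-major start indices for a tensor with a fully static shape.
// Returns std::nullopt if the offset is misaligned or out of bounds.
std::optional<std::vector<int64_t>>
byteOffsetToStartIndices(const std::vector<int64_t> &shape,
                         int64_t elementSizeBytes, int64_t byteOffset) {
  if (elementSizeBytes <= 0 || byteOffset < 0)
    return std::nullopt;
  // The offset must land exactly on an element boundary.
  if (byteOffset % elementSizeBytes != 0)
    return std::nullopt;
  int64_t linear = byteOffset / elementSizeBytes;

  int64_t numElements = 1;
  for (int64_t d : shape) {
    if (d < 0)
      return std::nullopt; // dynamic dimensions are not handled in this sketch
    numElements *= d;
  }
  // Bounds check: the first element handed to the callee must exist.
  if (linear >= numElements)
    return std::nullopt;

  // Peel indices off from the innermost (fastest-varying) dimension outward.
  std::vector<int64_t> indices(shape.size(), 0);
  for (std::size_t i = shape.size(); i-- > 0;) {
    indices[i] = linear % shape[i];
    linear /= shape[i];
  }
  return indices;
}
```

For a 2x3x4 tensor of 4-byte elements, a byte offset of 68 is element 17 in row-major order, which corresponds to indices (1, 1, 1) because 1*12 + 1*4 + 1 = 17. An offset of 6 is rejected as misaligned, and an offset of 96 is rejected as out of bounds.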

Implementation Details for Type Conversion:

  1. Same element size: Direct bitcast conversion
  2. Target element smaller: Validates divisibility, bitcast with added dimension, then dynamic reshape
  3. Target element larger: Validates divisibility, dynamic reshape first, then bitcast
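The three cases above can be illustrated by tracking only the tensor shapes through each step. The sketch below is an assumption-laden illustration, not the PR's implementation: `planBitcast` and `BitcastPlan` are hypothetical names, shapes are assumed static with rank of at least 1 when sizes differ, and the extra dimension is assumed to be folded into (or split out of) the last dimension.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical record of the shapes produced by the two-step sequence.
struct BitcastPlan {
  Shape afterFirstStep; // shape after the first op in the sequence
  Shape result;         // final shape seen by the callee
};

// Hypothetical helper: computes the shapes for each of the three cases.
// Returns std::nullopt when the element sizes do not divide evenly.
std::optional<BitcastPlan> planBitcast(const Shape &shape,
                                       int64_t currentSizeBytes,
                                       int64_t targetSizeBytes) {
  if (currentSizeBytes <= 0 || targetSizeBytes <= 0)
    return std::nullopt;
  // Case 1: same element size, a direct bitcast keeps the shape.
  if (currentSizeBytes == targetSizeBytes)
    return BitcastPlan{shape, shape};
  if (shape.empty())
    return std::nullopt; // rank-0 tensors are not handled in this sketch

  if (targetSizeBytes < currentSizeBytes) {
    // Case 2: target element smaller. Bitcast first, then reshape.
    if (currentSizeBytes % targetSizeBytes != 0)
      return std::nullopt;
    int64_t ratio = currentSizeBytes / targetSizeBytes;
    Shape bitcast = shape;
    bitcast.push_back(ratio); // bitcast appends a trailing dimension
    Shape result = shape;
    result.back() *= ratio; // reshape folds it into the last dimension
    return BitcastPlan{bitcast, result};
  }

  // Case 3: target element larger. Reshape first, then bitcast.
  if (targetSizeBytes % currentSizeBytes != 0)
    return std::nullopt;
  int64_t ratio = targetSizeBytes / currentSizeBytes;
  if (shape.back() % ratio != 0)
    return std::nullopt;
  Shape result = shape;
  result.back() /= ratio; // bitcast will drop the trailing dimension
  Shape reshaped = result;
  reshaped.push_back(ratio); // reshape splits the last dim into [n/ratio, ratio]
  return BitcastPlan{reshaped, result};
}
```

Under these assumptions, a 2x3 tensor of 4-byte elements reinterpreted as 1-byte elements passes through shape 2x3x4 and ends as 2x12. The reverse direction starts at 2x12, is reshaped to 2x3x4, and ends as 2x3.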

The implementation follows the pattern in AffineToStableHLORaising.cpp. Relative to the earlier revisions of this PR, it validates that element sizes divide evenly, computes element sizes through AutoDiffTypeInterface, builds constants with proper MLIR tensor types, and uses a single reshape path that relies on later compiler optimizations to fold away size queries on static shapes.

Build Status:

✅ All compilation errors fixed
✅ Simplified implementation using uniform GetDimensionSizeOp approach
✅ Code formatting fixed

Original prompt

This section details the original issue you should resolve

<issue_title>Element type Conversion and offset</issue_title>
<issue_description>We need a utility function adaptToCallingConvention that takes an mlir function f, a list of input tensor types, and a list of byte offsets, passed as an ArrayRef<size_t>. adaptToCallingConvention should output a new mlir function.

This new function should call the original f. If the element types differ, it should perform casts so that an argument of the type given in the outer specification can be passed to f. Also, if the corresponding byte offset is non-zero, it should slice the argument so that the first byte passed to f lies that many bytes into the outer argument.

Code should be added to Utils.cpp and Utils.h

</issue_description>

Comments on the Issue (you are @copilot in this section)


Copilot AI and others added 3 commits December 8, 2025 00:25
- Added function declaration to Utils.h
- Implemented adaptToCallingConvention in Utils.cpp
- Function creates wrapper that handles element type conversion and byte offsets
- Uses stablehlo::ConvertOp for type conversion
- Uses stablehlo::SliceOp for byte offset handling

Co-authored-by: wsmoses <[email protected]>
- Convert linear byte offset to proper multi-dimensional indices
- Use row-major (C-style) layout for offset calculation
- Handle arbitrary tensor shapes correctly

Co-authored-by: wsmoses <[email protected]>
- Remove trailing whitespace
- Add proper handling for complex types in byte offset calculation
- Add validation that byte offset aligns with element boundaries
- Add bounds checking to prevent out-of-bounds slice access
- Add assertions for invalid element types

Co-authored-by: wsmoses <[email protected]>
Copilot AI changed the title from "[WIP] Add utility function to adapt function calling convention" to "Add adaptToCallingConvention utility for element type conversion and byte offset handling" Dec 8, 2025
Copilot AI requested a review from wsmoses December 8, 2025 00:30
if (outerType.getElementType() != innerType.getElementType()) {
  auto convertedType = RankedTensorType::get(
      cast<RankedTensorType>(adaptedArg.getType()).getShape(),
      innerType.getElementType());
Member

@copilot you need to use BitcastConvertOp. Look at other uses of it in the repo to learn how it might force an additional tensor dimension to be added or removed. At this point we do not know whether the target type has any dynamic dimensions, so you will need to pick between a dynamic or static reshape accordingly.

Author

Updated to use BitcastConvertOp instead of ConvertOp. The implementation now:

  • Handles element size differences by adding/removing dimensions as needed
  • Supports both dynamic and static reshape operations based on whether dimensions are dynamic
  • Follows the pattern from AffineToStableHLORaising.cpp

Fixed in commit 18cb37a.


// Get element size in bytes
int64_t elementBytes = 0;
if (auto complexType = dyn_cast<ComplexType>(elementType)) {
Member

@copilot you can use AutoDiffTypeInterface to simplify this

Author

Simplified to use AutoDiffTypeInterface::getApproxSize() instead of manual bit width calculation. This handles all element types including complex types uniformly.

Fixed in commit 30715f5.
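For context, the kind of manual calculation that this change replaced can be sketched as follows. This is an illustrative reconstruction under assumptions, not the removed code and not a description of what getApproxSize() does internally: `ElementDesc` is a hypothetical stand-in for an MLIR element type, and the rule shown (round the bit width up to whole bytes, double it for complex types) is only the common convention.

```cpp
#include <cstdint>

// Hypothetical stand-in for an MLIR element type.
struct ElementDesc {
  unsigned bitWidth; // width of one scalar component, in bits
  bool isComplex;    // complex types store two components (real, imaginary)
};

// Manual element size in bytes: round each component up to whole bytes,
// then double the size for complex types.
int64_t elementSizeBytes(const ElementDesc &t) {
  int64_t componentBytes = (static_cast<int64_t>(t.bitWidth) + 7) / 8;
  return t.isComplex ? 2 * componentBytes : componentBytes;
}
```

Branching on the type category like this has to be repeated wherever a size is needed, which is the duplication that a single interface call avoids.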

Copilot AI and others added 2 commits December 8, 2025 00:39
- Use BitcastConvertOp instead of ConvertOp as requested
- Handle cases where element sizes differ (add/remove dimensions)
- Support both dynamic and static reshape operations
- Add element size calculation for complex types
- Match pattern from AffineToStableHLORaising.cpp

Co-authored-by: wsmoses <[email protected]>
- Add validation that element sizes divide evenly
- Use consistent loop variable types (size_t)
- Extract size ratio to avoid repeated division
- Fix potential integer truncation issues

Co-authored-by: wsmoses <[email protected]>
Copilot AI and others added 2 commits December 8, 2025 00:45
- Add include for Interfaces/AutoDiffTypeInterface.h
- Replace manual bit width calculation with getApproxSize()
- Simplifies byte offset calculation for all element types
- Simplifies size calculation for BitcastConvertOp logic

Co-authored-by: wsmoses <[email protected]>
- Rename elementBytes to elementSizeBytes
- Rename currentSize/targetSize to currentSizeBytes/targetSizeBytes
- Makes it clear that sizes are in bytes, not bits

Co-authored-by: wsmoses <[email protected]>
/// \param inputTensorTypes The tensor types for the wrapper function arguments
/// \param byteOffsets Byte offsets for each argument (0 means no offset)
/// \return A new function that adapts the calling convention
mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
Member

@copilot

Use --sandbox_debug to see verbose messages from the sandbox and retain the sandbox build root for debugging
In file included from src/enzyme_ad/jax/Dialect/Ops.cpp:14:
./src/enzyme_ad/jax/Utils.h:850:7: error: no member named 'func' in namespace 'mlir'
  850 | mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
      | ~~~~~~^
./src/enzyme_ad/jax/Utils.h:850:51: error: no member named 'func' in namespace 'mlir'
  850 | mlir::func::FuncOp adaptToCallingConvention(mlir::func::FuncOp f,
      |                                             ~~~~~~^
src/enzyme_ad/jax/Dialect/Ops.cpp:1776:9: warning: use of bitwise '|' with boolean operands [-Wbitwise-instead-of-logical]
 1776 |     if (!srcMemRefType.hasStaticShape() | !resMemRefType.hasStaticShape())
      |         ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
      |                                         ||
src/enzyme_ad/jax/Dialect/Ops.cpp:1776:9: note: cast one or both operands to int to silence this warning
1 warning and 2 errors generated.

Author

Fixed the compilation error by adding the missing include mlir/Dialect/Func/IR/FuncOps.h to Utils.h. This resolves the "no member named 'func' in namespace 'mlir'" error.

Fixed in commit 7fce353.

- Add mlir/Dialect/Func/IR/FuncOps.h include to Utils.h
- Fixes compilation error: no member named 'func' in namespace 'mlir'

Co-authored-by: wsmoses <[email protected]>
Copilot AI requested a review from wsmoses December 8, 2025 01:49
Member

wsmoses commented Dec 8, 2025

@copilot


ERROR: /__w/Enzyme-JAX/Enzyme-JAX/src/enzyme_ad/jax/BUILD:840:11: Compiling src/enzyme_ad/jax/Utils.cpp failed: (Exit 1): clang failed: error executing CppCompile command (from target //src/enzyme_ad/jax:XLADerivatives) /usr/lib/llvm-18/bin/clang -U_FORTIFY_SOURCE -fstack-protector -Wall -Wthread-safety -Wself-assign -Wunused-but-set-parameter -Wno-free-nonheap-object -fcolor-diagnostics -fno-omit-frame-pointer -g0 ... (remaining 585 arguments skipped)

Use --sandbox_debug to see verbose messages from the sandbox and retain the sandbox build root for debugging
src/enzyme_ad/jax/Utils.cpp:264:33: warning: lambda capture 'op' is not used [-Wunused-lambda-capture]
  264 |     if (!llvm::all_of(effects, [op](const MemoryEffects::EffectInstance &it) {
      |                                 ^~
src/enzyme_ad/jax/Utils.cpp:294:32: warning: lambda capture 'op' is not used [-Wunused-lambda-capture]
  294 |     if (llvm::any_of(effects, [op](const MemoryEffects::EffectInstance &it) {
      |                                ^~
In file included from src/enzyme_ad/jax/Utils.cpp:9:
In file included from src/enzyme_ad/jax/Utils.h:14:
In file included from external/llvm-project/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h:18:
In file included from external/llvm-project/mlir/include/mlir/Dialect/Arith/IR/Arith.h:19:
In file included from external/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h:18:
external/llvm-project/mlir/include/mlir/IR/Builders.h:509:11: error: no matching member function for call to 'build'
  509 |     OpTy::build(*this, state, std::forward<Args>(args)...);
      |     ~~~~~~^~~~~
src/enzyme_ad/jax/Utils.cpp:1410:41: note: in instantiation of function template specialization 'mlir::OpBuilder::create<mlir::stablehlo::ConstantOp, mlir::IntegerType, mlir::IntegerAttr>' requested here
 1410 |               auto constValue = builder.create<stablehlo::ConstantOp>(
      |                                         ^
bazel-out/k8-opt/bin/external/stablehlo/stablehlo/dialect/StablehloOps.h.inc:8869:15: note: candidate function not viable: no known conversion from 'mlir::IntegerAttr' to '::mlir::ElementsAttr' for 4th argument
 8869 |   static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type output, ::mlir::ElementsAttr value);
      |               ^                                                                                           ~~~~~~~~~~~~~~~~~~~~~~~~~~
bazel-out/k8-opt/bin/external/stablehlo/stablehlo/dialect/StablehloOps.h.inc:8875:15: note: candidate function not viable: no known conversion from 'mlir::IntegerAttr' to '::mlir::ElementsAttr' for 4th argument
 8875 |   static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ElementsAttr value);
      |               ^                                                                                                     ~~~~~~~~~~~~~~~~~~~~~~~~~~
bazel-out/k8-opt/bin/external/stablehlo/stablehlo/dialect/StablehloOps.h.inc:8878:15: note: candidate function not viable: no known conversion from 'mlir::IntegerAttr' to '::mlir::ValueRange' for 4th argument
 8878 |   static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
      |               ^                                                                                           ~~~~~~~~~~~~~~~~~~~~~~~~~~~
bazel-out/k8-opt/bin/external/stablehlo/stablehlo/dialect/StablehloOps.h.inc:8881:15: note: candidate function not viable: no known conversion from 'mlir::IntegerType' to '::mlir::ValueRange' for 3rd argument
 8881 |   static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
      |               ^                                                                      ~~~~~~~~~~~~~~~~~~~~~~~~~~~
bazel-out/k8-opt/bin/external/stablehlo/stablehlo/dialect/StablehloOps.h.inc:8887:15: note: candidate function not viable: no known conversion from 'mlir::IntegerType' to '::mlir::ValueRange' for 3rd argument
 8887 |   static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, const Properties &properties, ::llvm::ArrayRef<::mlir::NamedAttribute> discardableAttributes = {});
      |               ^                                                                      ~~~~~~~~~~~~~~~~~~~~~~~~~~~
bazel-out/k8-opt/bin/external/stablehlo/stablehlo/dialect/StablehloOps.h.inc:8866:15: note: candidate function not viable: requires 3 arguments, but 4 were provided
 8866 |   static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Attribute value);
      |               ^     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
bazel-out/k8-opt/bin/external/stablehlo/stablehlo/dialect/StablehloOps.h.inc:8872:15: note: candidate function not viable: requires 3 arguments, but 4 were provided
 8872 |   static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ElementsAttr value);
      |               ^     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
bazel-out/k8-opt/bin/external/stablehlo/stablehlo/dialect/StablehloOps.h.inc:8884:15: note: candidate function not viable: requires at least 5 arguments, but 4 were provided
 8884 |   static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, const Properties &properties, ::llvm::ArrayRef<::mlir::NamedAttribute> discardableAttributes = {});
      |               ^     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2 warnings and 1 error generated.

- Change ConstantOp to use RankedTensorType::get({}, i32Type) for scalar tensors
- Use cast<ElementsAttr>(makeAttr(...)) for creating constant values
- Fixes compilation error with stablehlo::ConstantOp::build
- Also fix GetDimensionSizeOp to return tensor type instead of scalar

Co-authored-by: wsmoses <[email protected]>
Author

Copilot AI commented Dec 8, 2025

@copilot


...

Fixed the compilation error in stablehlo::ConstantOp::create calls. The issue was using scalar types instead of tensor types. Changed to use RankedTensorType::get({}, builder.getI32Type()) for scalar i32 tensors and cast<ElementsAttr>(makeAttr(...)) for creating constant values. Also fixed GetDimensionSizeOp to return the proper tensor type.

Fixed in commit a4c219a.

  // Get dynamic dimension from original tensor
  auto dimValue = builder.create<stablehlo::GetDimensionSizeOp>(
-     loc, builder.getI32Type(), adaptedArg, i);
+     loc, scalarI32Type, adaptedArg, i);
Member

@copilot can you use the version that doesn't explicitly pass the return type here? It will be deduced automatically.

Moreover, there is no need for separate dynamic and constant paths: you can always use GetDimensionSizeOp, since it will get optimized away later.

Author

Simplified the reshape logic:

  • Removed conditional branches for static vs dynamic shapes
  • Always use GetDimensionSizeOp for all dimensions (will be optimized away for static shapes)
  • Removed explicit type parameter from ConstantOp::create (type is deduced automatically)

This reduces code complexity from ~73 lines to ~27 lines while maintaining the same functionality.

Fixed in commit 33cf3f6.
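The idea behind this simplification, emitting a size query for every dimension and letting a later pass fold the static ones, can be modeled in plain C++. This is a toy model under assumptions, not MLIR code: `DimValue`, `buildShapeOperands`, `foldStaticDims`, and the `kDynamic` marker are all hypothetical names introduced for illustration.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = -1; // hypothetical marker for an unknown extent

// Toy model of a shape operand: either a folded constant or a runtime query.
struct DimValue {
  bool isConstant;
  int64_t constant; // meaningful only when isConstant is true
  std::size_t dim;  // dimension queried at runtime otherwise
};

// Uniform construction: one size query per dimension, with no branching
// on whether the dimension is static or dynamic.
std::vector<DimValue> buildShapeOperands(std::size_t rank) {
  std::vector<DimValue> dims;
  for (std::size_t i = 0; i < rank; ++i)
    dims.push_back({false, 0, i});
  return dims;
}

// Later simplification: queries on statically known dimensions fold to
// constants, while dynamic dimensions keep their runtime query.
void foldStaticDims(std::vector<DimValue> &dims,
                    const std::vector<int64_t> &shape) {
  for (DimValue &d : dims)
    if (!d.isConstant && shape[d.dim] != kDynamic)
      d = {true, shape[d.dim], d.dim};
}
```

The construction code stays free of static-versus-dynamic branching, and the result after folding matches what a hand-written static path would have produced for the known dimensions.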

- Define scalarI32Type once at the beginning of type conversion block
- Remove duplicate definitions within nested scopes
- Improves code maintainability and reduces duplication

Co-authored-by: wsmoses <[email protected]>
- Remove conditional branches for static vs dynamic reshape
- Always use GetDimensionSizeOp for all dimensions (optimized away for static shapes)
- Remove explicit type parameter from ConstantOp::create (type deduced automatically)
- Simplifies code and relies on compiler optimizations

Co-authored-by: wsmoses <[email protected]>
Contributor

github-actions bot left a comment

EnzymeJAX Benchmarks

Benchmark suite                        Current: 33cf3f6             Previous: ead4414            Ratio
scatter_sum / JaX / cpu / Primal       0.000006275461000041105 s    0.000004329831799987005 s    1.45
scatter_sum / JaXPipe / cpu / Primal   0.000003951574000166147 s    0.000003912584799763863 s    1.01
scatter_sum / JaX / tpu / Primal       0.0001316881838996 s         0.0001478510524 s            0.89
scatter_sum / JaXPipe / tpu / Primal   0.0001495134879001 s         0.0001495474493 s            1.00

This comment was automatically generated by workflow using github-action-benchmark.

Member

wsmoses commented Dec 8, 2025

@copilot the format is wrong. see ci

Run clang-format on Utils.cpp and Utils.h to fix formatting issues

Co-authored-by: wsmoses <[email protected]>
Author

Copilot AI commented Dec 8, 2025

@copilot the format is wrong. see ci

Fixed formatting issues by running clang-format on Utils.cpp and Utils.h.

Fixed in commit f132335.

Development

Successfully merging this pull request may close these issues.

Element type Conversion and offset

2 participants