New version of Enzyme-JAX breaks lots of sharding tests #1180

Closed
giordano opened this issue Apr 14, 2025 · 2 comments

@giordano
Member

PJRT:

Test Summary:                               | Pass  Error  Broken  Total      Time
Reactant.jl Tests                           | 1430     19       1   1450  22m57.4s
  Layout                                    |    5                     5      1.6s
  Tracing                                   |  146                   146      3.1s
  Basic                                     |  450                   450   7m31.3s
  Autodiff                                  |   30                    30     58.9s
  Complex                                   |   50                    50     50.1s
  Broadcast                                 |    9                     9     26.7s
  Struct                                    |    9              1     10      8.6s
  Closure                                   |    1                     1      1.2s
  Compile                                   |   25                    25     25.0s
  Buffer Donation                           |   12                    12     13.9s
  Shortcuts to MLIR ops                     |  291                   291   4m37.4s
  Wrapped Arrays                            |   32                    32   1m22.1s
  Control Flow                              |   84                    84   2m00.7s
  Sorting                                   |   74                    74   1m42.3s
  Indexing                                  |  153                   153   1m24.3s
  Custom Number Types                       |    5                     5     13.5s
  Sharding                                  |   31     15             46     40.8s
    Sharding Across 2 Devices               |   12      1             13      7.5s
    Sharding Across 8 Devices               |           2              2      3.1s
    Sharding with non-iota mesh             |                       None      0.0s
    Multiple Axis Partition Spec            |    2      2              4      5.8s
    Open Axis Partition Spec                |    2      2              4      0.9s
    Multiple Mesh Sharding                  |                       None      0.0s
    Sharding Constraint                     |    1      1              2      2.9s
    Sharding with non-divisible axes sizes  |    3      4              7      3.7s
      Handle Sub-Axis Info                  |    1                     1      0.5s
    Device List from Iota Tile              |    3                     3      0.2s
    Sharding with Mutation                  |    3      1              4      5.8s
    Bad Codegen for Resharded Inputs: #1027 |                       None      0.0s
    Multiple Mesh Sharding                  |                       None      0.0s
    Initialize Sharded Data                 |    5      1              6      4.3s
    ShardyPropagationOptions                |           1              1      1.9s
  Comm Optimization                         |           4              4      6.9s
    Rotate                                  |           1              1      2.3s
    Pad                                     |           1              1      1.4s
    DUS                                     |           1              1      1.6s
    DUS2                                    |           1              1      1.4s
  Cluster Detection                         |   11                    11      0.7s
  Config                                    |   12                    12      8.1s

IFRT:

Test Summary:                               | Pass  Error  Broken  Total      Time
Reactant.jl Tests                           | 1426     26       1   1453  23m57.5s
  Layout                                    |    5                     5      1.9s
  Tracing                                   |  146                   146      3.2s
  Basic                                     |  446                   446   7m42.2s
  Autodiff                                  |   30                    30     59.0s
  Complex                                   |   50                    50     50.2s
  Broadcast                                 |    9                     9     26.5s
  Struct                                    |    9              1     10      8.9s
  Closure                                   |    1                     1      1.3s
  Compile                                   |   25                    25     27.2s
  Buffer Donation                           |   12                    12     13.7s
  Shortcuts to MLIR ops                     |  291                   291   4m48.4s
  Wrapped Arrays                            |   32                    32   1m27.7s
  Control Flow                              |   84                    84   2m06.1s
  Sorting                                   |   74                    74   1m46.9s
  Indexing                                  |  153                   153   1m25.8s
  Custom Number Types                       |    5                     5     13.9s
  Sharding                                  |   31     22             53     58.1s
    Sharding Across 2 Devices               |   12      1             13      7.9s
    Sharding Across 8 Devices               |           2              2      3.2s
    Sharding with non-iota mesh             |           4              4      4.3s
    Multiple Axis Partition Spec            |    2      2              4      3.2s
    Open Axis Partition Spec                |    2      2              4      0.9s
    Multiple Mesh Sharding                  |           1              1      3.6s
    Sharding Constraint                     |    1      1              2      3.0s
    Sharding with non-divisible axes sizes  |    3      4              7      5.7s
      Handle Sub-Axis Info                  |    1                     1      0.4s
    Device List from Iota Tile              |    3                     3      0.2s
    Sharding with Mutation                  |    3      1              4      7.1s
    Bad Codegen for Resharded Inputs: #1027 |           1              1      2.1s
    Multiple Mesh Sharding                  |           1              1      5.7s
    Initialize Sharded Data                 |    5      1              6      4.0s
    ShardyPropagationOptions                |           1              1      1.7s
  Comm Optimization                         |           4              4      7.7s
    Rotate                                  |           1              1      2.5s
    Pad                                     |           1              1      1.4s
    DUS                                     |           1              1      1.7s
    DUS2                                    |           1              1      2.0s
  Cluster Detection                         |   11                    11      0.7s
  Config                                    |   12                    12      8.1s

The error seems to be:

error: couldn't find a pattern operation corresponding to concatreshape_to_onedim_dus

@giordano
Member Author

Is the issue that #1172 introduced this pass, but using it requires a future JLL build?

@giordano
Member Author

This was indeed fixed by upgrading to a newer JLL build in 44d3d63.
