Basic FP8 Training Support #992

Merged: 7 commits merged into apple:main from hanzhi713:hanzhi/fp8-update on Mar 22, 2025

Conversation

hanzhi713 (Member):

This PR adds FP8 training support using either in-batch scaling or delayed scaling. There is one scaling factor per tensor.
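
For context, a minimal JAX sketch of the two scaling modes (hypothetical helper names such as in_batch_scale and delayed_scale, not the PR's code): in-batch scaling derives the per-tensor scale from the current tensor's amax, while delayed scaling derives it from a rolling amax history updated each step.

```python
import jax.numpy as jnp

E4M3_MAX = 448.0  # max representable magnitude of float8_e4m3fn

def in_batch_scale(x: jnp.ndarray) -> jnp.ndarray:
    # One scaling factor per tensor, computed from the current batch's amax.
    return E4M3_MAX / jnp.maximum(jnp.max(jnp.abs(x)), 1e-12)

def delayed_scale(amax_history: jnp.ndarray) -> jnp.ndarray:
    # One scaling factor per tensor, computed from a rolling amax history
    # gathered over previous steps (delayed scaling).
    return E4M3_MAX / jnp.maximum(jnp.max(amax_history), 1e-12)
```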

@hanzhi713 hanzhi713 requested review from ruomingp, markblee and a team as code owners February 13, 2025 23:42
ruomingp (Contributor) left a comment:

Thanks.

Comment on lines 317 to 330
update_cfg: OverrideInplaceUpdateTransformation.Config = (
OverrideInplaceUpdateTransformation.default_config()
)
update_cfg.rules = [
f".*/{x}"
for x in [
"input_scale",
"kernel_scale",
"output_grad_scale",
"input_amax_history",
"kernel_amax_history",
"output_grad_amax_history",
]
]

ruomingp (Contributor):

Can we use update_rules? See:

    ("state_updates_v", [(".*v", UpdateType.STATE_UPDATES)]),

hanzhi713 (Member Author):

No. There is a tracer problem when plumbing the FP8 stats through as state updates, because they are produced in the backward pass of a custom vjp. To circumvent this, we produce them as gradients and treat them as state updates.

See internal PRs 806 and 816 for context and discussion.
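
To illustrate the trick being described, here is a minimal, self-contained JAX sketch with hypothetical names (fp8_dot and its amax handling are simplified stand-ins, not the PR's code): the statistic is computed in the backward pass of a jax.custom_vjp and returned in a gradient slot, so an update rule can later overwrite the corresponding state instead of applying an optimizer step to it.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def fp8_dot(x, w, input_amax_history):
    return jnp.dot(x, w)

def _fp8_dot_fwd(x, w, input_amax_history):
    return jnp.dot(x, w), (x, w, input_amax_history)

def _fp8_dot_bwd(res, g):
    x, w, input_amax_history = res
    # The new amax history is only available here, in the backward pass...
    new_hist = jnp.roll(input_amax_history, 1).at[0].set(jnp.max(jnp.abs(x)))
    # ...so it is returned in the "gradient" slot of `input_amax_history`.
    # A rule such as ".*/input_amax_history" then tells the learner to write
    # this value into the state in place of a regular gradient update.
    return jnp.dot(g, w.T), jnp.dot(x.T, g), new_hist

fp8_dot.defvjp(_fp8_dot_fwd, _fp8_dot_bwd)
```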

secondary: Nested[Any],
override_fn: Callable[[Any, Any], Any] = tree_merge_default_override_fn,
) -> Nested[Any]:
"""Merge `secondary` into `primary`. The result contains shallow copies of subtrees from both.

ruomingp (Contributor):

Consider making them deep copies, to be consistent with the JAX tree utils. Deep copies are also less error-prone, since they decouple the result from side effects on the inputs.

hanzhi713 (Member Author):

Done.

if k in primary:
out_tree[k] = tree_merge(primary[k], secondary=secondary[k], override_fn=override_fn)
else:
out_tree[k] = secondary[k]

ruomingp (Contributor):

Suggested change:
- out_tree[k] = secondary[k]
+ out_tree[k] = copy.deepcopy(secondary[k])

hanzhi713 (Member Author):

Done.

if isinstance(primary, dict) ^ isinstance(secondary, dict):
raise ValueError(f"Trying to merge incompatible subtrees: {primary=}, {secondary=}")
if not (isinstance(primary, dict) or isinstance(secondary, dict)):
return override_fn(primary, secondary)

ruomingp (Contributor):

Suggested change:
- return override_fn(primary, secondary)
+ return copy.deepcopy(override_fn(primary, secondary))

hanzhi713 (Member Author):

Done.

primary: Nested[Any],
*,
secondary: Nested[Any],
override_fn: Callable[[Any, Any], Any] = tree_merge_default_override_fn,

ruomingp (Contributor):

Suggested change:
- override_fn: Callable[[Any, Any], Any] = tree_merge_default_override_fn,
+ leaf_merge_fn: Callable[[Any, Any], Any] = tree_merge_default_override_fn,

hanzhi713 (Member Author):

Done.

primary: Nested[Any],
*,
secondary: Nested[Any],
override_fn: Callable[[Any, Any], Any] = tree_merge_default_override_fn,

ruomingp (Contributor):

How about using a simpler function (always choose primary) as the default merge function?

Suggested change:
- override_fn: Callable[[Any, Any], Any] = tree_merge_default_override_fn,
+ override_fn: Callable[[Any, Any], Any] = lambda x, y: x,

hanzhi713 (Member Author):

I think the current default makes more sense. Always choosing primary would discard the corresponding subtree of secondary even when the primary tree has an empty leaf or None at that position.
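
Pulling the snippets in this thread together, here is a self-contained sketch of the merge semantics under discussion (assumed and simplified; the default leaf function below is one plausible reading of the comment above, not the PR's exact code):

```python
import copy
from typing import Any, Callable

def tree_merge_default_override_fn(primary: Any, secondary: Any) -> Any:
    # Assumed behavior: fill empty (None) leaves of `primary` from `secondary`.
    # A plain `lambda x, y: x` would silently drop `secondary` here.
    return secondary if primary is None else primary

def tree_merge(
    primary: Any,
    *,
    secondary: Any,
    override_fn: Callable[[Any, Any], Any] = tree_merge_default_override_fn,
) -> Any:
    """Merges `secondary` into `primary`, deep-copying merged-in values."""
    if isinstance(primary, dict) ^ isinstance(secondary, dict):
        raise ValueError(f"Trying to merge incompatible subtrees: {primary=}, {secondary=}")
    if not (isinstance(primary, dict) or isinstance(secondary, dict)):
        return copy.deepcopy(override_fn(primary, secondary))
    out_tree = dict(primary)
    for k in secondary:
        if k in primary:
            out_tree[k] = tree_merge(primary[k], secondary=secondary[k], override_fn=override_fn)
        else:
            out_tree[k] = copy.deepcopy(secondary[k])
    return out_tree
```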

@hanzhi713 hanzhi713 requested a review from ruomingp February 27, 2025 00:17
@hanzhi713 (Member Author):

@ruomingp Could you please take a look again?

@hanzhi713 hanzhi713 enabled auto-merge March 13, 2025 23:25
@hanzhi713 (Member Author):

@ruomingp Could you please take a look again?

@ruomingp (Contributor):

> @ruomingp Could you please take a look again?

Sorry about the delay. Will take a look soon.

ruomingp (Contributor) left a comment:

Apologies for the delay.

"output_grad_amax_history",
]
]
transformation = maybe_instantiate(cfg.learner.optimizer)

ruomingp (Contributor):

Usually we avoid instantiation when constructing configs, to preserve config readability. Can we defer the maybe_instantiate call to OverrideInplaceUpdateTransformation.__init__?

hanzhi713 (Member Author):

Done.
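
A minimal illustration of the deferred-instantiation pattern agreed on here (hypothetical, simplified classes; maybe_instantiate below is a stand-in for axlearn's helper): the config tree keeps the wrapped optimizer as a config, and the object is only built inside __init__.

```python
from dataclasses import dataclass, field
from typing import Any, Optional

def maybe_instantiate(cfg: Any) -> Any:
    # Stand-in: build the object if handed a config, else pass it through.
    return cfg.instantiate() if hasattr(cfg, "instantiate") else cfg

@dataclass
class OverrideInplaceUpdateTransformationConfig:
    rules: list = field(default_factory=list)
    transformation: Optional[Any] = None  # stored as a config, not an instance

class OverrideInplaceUpdateTransformation:
    def __init__(self, cfg: OverrideInplaceUpdateTransformationConfig):
        self.cfg = cfg
        # Instantiation is deferred to here, so the config tree that gets
        # printed or serialized upstream stays human-readable.
        self._transformation = maybe_instantiate(cfg.transformation)
```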

Comment on lines 323 to 328
"input_scale",
"kernel_scale",
"output_grad_scale",
"input_amax_history",
"kernel_amax_history",
"output_grad_amax_history",

ruomingp (Contributor):

I wonder how we can keep this list in sync with quantized_dot_general/layers.py. Consider:

  • Defining them as an enum in quantized_dot_general/common.py and enumerating the enum members here.
  • Moving this class to quantized_dot_general/update_transformation.py.
  • Changing quantized_dot_general/layers.py to use the enums instead of the string literals.

hanzhi713 (Member Author):

Done.
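
A sketch of the enum approach (module path from the comment and member names from the rules list above; the enum's name is a hypothetical choice, and the merged PR may differ):

```python
# quantized_dot_general/common.py (sketch): a single source of truth for the
# FP8 stat names shared by the layers and the update-transformation rules.
import enum

class FP8Stat(enum.Enum):
    INPUT_SCALE = "input_scale"
    KERNEL_SCALE = "kernel_scale"
    OUTPUT_GRAD_SCALE = "output_grad_scale"
    INPUT_AMAX_HISTORY = "input_amax_history"
    KERNEL_AMAX_HISTORY = "kernel_amax_history"
    OUTPUT_GRAD_AMAX_HISTORY = "output_grad_amax_history"

# At the call sites, enumerate the members instead of repeating string literals:
rules = [f".*/{stat.value}" for stat in FP8Stat]
```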

@hanzhi713 hanzhi713 requested a review from ruomingp March 17, 2025 21:32
@hanzhi713 hanzhi713 requested a review from ruomingp March 19, 2025 17:11
@hanzhi713
Copy link
Member Author

@ruomingp Could you please take a look again?

@hanzhi713 hanzhi713 added this pull request to the merge queue Mar 22, 2025
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Mar 22, 2025
@hanzhi713 hanzhi713 added this pull request to the merge queue Mar 22, 2025
Merged via the queue into apple:main with commit a6df285 Mar 22, 2025
6 checks passed
@hanzhi713 hanzhi713 deleted the hanzhi/fp8-update branch March 22, 2025 04:51
rahul003 added a commit to rahul003/axlearn that referenced this pull request Mar 31, 2025
commit 95c97d3
Author: Rahul H <[email protected]>
Date:   Mon Mar 31 23:58:24 2025 +0000

    Remove cache step arg

commit 95a3c2d
Author: Rahul H <[email protected]>
Date:   Mon Mar 31 23:24:21 2025 +0000

    Fix envy config

commit 5255eb0
Merge: 048145f 862d3a6
Author: Rahul H <[email protected]>
Date:   Mon Mar 31 23:15:21 2025 +0000

    Merge branch 'main' of https://github.com/apple/axlearn into rebasemain

commit 862d3a6
Author: Chang Liu <[email protected]>
Date:   Sun Mar 30 21:20:15 2025 -0700

    Consolidate QuotaInfo and UserQuotaInfo into one class (apple#1080)

commit 1e2b306
Author: muyangyuapple <[email protected]>
Date:   Fri Mar 28 16:30:04 2025 -0700

    Don't set location_hint if it is None (apple#1075)

commit 0d7e65b
Author: Dongseong Hwang <[email protected]>
Date:   Fri Mar 28 15:51:34 2025 -0700

    Flash attention fallbacks to standard attention, not ReferenceMHA. (apple#1077)

    `ReferenceMHA` is essentially a duplication of standard attention’s
    `_compute_attention`.
    Instead of duplication, we now call `_compute_attention` directly.

    Another drawback of `ReferenceMHA` is that it **doesn’t apply gradient
    checkpointing**, which can lead to **significantly higher memory usage**.
    For example, training a **600M Conformer model** downstream can cause memory
    usage to jump from **24.89 GB to 157.58 GB**.

commit 3285ee0
Author: Dongseong Hwang <[email protected]>
Date:   Fri Mar 28 14:13:32 2025 -0700

    Make `attention_logit_biases` in init_states/extend_step Optional (apple#1076)

    This is legacy. Make it optional. Use cfg.mask to avoid unnecessary O(T^2) biases.
    Currently, it's used only by the T5 model.

commit 5db2378
Author: Apoorv Gupta <[email protected]>
Date:   Fri Mar 28 11:57:56 2025 -0700

    Improvements and fixes to gradient accumulation (apple#993)

    * Improvements and fixes to gradient accumulation

    - Fix to with_minibatch_steps decorator to generate correct primal outputs shapes.
    - Improved with_minibatch_steps to take a minibatch_partitioner that constrains the input batch to the same PartitionSpec as the input partitioner.

    * remove sharding constraints from gradient accumulation

    * introduce reshape+transpose to avoid input batch reshard in accumulation

    * Address comments and add minibatch sharding test

    * TestMinibatchSteps now asserts carry shape matches

    * add doc string on summary accumulation

    * reorder pytest skip decorator

    * update golden configs

commit 2ce0aee
Author: Mark Lee <[email protected]>
Date:   Thu Mar 27 21:00:14 2025 -0700

    Disables hf test. (apple#1073)

commit 02783b5
Author: Mark Lee <[email protected]>
Date:   Thu Mar 27 20:58:50 2025 -0700

    Temporarily retain instance_type behavior and fix aliased defaults. (apple#1072)

commit 896617c
Author: Hanzhi Zhou <[email protected]>
Date:   Thu Mar 27 20:56:48 2025 -0400

    Reduce amount of cpu simulation tests (apple#1074)

commit 0ca5f7c
Author: muyangyuapple <[email protected]>
Date:   Thu Mar 27 10:30:28 2025 -0700

    Changed job submitter flags (apple#1070)

    * Changed job submitter flags

    * Update axlearn/cloud/gcp/job_flink.py

    Co-authored-by: Mark Lee <[email protected]>

    * updated unit test

    ---------

    Co-authored-by: Mark Lee <[email protected]>

commit da55f3f
Author: Dongseong Hwang <[email protected]>
Date:   Wed Mar 26 21:55:45 2025 -0700

    KV cache update respects cached_dtype (apple#1071)

    The newly added dynamic_update_slice_in_dim path was ignoring cached_dtype.
    This fixes that issue.

commit 8ce86ac
Author: Firenze11 <[email protected]>
Date:   Tue Mar 25 16:25:41 2025 -0700

    enable arbitrary head_dim when using gpu flash attention (apple#1048)

    * enable arbitrary head_dim when using gpu flash attention

    * Update gpu_attention_test.py

    * Update gpu_attention.py

    * Update gpu_attention.py

    * Update gpu_attention.py

    * Update gpu_attention.py

    * Update gpu_attention_test.py

commit 5ac077c
Author: Dongseong Hwang <[email protected]>
Date:   Tue Mar 25 16:04:16 2025 -0700

    Eliminating Unnecessary Quadratic Memory Usage in KV Cache Prefill (apple#1068)

    When updating the KV cache, the current approach temporarily creates a [B,
    step, kv_len] one-hot vector.
    * This is fine when step is small, but problematic when step is large.
    * For example, if video is used as a prefill, the KV cache update could cause
    OOM errors.

    Solution
    * When step size is large, we update the KV cache using
      dynamic_update_slice_in_dim instead.

    Benchmark Results
    * On TPU, dynamic_update_slice_in_dim becomes faster when step size ≈ 1024.
    * On GPU, dynamic_update_slice_in_dim becomes faster when step size ≈ 32.

    Thus, we set different thresholds per platform.

    Note:
    * On GPU, dynamic_update_slice_in_dim significantly reduces memory usage.
    * On TPU, however, both methods use about the same amount of memory as GPU's
      dynamic_update_slice_in_dim.

    It seems that TPU's XLA compilation performs a "magical" memory optimization
    for one-hot matmul somehow.

    ---

    TPU v5p
    * one-hot matmul
    QkvLinearExtendStepBenchmark/2048/16/1024/1           0.380 ms        0.086 ms         8462
    QkvLinearExtendStepBenchmark/2048/16/4096/1           0.599 ms        0.092 ms         8487
    QkvLinearExtendStepBenchmark/2048/16/4096/64          0.562 ms        0.080 ms         8242
    QkvLinearExtendStepBenchmark/2048/16/4096/512         0.654 ms        0.078 ms         8862
    QkvLinearExtendStepBenchmark/2048/16/32768/1           2.18 ms        0.082 ms         8627 3.05 GB
    QkvLinearExtendStepBenchmark/2048/16/32768/1024        2.77 ms        0.080 ms         1000 3.08 GB
    QkvLinearExtendStepBenchmark/2048/16/32768/8192        15.6 ms        0.215 ms         1000 3.30 GB
    QkvLinearExtendStepBenchmark/2048/16/32768/32768       53.5 ms        0.224 ms          100 4.05 GB
    MHADecodeBenchmark/2048/16/1024/1                     0.358 ms        0.072 ms         9105
    MHADecodeBenchmark/2048/16/4096/1                     0.542 ms        0.066 ms         9443
    MHADecodeBenchmark/2048/16/4096/64                    0.556 ms        0.073 ms         9444
    MHADecodeBenchmark/2048/16/4096/512                   0.710 ms        0.071 ms         9273
    MHADecodeBenchmark/2048/16/32768/1                     1.95 ms        0.073 ms         9139 2.07 GB
    MHADecodeBenchmark/2048/16/32768/1024                  7.36 ms        0.070 ms         1000 6.11 GB
    MHADecodeBenchmark/2048/16/32768/8192                  55.7 ms        0.210 ms          100 34.38 GB

    * dynamic_update_slice_in_dim
    QkvLinearExtendStepBenchmark/2048/16/1024/1           0.374 ms        0.078 ms         8875
    QkvLinearExtendStepBenchmark/2048/16/4096/1           0.572 ms        0.079 ms         8471
    QkvLinearExtendStepBenchmark/2048/16/4096/64          0.575 ms        0.077 ms         9208
    QkvLinearExtendStepBenchmark/2048/16/4096/512         0.666 ms        0.080 ms         8590
    QkvLinearExtendStepBenchmark/2048/16/32768/1           2.15 ms        0.082 ms         8296 3.05 GB
    QkvLinearExtendStepBenchmark/2048/16/32768/1024        2.34 ms        0.083 ms         8352 3.08 GB
    QkvLinearExtendStepBenchmark/2048/16/32768/8192        4.17 ms        0.085 ms         1000 3.30 GB
    QkvLinearExtendStepBenchmark/2048/16/32768/32768       10.8 ms        0.131 ms         1000 4.05 GB
    MHADecodeBenchmark/2048/16/1024/1                     0.370 ms        0.071 ms        10258
    MHADecodeBenchmark/2048/16/4096/1                     0.510 ms        0.070 ms        10296
    MHADecodeBenchmark/2048/16/4096/64                     2.06 ms        0.070 ms        10345
    MHADecodeBenchmark/2048/16/4096/512                    13.1 ms        0.207 ms         1000
    MHADecodeBenchmark/2048/16/32768/1                     1.78 ms        0.074 ms        10033 2.07 GB
    MHADecodeBenchmark/2048/16/32768/1024                  7.59 ms        0.068 ms         1000 6.11 GB
    MHADecodeBenchmark/2048/16/32768/8192                  52.8 ms        0.206 ms          100 34.38 GB

    A100
    * one-hot matmul
    QkvLinearExtendStepBenchmark/2048/16/1024/1           0.322 ms        0.180 ms         3523
    QkvLinearExtendStepBenchmark/2048/16/4096/1           0.455 ms        0.201 ms         4016
    QkvLinearExtendStepBenchmark/2048/16/4096/64          0.620 ms        0.188 ms         3524
    QkvLinearExtendStepBenchmark/2048/16/4096/512         0.826 ms        0.218 ms         3551
    QkvLinearExtendStepBenchmark/2048/16/32768/1           1.48 ms        0.203 ms         3608
    QkvLinearExtendStepBenchmark/2048/16/32768/1024        4.36 ms        0.279 ms         1000 3.11 GB
    QkvLinearExtendStepBenchmark/2048/16/32768/8192        18.0 ms        0.328 ms         1000 5.30 GB
    QkvLinearExtendStepBenchmark/2048/16/32768/32768        108 ms        0.352 ms          100 12.05 GB
    MHADecodeBenchmark/2048/16/1024/1                     0.307 ms        0.155 ms         4497
    MHADecodeBenchmark/2048/16/4096/1                     0.389 ms        0.145 ms         4729
    MHADecodeBenchmark/2048/16/4096/64                    0.586 ms        0.143 ms         4593
    MHADecodeBenchmark/2048/16/4096/512                    1.16 ms        0.174 ms         4493
    MHADecodeBenchmark/2048/16/32768/1                     1.16 ms        0.153 ms         4506
    MHADecodeBenchmark/2048/16/32768/1024                  12.9 ms        0.257 ms         1000
    MHADecodeBenchmark/2048/16/32768/8192                  88.7 ms        0.375 ms          100 66.31 GB

    * dynamic_update_slice_in_dim
    QkvLinearExtendStepBenchmark/2048/16/1024/1           0.345 ms        0.171 ms         3952
    QkvLinearExtendStepBenchmark/2048/16/4096/1           0.502 ms        0.178 ms         3879
    QkvLinearExtendStepBenchmark/2048/16/4096/64          0.544 ms        0.199 ms         3823
    QkvLinearExtendStepBenchmark/2048/16/4096/512         0.583 ms        0.177 ms         3912
    QkvLinearExtendStepBenchmark/2048/16/32768/1           2.04 ms        0.199 ms         3357
    QkvLinearExtendStepBenchmark/2048/16/32768/1024        2.20 ms        0.218 ms         3322 3.08 GB
    QkvLinearExtendStepBenchmark/2048/16/32768/8192        3.76 ms        0.264 ms         1000 3.30 GB
    QkvLinearExtendStepBenchmark/2048/16/32768/32768       8.84 ms        0.282 ms         1000 4.05 GB
    MHADecodeBenchmark/2048/16/1024/1                     0.299 ms        0.140 ms         5119
    MHADecodeBenchmark/2048/16/4096/1                     0.473 ms        0.143 ms         4915
    MHADecodeBenchmark/2048/16/4096/64                    0.591 ms        0.143 ms         4937
    MHADecodeBenchmark/2048/16/4096/512                    1.01 ms        0.149 ms         4696
    MHADecodeBenchmark/2048/16/32768/1                     1.50 ms        0.146 ms         4711
    MHADecodeBenchmark/2048/16/32768/1024                  11.3 ms        0.250 ms         1000
    MHADecodeBenchmark/2048/16/32768/8192                  74.1 ms        0.275 ms          100 66.31 GB

commit 363f2aa
Author: muyangyuapple <[email protected]>
Date:   Tue Mar 25 14:06:52 2025 -0700

    Added v6e to system_characteristics.py (apple#1069)

commit 65994b0
Author: mmllee <[email protected]>
Date:   Tue Mar 25 18:48:38 2025 +0100

    W&B: log flat config along with the nested config (apple#1065)

commit 7a851c0
Author: mmllee <[email protected]>
Date:   Tue Mar 25 18:48:12 2025 +0100

    reflect git repo state at job submission time in docker labels and the code bundle (apple#1064)

commit 96d8c58
Author: Hanzhi Zhou <[email protected]>
Date:   Mon Mar 24 21:02:50 2025 -0700

    Fix typo (apple#1067)

commit fbbef0f
Author: Mark Lee <[email protected]>
Date:   Fri Mar 21 21:31:49 2025 -0700

    Refactors config and flag handling for composition. (apple#1063)

commit a6df285
Author: Hanzhi Zhou <[email protected]>
Date:   Fri Mar 21 21:11:01 2025 -0700

    Basic FP8 Training Support (apple#992)

    * Basic FP8 Training

    * Address comments

    * Address comments

    * Rename

    * Use enum

    * Use enum

    * Address comments

commit bf15ab3
Author: Hanzhi Zhou <[email protected]>
Date:   Fri Mar 21 13:57:43 2025 -0700

    Refactor and optimize FlashAttention dispatch (apple#1058)

    * Refactor and Optimize FlashAttention Dispatch

    * Additional checks

    * add jit and fix some tests

    * fix test

    * refactor cudnn

    * update benchmarks

    * Fix TPU tests

    * Move

    * Fix TPU benchmark

    * Fix grad

    * Fix typing

    * Fix comment and reorder

    * Fix type

    * Address comments

    * Skip broken custom mask

    * Add back gpu fwd test

    * Simplify returns

    * Address comments

    * Add default config

    * Fix

    * Fix wrong msg

    * Add comment

commit 976ccab
Author: Mark Lee <[email protected]>
Date:   Thu Mar 20 11:45:57 2025 -0700

    Removes QRM codepaths. (apple#1062)

commit c361523
Author: muyangyuapple <[email protected]>
Date:   Wed Mar 19 21:10:52 2025 -0700

    Revert "Revert "Support Beam pipelines on TPU (apple#1055)" (apple#1057)" (apple#1059)

    This reverts commit d54aa10.

commit 03cfa39
Author: Dongseong Hwang <[email protected]>
Date:   Wed Mar 19 14:00:15 2025 -0700

    Refactor KV Cache Out of `QKVLinear` (apple#1041)

    * Refactor KV Cache Out of `QKVLinear`

    Currently, **`QKVLinear` is overly complex** because it handles **both QKV
    computation and KV cache management**.

    Although **QKVLinear’s role and KV cache strategy should be independent**, the
    current implementation **forces QKVLinear to manage KV cache**, making it
    necessary to **carefully maintain every QKVLinear subclass** (`FusedQKV`,
    `GroupedQKV`, `RoPEQKV`, etc.) to ensure they correctly handle KV cache.

    **Key Changes in This PR**
    This PR **removes KV cache logic from `QKVLinear`**, turning it into a **pure
    `forward`-only class** (such as QKV proj and RoPE) that **no longer needs to
    handle decoding**.

    Instead, **Attention now owns the KV cache directly**, making it **more
    flexible for future KV cache strategies**.

    Currently, **`QKVLinear` supports only one KV cache behavior**, which maintains
    a **fixed max length**. However, in the near future, we will introduce more
    **KV cache strategies**, such as:

    - **Sliding Window Attention** → Requires a **sliding window KV cache**.
    - **Sparse Attention** → Needs a KV cache that **dynamically selects sparse KV**
      (similar to DeepSeek). https://arxiv.org/abs/2502.11089

    **Implementation Details**
    A key aspect of this refactor is **how query positions and key positions are
    generated**.
    Previously, the related logic was **scattered across multiple places**, but
    now, **positions are computed in a single place**:

    - **Query positions** → Must be determined **before RoPE** since RoPE requires
      them. The **same query positions** are then **reused throughout the code**.
    - **Key positions** → Only the **KV cache layer** can determine them
      **accurately** since **KV cache strategies** directly affect key positions.
      So, **KV cache is now responsible for generating key positions**.

    In addition, **`KVState` now carries both KV values and key positions**.

    * Introduce live_step_len in Prefill.

    In KVCache, the valid parts of the input KV are marked using **live_step_len**.
    This information is necessary for the future (and currently downstream) sliding
    window KV cache. Otherwise, the cache would store invalid KV entries and
    simply roll the window by a fixed size, potentially losing past KV entries that
    should remain within the window.

    In the long run, **multi-step `extend_step` should also use live_step_len**,
    since the valid sequence length may vary across different batches.

commit a39f8ca
Author: Dongseong Hwang <[email protected]>
Date:   Wed Mar 19 09:22:38 2025 -0700

    TPU flash attention: Allow custom mask. 2/2 (apple#1061)

    This is follow-up of apple#1050

    `mask.bool_value()` internally uses `mask.target_positions` tensor, so it cannot be
    used inside `with jax.ensure_compile_time_eval():`, even if its shape is
    statically determined.

    Before the changes in the previous PR, apple#1028
    `mask.bool_value()` worked because mask did not store any tensors.

    To resolve this, this PR reimplemented the same logic without using
    `jax.ensure_compile_time_eval()`.

    As mentioned earlier, since the shape is statically determined,
    `target_positions` can be converted into a NumPy array at runtime.

commit d54aa10
Author: muyangyuapple <[email protected]>
Date:   Tue Mar 18 10:38:21 2025 -0700

    Revert "Support Beam pipelines on TPU (apple#1055)" (apple#1057)

    This reverts commit 4401985.

commit 9999da2
Author: Haoshuo Huang <[email protected]>
Date:   Mon Mar 17 23:26:53 2025 -0700

    Fix a wrong assert in input_grain_lm (apple#1056)

commit 4401985
Author: muyangyuapple <[email protected]>
Date:   Mon Mar 17 22:33:45 2025 -0700

    Support Beam pipelines on TPU (apple#1055)

    * Support Beam pipelines on TPU

    * Update _Matcher signature to meet pytype requirement.

    * Update axlearn/cloud/gcp/job_flink.py

    Co-authored-by: Mark Lee <[email protected]>

    * Update axlearn/cloud/gcp/job_flink.py

    Co-authored-by: Mark Lee <[email protected]>

    * fix pytype

    * fix a minor bug during pre-merge verification

    ---------

    Co-authored-by: Mark Lee <[email protected]>

commit 5a17918
Author: Hanzhi Zhou <[email protected]>
Date:   Mon Mar 17 15:10:25 2025 -0400

    Lazy import orbax (apple#1054)

    * Lazy import orbax

    * fix test

commit 53eea9d
Author: Dongseong Hwang <[email protected]>
Date:   Mon Mar 17 11:47:14 2025 -0700

    The dtype of paddings becomes jnp.bool. (apple#1053)

    Previously, it was assumed to have the same dtype as the input. However, since
    paddings only store 0/1 values, this restriction is unnecessary.

    There is a need to save memory by using a much smaller dtype instead of
    float32/bfloat16.
    We prefer to use jnp.bool because

    Backend | Boolean Storage Size (jnp.bool)
    -- | --
    CPU (XLA:CPU) | 8 bits (1 byte per bool)
    GPU (XLA:GPU, CUDA, ROCm) | 8 bits (1 byte per bool, no native bit-packing)
    TPU (XLA:TPU) | 1 bit per bool (native bit-packing)

commit 182605b
Author: fnan <[email protected]>
Date:   Mon Mar 17 13:01:12 2025 -0400

    fix tool call intent when tool_calls is empty list (apple#1052)

commit 8a52817
Author: Dongseong Hwang <[email protected]>
Date:   Sat Mar 15 09:24:56 2025 -0700

    Fix Broken CPU Code Path in `run_aot_compilation.py` (apple#1051)

    `run_aot_compilation.py` was designed to run on **CPU** when `--topology=None`,
    but this **code path was broken**. This PR fixes it to ensure proper execution
    on **CPU**.

    Now, --topology is a required flag.
    If you want to perform AOT compilation using the JAX CPU library, you must
    specify --topology=cpu-256 (or a similar cpu-<digit> format).
    This is necessary because, even for CPU execution, we need to determine the
    logical device number to estimate activation memory usage.

commit fde078d
Author: Dongseong Hwang <[email protected]>
Date:   Fri Mar 14 12:45:08 2025 -0700

    TPU flash attention: Allow custom mask. (apple#1050)

    TPU Flash Attention previously supported custom masks through
    splash_attention_mask.NumpyMask, but the following PR slightly disabled this
    functionality.
    https://github.pie.apple.com/foundation-models/ajax/pull/19016

    In production, only causal and sliding window masks are used, so TPU Flash
    Attention has been working without issues.

    However, since there’s no need to remove support for something that was
    originally available, this PR re-enables custom mask support as before.

    TEST=axlearn/common/flash_attention/tpu_attention_test.py on v5p

commit 1a8b586
Author: Haoshuo Huang <[email protected]>
Date:   Thu Mar 13 18:25:47 2025 -0700

    Add streaming packing support for grain (apple#1046)

    * Add streaming packing

    * fix typo

    * minor cleanup

    * minor cleanup

    * fix format

    * Fix comments

    * fix comments

    * fix comments

    * minor cleanup

    * fix comments

    * rephrase

commit 216fa1e
Author: Hanzhi Zhou <[email protected]>
Date:   Wed Mar 12 22:08:54 2025 -0700

    Fix TPU decoding when KV and Q have different dtype (apple#1049)

    * Fix TPU decoding dtype issue

    * Fix TPU decoding dtype issue

commit 1d4fc81
Author: Mark Lee <[email protected]>
Date:   Wed Mar 12 22:02:30 2025 -0700

    Generalizes checkpointer to support arbitrary PythonSavables. (apple#1047)

    * Generalizes checkpointer to support arbitrary PythonSavables.

    * Switches to msgpack.

commit 2795c6e
Author: muyangyuapple <[email protected]>
Date:   Wed Mar 12 16:10:56 2025 -0700

    Runners and utils to support Beam on Flink (apple#1044)

    * Runners and utils to support Beam on Flink

    * Update job_flink_test.py

    * Update gke_runner.py

    * Update axlearn/cloud/gcp/jobs/tpu_utils.py

    Co-authored-by: Mark Lee <[email protected]>

    * Update axlearn/cloud/gcp/jobs/gke_runner.py

    Co-authored-by: Mark Lee <[email protected]>

    * Update job_flink.py

    * Fix a typo

    * typo fixes and proofread

    ---------

    Co-authored-by: Mark Lee <[email protected]>

commit 2249dd5
Author: Hanzhi Zhou <[email protected]>
Date:   Tue Mar 11 17:20:12 2025 -0700

    Fix gpu flash test (apple#1045)

    * Fix GPU Flash tests

    * Fix GPU Flash tests

    * relax

commit 93a9ad6
Author: Meng (Ethan) Li <[email protected]>
Date:   Tue Mar 11 11:56:58 2025 -0700

    Set megascale_grpc_enable_xor_tracer to false by default (apple#1043)

commit 281508d
Author: Dongseong Hwang <[email protected]>
Date:   Mon Mar 10 19:49:18 2025 -0700

    Fix KeyError of aot_model_analysis in GPU (apple#1042)

    On H100, XLA AOT analysis works well, but not all GPUs generate proper logs.
    Therefore, I wrapped the entire AOT log section in a try-except block.

commit 2673462
Author: Mark Lee <[email protected]>
Date:   Sat Mar 8 12:47:36 2025 -0800

    Simplifies logical dispatch by direct global array construction. (apple#1040)

    * Simplifies logical dispatch by direct global array construction.

    * Address comments.

    * Fixes non-compliant testing inputs.

commit 771926e
Author: Hanzhi Zhou <[email protected]>
Date:   Fri Mar 7 18:48:02 2025 -0800

    Implements TPU decoding in Pallas (apple#1039)

    * Implements TPU decoding in Pallas

    * Add support for shorter kv len

commit b1e7b37
Author: Jialing Tong <[email protected]>
Date:   Thu Mar 6 20:23:47 2025 -0800

    Enable to assign different parameters dtype during training (apple#1037)

    * Enable per parameter train_dtype

    * nit

commit 3d62075
Author: Ke Ye <[email protected]>
Date:   Thu Mar 6 22:58:45 2025 -0500

    Allow more flexible attribute name for getting named configs. (apple#1038)

commit 0acb9a5
Author: Dongseong Hwang <[email protected]>
Date:   Thu Mar 6 17:39:12 2025 -0800

    Trainer: the model analysis on the AOT compiled JAX program. (apple#1036)

    This will help researchers estimate HBM usage and computation costs before
    launching a job, allowing them to determine whether a model is compute-bound or
    memory-bound.

    Introduced aot_model_analysis(), which returns analysis results as a string,
    making it reusable (e.g., in Jupyter notebooks).

    `run_aot_compilation` tool prints fuji-1B-v3 model analysis as follows.
    ```
    ======= Memory Analysis ==================================
    Input memory: 4.4GB
    Output memory: 4.4GB
    Temp memory: 170.9GB
    Code memory: 0.0MB
    Total HBM memory: 179.6GB
    ======= Cost Analysis ====================================
    FLOPS: 70052.0G
    The number of exp/log/sin/cos ops: 20.9G
    The total memory traffic: 1750.7GB
      HBM access: 733.9GB
      L2 cache access: 321.0GB
      Register usage: 59.8GB
      Output data transferred: 661.4GB
    Hardware utilization scores
      Tensor Cores / MatMul units: 630.0
      ALU (Arithmetic Logic Unit): 416.0
      Memory Load/Store Units: 143.0
      L1 Cache Operations: 92.0
      L2 Cache Operations: 60.0
      Special Function Units (exp/log/sin/cos): 41.0
      Integer Units (for indexing, loop counters): 16.0
      Branch Divergence (Control Flow Processing): 12.0
      Load Balancing / Dispatch: 10.0
      Texture Units (or Rarely Used Compute Units): 8.0
    ```

commit 35a189c
Author: Chang Liu <[email protected]>
Date:   Wed Mar 5 10:22:35 2025 -0800

    Add gke_reservation_project config to support shared reservation. (apple#1033)

commit dcfdbc1
Author: Mark Lee <[email protected]>
Date:   Wed Mar 5 09:39:20 2025 -0800

    Disables some flaky metrics tests. (apple#1034)

commit 88c2ff6
Author: Mark Lee <[email protected]>
Date:   Wed Mar 5 09:39:07 2025 -0800

    Enables input multi-device tests in CI. (apple#1035)

commit 6e34bcb
Author: Hanzhi Zhou <[email protected]>
Date:   Tue Mar 4 22:33:12 2025 -0800

    XLA flag autotuning for v6e (apple#1032)

commit 94d7fa3
Author: Dongseong Hwang <[email protected]>
Date:   Tue Mar 4 21:14:51 2025 -0800

    Little Optimization for RoPE Computation (apple#1031)

    In the existing `_rotary_sinusoidal_positional_embeddings()`, the same
    `position_enc[:, :, 0::2]` and `position_enc[:, :, 1::2]` computations were
    duplicated, followed by an interleaving split operation. This PR removes that
    redundant computation.

    Additionally, I refactored the code using `einops` to improve readability. The
    benchmark results confirm that `einops` does not slow down execution on
    TPU/GPU.

    **Benchmark Results**

    **Note:** `8192/0` is the benchmark without JIT, while `8192/1` is the
    benchmark with JIT enabled. The results show that even without JIT, `einops`
    does not cause a slowdown in the code.

    - **TPU (v5p)**: Comparison between **AS-IS** and **this PR**
    ```
    AS-IS
    -------------------------------------------------------------------------------------------
    Benchmark                                 Time             CPU   Iterations      HBM
    -------------------------------------------------------------------------------------------
    QkvLinearBenchmark/1024/4/8192/0       13.4 ms         12.9 ms           56 546.21 MB
    QkvLinearBenchmark/2048/4/8192/0       12.4 ms         11.1 ms           62 1143.38 MB
    QkvLinearBenchmark/1024/4/8192/1       1.69 ms        0.068 ms        10071 546.21 MB
    QkvLinearBenchmark/2048/4/8192/1       3.90 ms        0.080 ms         1000 1143.38 MB

    This PR
    QkvLinearBenchmark/1024/4/8192/0       10.5 ms         10.2 ms           60 545.99 MB
    QkvLinearBenchmark/2048/4/8192/0       11.0 ms         9.82 ms           69 1142.65 MB
    QkvLinearBenchmark/1024/4/8192/1       1.68 ms        0.067 ms        10237 545.99 MB
    QkvLinearBenchmark/2048/4/8192/1       3.83 ms        0.065 ms         1000 1142.65 MB
    ```

    - **GPU (A100)**: Comparison between **AS-IS** and **this PR**
    ```
    AS-IS
    QkvLinearBenchmark/1024/4/8192/0       12.8 ms         12.8 ms           54 428.03 MB
    QkvLinearBenchmark/2048/4/8192/0       13.0 ms         12.8 ms           55 848.05 MB
    QkvLinearBenchmark/1024/4/8192/1      0.665 ms        0.129 ms         5545 428.03 MB
    QkvLinearBenchmark/2048/4/8192/1       1.90 ms        0.160 ms         4661 848.05 MB

    This PR
    QkvLinearBenchmark/1024/4/8192/0       11.4 ms         11.3 ms           61 428.03 MB
    QkvLinearBenchmark/2048/4/8192/0       11.6 ms         11.4 ms           62 848.04 MB
    QkvLinearBenchmark/1024/4/8192/1      0.631 ms        0.137 ms         5595 428.03 MB
    QkvLinearBenchmark/2048/4/8192/1       1.85 ms        0.152 ms         4652 848.04 MB
    ```

commit 2a6b242
Author: Dongseong Hwang <[email protected]>
Date:   Tue Mar 4 18:24:55 2025 -0800

    Remove "einops==0.8.0" dependency in audio as core already has it. (apple#1030)
loofahcus pushed a commit to loofahcus/axlearn that referenced this pull request Apr 10, 2025