Basic FP8 Training Support #992
Conversation
The branch was force-pushed from 0ba5ab6 to d0e4916.
Thanks.
update_cfg: OverrideInplaceUpdateTransformation.Config = (
    OverrideInplaceUpdateTransformation.default_config()
)
update_cfg.rules = [
    f".*/{x}"
    for x in [
        "input_scale",
        "kernel_scale",
        "output_grad_scale",
        "input_amax_history",
        "kernel_amax_history",
        "output_grad_amax_history",
    ]
]
Can we use update_rules? See axlearn/common/learner_test.py, line 343 in daec8c5:
("state_updates_v", [(".*v", UpdateType.STATE_UPDATES)]),
No. There is a tracer problem when plumbing the FP8 stats through as state updates, because they are produced in the backward pass of a custom vjp. To circumvent this, we produce them as gradients and treat them as state updates.
See internal PRs 806 and 816 for context and discussion.
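For illustration, a minimal, hypothetical sketch of this pattern (not the axlearn or PR code; all names are made up): the backward pass of a custom_vjp computes the new stats and emits them as the "gradient" of a dummy stats argument, and an update rule then writes that "gradient" into the state instead of applying an optimizer step.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def dot_with_stats(x, w, output_grad_amax_history):
    # Forward ignores the stats argument; it exists only to give the backward
    # pass a slot through which new stats can be emitted.
    return jnp.dot(x, w)


def _fwd(x, w, output_grad_amax_history):
    return jnp.dot(x, w), (x, w, output_grad_amax_history)


def _bwd(residuals, g):
    x, w, amax_history = residuals
    # The output-grad amax is only observable here, in the backward pass, so
    # it cannot be returned through the usual state-update path.
    new_history = jnp.roll(amax_history, 1).at[0].set(jnp.max(jnp.abs(g)))
    # The third "cotangent" is really the updated stats tensor; a rule like
    # OverrideInplaceUpdateTransformation then assigns it to the state
    # directly rather than treating it as a true gradient.
    return g @ w.T, x.T @ g, new_history


dot_with_stats.defvjp(_fwd, _bwd)
```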
axlearn/common/utils.py
Outdated
    secondary: Nested[Any],
    override_fn: Callable[[Any, Any], Any] = tree_merge_default_override_fn,
) -> Nested[Any]:
    """Merge `secondary` into `primary`. The result contains shallow copies of subtrees from both.
Consider making them deep copies to be consistent with the JAX tree utils. Deep copies are also less error-prone, since they decouple the result from side effects on the inputs.
Done.
axlearn/common/utils.py
Outdated
if k in primary:
    out_tree[k] = tree_merge(primary[k], secondary=secondary[k], override_fn=override_fn)
else:
    out_tree[k] = secondary[k]
Suggested change:
-    out_tree[k] = secondary[k]
+    out_tree[k] = copy.deepcopy(secondary[k])
Done.
axlearn/common/utils.py
Outdated
if isinstance(primary, dict) ^ isinstance(secondary, dict):
    raise ValueError(f"Trying to merge incompatible subtrees: {primary=}, {secondary=}")
if not (isinstance(primary, dict) or isinstance(secondary, dict)):
    return override_fn(primary, secondary)
Suggested change:
-    return override_fn(primary, secondary)
+    return copy.deepcopy(override_fn(primary, secondary))
Done.
axlearn/common/utils.py
Outdated
    primary: Nested[Any],
    *,
    secondary: Nested[Any],
    override_fn: Callable[[Any, Any], Any] = tree_merge_default_override_fn,
Suggested change:
-    override_fn: Callable[[Any, Any], Any] = tree_merge_default_override_fn,
+    leaf_merge_fn: Callable[[Any, Any], Any] = tree_merge_default_override_fn,
Done.
axlearn/common/utils.py
Outdated
    primary: Nested[Any],
    *,
    secondary: Nested[Any],
    override_fn: Callable[[Any, Any], Any] = tree_merge_default_override_fn,
How about using a simpler function (always choose primary) as the default merge function?
Suggested change:
-    override_fn: Callable[[Any, Any], Any] = tree_merge_default_override_fn,
+    override_fn: Callable[[Any, Any], Any] = lambda x, y: x,
I think the current default makes more sense. Always choosing primary would discard the subtree of secondary even when the primary tree has an empty leaf or None at the corresponding position.
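For reference, the utils.py fragments above can be assembled into a runnable sketch with the accepted suggestions applied (deep copies, the leaf_merge_fn rename, and a default that prefers primary but falls back to secondary on None); the actual implementation may differ in detail:

```python
import copy
from typing import Any, Callable

Nested = Any  # Stand-in for axlearn's Nested[...] alias.


def tree_merge_default_override_fn(primary: Any, secondary: Any) -> Any:
    # Prefer primary, but do not discard secondary's value when primary has
    # None at the corresponding position (the case discussed above).
    return secondary if primary is None else primary


def tree_merge(
    primary: Nested,
    *,
    secondary: Nested,
    leaf_merge_fn: Callable[[Any, Any], Any] = tree_merge_default_override_fn,
) -> Nested:
    """Merges `secondary` into `primary`, returning deep copies of subtrees."""
    if isinstance(primary, dict) ^ isinstance(secondary, dict):
        raise ValueError(f"Trying to merge incompatible subtrees: {primary=}, {secondary=}")
    if not (isinstance(primary, dict) or isinstance(secondary, dict)):
        return copy.deepcopy(leaf_merge_fn(primary, secondary))
    out_tree = {}
    for k in primary:
        if k in secondary:
            out_tree[k] = tree_merge(primary[k], secondary=secondary[k], leaf_merge_fn=leaf_merge_fn)
        else:
            out_tree[k] = copy.deepcopy(primary[k])
    for k in secondary:
        if k not in primary:
            out_tree[k] = copy.deepcopy(secondary[k])
    return out_tree
```

For example, tree_merge({"a": None, "b": 1}, secondary={"a": 2}) yields {"a": 2, "b": 1}, whereas a lambda x, y: x default would keep {"a": None}.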
@ruomingp Could you please take a look again?
Sorry about the delay. Will take a look soon.
Apologies about the delay.
"output_grad_amax_history", | ||
] | ||
] | ||
transformation = maybe_instantiate(cfg.learner.optimizer) |
Usually we avoid instantiation when constructing configs to preserve the readability of configs. Can we defer the maybe_instantiate call to OverrideInplaceUpdateTransformation.__init__?
Done.
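A sketch of the deferral under discussion (the class layout is illustrative and may differ from the merged code): the config stores the wrapped transformation as a config, and maybe_instantiate runs in __init__, so building the trainer config never triggers instantiation.

```python
from axlearn.common.config import (
    REQUIRED,
    InstantiableConfig,
    Required,
    config_class,
    maybe_instantiate,
)
from axlearn.common.module import Module


class OverrideInplaceUpdateTransformation(Module):
    @config_class
    class Config(Module.Config):
        # The wrapped optimizer/transformation, kept as a config so that
        # printed trainer configs stay readable.
        transformation: Required[InstantiableConfig] = REQUIRED

    def __init__(self, cfg: Config, *, parent: Module):
        super().__init__(cfg, parent=parent)
        # Instantiation is deferred to module construction time.
        self._transformation = maybe_instantiate(cfg.transformation)
```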
"input_scale", | ||
"kernel_scale", | ||
"output_grad_scale", | ||
"input_amax_history", | ||
"kernel_amax_history", | ||
"output_grad_amax_history", |
I wonder how we can keep this list in sync with quantized_dot_general/layers.py. Consider:
1. Defining them as an enum in quantized_dot_general/common.py and enumerating the enum members here.
2. Moving this class to quantized_dot_general/update_transformation.py.
3. Changing quantized_dot_general/layers.py to use the enums instead of the string literals?
Done.
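A sketch of what the suggested enum might look like (the enum name is hypothetical; module placement per the suggestion above):

```python
import enum


class Fp8Stats(enum.Enum):
    """FP8 scaling stats; hypothetical enum for quantized_dot_general/common.py."""

    INPUT_SCALE = "input_scale"
    KERNEL_SCALE = "kernel_scale"
    OUTPUT_GRAD_SCALE = "output_grad_scale"
    INPUT_AMAX_HISTORY = "input_amax_history"
    KERNEL_AMAX_HISTORY = "kernel_amax_history"
    OUTPUT_GRAD_AMAX_HISTORY = "output_grad_amax_history"


# Both the layer and the update rules can then enumerate the members,
# keeping the two lists in sync by construction:
rules = [f".*/{m.value}" for m in Fp8Stats]
```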
The branch was force-pushed from 2c44d6b to fad667c.
@ruomingp Could you please take a look again?
Rebase commit log (one line per commit):

commit 95c97d3 Remove cache step arg
commit 95a3c2d Fix envy config
commit 5255eb0 Merge branch 'main' of https://github.com/apple/axlearn into rebasemain
commit 862d3a6 Consolidate QuotaInfo and UserQuotaInfo into one class (apple#1080)
commit 1e2b306 Don't set location_hint if it is None (apple#1075)
commit 0d7e65b Flash attention fallbacks to standard attention, not ReferenceMHA. (apple#1077)
commit 3285ee0 Make `attention_logit_biases` in init_states/extend_step Optional (apple#1076)
commit 5db2378 Improvements and fixes to gradient accumulation (apple#993)
commit 2ce0aee Disables hf test. (apple#1073)
commit 02783b5 Temporarily retain instance_type behavior and fix aliased defaults. (apple#1072)
commit 896617c Reduce amount of cpu simulation tests (apple#1074)
commit 0ca5f7c Changed job submitter flags (apple#1070)
commit da55f3f KV cache update respects cached_dtype (apple#1071)
commit 8ce86ac enable arbitrary head_dim when using gpu flash attention (apple#1048)
commit 5ac077c Eliminating Unnecessary Quadratic Memory Usage in KV Cache Prefill (apple#1068)
commit 363f2aa Added v6e to system_characteristics.py (apple#1069)
commit 65994b0 W&B: log flat config along with the nested config (apple#1065)
commit 7a851c0 reflect git repo state at job submission time in docker labels and the code bundle (apple#1064)
commit 96d8c58 Fix typo (apple#1067)
commit fbbef0f Refactors config and flag handling for composition. (apple#1063)
commit a6df285 Basic FP8 Training Support (apple#992)
commit bf15ab3 Refactor and optimize FlashAttention dispatch (apple#1058)
commit 976ccab Removes QRM codepaths. (apple#1062)
commit c361523 Revert "Revert "Support Beam pipelines on TPU (apple#1055)" (apple#1057)" (apple#1059)
commit 03cfa39 Refactor KV Cache Out of `QKVLinear` (apple#1041)
commit a39f8ca TPU flash attention: Allow custom mask. 2/2 (apple#1061)
commit d54aa10 Revert "Support Beam pipelines on TPU (apple#1055)" (apple#1057)
commit 9999da2 Fix a wrong assert in input_grain_lm (apple#1056)
commit 4401985 Support Beam pipelines on TPU (apple#1055)
commit 5a17918 Lazy import orbax (apple#1054)
commit 53eea9d The dtype of paddings becomes jnp.bool. (apple#1053)
commit 182605b fix tool call intent when tool_calls is empty list (apple#1052)
commit 8a52817 Fix Broken CPU Code Path in `run_aot_compilation.py` (apple#1051)
commit fde078d TPU flash attention: Allow custom mask. (apple#1050)
commit 1a8b586 Add streaming packing support for grain (apple#1046)
commit 216fa1e Fix TPU decoding when KV and Q have different dtype (apple#1049)
commit 1d4fc81 Generalizes checkpointer to support arbitrary PythonSavables. (apple#1047)
commit 2795c6e Runners and utils to support Beam on Flink (apple#1044)
commit 2249dd5 Fix gpu flash test (apple#1045)
commit 93a9ad6 Set megascale_grpc_enable_xor_tracer to false by default (apple#1043)
commit 281508d Fix KeyError of aot_model_analysis in GPU (apple#1042)
commit 2673462 Simplifies logical dispatch by direct global array construction. (apple#1040)
commit 771926e Implements TPU decoding in Pallas (apple#1039)
commit b1e7b37 Enable to assign different parameters dtype during training (apple#1037)
commit 3d62075 Allow more flexible attribute name for getting named configs. (apple#1038)
commit 0acb9a5 Trainer: the model analysis on the AOT compiled JAX program. (apple#1036)
commit 35a189c Add gke_reservation_project config to support shared reservation. (apple#1033)
commit dcfdbc1 Disables some flaky metrics tests. (apple#1034)
commit 88c2ff6 Enables input multi-device tests in CI. (apple#1035)
commit 6e34bcb XLA flag autotuning for v6e (apple#1032)
commit 94d7fa3 Little Optimization for RoPE Computation (apple#1031)
commit 2a6b242 Remove "einops==0.8.0" dependency in audio as core already has it. (apple#1030)
* Basic FP8 Training
* Address comments
* Address comments
* Rename
* Use enum
* Use enum
* Address comments
This PR adds FP8 training support using either in-batch scaling or delayed scaling. There is one scaling factor per tensor.
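An illustrative sketch of the two per-tensor recipes (not the PR's exact code; F8_MAX and the function names are assumptions): in-batch scaling derives the scale from the current tensor's amax, while delayed scaling derives it from a history of amax values recorded on earlier steps.

```python
import jax.numpy as jnp

F8_MAX = 448.0  # Max magnitude representable in float8_e4m3fn.


def in_batch_scale(x):
    # Scale computed from the tensor being quantized right now.
    return jnp.maximum(jnp.max(jnp.abs(x)), 1e-6) / F8_MAX


def delayed_scale(amax_history):
    # Scale computed from amax values recorded on earlier steps; the current
    # step's amax is appended to the history for use on later steps.
    return jnp.maximum(jnp.max(amax_history), 1e-6) / F8_MAX


def quantize(x, scale):
    # One scaling factor for the whole tensor.
    return jnp.clip(x / scale, -F8_MAX, F8_MAX).astype(jnp.float8_e4m3fn)
```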