
Conversation

@kevssim (Collaborator) commented on Dec 17, 2025

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

Summary

Integrates Tiled MLP into the Swift training framework for memory-efficient long-sequence training. Tiled MLP splits the MLP computation into shards along the sequence dimension, trading extra compute time for significant memory savings at long sequence lengths.
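
As a rough back-of-the-envelope check (not from the PR; it uses the benchmark settings hidden_size=4096, intermediate_size=12288, bf16), the gate/up projection activations dominate MLP memory at long sequence lengths, and tiling keeps only one shard of them alive at a time:

# Rough activation-memory arithmetic for a SwiGLU MLP at seq_len=131072.
# This ignores weights, attention, and other activations, so it will not match
# the benchmark tables exactly; it only explains the trend.
seq_len = 131072
intermediate_size = 12288
bytes_per_elem = 2          # bf16
num_shards = 4

# gate_proj and up_proj each produce a (seq_len, intermediate_size) activation.
full_mib = 2 * seq_len * intermediate_size * bytes_per_elem / 2**20   # ~6144 MiB
tiled_mib = full_mib / num_shards                                     # ~1536 MiB alive per shard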

Features

  • FSDP2: Uses custom TiledMLP implementation
  • DeepSpeed/Single GPU: Uses liger_kernel's LigerTiledSwiGLUMLP
  • DeepSpeed/Single NPU: Uses LigerTiledSwiGLUMLP with native PyTorch _mlp_forward

Usage

Add the following arguments to enable tiled MLP:

--use_tiled_mlp true
--tiled_mlp_num_shards 4  # optional, default is 4
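
For example, a full invocation might look like the following (hypothetical; everything except the two new flags is the usual swift sft setup, and the placeholders should be adapted to your model, data, and environment):

swift sft \
    --model <model-id-or-path> \
    --dataset <dataset> \
    --max_length 65536 \
    --torch_dtype bfloat16 \
    --use_tiled_mlp true \
    --tiled_mlp_num_shards 4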

Experiment results

1. FSDP2 + GPU (2x NVIDIA H800)

Env: 2x NVIDIA H800, FSDP2, bf16, batch_size=1, num_shards=4, hidden_size=4096, intermediate_size=12288

| Seq Length | Regular (ms) | Tiled (ms) | Speed Δ | Regular (MB) | Tiled (MB) | Memory Δ |
| --- | --- | --- | --- | --- | --- | --- |
| 2048 | 18.59 | 23.51 | +26% | 1216.00 | 1444.00 | +19% |
| 4096 | 22.43 | 31.57 | +41% | 1504.00 | 1576.00 | +5% |
| 8192 | 33.11 | 46.01 | +39% | 2080.00 | 1840.00 | -12% |
| 16384 | 54.06 | 76.76 | +42% | 3424.00 | 2368.00 | -31% |
| 32768 | 103.93 | 136.67 | +31% | 6112.00 | 3456.00 | -43% |
| 65536 | 190.89 | 258.67 | +36% | 11488.00 | 5696.00 | -50% |
| 131072 | 391.50 | 501.15 | +28% | 22240.00 | 10176.00 | -54% |

2. FSDP2 + NPU (910B)

Env: FSDP2, bf16, NPU

| Seq Length | Regular (ms) | Tiled (ms) | Speed Δ | Regular (MB) | Tiled (MB) | Memory Δ |
| --- | --- | --- | --- | --- | --- | --- |
| 2048 | 54.41 | 46.26 | -15% | 1152.01 | 1460.02 | +27% |
| 4096 | 65.69 | 81.05 | +23% | 1460.01 | 1576.01 | +8% |
| 8192 | 80.46 | 98.85 | +23% | 2036.01 | 1808.01 | -11% |
| 16384 | 108.22 | 141.93 | +31% | 3360.01 | 2324.01 | -31% |
| 32768 | 177.25 | 228.90 | +29% | 6048.01 | 3392.01 | -44% |
| 65536 | 313.80 | 419.55 | +34% | 11424.01 | 5632.01 | -51% |
| 131072 | 601.27 | 791.40 | +32% | 22176.01 | 10112.01 | -54% |

3. DeepSpeed ZeRO-3 + GPU

Env: DeepSpeed ZeRO-3, bf16, GPU

| Seq Length | Standard (ms) | Tiled (ms) | Speed Δ | Standard (MB) | Tiled (MB) | Memory Δ |
| --- | --- | --- | --- | --- | --- | --- |
| 2048 | 9.75 | 20.73 | +113% | 1920.96 | 3009.91 | +57% |
| 4096 | 12.63 | 21.60 | +71% | 4546.87 | 5451.82 | +20% |
| 8192 | 17.55 | 26.45 | +51% | 7524.78 | 8053.73 | +7% |
| 16384 | 29.66 | 40.83 | +38% | 11238.69 | 10951.64 | -3% |
| 32768 | 53.93 | 71.41 | +32% | 16360.60 | 14441.55 | -12% |
| 65536 | 100.70 | 132.27 | +31% | 24298.51 | 19115.46 | -21% |
| 131072 | 194.67 | 254.97 | +31% | 37868.42 | 26157.37 | -31% |

4. DeepSpeed ZeRO-3 + NPU

Env: DeepSpeed ZeRO-3, bf16, NPU

| Seq Length | Standard (ms) | Tiled (ms) | Speed Δ | Standard (MB) | Tiled (MB) | Memory Δ |
| --- | --- | --- | --- | --- | --- | --- |
| 2048 | 34.77 | 77.10 | +122% | 1888.97 | 3029.93 | +60% |
| 4096 | 40.82 | 83.04 | +103% | 4502.88 | 5483.84 | +22% |
| 8192 | 56.68 | 103.27 | +82% | 7460.80 | 8085.76 | +8% |
| 16384 | 92.51 | 143.88 | +56% | 11174.71 | 11003.67 | -2% |
| 32768 | 157.39 | 229.23 | +46% | 16296.63 | 14665.59 | -10% |
| 65536 | 299.29 | 418.76 | +40% | 24234.54 | 19723.50 | -19% |
| 131072 | 592.62 | 804.98 | +36% | 37804.46 | 27533.42 | -27% |

Note: a positive Speed Δ means tiled is slower than the regular/standard MLP; a negative Memory Δ means tiled uses less memory.

@gemini-code-assist (Contributor) commented:

Summary of Changes

Hello @kevssim, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the training capabilities by integrating Tiled MLP, a technique crucial for optimizing memory usage, especially when working with very long sequences in large language models. By providing specific support for FSDP2 and leveraging existing solutions for DeepSpeed and single-GPU environments, it allows users to train larger models or longer contexts more efficiently, preventing out-of-memory errors and improving scalability.

Highlights

  • Tiled MLP Integration: Introduces Tiled MLP for memory-efficient training, supporting FSDP2 with a custom implementation and DeepSpeed/single-GPU with liger_kernel.
  • FSDP2 Compatibility: Provides a dedicated TiledSwiGLUMLP implementation and an autograd function (TiledMLPFunction) specifically designed for FSDP2, including a GradientAccumulator for efficient gradient handling (a minimal sketch of this pattern follows the list).
  • Dynamic MLP Replacement: Implements a mechanism to dynamically replace standard MLP layers with their tiled counterparts based on the detected distributed training environment (FSDP2, DeepSpeed, or single GPU), with FSDP1 explicitly not supported.
  • New Training Arguments: Adds --use_tiled_mlp and --tiled_mlp_num_shards to the training arguments, allowing users to easily enable and configure Tiled MLP.
  • Example Scripts: Includes new example scripts (train_deepspeed.sh, train_fsdp2.sh, fsdp2.json) to demonstrate Tiled MLP usage with DeepSpeed and FSDP2.
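
To make the shard-and-recompute pattern concrete, here is a minimal, self-contained sketch of the idea behind a tiled MLP autograd function. It is illustrative only, not this PR's TiledMLPFunction: it chunks on dim=0 as in the quoted diff, recomputes each shard in backward, and lets weight gradients accumulate into .grad (the part the PR's GradientAccumulator manages under FSDP2).

import torch

class TiledMLPSketch(torch.autograd.Function):
    # Forward: run the wrapped MLP shard by shard under no_grad, so only one
    # shard's intermediate activations exist at a time and none are saved.
    # Backward: recompute each shard with grad enabled, trading compute for memory.
    @staticmethod
    def forward(ctx, x, mlp_forward, num_shards):
        ctx.mlp_forward = mlp_forward
        ctx.num_shards = num_shards
        ctx.save_for_backward(x)
        with torch.no_grad():
            shards = torch.chunk(x, num_shards, dim=0)
            return torch.cat([mlp_forward(s) for s in shards], dim=0)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        x_shards = torch.chunk(x, ctx.num_shards, dim=0)
        grad_shards = torch.chunk(grad_out, ctx.num_shards, dim=0)
        grad_x = []
        for xs, gs in zip(x_shards, grad_shards):
            xs = xs.detach().requires_grad_(True)
            with torch.enable_grad():
                ys = ctx.mlp_forward(xs)
            # Weight grads accumulate into param.grad across shards; the input
            # grad for this shard is collected and concatenated below.
            torch.autograd.backward(ys, gs)
            grad_x.append(xs.grad)
        return torch.cat(grad_x, dim=0), None, None

Usage would be something like out = TiledMLPSketch.apply(x, mlp_block.forward, 4) on a flattened (tokens, hidden) input; the PR's real implementation additionally handles FSDP2-sharded parameters and gradient dtype via its GradientAccumulator.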

@gemini-code-assist (bot) left a review:

Code Review

This pull request introduces support for Tiled MLP to reduce memory usage during training, with a custom implementation for FSDP2 and a fallback to liger_kernel for DeepSpeed and other setups. The changes include new training arguments, example scripts, and the core Tiled MLP implementation. The code is well-structured, particularly the environment detection logic and the dynamic MLP replacement. The custom autograd function for FSDP2 compatibility is a sophisticated piece of engineering. I have a couple of suggestions to improve robustness regarding gradient accumulation precision and exception handling.

# Split on dim=0
x_shards = list(torch.chunk(x, chunks=shards, dim=0))

grad_accumulator = GradientAccumulator(compute_params, shards, dtype=x.dtype)
Review comment (severity: medium):

The GradientAccumulator is initialized with dtype=x.dtype. This means that if x is a low-precision type like bfloat16 or float16, the gradients for the MLP weights will be accumulated in that same low precision. While this might be acceptable for bfloat16, it can lead to precision loss and numerical instability with float16. It is generally safer to perform gradient accumulation in float32 to maintain precision. The GradientAccumulator already defaults to torch.float32 if dtype is not provided. I suggest removing dtype=x.dtype to use the safer default.

Suggested change:

- grad_accumulator = GradientAccumulator(compute_params, shards, dtype=x.dtype)
+ grad_accumulator = GradientAccumulator(compute_params, shards)
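
To illustrate the precision concern with a toy example (not the PR's code): accumulating many small per-shard gradient contributions in float16 can stall once the running sum grows large relative to the addend, while a float32 accumulator keeps them.

import torch

# 10,000 identical small contributions of 1e-4; the true sum is 1.0.
addend = torch.full((1,), 1e-4)

acc_fp16 = torch.zeros(1, dtype=torch.float16)
acc_fp32 = torch.zeros(1, dtype=torch.float32)
for _ in range(10_000):
    acc_fp16 += addend.to(torch.float16)
    acc_fp32 += addend

print(acc_fp16.item())  # stalls well below 1.0: additions round away once the sum grows
print(acc_fp32.item())  # ~1.0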

Comment on lines +214 to +222
try:
    from accelerate import PartialState
    state = PartialState()
    if hasattr(state, 'fsdp_plugin') and state.fsdp_plugin is not None:
        # Check if fsdp_version is 2 in the plugin
        if hasattr(state.fsdp_plugin, 'fsdp_version'):
            return state.fsdp_plugin.fsdp_version == 2
except Exception:
    pass
Review comment (severity: medium):

Using a broad except Exception: clause can hide unexpected errors and make debugging more difficult. It's better to catch specific exceptions that you expect to handle, such as ImportError or RuntimeError if accelerate is not fully configured. This will make the code more robust and maintainable. The same applies to the is_fsdp1_enabled function.

Suggested change:

  try:
      from accelerate import PartialState
      state = PartialState()
      if hasattr(state, 'fsdp_plugin') and state.fsdp_plugin is not None:
          # Check if fsdp_version is 2 in the plugin
          if hasattr(state.fsdp_plugin, 'fsdp_version'):
              return state.fsdp_plugin.fsdp_version == 2
- except Exception:
+ except (ImportError, RuntimeError):
+     # It's possible that accelerate is not fully initialized or available.
+     # In such cases, we can safely ignore the exception and rely on env vars.
      pass

Comment on lines +233 to +240
try:
    from accelerate import PartialState
    state = PartialState()
    if hasattr(state, 'fsdp_plugin') and state.fsdp_plugin is not None:
        if hasattr(state.fsdp_plugin, 'fsdp_version'):
            return state.fsdp_plugin.fsdp_version != 2
except Exception:
    pass
Review comment (severity: medium):

Similar to the is_fsdp2_enabled function, using a broad except Exception: is not ideal as it can suppress unexpected errors. It's better to catch specific exceptions like ImportError and RuntimeError to avoid masking other potential issues and improve code robustness.

Suggested change:

  try:
      from accelerate import PartialState
      state = PartialState()
      if hasattr(state, 'fsdp_plugin') and state.fsdp_plugin is not None:
          if hasattr(state.fsdp_plugin, 'fsdp_version'):
              return state.fsdp_plugin.fsdp_version != 2
- except Exception:
+ except (ImportError, RuntimeError):
+     # It's possible that accelerate is not fully initialized or available.
+     # In such cases, we can safely ignore the exception and rely on env vars.
      pass


This module provides a tiled MLP implementation that is compatible with FSDP2.
- FSDP2: Uses custom TiledMLP implementation (this file)
- DeepSpeed/Single GPU: Uses liger_kernel's LigerTiledSwiGLUMLP
A collaborator commented:

@kevssim I'm not sure if LigerTiledSwiGLUMLP is available on an NPU. Is it possible to provide an NPU-compatible implementation?

@kevssim (author) replied:

LigerTiledSwiGLUMLP is implemented using native PyTorch and theoretically supports NPU, but further testing and verification are needed.

A collaborator commented:

Overall, this PR LGTM. Just one small request: could you test its gain on the NPU? Thanks.

@kevssim (author) replied:

I have added NPU support to LigerTiledSwiGLUMLP by replacing the LigerSiLUMulFunction with a native PyTorch implementation.

I have also added the NPU benchmark results to the PR description.
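
For reference, the device-agnostic replacement is essentially the standard SwiGLU forward written in plain PyTorch; the projection names below follow the common LLaMA-style MLP layout and are an assumption, not code copied from this PR.

import torch.nn.functional as F

def native_swiglu_forward(mlp, x):
    # silu(gate_proj(x)) * up_proj(x), then down_proj -- pure PyTorch ops,
    # so it runs on any backend PyTorch supports (CUDA, NPU, CPU), unlike a
    # Triton-fused kernel such as LigerSiLUMulFunction.
    return mlp.down_proj(F.silu(mlp.gate_proj(x)) * mlp.up_proj(x))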

@kevssim marked this pull request as draft on December 18, 2025 at 02:42
@kevssim marked this pull request as ready for review on December 18, 2025 at 07:17