[feat] support TiledMLP in Deepspeed and FSDP2 #7090
Conversation
Summary of Changes
Hello @kevssim, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the training framework by integrating Tiled MLP, a technique for reducing activation memory when training large language models on very long sequences. By providing a dedicated implementation for FSDP2 and reusing existing solutions for DeepSpeed and single-GPU setups, it lets users train larger models or longer contexts more efficiently, avoiding out-of-memory errors and improving scalability.
Highlights
Code Review
This pull request introduces support for Tiled MLP to reduce memory usage during training, with a custom implementation for FSDP2 and a fallback to liger_kernel for DeepSpeed and other setups. The changes include new training arguments, example scripts, and the core Tiled MLP implementation. The code is well-structured, particularly the environment detection logic and the dynamic MLP replacement. The custom autograd function for FSDP2 compatibility is a sophisticated piece of engineering. I have a couple of suggestions to improve robustness regarding gradient accumulation precision and exception handling.
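As background for the comments below, here is a heavily simplified sketch of the tiling-with-manual-gradient-accumulation pattern the summary refers to. It is not the PR's actual TiledMLP (which handles a full gated MLP and FSDP2-sharded parameters); a plain linear layer stands in for the MLP, and the class and function names are illustrative only.

```python
import torch

class TiledLinearFn(torch.autograd.Function):
    """Illustrative only: forward runs shard-by-shard under no_grad; backward
    recomputes per-shard gradients and accumulates the weight gradient across
    shards, which is the core idea behind a tiled, FSDP2-friendly MLP."""

    @staticmethod
    def forward(ctx, x, weight, num_shards):
        ctx.save_for_backward(x, weight)
        ctx.num_shards = num_shards
        with torch.no_grad():
            outs = [shard @ weight.t() for shard in torch.chunk(x, num_shards, dim=0)]
        return torch.cat(outs, dim=0)

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        grad_x_shards, grad_w = [], torch.zeros_like(weight)
        for xs, gos in zip(torch.chunk(x, ctx.num_shards, dim=0),
                           torch.chunk(grad_out, ctx.num_shards, dim=0)):
            grad_x_shards.append(gos @ weight)  # dL/dx for this shard
            grad_w += gos.t() @ xs              # accumulate dL/dW across shards
        return torch.cat(grad_x_shards, dim=0), grad_w, None
```

Because the weight gradient is summed shard by shard, the dtype used for the accumulation buffer (the subject of the first comment below) directly affects numerical precision.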
```python
# Split on dim=0
x_shards = list(torch.chunk(x, chunks=shards, dim=0))
# ...
grad_accumulator = GradientAccumulator(compute_params, shards, dtype=x.dtype)
```
The GradientAccumulator is initialized with dtype=x.dtype. This means that if x is a low-precision type like bfloat16 or float16, the gradients for the MLP weights will be accumulated in that same low precision. While this might be acceptable for bfloat16, it can lead to precision loss and numerical instability with float16. It is generally safer to perform gradient accumulation in float32 to maintain precision. The GradientAccumulator already defaults to torch.float32 if dtype is not provided. I suggest removing dtype=x.dtype to use the safer default.
Suggested change:
```diff
-grad_accumulator = GradientAccumulator(compute_params, shards, dtype=x.dtype)
+grad_accumulator = GradientAccumulator(compute_params, shards)
```
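As a minimal illustration of the precision argument (a hypothetical helper, not the PR's GradientAccumulator): keeping the running sum in float32 and downcasting once at the end avoids the rounding error that builds up when many small fp16/bf16 shard gradients are summed in low precision.

```python
import torch

def accumulate_shard_grads(shard_grads, out_dtype):
    # Running buffer in float32, regardless of the compute dtype.
    acc = torch.zeros_like(shard_grads[0], dtype=torch.float32)
    for g in shard_grads:
        acc += g.to(torch.float32)  # upcast each shard's gradient before adding
    return acc.to(out_dtype)        # single downcast after all shards are summed
```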
```python
try:
    from accelerate import PartialState
    state = PartialState()
    if hasattr(state, 'fsdp_plugin') and state.fsdp_plugin is not None:
        # Check if fsdp_version is 2 in the plugin
        if hasattr(state.fsdp_plugin, 'fsdp_version'):
            return state.fsdp_plugin.fsdp_version == 2
except Exception:
    pass
```
Using a broad except Exception: clause can hide unexpected errors and make debugging more difficult. It's better to catch specific exceptions that you expect to handle, such as ImportError or RuntimeError if accelerate is not fully configured. This will make the code more robust and maintainable. The same applies to the is_fsdp1_enabled function.
Suggested change:
```diff
 try:
     from accelerate import PartialState
     state = PartialState()
     if hasattr(state, 'fsdp_plugin') and state.fsdp_plugin is not None:
         # Check if fsdp_version is 2 in the plugin
         if hasattr(state.fsdp_plugin, 'fsdp_version'):
             return state.fsdp_plugin.fsdp_version == 2
-except Exception:
-    pass
+except (ImportError, RuntimeError):
+    # It's possible that accelerate is not fully initialized or available.
+    # In such cases, we can safely ignore the exception and rely on env vars.
+    pass
```
```python
try:
    from accelerate import PartialState
    state = PartialState()
    if hasattr(state, 'fsdp_plugin') and state.fsdp_plugin is not None:
        if hasattr(state.fsdp_plugin, 'fsdp_version'):
            return state.fsdp_plugin.fsdp_version != 2
except Exception:
    pass
```
Similar to the is_fsdp2_enabled function, using a broad except Exception: is not ideal as it can suppress unexpected errors. It's better to catch specific exceptions like ImportError and RuntimeError to avoid masking other potential issues and improve code robustness.
Suggested change:
```diff
 try:
     from accelerate import PartialState
     state = PartialState()
     if hasattr(state, 'fsdp_plugin') and state.fsdp_plugin is not None:
         if hasattr(state.fsdp_plugin, 'fsdp_version'):
             return state.fsdp_plugin.fsdp_version != 2
-except Exception:
-    pass
+except (ImportError, RuntimeError):
+    # It's possible that accelerate is not fully initialized or available.
+    # In such cases, we can safely ignore the exception and rely on env vars.
+    pass
```
```python
"""
This module provides a tiled MLP implementation that is compatible with FSDP2.
- FSDP2: Uses custom TiledMLP implementation (this file)
- DeepSpeed/Single GPU: Uses liger_kernel's LigerTiledSwiGLUMLP
"""
```
@kevssim I'm not sure if LigerTiledSwiGLUMLP is available on an NPU. Is it possible to provide an NPU-compatible implementation?
LigerTiledSwiGLUMLP is implemented using native PyTorch and theoretically supports NPU, but further testing and verification are needed.
Overall, this PR LGTM. Just one small request: could you test its gain on the NPU? Thanks.
I have added NPU support to LigerTiledSwiGLUMLP by replacing the LigerSiLUMulFunction with a native PyTorch implementation.
I have also added test results on the NPU.
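For reference, the replacement amounts to expressing the SwiGLU activation with plain PyTorch ops, which run on any backend (CUDA, NPU, CPU). A minimal sketch, with an illustrative function name rather than the PR's exact code:

```python
import torch
import torch.nn.functional as F

def silu_mul(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    # Stands in for the fused LigerSiLUMulFunction kernel: silu(gate) * up,
    # written with native PyTorch operators only.
    return F.silu(gate) * up
```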
PR type
PR information
Summary
Integrate Tiled MLP into the Swift training framework for memory-efficient long-sequence training. Tiled MLP splits the sequence computation into shards, trading compute time for significant memory savings at long sequence lengths.
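The trade-off can be sketched as follows (illustrative only; the PR's TiledMLP and the liger_kernel path differ in detail). Because the MLP has no cross-token interaction, the input can be split along the sequence dimension and each shard recomputed during backward, so only a 1/num_shards slice of the large intermediate activation is alive at any time, at the cost of one extra forward pass per shard:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

def tiled_mlp_forward(mlp: nn.Module, x: torch.Tensor, num_shards: int = 4) -> torch.Tensor:
    # x: (batch, seq_len, hidden_size); sharding the sequence dimension does
    # not change the result because the MLP is applied position-wise.
    shards = torch.chunk(x, chunks=num_shards, dim=1)
    # Checkpoint each shard so its intermediate activations are recomputed in
    # backward instead of being kept alive for the whole sequence.
    outputs = [checkpoint(mlp, shard, use_reentrant=False) for shard in shards]
    return torch.cat(outputs, dim=1)
```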
Features
Usage
Add the following arguments to enable tiled MLP:
Experiment results
1. FSDP2 + GPU (2x NVIDIA H800)
Env: 2x NVIDIA H800, FSDP2, bf16, batch_size=1, num_shards=4, hidden_size=4096, intermediate_size=12288
2. FSDP2 + NPU (910B)
Env: FSDP2, bf16, NPU
3. DeepSpeed ZeRO-3 + GPU
Env: DeepSpeed ZeRO-3, bf16, GPU
4. DeepSpeed ZeRO-3 + NPU
Env: DeepSpeed ZeRO-3, bf16, NPU
Note: Speed ↑ means the tiled variant is slower; Memory ↓ means the tiled variant uses less memory.