add multi-stage guide #234

Open
wants to merge 3 commits into base: main

tscholak
Collaborator

✨ Description

add short multi-stage guide

πŸ” Type of change

Select all that apply:

  • πŸ› Bug fix (non-breaking change that addresses a specific issue)
  • πŸš€ New feature (non-breaking change that adds functionality)
  • ⚠️ Breaking change (a change that could affect existing functionality)
  • πŸ“ˆ Performance improvement/optimization (improves speed, memory usage, or efficiency)
  • πŸ› οΈ Code refactor (non-functional changes that improve code readability, structure, etc.)
  • πŸ“¦ Dependency bump (updates dependencies, including Dockerfile or package changes)
  • πŸ“ Documentation change (updates documentation, including new content or typo fixes)
  • πŸ”§ Infrastructure/Build change (affects build process, CI/CD, or dependencies)

@tscholak requested a review from jlamypoirier on April 16, 2025 at 17:44
Collaborator

@jlamypoirier left a comment

Thanks for the guide. I have some minor comments.

@tscholak requested a review from jlamypoirier on April 17, 2025 at 00:46
Collaborator

@jlamypoirier left a comment

Thanks for the changes! This looks good, though I have some minor suggestions.

| `2` | Replicated | Sharded | Sharded | Moderate[^1] |
| `3` | Sharded | Sharded | Sharded | High[^2] |

[^1]: Communication overhead for ZeRO Stage 2 is similar to Stage 1, except during (depth-first) gradient accumulation when additional all-reduce operations occur.
Collaborator

Technically reduce-scatter
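
To make the distinction concrete, here is a minimal pure-Python sketch that simulates ranks as plain lists. It uses no real communication library, and the helper names `all_reduce` and `reduce_scatter` are illustrative only; they are not Fast-LLM functions. After an all-reduce every rank holds the full summed gradient, whereas after a reduce-scatter each rank holds only its own shard of the sum, which is what a setup with sharded gradients needs.

```python
from typing import List


def all_reduce(per_rank: List[List[float]]) -> List[List[float]]:
    """Every rank ends up with the full element-wise sum."""
    total = [sum(values) for values in zip(*per_rank)]
    return [list(total) for _ in per_rank]


def reduce_scatter(per_rank: List[List[float]]) -> List[List[float]]:
    """Each rank ends up with only its own contiguous shard of the sum."""
    world_size = len(per_rank)
    total = [sum(values) for values in zip(*per_rank)]
    shard = len(total) // world_size  # assumes the length divides evenly
    return [total[r * shard:(r + 1) * shard] for r in range(world_size)]


# Two simulated ranks, each holding a local gradient of four elements.
grads = [[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]]
full = all_reduce(grads)         # each rank: the whole summed gradient
sharded = reduce_scatter(grads)  # each rank: half of the summed gradient
```

In this toy setup each rank receives half as many elements from the reduce-scatter as from the all-reduce.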


### Buffers

When gradients or weights are sharded, Fast-LLM accumulates partial results in shared *buffers* during forward and backward passes, separately for gradients and weights. These buffers reduce communication overhead by batching gradient or weight updates across GPUs or nodes. The options `num_grad_buffers` and `num_weight_buffers` control the number of buffers used for gradients and weights, respectively.
Collaborator

Might be useful to state explicitly how this relates to ZeRO stages:

  • num_layers buffers: Store all layers in memory, as in traditional (non-ZeRO) DP.
  • 2 buffers: Keep weights/gradients one layer at a time, i.e. ZeRO stage 2/3. The second buffer is there for network overlap.
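
The two cases above can be sketched as a rough upper bound on how much reconstructed (unsharded) data is resident at once. This is a toy estimate under stated assumptions, not Fast-LLM's actual accounting: the function name `resident_buffer_bytes`, the per-layer parameter counts, and the 2-bytes-per-parameter default are all hypothetical, and the bound simply assumes the largest layers occupy the buffers.

```python
from typing import Sequence


def resident_buffer_bytes(
    layer_sizes: Sequence[int], num_buffers: int, bytes_per_param: int = 2
) -> int:
    """Worst-case bytes held in buffers: the `num_buffers` largest layers."""
    if not 1 <= num_buffers <= len(layer_sizes):
        raise ValueError("num_buffers must be between 1 and the number of layers")
    largest = sorted(layer_sizes, reverse=True)[:num_buffers]
    return sum(largest) * bytes_per_param


# Hypothetical parameter counts for a four-layer model.
layer_sizes = [100, 200, 200, 50]

all_layers = resident_buffer_bytes(layer_sizes, num_buffers=len(layer_sizes))
two_buffers = resident_buffer_bytes(layer_sizes, num_buffers=2)
one_buffer = resident_buffer_bytes(layer_sizes, num_buffers=1)
```

With one buffer per layer everything stays resident, as in plain data parallelism; with two buffers only the current layer and the one being prefetched are held.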


By default, Fast-LLM assigns one gradient and weight buffer per stage, where the number of stages equals the total number of logical partitions (stages) of the model. This enables overlapping communication (e.g., data transfers between GPUs or nodes) with computation (actual processing done by each GPU or node). Lower values (e.g., 1) reduce this overlap, potentially increasing communication waiting times.

Increasing `num_grad_buffers` or `num_weight_buffers` provides more room for overlapping communication with compute. This can help in some setups, especially when stages are imbalanced, but generally isn't necessary. Note that this does not reduce total communication; it just shifts when it happens.
Collaborator

Missing transition from the last paragraph; this makes it look like we're going higher than num_layers. Reducing (to 1) is also an option, sacrificing network overlap for lower memory usage.
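
The trade-off can be illustrated with a toy timing model. It is a sketch only: `step_time` is a hypothetical helper, every layer is assumed to take the same compute and communication time, and real scheduling is more involved. With a single buffer, each layer's communication must finish before its compute can start, so the two serialize. With two or more buffers, the next layer's communication can run while the current layer computes.

```python
def step_time(num_layers: int, compute: float, comm: float, num_buffers: int) -> float:
    """Toy estimate of one pass over `num_layers` uniform layers."""
    if num_layers < 1 or num_buffers < 1:
        raise ValueError("num_layers and num_buffers must be at least 1")
    if num_buffers == 1:
        # No spare buffer: fetch, then compute, strictly one after the other.
        return num_layers * (comm + compute)
    # A spare buffer lets the next fetch overlap with the current compute.
    return comm + (num_layers - 1) * max(compute, comm) + compute


serial = step_time(num_layers=4, compute=3.0, comm=1.0, num_buffers=1)
overlapped = step_time(num_layers=4, compute=3.0, comm=1.0, num_buffers=2)
```

In this uniform model a third buffer changes nothing, which is consistent with the guide's remark that extra buffers mainly help when stages are imbalanced.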

- **`stages_per_pipeline_stage`**: Intended to specify how many stages run per pipeline worker when pipeline parallelism is active.

!!! warning
    This feature is currently **not implemented**. Changing this value has no effect.
Collaborator

Technically validation will fail
