Support frozen weights #185

Merged: jlamypoirier merged 20 commits into main from frozen_weights on Mar 27, 2025

Conversation

@jlamypoirier (Collaborator) commented Mar 11, 2025

✨ Description

Fix #183

  • Extract the FSDP stuff from Stage so we can have a separate one for frozen weights. (811739a)
  • Separate the shards so they can have different sizes (i.e. no grad or optimizer shards for frozen parameters). Shards are now stored as a dict {shard_name: shard} instead of a single tensor of shape (num_shards, shard_size); a rough sketch is included after the log output below. (b9b017f)
  • Remove unnecessary buffers and shards for frozen weights by setting their size to zero. (b9b017f)
  • Add test for frozen weights and make it pass. (e878656)
  • Train a small model (no frozen weights) to check for regressions.
  • Try loading an older checkpoint in distributed format to verify backward compatibility.
$ fast-llm train gpt
[...]
2025-03-18 03:02:53,421 >>> Allocating 14 weight buffers (692.54 MiB)
2025-03-18 03:02:53,952 >>> Allocating 14 grad buffers (692.54 MiB)
2025-03-18 03:02:53,952 >>> Allocating 4 shards (2,770.14 MiB)
2025-03-18 03:02:53,953 Total allocated: 4,155.21 MiB
[...]
$ fast-llm train gpt model.base_model.transformer.mlp_lr_scale=[0]
[...]
2025-03-18 03:01:30,496 >>> Allocating 14 weight buffers (692.54 MiB)
2025-03-18 03:01:31,326 >>> Allocating 14 grad buffers (308.30 MiB)
2025-03-18 03:01:31,327 >>> Allocating 4 shards (1,617.44 MiB)
2025-03-18 03:01:31,327 Total allocated: 2,618.27 MiB
[...]
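
For illustration, here is a minimal sketch of the named-shard layout described in the second bullet above. The function and shard names are hypothetical and this is not the actual Fast-LLM code; it only shows how per-name shards can have different sizes, with the grad and optimizer-state shards shrinking (to zero for a fully frozen stage):

import torch

# Hypothetical sketch: shards keyed by name instead of a single
# (num_shards, shard_size) tensor, so each shard can have its own size.
def allocate_shards(numel_trainable: int, numel_frozen: int) -> dict[str, torch.Tensor]:
    return {
        # The weight shard covers all parameters, frozen or not.
        "weights": torch.empty(numel_trainable + numel_frozen),
        # Grad and optimizer-state shards only cover trainable parameters.
        "grads": torch.empty(numel_trainable),
        "exp_avg": torch.empty(numel_trainable),
        "exp_avg_sq": torch.empty(numel_trainable),
    }

shards = allocate_shards(numel_trainable=1_000_000, numel_frozen=500_000)
for name, shard in shards.items():
    print(f"{name}: {shard.numel() * shard.element_size() / 2**20:.2f} MiB")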

🔍 Type of change

Select all that apply:

  • 🐛 Bug fix (non-breaking change that addresses a specific issue)
  • 🚀 New feature (non-breaking change that adds functionality)
  • ⚠️ Breaking change (a change that could affect existing functionality)
  • 📈 Performance improvement/optimization (improves speed, memory usage, or efficiency)
  • 🛠️ Code refactor (non-functional changes that improve code readability, structure, etc.)
  • 📦 Dependency bump (updates dependencies, including Dockerfile or package changes)
  • 📝 Documentation change (updates documentation, including new content or typo fixes)
  • 🔧 Infrastructure/Build change (affects build process, CI/CD, or dependencies)

@jlamypoirier marked this pull request as ready for review March 18, 2025 22:20
@jlamypoirier requested a review from tscholak March 18, 2025 22:22
@jlamypoirier (Collaborator, Author)

@tscholak The PR is ready for review; feel free to volunteer another lucky reviewer. I guess it's too big for a full review, but the general structure should be fine.
I still have some one-off tests to do, but they aren't strictly part of the PR, so they can be done in parallel with the review.

@jlamypoirier mentioned this pull request Mar 20, 2025
@jlamypoirier (Collaborator, Author)

No regression found on the tutorial, so things seem to be working properly.

[main]
2025-03-21 21:57:23,638 [Rank 0] PhaseType.training @ iteration    100/   100 | consumed samples:       48,000 | consumed tokens:       49,152,000 | batch size: 480 | step time: 191.74 ms | throughput: 292.57 tflop/s (model) | 298.24 tflop/s (hardware) | 320432.24 tokens/s/gpu | Memory allocated 1,091.40 MiB | max allocated 40,576.67 MiB | reserved 47,002.00 MiB | max reserved 47,002.00 MiB | global max reserved 47,002.00 MiB | learning rate: 6.000e-04 | loss scale:     1 | grad norm: 0.3063 | skipped iterations:   0 | nan iterations:   0 | average step time 231.56 ms | remaining 0:00:00  | completion 2025-03-21 21:57:24 (100.00 %) | language model loss: 6.17838 | run: 1
[frozen_weights]
2025-03-21 21:59:58,159 [Rank 0] PhaseType.training @ iteration    100/   100 | consumed samples:       48,000 | consumed tokens:       49,152,000 | batch size: 480 | step time: 191.96 ms | throughput: 292.23 tflop/s (model) | 297.90 tflop/s (hardware) | 320063.90 tokens/s/gpu | Memory allocated 1,091.40 MiB | max allocated 40,576.67 MiB | reserved 47,002.00 MiB | max reserved 47,002.00 MiB | global max reserved 47,002.00 MiB | learning rate: 6.000e-04 | loss scale:     1 | grad norm: 0.7672 | skipped iterations:   0 | nan iterations:   0 | average step time 231.80 ms | remaining 0:00:00  | completion 2025-03-21 21:59:58 (100.00 %) | language model loss: 6.16544 | run: 0

Might be worth trying another example closer to our experimental setup though (@RaymondLi0 ?)

@tscholak (Collaborator) left a comment

Looks good overall. This is mostly a refactor, and I appreciate that tests were updated and extended. I'm not going to deep-review every detail of the checkpointing and loading logic, but from what I see, the transition to named shards and frozen weight handling is clean and backward-compatible.

One requirement before merging: we need a real-world, non-trivial training run post-refactor to confirm nothing breaks under load. Ideally something with a bit of complexity, perhaps a multi-stage setup with MoEs and a mix of frozen and unfrozen weights.

@RaymondLi0: can you own this and re-run, from this branch, a chilled/frozen MoE test job from back when you experimented with this? Or something else that is similarly complex and can be compared to previous results in wandb.

@RaymondLi0 (Contributor)

Sure! So just to confirm, the goal would be to reproduce a previously run experiment with frozen weights, and observe: less memory usage, same loss curve. Is my understanding correct?
If so I could re-run https://wandb.ai/maxmatical/moe-cl/runs/rsodr3jlilcoyjoh/overview which involved frozen weights.

@jlamypoirier (Collaborator, Author)

@RaymondLi0 Something like this. I already confirmed the lower memory usage, so the loss curve is the main thing to check (maybe throughput too). My main concern is regressions from the refactoring more than the new frozen feature, so strictly speaking the test doesn't even need to use frozen weights.

@tscholak (Collaborator)

You could also run two experiments, one with lr-scale set to 0 and one with it set to 1e-12 or something ridiculously small, and make sure that the losses are consistent.

@RaymondLi0 (Contributor)

Sounds good, I'll re-run that experiment then.
If the main concern is regressions from the refactoring, though, it may be best to re-run a more recent experiment, like one from SLAM. I can launch this as well.

@RaymondLi0 (Contributor)

@jlamypoirier I got the following error trying to re-run a SLAM experiment:

  File "/app/fast_llm/engine/checkpoint/distributed.py", line 100, in load
    log_main_rank("Using legacy distributed checkpoint loader.", log_fn=logger.warning())
                                                                        ^^^^^^^^^^^^^^^^
TypeError: Logger.warning() missing 1 required positional argument: 'msg'
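
(For what it's worth, the traceback shows logger.warning() being invoked with no message instead of being passed as a callable; assuming log_main_rank expects a callable for log_fn, the fix is presumably the one-liner below.)

  # Pass the bound method itself rather than calling it:
  log_main_rank("Using legacy distributed checkpoint loader.", log_fn=logger.warning)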

@RaymondLi0 (Contributor) commented Mar 26, 2025

[Screenshot: loss curve comparison] The loss is a bit higher than the original run (which is from early Feb). I will run the same on main to see if the difference comes from this PR. (I also tried running the experiment with frozen experts, but it fails with another issue not related to this PR; I think this recent run is more relevant anyway.)

@RaymondLi0 (Contributor)

[Screenshot: loss curve comparison] I get the same loss curve on main, so this refactoring seems correct. Anything in the last 1.5 months that could have caused this variation?

@tscholak (Collaborator)

Dataset sampling changes most likely

@jlamypoirier (Collaborator, Author)

Since it's unrelated to this PR, let's merge and open a separate issue? It doesn't look like a big difference, but it's likely hiding a real bug. (@tscholak: this is using the legacy dataset, so it should not have changed at all.)

Which commit was the previous run using? Could we bisect the commit that caused that difference? (And re-run the old one to rule out other factors?)

@tscholak (Collaborator) left a comment

LGTM! thanks!

@jlamypoirier merged commit 14c980b into main Mar 27, 2025
2 checks passed
@jlamypoirier deleted the frozen_weights branch March 27, 2025 02:45
@RaymondLi0 (Contributor)

Re-running with the old commit fails because of the memmap version:

  File "/app/fast_llm/data/dataset/gpt/memmap.py", line 32, in _init
    Assert.eq(struct.unpack("<Q", stream.read(8))[0], 1)
  File "/app/fast_llm/utils.py", line 88, in eq
    assert x == arg, f"{x} != {arg}"
           ^^^^^^^^
AssertionError: 2 != 1

(Indeed, one of the datasets was deleted and re-tokenized a few weeks ago.)
Any idea how to work around this?

@jlamypoirier (Collaborator, Author)

You aren't using the exact same prepared dataset? That could explain the difference...
The new version was introduced in #113; is the old commit far behind? I guess you could try running from that commit first, or add the memmap versioning part from there to the old commit.
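
To make the backport concrete, here is a rough sketch of what the minimal version-check change on the old commit might look like. The helper name is made up, and it assumes the rest of the on-disk layout did not change between versions 1 and 2; if #113 also changed the payload format, more of it would need to be backported:

import struct
from typing import BinaryIO

def read_memmap_index_version(stream: BinaryIO) -> int:
    # Read the version field exactly where the traceback above reads it:
    # a little-endian uint64.
    version = struct.unpack("<Q", stream.read(8))[0]
    # Hypothetical relaxation of the original Assert.eq(version, 1):
    # accept version 2 as well instead of failing on it.
    assert version in (1, 2), f"Unsupported memmap index version: {version}"
    return version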

@RaymondLi0 (Contributor)

The data was re-tokenized, but it should be the same, except that the version is different...
The old commit is a93891ec73f27a67a917aaec02290b3445b6d7cb, which is based on #117.

@RaymondLi0 (Contributor)

The original commit was before this fix: 3921dae
This probably explains the difference?
