
Split wgrad & dgrad from backward() to support a2a overlap #1653

Merged
ksivaman merged 30 commits into NVIDIA:main from lhb8125:hongbinl/split_wgrad_new on Apr 18, 2025

Conversation

@lhb8125 (Contributor) commented Apr 8, 2025

Description

Add a flag split_bw that controls whether wgrad is separated from backward() and scheduled in a separate function, so that the a2a communication can be better hidden when training MoE models.
This MR supports 1F1B with a2a overlap in MCore, similar to the idea in DualPipe.
This feature has an assertion:

ub_bulk_wgrad == False

because ub_bulk_wgrad would bind the output of wgrad to dgrad, which complicates the computing context of wgrad.
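For illustration only, here is a minimal sketch (not the actual TE implementation) of what splitting dgrad from wgrad enables: dgrad is produced immediately so the all-to-all can be launched, while the wgrad matmul is captured and replayed later to fill the communication time. All tensor names and shapes below are made up.

```python
import torch

# Illustrative sketch only (not the TE code added by this PR): split a linear
# layer's backward into dgrad and wgrad so that wgrad can be deferred and run
# while the MoE all-to-all is in flight. Shapes and names are hypothetical.
x = torch.randn(8, 16)                          # input saved in forward
w = torch.randn(32, 16, requires_grad=True)     # weight
grad_out = torch.randn(8, 32)                   # gradient from the next layer

# dgrad: needed right away so the upstream a2a communication can be launched.
grad_input = grad_out @ w                       # shape (8, 16)

# wgrad: capture the computation instead of executing it eagerly.
deferred_wgrads = [lambda: grad_out.t() @ x]    # shape (32, 16)

# ... launch the all-to-all with grad_input here and overlap it ...

# Later, a wgrad_comp()-style call drains the deferred work.
for compute_wgrad in deferred_wgrads:
    wgrad = compute_wgrad()
    w.grad = wgrad if w.grad is None else w.grad + wgrad
```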

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add class WeightGradStore to put and pop the context of the wgrad computation;
  • Wrap and store the wgrad computation of the Linear/LayernormLinear/GroupedLinear classes and pop it in wgrad_comp() (a minimal sketch of this put/pop pattern follows this list);
  • Add unit tests in test_numerics.py.
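As a rough illustration of the put/pop pattern described above: the class name below matches the PR, but the body is a simplified guess, not the actual implementation.

```python
import queue
from typing import Callable

class WeightGradStore:
    """Simplified sketch of a store that defers wgrad computation.

    The real class added by this PR may differ; this only illustrates the
    put/pop flow described in the change list above.
    """

    _store: "queue.Queue[Callable[[], None]]" = queue.Queue()

    @classmethod
    def put(cls, wgrad_fn: Callable[[], None]) -> None:
        # Called from a module's backward(): stash the wgrad computation as a
        # closure over the saved input and the incoming grad_output.
        cls._store.put(wgrad_fn)

    @classmethod
    def pop(cls) -> None:
        # Called from a wgrad_comp()-style function by the scheduler,
        # typically while all-to-all communication is still in flight.
        while not cls._store.empty():
            cls._store.get()()
```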

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

lhb8125 and others added 8 commits on April 7, 2025
@ksivaman (Member) commented:

/te-ci pytorch L0 L1

ksivaman requested a review from denera on April 10, 2025
lhb8125 force-pushed the hongbinl/split_wgrad_new branch from a718320 to 76eea17 on April 11, 2025
lhb8125 and others added 4 commits on April 11, 2025
@denera (Collaborator) commented Apr 11, 2025

/te-ci pytorch L0 L1

lhb8125 force-pushed the hongbinl/split_wgrad_new branch from b80a842 to 7ec4182 on April 14, 2025
@lhb8125 (Contributor, Author) commented Apr 14, 2025

/te-ci pytorch L0 L1

lhb8125 requested a review from denera on April 14, 2025
denera previously approved these changes Apr 14, 2025

@denera (Collaborator) left a comment:
LGTM!

@lhb8125 (Contributor, Author) commented Apr 17, 2025

/te-ci pytorch L0 L1

@ptrendx (Member) commented Apr 17, 2025

What about LayerNormMLP? Since the functionality is advertised as general across all TE modules, LNMLP should also be changed.

@lhb8125 (Contributor, Author) commented Apr 17, 2025

> What about LayerNormMLP? Since the functionality is advertised as general across all TE modules, LNMLP should also be changed.

@ptrendx LNMLP has already been changed to delay the wgrad computation.

@ksivaman (Member) commented:

/te-ci pytorch L0 L1

ksivaman previously approved these changes Apr 18, 2025
ksivaman merged commit 9f8aadd into NVIDIA:main on Apr 18, 2025
11 checks passed
timmoon10 mentioned this pull request on Apr 25, 2025