
Fix DDP with nf4 #1684

Open · wants to merge 5 commits into main
Conversation

jeromeku (Collaborator) commented Feb 9, 2025

@weifengpy @drisspg

Fix #1665

TLDR: Implement `aten.cat.default` so that `NF4Tensor` can be used with DDP.

Overview

DDP syncs params and buffers during `__init__`. This sync dispatches a call to `aten.cat.default` with (potentially) a list of tensors of mixed dtypes when NF4 tensors fall into the same bucket as regular tensors.

Implementing `aten.cat.default` fixes this issue by unpacking the NF4 tensors into their original tensors. The operations that run after the sync are already implemented, so the synced modules can be properly reconstructed.
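For illustration, here is a minimal sketch of the unpacking idea (not the PR's actual code: the `__torch_dispatch__` registration is omitted, it assumes torchao's `to_nf4` and `NF4Tensor.get_original_weight` helpers, and the real implementation may unpack differently):

```python
# Minimal, illustrative sketch of the unpacking idea behind an
# aten.cat.default override for NF4Tensor. Not the PR's actual code:
# the __torch_dispatch__ registration is omitted.
import torch
from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4

def nf4_cat(tensors, dim=0):
    # DDP's param/buffer sync can bucket NF4 params together with ordinary
    # tensors; unpack the NF4 entries first so torch.cat sees uniform dtypes.
    unpacked = [
        t.get_original_weight() if isinstance(t, NF4Tensor) else t
        for t in tensors
    ]
    return torch.cat(unpacked, dim=dim)

# Hypothetical usage: an NF4-quantized weight mixed with a regular tensor.
w_nf4 = to_nf4(torch.randn(64, 64, dtype=torch.bfloat16))
w_reg = torch.randn(64, 64, dtype=torch.bfloat16)
flat = nf4_cat([w_nf4, w_reg], dim=0)  # shape (128, 64), uniform dtype
```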

Tests

Tests are located in tests/dtypes/ddp and can be run by executing the run_ddp_nf4_test.sh script.

This script does the following:

  1. Runs a LoraLinear model (ddp_nf4.py) with world size 1 to generate a reference checkpoint.
  2. Runs ddp_nf4.py with world size 2 to generate test checkpoints.
  3. Checks that the params of (1) and (2) are all close (a sketch of this check follows the example output below).

Example output:

Step 1: Generating reference checkpoint...
torchrun --nproc_per_node 1 ddp_nf4.py --global_bs 8 --dim 128 --num_linears 1 --num_steps 3 --save_dir checkpoints/ref
[rank0]:  Dist initialized with world size 1
[rank0]:  Saved model to checkpoints/ref/ddp-0.pt
Cleaning up dist

 --- 

Step 2: Generating test checkpoints...
torchrun --nproc_per_node 2 ddp_nf4.py --global_bs 8 --dim 128 --num_linears 1 --num_steps 3 --save_dir checkpoints/test
W0209 09:55:37.887000 1372779 torch/distributed/run.py:792] 
W0209 09:55:37.887000 1372779 torch/distributed/run.py:792] *****************************************
W0209 09:55:37.887000 1372779 torch/distributed/run.py:792] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0209 09:55:37.887000 1372779 torch/distributed/run.py:792] *****************************************
[rank0]:  Dist initialized with world size 2
[rank1]:  Dist initialized with world size 2
[rank0]:  Saved model to checkpoints/test/ddp-0.pt
[rank1]:  Saved model to checkpoints/test/ddp-1.pt
Cleaning up dist

 --- 

Step 3: Checking params...
python check_ddp_nf4.py --ref_checkpoint_dir checkpoints/ref --test_checkpoints_dir checkpoints/test
Ref checkpoint: checkpoints/ref/ddp-0.pt
Checking checkpoints/test/ddp-0.pt
Checking module.0.weight <class 'torchao.dtypes.nf4tensor.NF4Tensor'> <class 'torchao.dtypes.nf4tensor.NF4Tensor'>
 ✓ Param module.0.weight is consistent
Checking module.0.lora_a.weight <class 'torch.Tensor'> <class 'torch.Tensor'>
 ✓ Param module.0.lora_a.weight is consistent
Checking module.0.lora_b.weight <class 'torch.Tensor'> <class 'torch.Tensor'>
 ✓ Param module.0.lora_b.weight is consistent
Passed!
Checking checkpoints/test/ddp-1.pt
Checking module.0.weight <class 'torchao.dtypes.nf4tensor.NF4Tensor'> <class 'torchao.dtypes.nf4tensor.NF4Tensor'>
 ✓ Param module.0.weight is consistent
Checking module.0.lora_a.weight <class 'torch.Tensor'> <class 'torch.Tensor'>
 ✓ Param module.0.lora_a.weight is consistent
Checking module.0.lora_b.weight <class 'torch.Tensor'> <class 'torch.Tensor'>
 ✓ Param module.0.lora_b.weight is consistent
Passed!

 --- 

Step 4: Cleaning up...
rm -rf checkpoints

 --- 

Done!
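For reference, the core of the step-3 check might look like the following minimal sketch (illustrative names, not necessarily check_ddp_nf4.py's actual code; it assumes each checkpoint is a state dict saved with `torch.save` and dequantizes NF4 params before comparing):

```python
# Illustrative sketch of the step-3 comparison; not check_ddp_nf4.py's
# actual code.
import torch
from torchao.dtypes.nf4tensor import NF4Tensor

def dequant(t):
    # NF4 params are dequantized so a plain allclose-style comparison applies.
    return t.get_original_weight() if isinstance(t, NF4Tensor) else t

def check_checkpoint(ref_path, test_path):
    ref = torch.load(ref_path, weights_only=False)   # assumed: a state dict
    test = torch.load(test_path, weights_only=False)
    for name, ref_param in ref.items():
        print(f"Checking {name} {type(ref_param)} {type(test[name])}")
        torch.testing.assert_close(dequant(ref_param), dequant(test[name]))
        print(f" ✓ Param {name} is consistent")
    print("Passed!")
```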

ddp_nf4.py can be parametrized:

python ddp_nf4.py --help
usage: ddp_nf4.py [-h] [--global_bs GLOBAL_BS] [--dim DIM] [--num_linears NUM_LINEARS] [--seed SEED] [--device DEVICE] [--dtype DTYPE] [--num_steps NUM_STEPS] [--save_dir SAVE_DIR] [--compile] [--optimize_ddp OPTIMIZE_DDP]

options:
  -h, --help            show this help message and exit
  --global_bs GLOBAL_BS
  --dim DIM
  --num_linears NUM_LINEARS
  --seed SEED
  --device DEVICE
  --dtype DTYPE
  --num_steps NUM_STEPS
  --save_dir SAVE_DIR
  --compile
  --optimize_ddp OPTIMIZE_DDP
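For context, a condensed sketch of the kind of model this exercises (illustrative only; the actual LoraLinear lives in tests/dtypes/ddp, and this assumes torchao's `to_nf4` and `linear_nf4` helpers):

```python
# Condensed, illustrative sketch of an NF4 LoRA linear wrapped in DDP.
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torchao.dtypes.nf4tensor import linear_nf4, to_nf4

class LoraLinear(nn.Module):
    def __init__(self, dim, rank=16, dtype=torch.bfloat16):
        super().__init__()
        # Frozen base weight quantized to NF4; only the LoRA factors train.
        base = torch.randn(dim, dim, dtype=dtype)
        self.weight = nn.Parameter(to_nf4(base), requires_grad=False)
        self.lora_a = nn.Linear(dim, rank, bias=False, dtype=dtype)
        self.lora_b = nn.Linear(rank, dim, bias=False, dtype=dtype)

    def forward(self, x):
        return linear_nf4(x, self.weight) + self.lora_b(self.lora_a(x))

# With aten.cat.default implemented, DDP's __init__-time param sync no longer
# fails when the NF4 weight lands in the same bucket as regular params.
# model = DDP(LoraLinear(128).cuda(), device_ids=[local_rank])
```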

pytorch-bot commented Feb 9, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1684

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 22bc211 with merge base c8eb8d3:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) Feb 9, 2025
jeromeku mentioned this pull request Feb 9, 2025
jeromeku added the topic: bug fix label (use this tag for PRs that fix bugs) Feb 9, 2025
psinger commented Feb 12, 2025

Can confirm that this PR solves my issues reported in #1665
