
Commit aa9b9c9

jeromeku and msaroufim authored
Fix DDP with nf4 (#1684)
* implement aten.cat.default for nf4
* add nf4 ddp tests
* run ruff
* add dtype check
* formatting
* run ruff format on nf4tensor

---------

Co-authored-by: Mark Saroufim <[email protected]>
1 parent 22d7d51 commit aa9b9c9
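The core change is an aten.cat.default implementation for NF4Tensor (see the nf4tensor.py diff below): NF4 inputs are dequantized with get_original_weight() and concatenated as regular tensors, with a check that all inputs share one dtype. A minimal sketch of what this enables, assuming default to_nf4 settings; the shapes and values are illustrative, not taken from the commit:

```python
# Hedged sketch: after this change, torch.cat over a mix of NF4 and regular
# tensors works by dequantizing the NF4 inputs first (all inputs must share
# the same dtype once dequantized). Shapes/values are illustrative only.
import torch

from torchao.dtypes.nf4tensor import to_nf4

a = to_nf4(torch.randn(64, 64))  # NF4Tensor whose original dtype is float32
b = torch.randn(64, 64)          # plain float32 tensor
out = torch.cat([a, b], dim=0)   # dispatches to the new nf4_cat handler
print(type(out), out.shape)      # regular torch.Tensor, torch.Size([128, 64])
```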

File tree

4 files changed: +273 -0 lines changed


test/dtypes/ddp/check_ddp_nf4.py

Lines changed: 40 additions & 0 deletions
```python
import argparse
from pathlib import Path

import torch

from torchao.dtypes.nf4tensor import NF4Tensor

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ref_checkpoint_dir", type=str, required=True)
    parser.add_argument("--test_checkpoints_dir", type=str, required=True)

    args = parser.parse_args()

    ref_checkpoints = list(Path(args.ref_checkpoint_dir).glob("*.pt"))
    assert len(ref_checkpoints) == 1, "Expected exactly one reference checkpoint"
    ref_checkpoint = ref_checkpoints[0]
    ref_state_dict = torch.load(ref_checkpoint, weights_only=True, map_location="cpu")
    print(f"Ref checkpoint: {ref_checkpoint}")

    for path in Path(args.test_checkpoints_dir).glob("*.pt"):
        print(f"Checking {path}")
        state_dict = torch.load(path, weights_only=True, map_location="cpu")
        assert ref_state_dict.keys() == state_dict.keys()
        for name in ref_state_dict.keys():
            ref_param = ref_state_dict[name]
            test_param = state_dict[name]
            print(f"Checking {name} {type(ref_param)} {type(test_param)}")

            if isinstance(ref_param, NF4Tensor):
                ref_param = ref_param.get_original_weight()
                assert isinstance(test_param, NF4Tensor)
                test_param = test_param.get_original_weight()

            if not torch.allclose(ref_param, test_param, atol=1e-4, rtol=1e-4):
                diff = (ref_param - test_param).abs().max()
                print(f" \u2718 Param {name} differs by {diff}")
            else:
                print(f" \u2713 Param {name} is consistent")
    print("Passed!")
```

test/dtypes/ddp/ddp_nf4.py

Lines changed: 155 additions & 0 deletions
```python
import argparse
import math
import os
import time
from contextlib import contextmanager

import torch
import torch.distributed as dist
import torch.nn as nn
from torch._dynamo import config as dynamo_config
from torch.nn.parallel import DistributedDataParallel as DDP

from torchao.dtypes.nf4tensor import linear_nf4, to_nf4


class LoRALinear(nn.Module):
    def __init__(
        self,
        hidden_dim: int,
        lora_rank: int = None,
        lora_alpha: float = 16,
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        if lora_rank is None:
            lora_rank = hidden_dim // 2

        weight = torch.randn(hidden_dim, hidden_dim, dtype=dtype)
        self.lora_rank = lora_rank
        self.lora_alpha = lora_alpha
        self.register_parameter(
            "weight", nn.Parameter(to_nf4(weight), requires_grad=False)
        )
        self.lora_a = nn.Linear(
            in_features=hidden_dim, out_features=self.lora_rank, bias=False
        )
        self.lora_b = nn.Linear(
            in_features=self.lora_rank, out_features=hidden_dim, bias=False
        )
        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.lora_b.weight, a=math.sqrt(5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = linear_nf4(input=x, weight=self.weight)
        lora_out = self.lora_a(x)
        lora_out = (self.lora_alpha / self.lora_rank) * self.lora_b(lora_out)
        return out + lora_out


def _init_model(dim, num_linears, device, dtype) -> nn.Module:
    with torch.device(device):
        modules = []
        for i in range(num_linears):
            modules += [LoRALinear(hidden_dim=dim, dtype=dtype)]
        seq = nn.Sequential(*modules)

    return seq


def dist_print(*args, delay=0.5):
    rank = dist.get_rank()
    time.sleep(delay * rank)
    print(f"[rank{rank}]: ", *args, flush=True)


def make_batch(global_bs, dim, dtype, device):
    batch = torch.randn((global_bs, dim), dtype=dtype, device=device)
    if dist.get_world_size() > 1:
        batch = batch.chunk(dist.get_world_size(), dim=0)[dist.get_rank()]
    return batch


def run_ddp(global_bs, dim, num_linears, device, dtype, num_steps, save_dir, compile):
    os.makedirs(save_dir, exist_ok=True)
    model = _init_model(dim, num_linears, device, dtype)
    model = DDP(model, device_ids=[device])

    if compile:
        model = torch.compile(model)
    optim = torch.optim.Adam(model.parameters(), lr=1e-2)

    losses = []

    for i in range(num_steps):
        inp = make_batch(global_bs, dim, dtype, device)
        loss = model(inp).sum()
        losses.append(loss)
        loss.backward()
        optim.step()
        optim.zero_grad()

    dist.barrier()

    save_path = f"{save_dir}/ddp-{dist.get_rank()}.pt"
    torch.save(model.state_dict(), save_path)
    dist_print("Saved model to", save_path)


def init_dist():
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())
    dist_print("Dist initialized with world size", dist.get_world_size())


def cleanup_dist():
    dist.barrier()
    if dist.get_rank() == 0:
        print("Cleaning up dist")
    dist.destroy_process_group()


@contextmanager
def distributed_context():
    init_dist()
    yield
    cleanup_dist()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--global_bs", type=int, default=8)
    parser.add_argument("--dim", type=int, default=128)
    parser.add_argument("--num_linears", type=int, default=1)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--dtype", type=str, default="float32")
    parser.add_argument("--num_steps", type=int, default=3)
    parser.add_argument("--save_dir", type=str, default="checkpoints")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--optimize_ddp", type=str, default="ddp_optimizer")
    args = parser.parse_args()

    args.dtype = getattr(torch, args.dtype)
    dynamo_config.optimize_ddp = args.optimize_ddp

    if args.optimize_ddp == "python_reducer":
        dynamo_config.compiled_autograd = True

    with distributed_context():
        torch.manual_seed(args.seed)
        run_ddp(
            global_bs=args.global_bs,
            dim=args.dim,
            num_linears=args.num_linears,
            device=args.device,
            dtype=args.dtype,
            num_steps=args.num_steps,
            save_dir=args.save_dir,
            compile=args.compile,
        )
```
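In this harness only the LoRA adapters train: the NF4 base weight is registered with requires_grad=False, so DDP reduces gradients just for lora_a and lora_b, and the fix in this commit concerns DDP's handling of the frozen NF4 weight itself. A small sketch, assuming the LoRALinear and _init_model definitions above, run on a single process without DDP (CPU here purely for illustration):

```python
# Hedged sketch (single process, no DDP): the NF4 base weight is frozen, so
# only the LoRA adapter weights show up as trainable parameters.
import torch

model = _init_model(dim=128, num_linears=1, device="cpu", dtype=torch.float32)
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # expected: ['0.lora_a.weight', '0.lora_b.weight']
```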
Lines changed: 48 additions & 0 deletions
```bash
#!/bin/bash

set -euo pipefail
WORLD_SIZE=${1:-2}


# Test params
GLOBAL_BS=8
DIM=128
NUM_LINEARS=1
NUM_STEPS=3

PARAMS="--global_bs $GLOBAL_BS --dim $DIM --num_linears $NUM_LINEARS --num_steps $NUM_STEPS"
SAVE_DIR="checkpoints"
REF_DIR="${SAVE_DIR}/ref"
TEST_DIR="${SAVE_DIR}/test"
DDP_PROGRAM="ddp_nf4.py"
CHECK_PROGRAM="check_ddp_nf4.py"
REF_CMD="torchrun --nproc_per_node 1 $DDP_PROGRAM $PARAMS --save_dir $REF_DIR"
TEST_CMD="torchrun --nproc_per_node $WORLD_SIZE $DDP_PROGRAM $PARAMS --save_dir $TEST_DIR"
CHECK_CMD="python $CHECK_PROGRAM --ref_checkpoint_dir $REF_DIR --test_checkpoints_dir $TEST_DIR"
CLEANUP_CMD="rm -rf $SAVE_DIR"

echo "Step 1: Generating reference checkpoint..."
echo $REF_CMD
$REF_CMD
echo -e "\n --- \n"
sleep 2

echo "Step 2: Generating test checkpoints..."
echo $TEST_CMD
$TEST_CMD
echo -e "\n --- \n"
sleep 2

# Check params
echo "Step 3: Checking params..."
echo $CHECK_CMD
$CHECK_CMD
echo -e "\n --- \n"
sleep 2

# Cleanup
echo "Step 4: Cleaning up..."
echo $CLEANUP_CMD
$CLEANUP_CMD
echo -e "\n --- \n"
echo "Done!"
```

torchao/dtypes/nf4tensor.py

Lines changed: 30 additions & 0 deletions
```diff
@@ -423,6 +423,35 @@ def nf4_pin_memory(aten_op, args, kwargs=None):
     return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
 
 
+@implements(
+    [
+        aten.cat.default,
+    ]
+)
+def nf4_cat(aten_op: torch._ops.OpOverload, args, kwargs=None):
+    tensors_to_cat = args[0]
+    assert all(isinstance(t, torch.Tensor) for t in tensors_to_cat)
+    remaining_args = args[1:]
+
+    ts = []
+    for t in tensors_to_cat:
+        assert isinstance(t, torch.Tensor)
+
+        if isinstance(t, NF4Tensor):
+            ts.append(t.get_original_weight())
+        else:
+            ts.append(t)
+
+    dtype = ts[0].dtype
+    assert all(t.dtype == dtype for t in ts)
+
+    if kwargs is None:
+        kwargs = {}
+
+    tensors = aten_op(ts, *remaining_args, **kwargs)
+    return tensors
+
+
 @dataclass(frozen=True)
 class SubclassTensorArgs:
     original_shape: torch.Size
@@ -1058,3 +1087,4 @@ def nf4_constructor(
 
 if TORCH_VERSION_AT_LEAST_2_5:
     torch.serialization.add_safe_globals([NF4Tensor])
+    torch.serialization.add_safe_globals([NF4Tensor])
```
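The trailing hunk touches the safe-globals registration at the bottom of nf4tensor.py. Registering NF4Tensor with torch.serialization.add_safe_globals (gated on TORCH_VERSION_AT_LEAST_2_5) is what lets the check script above load checkpoints containing NF4Tensor entries with weights_only=True. A hedged sketch; the checkpoint path assumes the DDP test scripts above were run first:

```python
# Hedged sketch: weights_only loading of a DDP checkpoint that contains
# NF4Tensor values relies on the add_safe_globals([NF4Tensor]) registration
# performed when torchao.dtypes.nf4tensor is imported (PyTorch >= 2.5).
# The path below assumes the test scripts above produced this checkpoint.
import torch

from torchao.dtypes.nf4tensor import NF4Tensor

state_dict = torch.load(
    "checkpoints/test/ddp-0.pt", weights_only=True, map_location="cpu"
)
nf4_keys = [k for k, v in state_dict.items() if isinstance(v, NF4Tensor)]
print(f"{len(nf4_keys)} NF4 entries out of {len(state_dict)}")
```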
