Fix DDP with nf4 #1684

Merged: 6 commits, Feb 18, 2025
40 changes: 40 additions & 0 deletions test/dtypes/ddp/check_ddp_nf4.py
@@ -0,0 +1,40 @@
import argparse
from pathlib import Path

import torch

from torchao.dtypes.nf4tensor import NF4Tensor

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ref_checkpoint_dir", type=str, required=True)
    parser.add_argument("--test_checkpoints_dir", type=str, required=True)

    args = parser.parse_args()

    ref_checkpoints = list(Path(args.ref_checkpoint_dir).glob("*.pt"))
    assert len(ref_checkpoints) == 1, "Expected exactly one reference checkpoint"
    ref_checkpoint = ref_checkpoints[0]
    ref_state_dict = torch.load(ref_checkpoint, weights_only=True, map_location="cpu")
    print(f"Ref checkpoint: {ref_checkpoint}")

    for path in Path(args.test_checkpoints_dir).glob("*.pt"):
        print(f"Checking {path}")
        state_dict = torch.load(path, weights_only=True, map_location="cpu")
        assert ref_state_dict.keys() == state_dict.keys()
        for name in ref_state_dict.keys():
            ref_param = ref_state_dict[name]
            test_param = state_dict[name]
            print(f"Checking {name} {type(ref_param)} {type(test_param)}")

            if isinstance(ref_param, NF4Tensor):
                ref_param = ref_param.get_original_weight()
                assert isinstance(test_param, NF4Tensor)
                test_param = test_param.get_original_weight()

            if not torch.allclose(ref_param, test_param, atol=1e-4, rtol=1e-4):
                diff = (ref_param - test_param).abs().max()
                print(f" ✘ Param {name} differs by {diff}")
            else:
                print(f" ✓ Param {name} is consistent")
    print("Passed!")
155 changes: 155 additions & 0 deletions test/dtypes/ddp/ddp_nf4.py
@@ -0,0 +1,155 @@
import argparse
import math
import os
import time
from contextlib import contextmanager

import torch
import torch.distributed as dist
import torch.nn as nn
from torch._dynamo import config as dynamo_config
from torch.nn.parallel import DistributedDataParallel as DDP

from torchao.dtypes.nf4tensor import linear_nf4, to_nf4


class LoRALinear(nn.Module):
    def __init__(
        self,
        hidden_dim: int,
        lora_rank: int = None,
        lora_alpha: float = 16,
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        if lora_rank is None:
            lora_rank = hidden_dim // 2

        weight = torch.randn(hidden_dim, hidden_dim, dtype=dtype)
        self.lora_rank = lora_rank
        self.lora_alpha = lora_alpha
        self.register_parameter(
            "weight", nn.Parameter(to_nf4(weight), requires_grad=False)
        )
        self.lora_a = nn.Linear(
            in_features=hidden_dim, out_features=self.lora_rank, bias=False
        )
        self.lora_b = nn.Linear(
            in_features=self.lora_rank, out_features=hidden_dim, bias=False
        )
        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.lora_b.weight, a=math.sqrt(5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = linear_nf4(input=x, weight=self.weight)
        lora_out = self.lora_a(x)
        lora_out = (self.lora_alpha / self.lora_rank) * self.lora_b(lora_out)
        return out + lora_out


def _init_model(dim, num_linears, device, dtype) -> nn.Module:
    with torch.device(device):
        modules = []
        for i in range(num_linears):
            modules += [LoRALinear(hidden_dim=dim, dtype=dtype)]
        seq = nn.Sequential(*modules)

    return seq


def dist_print(*args, delay=0.5):
    rank = dist.get_rank()
    time.sleep(delay * rank)
    print(f"[rank{rank}]: ", *args, flush=True)


def make_batch(global_bs, dim, dtype, device):
    batch = torch.randn((global_bs, dim), dtype=dtype, device=device)
    if dist.get_world_size() > 1:
        batch = batch.chunk(dist.get_world_size(), dim=0)[dist.get_rank()]
    return batch


def run_ddp(global_bs, dim, num_linears, device, dtype, num_steps, save_dir, compile):
    os.makedirs(save_dir, exist_ok=True)
    model = _init_model(dim, num_linears, device, dtype)
    model = DDP(model, device_ids=[device])

    if compile:
        model = torch.compile(model)
    optim = torch.optim.Adam(model.parameters(), lr=1e-2)

    losses = []

    for i in range(num_steps):
        inp = make_batch(global_bs, dim, dtype, device)
        loss = model(inp).sum()
        losses.append(loss)
        loss.backward()
        optim.step()
        optim.zero_grad()

    dist.barrier()

    save_path = f"{save_dir}/ddp-{dist.get_rank()}.pt"
    torch.save(model.state_dict(), save_path)
    dist_print("Saved model to", save_path)


def init_dist():
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())
    dist_print("Dist initialized with world size", dist.get_world_size())


def cleanup_dist():
    dist.barrier()
    if dist.get_rank() == 0:
        print("Cleaning up dist")
    dist.destroy_process_group()


@contextmanager
def distributed_context():
    init_dist()
    yield
    cleanup_dist()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--global_bs", type=int, default=8)
    parser.add_argument("--dim", type=int, default=128)
    parser.add_argument("--num_linears", type=int, default=1)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--dtype", type=str, default="float32")
    parser.add_argument("--num_steps", type=int, default=3)
    parser.add_argument("--save_dir", type=str, default="checkpoints")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--optimize_ddp", type=str, default="ddp_optimizer")
    args = parser.parse_args()

    args.dtype = getattr(torch, args.dtype)
    dynamo_config.optimize_ddp = args.optimize_ddp

    if args.optimize_ddp == "python_reducer":
        dynamo_config.compiled_autograd = True

    with distributed_context():
        torch.manual_seed(args.seed)
        run_ddp(
            global_bs=args.global_bs,
            dim=args.dim,
            num_linears=args.num_linears,
            device=args.device,
            dtype=args.dtype,
            num_steps=args.num_steps,
            save_dir=args.save_dir,
            compile=args.compile,
        )
48 changes: 48 additions & 0 deletions test/dtypes/ddp/run_ddp_nf4_test.sh
@@ -0,0 +1,48 @@
#!/bin/bash

set -euo pipefail
WORLD_SIZE=${1:-2}


# Test params
GLOBAL_BS=8
DIM=128
NUM_LINEARS=1
NUM_STEPS=3

PARAMS="--global_bs $GLOBAL_BS --dim $DIM --num_linears $NUM_LINEARS --num_steps $NUM_STEPS"
SAVE_DIR="checkpoints"
REF_DIR="${SAVE_DIR}/ref"
TEST_DIR="${SAVE_DIR}/test"
DDP_PROGRAM="ddp_nf4.py"
CHECK_PROGRAM="check_ddp_nf4.py"
REF_CMD="torchrun --nproc_per_node 1 $DDP_PROGRAM $PARAMS --save_dir $REF_DIR"
TEST_CMD="torchrun --nproc_per_node $WORLD_SIZE $DDP_PROGRAM $PARAMS --save_dir $TEST_DIR"
CHECK_CMD="python $CHECK_PROGRAM --ref_checkpoint_dir $REF_DIR --test_checkpoints_dir $TEST_DIR"
CLEANUP_CMD="rm -rf $SAVE_DIR"

echo "Step 1: Generating reference checkpoint..."
echo $REF_CMD
$REF_CMD
echo -e "\n --- \n"
sleep 2

echo "Step 2: Generating test checkpoints..."
echo $TEST_CMD
$TEST_CMD
echo -e "\n --- \n"
sleep 2

# Check params
echo "Step 3: Checking params..."
echo $CHECK_CMD
$CHECK_CMD
echo -e "\n --- \n"
sleep 2

# Cleanup
echo "Step 4: Cleaning up..."
echo $CLEANUP_CMD
$CLEANUP_CMD
echo -e "\n --- \n"
echo "Done!"
30 changes: 30 additions & 0 deletions torchao/dtypes/nf4tensor.py
@@ -423,6 +423,35 @@ def nf4_pin_memory(aten_op, args, kwargs=None):
    return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))


@implements(
    [
        aten.cat.default,
    ]
)
def nf4_cat(aten_op: torch._ops.OpOverload, args, kwargs=None):
    tensors_to_cat = args[0]
    assert all(isinstance(t, torch.Tensor) for t in tensors_to_cat)
    remaining_args = args[1:]

    ts = []
    for t in tensors_to_cat:
        assert isinstance(t, torch.Tensor)

        if isinstance(t, NF4Tensor):
            ts.append(t.get_original_weight())
        else:
            ts.append(t)

    dtype = ts[0].dtype
    assert all(t.dtype == dtype for t in ts)

    if kwargs is None:
        kwargs = {}

    tensors = aten_op(ts, *remaining_args, **kwargs)
    return tensors


@dataclass(frozen=True)
class SubclassTensorArgs:
    original_shape: torch.Size
@@ -1058,3 +1087,4 @@ def nf4_constructor(

if TORCH_VERSION_AT_LEAST_2_5:
    torch.serialization.add_safe_globals([NF4Tensor])
    torch.serialization.add_safe_globals([NF4Tensor])
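
The aten.cat.default override added above dequantizes NF4 inputs before concatenating, so a cat involving NF4 tensors yields a plain tensor rather than an NF4Tensor. A rough sketch of the intended behavior, assuming torch.cat dispatches to this override when an NF4Tensor is among the inputs; shapes are illustrative and not part of the PR:

# Hedged sketch (not part of the PR): shapes are illustrative.
import torch
from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4

a = to_nf4(torch.randn(64, 64))
b = torch.randn(64, 64)

out = torch.cat([a, b], dim=0)      # NF4 input is dequantized via get_original_weight()
assert not isinstance(out, NF4Tensor)
assert out.shape == (128, 64)
assert out.dtype == torch.float32   # matches the dequantized dtype of the first input
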