diff --git a/test/dtypes/ddp/check_ddp_nf4.py b/test/dtypes/ddp/check_ddp_nf4.py
new file mode 100644
index 0000000000..608bcb9c02
--- /dev/null
+++ b/test/dtypes/ddp/check_ddp_nf4.py
@@ -0,0 +1,40 @@
+import argparse
+from pathlib import Path
+
+import torch
+
+from torchao.dtypes.nf4tensor import NF4Tensor
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--ref_checkpoint_dir", type=str, required=True)
+    parser.add_argument("--test_checkpoints_dir", type=str, required=True)
+
+    args = parser.parse_args()
+
+    ref_checkpoints = list(Path(args.ref_checkpoint_dir).glob("*.pt"))
+    assert len(ref_checkpoints) == 1, "Expected exactly one reference checkpoint"
+    ref_checkpoint = ref_checkpoints[0]
+    ref_state_dict = torch.load(ref_checkpoint, weights_only=True, map_location="cpu")
+    print(f"Ref checkpoint: {ref_checkpoint}")
+
+    for path in Path(args.test_checkpoints_dir).glob("*.pt"):
+        print(f"Checking {path}")
+        state_dict = torch.load(path, weights_only=True, map_location="cpu")
+        assert ref_state_dict.keys() == state_dict.keys()
+        for name in ref_state_dict.keys():
+            ref_param = ref_state_dict[name]
+            test_param = state_dict[name]
+            print(f"Checking {name} {type(ref_param)} {type(test_param)}")
+
+            if isinstance(ref_param, NF4Tensor):
+                ref_param = ref_param.get_original_weight()
+                assert isinstance(test_param, NF4Tensor)
+                test_param = test_param.get_original_weight()
+
+            if not torch.allclose(ref_param, test_param, atol=1e-4, rtol=1e-4):
+                diff = (ref_param - test_param).abs().max()
+                raise AssertionError(f"✘ Param {name} differs by {diff}")
+            else:
+                print(f" ✓ Param {name} is consistent")
+    print("Passed!")
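The checker compares NF4 parameters by first dequantizing both sides with `get_original_weight()`. NF4 quantization is lossy, so a dequantized tensor only approximates its source; that is why the reference and test checkpoints (which go through identical quantization under the same seed) are compared against each other rather than against unquantized weights. A minimal sketch of the dequantization step in isolation (not part of the test; the 128x128 shape is arbitrary):

```python
import torch

from torchao.dtypes.nf4tensor import to_nf4

w = torch.randn(128, 128)
w_nf4 = to_nf4(w)                     # quantize to NF4
w_deq = w_nf4.get_original_weight()   # dequantize back to a dense tensor
print(w_deq.dtype, w_deq.shape)       # same dtype and shape as the source
print((w - w_deq).abs().max())        # nonzero: NF4 quantization is lossy
```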
diff --git a/test/dtypes/ddp/ddp_nf4.py b/test/dtypes/ddp/ddp_nf4.py
new file mode 100644
index 0000000000..e38d0015b1
--- /dev/null
+++ b/test/dtypes/ddp/ddp_nf4.py
@@ -0,0 +1,155 @@
+import argparse
+import math
+import os
+import time
+from contextlib import contextmanager
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch._dynamo import config as dynamo_config
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+from torchao.dtypes.nf4tensor import linear_nf4, to_nf4
+
+
+class LoRALinear(nn.Module):
+    def __init__(
+        self,
+        hidden_dim: int,
+        lora_rank: int = None,
+        lora_alpha: float = 16,
+        dtype: torch.dtype = torch.float32,
+    ):
+        super().__init__()
+        self.hidden_dim = hidden_dim
+        if lora_rank is None:
+            lora_rank = hidden_dim // 2
+
+        weight = torch.randn(hidden_dim, hidden_dim, dtype=dtype)
+        self.lora_rank = lora_rank
+        self.lora_alpha = lora_alpha
+        self.register_parameter(
+            "weight", nn.Parameter(to_nf4(weight), requires_grad=False)
+        )
+        self.lora_a = nn.Linear(
+            in_features=hidden_dim, out_features=self.lora_rank, bias=False
+        )
+        self.lora_b = nn.Linear(
+            in_features=self.lora_rank, out_features=hidden_dim, bias=False
+        )
+        self.initialize_parameters()
+
+    def initialize_parameters(self):
+        nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
+        nn.init.kaiming_uniform_(self.lora_b.weight, a=math.sqrt(5))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        out = linear_nf4(input=x, weight=self.weight)
+        lora_out = self.lora_a(x)
+        lora_out = (self.lora_alpha / self.lora_rank) * self.lora_b(lora_out)
+        return out + lora_out
+
+
+def _init_model(dim, num_linears, device, dtype) -> nn.Module:
+    with torch.device(device):
+        modules = []
+        for i in range(num_linears):
+            modules += [LoRALinear(hidden_dim=dim, dtype=dtype)]
+        seq = nn.Sequential(*modules)
+
+    return seq
+
+
+def dist_print(*args, delay=0.5):
+    rank = dist.get_rank()
+    time.sleep(delay * rank)
+    print(f"[rank{rank}]: ", *args, flush=True)
+
+
+def make_batch(global_bs, dim, dtype, device):
+    batch = torch.randn((global_bs, dim), dtype=dtype, device=device)
+    if dist.get_world_size() > 1:
+        batch = batch.chunk(dist.get_world_size(), dim=0)[dist.get_rank()]
+    return batch
+
+
+def run_ddp(global_bs, dim, num_linears, device, dtype, num_steps, save_dir, compile):
+    os.makedirs(save_dir, exist_ok=True)
+    model = _init_model(dim, num_linears, device, dtype)
+    model = DDP(model, device_ids=[device])
+
+    if compile:
+        model = torch.compile(model)
+    optim = torch.optim.Adam(model.parameters(), lr=1e-2)
+
+    losses = []
+
+    for i in range(num_steps):
+        inp = make_batch(global_bs, dim, dtype, device)
+        loss = model(inp).sum()
+        losses.append(loss)
+        loss.backward()
+        optim.step()
+        optim.zero_grad()
+
+    dist.barrier()
+
+    save_path = f"{save_dir}/ddp-{dist.get_rank()}.pt"
+    torch.save(model.state_dict(), save_path)
+    dist_print("Saved model to", save_path)
+
+
+def init_dist():
+    dist.init_process_group(backend="nccl")
+    torch.cuda.set_device(dist.get_rank())
+    dist_print("Dist initialized with world size", dist.get_world_size())
+
+
+def cleanup_dist():
+    dist.barrier()
+    if dist.get_rank() == 0:
+        print("Cleaning up dist")
+    dist.destroy_process_group()
+
+
+@contextmanager
+def distributed_context():
+    init_dist()
+    yield
+    cleanup_dist()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--global_bs", type=int, default=8)
+    parser.add_argument("--dim", type=int, default=128)
+    parser.add_argument("--num_linears", type=int, default=1)
+    parser.add_argument("--seed", type=int, default=42)
+    parser.add_argument("--device", type=str, default="cuda")
+    parser.add_argument("--dtype", type=str, default="float32")
+    parser.add_argument("--num_steps", type=int, default=3)
+    parser.add_argument("--save_dir", type=str, default="checkpoints")
+    parser.add_argument("--compile", action="store_true")
+    parser.add_argument("--optimize_ddp", type=str, default="ddp_optimizer")
+    args = parser.parse_args()
+
+    args.dtype = getattr(torch, args.dtype)
+    dynamo_config.optimize_ddp = args.optimize_ddp
+
+    if args.optimize_ddp == "python_reducer":
+        dynamo_config.compiled_autograd = True
+
+    with distributed_context():
+        torch.manual_seed(args.seed)
+        run_ddp(
+            global_bs=args.global_bs,
+            dim=args.dim,
+            num_linears=args.num_linears,
+            device=args.device,
+            dtype=args.dtype,
+            num_steps=args.num_steps,
+            save_dir=args.save_dir,
+            compile=args.compile,
+        )
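Both scripts depend on NF4 parameters surviving a `torch.save` / `torch.load(weights_only=True)` round trip, which in turn relies on `NF4Tensor` being registered as a safe global in `torchao/dtypes/nf4tensor.py` (see the `add_safe_globals` call there). A minimal sketch of that round trip, assuming torch >= 2.5; the dict key and the `/tmp` path are hypothetical:

```python
import torch

from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4

state_dict = {"layer.weight": to_nf4(torch.randn(128, 128))}
torch.save(state_dict, "/tmp/nf4_state_dict.pt")  # hypothetical path
loaded = torch.load("/tmp/nf4_state_dict.pt", weights_only=True, map_location="cpu")
assert isinstance(loaded["layer.weight"], NF4Tensor)  # NF4 param survives the round trip
```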
diff --git a/test/dtypes/ddp/run_ddp_nf4_test.sh b/test/dtypes/ddp/run_ddp_nf4_test.sh
new file mode 100755
index 0000000000..b9a3c2929f
--- /dev/null
+++ b/test/dtypes/ddp/run_ddp_nf4_test.sh
@@ -0,0 +1,48 @@
+#!/bin/bash
+
+set -euo pipefail
+WORLD_SIZE=${1:-2}
+
+
+# Test params
+GLOBAL_BS=8
+DIM=128
+NUM_LINEARS=1
+NUM_STEPS=3
+
+PARAMS="--global_bs $GLOBAL_BS --dim $DIM --num_linears $NUM_LINEARS --num_steps $NUM_STEPS"
+SAVE_DIR="checkpoints"
+REF_DIR="${SAVE_DIR}/ref"
+TEST_DIR="${SAVE_DIR}/test"
+DDP_PROGRAM="ddp_nf4.py"
+CHECK_PROGRAM="check_ddp_nf4.py"
+REF_CMD="torchrun --nproc_per_node 1 $DDP_PROGRAM $PARAMS --save_dir $REF_DIR"
+TEST_CMD="torchrun --nproc_per_node $WORLD_SIZE $DDP_PROGRAM $PARAMS --save_dir $TEST_DIR"
+CHECK_CMD="python $CHECK_PROGRAM --ref_checkpoint_dir $REF_DIR --test_checkpoints_dir $TEST_DIR"
+CLEANUP_CMD="rm -rf $SAVE_DIR"
+
+echo "Step 1: Generating reference checkpoint..."
+echo $REF_CMD
+$REF_CMD
+echo -e "\n --- \n"
+sleep 2
+
+echo "Step 2: Generating test checkpoints..."
+echo $TEST_CMD
+$TEST_CMD
+echo -e "\n --- \n"
+sleep 2
+
+# Check params
+echo "Step 3: Checking params..."
+echo $CHECK_CMD
+$CHECK_CMD
+echo -e "\n --- \n"
+sleep 2
+
+# Cleanup
+echo "Step 4: Cleaning up..."
+echo $CLEANUP_CMD
+$CLEANUP_CMD
+echo -e "\n --- \n"
+echo "Done!"
diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py
index 5ae06a1fe1..457cf352fa 100644
--- a/torchao/dtypes/nf4tensor.py
+++ b/torchao/dtypes/nf4tensor.py
@@ -423,6 +423,35 @@ def nf4_pin_memory(aten_op, args, kwargs=None):
     return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
 
 
+@implements(
+    [
+        aten.cat.default,
+    ]
+)
+def nf4_cat(aten_op: torch._ops.OpOverload, args, kwargs=None):
+    tensors_to_cat = args[0]
+    assert all(isinstance(t, torch.Tensor) for t in tensors_to_cat)
+    remaining_args = args[1:]
+
+    ts = []
+    for t in tensors_to_cat:
+        assert isinstance(t, torch.Tensor)
+
+        if isinstance(t, NF4Tensor):
+            ts.append(t.get_original_weight())
+        else:
+            ts.append(t)
+
+    dtype = ts[0].dtype
+    assert all(t.dtype == dtype for t in ts)
+
+    if kwargs is None:
+        kwargs = {}
+
+    tensors = aten_op(ts, *remaining_args, **kwargs)
+    return tensors
+
+
 @dataclass(frozen=True)
 class SubclassTensorArgs:
     original_shape: torch.Size
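Based on the handler above, `torch.cat` on a list that contains an `NF4Tensor` should dequantize the NF4 inputs via `get_original_weight()` and return a plain tensor in the original dtype rather than an `NF4Tensor`. A minimal sketch (shapes are arbitrary and not part of the test suite):

```python
import torch

from torchao.dtypes.nf4tensor import to_nf4

a = to_nf4(torch.randn(128, 128))       # NF4-quantized weight
b = torch.randn(128, 128)               # plain float32 tensor
out = torch.cat([a, b], dim=0)
print(type(out), out.shape, out.dtype)  # plain torch.Tensor, [256, 128], float32
```

The handler makes no attempt to re-quantize the concatenated result; presumably a dense output is all that is needed when DDP flattens and broadcasts module state at wrapper construction time.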