
NF4Tensor and DDP #1665

Closed
@psinger

Description

I am trying to use NF4Tensor weights in my model and wrap it with DistributedDataParallel, but I get the following error:

[rank0]:     model = DistributedDataParallel(
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/path/to/venv/lib/python3.12/site-packages/torch/nn/parallel/distributed.py", line 837, in __init__
[rank0]:     _sync_module_states(
[rank0]:   File "/path/to/venv/lib/python3.12/site-packages/torch/distributed/utils.py", line 313, in _sync_module_states
[rank0]:     _sync_params_and_buffers(process_group, module_states, broadcast_bucket_size, src)
[rank0]:   File "/path/to/venv/lib/python3.12/site-packages/torch/distributed/utils.py", line 324, in _sync_params_and_buffers
[rank0]:     dist._broadcast_coalesced(
[rank0]:   File "/path/to/venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/path/to/venv/lib/python3.12/site-packages/torchao/dtypes/nf4tensor.py", line 834, in __torch_dispatch__
[rank0]:     raise NotImplementedError(
[rank0]: NotImplementedError: NF4Tensor dispatch: attempting to run aten.cat.default, this is not supported

To replicate:

import os

import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel

from torchao.dtypes.nf4tensor import to_nf4

class NF4(nn.Module):
    def __init__(self, device=None):
        super().__init__()
        # Create a regular linear layer, then swap its weight for an
        # NF4-quantized copy of the same tensor.
        self.linear = nn.Linear(512, 512, bias=False, device=device)
        self.linear.weight = nn.Parameter(to_nf4(self.linear.weight))

if __name__ == "__main__":
    _local_rank = int(os.getenv("LOCAL_RANK", "0"))
    _device = f"cuda:{_local_rank}"

    # One process per GPU, launched via torchrun (see below).
    torch.distributed.init_process_group(
        backend="nccl",
        init_method="env://",
        device_id=torch.device(_device),
    )

    model = NF4(_device)

    # Fails inside DDP's constructor while it broadcasts module states.
    model = DistributedDataParallel(model)

Run with:

torchrun --nproc_per_node=2 script.py

The unsupported op shows up either as aten.cat.default (as in the traceback above) or as the collective itself:

NotImplementedError: NF4Tensor dispatch: attempting to run c10d.broadcast_.default, this is not supported
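For what it's worth, the failure does not look specific to DDP. Broadcasting an NF4Tensor directly should hit the same dispatch gap; a minimal sketch (untested, assuming the same two-rank NCCL launch):

import os

import torch
import torch.distributed as dist
from torchao.dtypes.nf4tensor import to_nf4

local_rank = int(os.getenv("LOCAL_RANK", "0"))
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl", init_method="env://")

w = to_nf4(torch.randn(512, 512, device=f"cuda:{local_rank}"))

# Expected to raise the same NotImplementedError from
# NF4Tensor.__torch_dispatch__, since c10d.broadcast_ has no NF4 handler.
dist.broadcast(w, src=0)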

Is there some way around this issue?
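One direction that might work (an untested sketch, relying on DDP's private _set_params_and_buffers_to_ignore_for_model helper): tell DDP to ignore the NF4 parameters, so it neither broadcasts them at construction time nor tries to reduce their gradients. This assumes every rank ends up with identical NF4 weights on its own, e.g. because each rank quantizes the same checkpointed weights, since DDP would no longer sync them:

from torch.nn.parallel import DistributedDataParallel
from torchao.dtypes.nf4tensor import NF4Tensor

model = NF4(_device)

# Fully-qualified names of the NF4-quantized parameters, e.g. "linear.weight".
nf4_param_names = [
    name for name, p in model.named_parameters() if isinstance(p, NF4Tensor)
]

# Private API: DDP skips these names both when syncing module states in
# its constructor and when building gradient-reduction buckets.
DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
    model, nf4_param_names
)

model = DistributedDataParallel(model)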
