Description
I am trying to use `NF4Tensor` weights in my model and wrap it with `DistributedDataParallel`, but I get the following error:
```
[rank0]:     model = DistributedDataParallel(
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/path/to/venv/lib/python3.12/site-packages/torch/nn/parallel/distributed.py", line 837, in __init__
[rank0]:     _sync_module_states(
[rank0]:   File "/path/to/venv/lib/python3.12/site-packages/torch/distributed/utils.py", line 313, in _sync_module_states
[rank0]:     _sync_params_and_buffers(process_group, module_states, broadcast_bucket_size, src)
[rank0]:   File "/path/to/venv/lib/python3.12/site-packages/torch/distributed/utils.py", line 324, in _sync_params_and_buffers
[rank0]:     dist._broadcast_coalesced(
[rank0]:   File "/path/to/venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/path/to/venv/lib/python3.12/site-packages/torchao/dtypes/nf4tensor.py", line 834, in __torch_dispatch__
[rank0]:     raise NotImplementedError(
[rank0]: NotImplementedError: NF4Tensor dispatch: attempting to run aten.cat.default, this is not supported
```
To replicate:
```python
from torchao.dtypes.nf4tensor import linear_nf4, to_nf4
from torch.nn.parallel import DistributedDataParallel
from torch import nn
import os
import torch


class NF4(nn.Module):
    def __init__(
        self,
        device=None,
    ):
        super().__init__()
        self.linear = nn.Linear(512, 512, bias=False, device=device)
        self.linear.weight = nn.Parameter(to_nf4(self.linear.weight))


if __name__ == "__main__":
    _local_rank = int(os.getenv("LOCAL_RANK", "0"))
    _device = f"cuda:{_local_rank}"
    torch.distributed.init_process_group(
        backend="nccl",
        init_method="env://",
        device_id=torch.device(_local_rank),
    )
    model = NF4(_device)
    model = DistributedDataParallel(model)
```
Run with:

```
torchrun --nproc_per_node=2 script.py
```
which fails with:

```
NotImplementedError: NF4Tensor dispatch: attempting to run c10d.broadcast_.default, this is not supported
```
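As far as I can tell from the two messages, `NF4Tensor` implements only a limited set of ops in `__torch_dispatch__`, and the ops DDP's initial parameter sync relies on (`aten.cat` for the coalesced broadcast, and `c10d.broadcast_`) are not among them. A smaller sketch that I think hits the same dispatch path, without DDP (untested, and the `cat` behavior is my assumption):

```python
# Untested sketch: DDP's coalesced broadcast concatenates parameters, and
# aten.cat does not appear to be in NF4Tensor's dispatch table, so a plain
# cat on an NF4Tensor should already fail the same way (assumption).
import torch
from torchao.dtypes.nf4tensor import to_nf4

w = to_nf4(torch.randn(512, 512, device="cuda"))
torch.cat([w, w])  # NotImplementedError: ... aten.cat.default ...
```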
Is there some way around this issue?
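One direction I have considered, but not verified, is asking DDP to skip the NF4 parameters during its initial state sync via the private `_set_params_and_buffers_to_ignore_for_model` hook. A minimal sketch, assuming the quantized weights are frozen, already identical on every rank (e.g. loaded from the same checkpoint), and that the model has other trainable parameters for DDP to reduce:

```python
# Untested sketch: have DDP ignore NF4 parameters so _sync_module_states
# never tries to broadcast them. Assumes the NF4 weights are frozen and
# identical on all ranks, so skipping the sync is safe.
from torchao.dtypes.nf4tensor import NF4Tensor

model = NF4(_device)
nf4_param_names = [
    name for name, p in model.named_parameters() if isinstance(p, NF4Tensor)
]
DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
    model, nf4_param_names
)
model = DistributedDataParallel(model)
```

I am not sure whether this interacts correctly with gradient reduction afterwards, so treat it as an idea rather than a fix.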