
"RuntimeError: !at::functionalization::impl::isFunctionalTensor(t)" when running a DTensor test with functionalization on #9472

@jeffhataws

Description

🐛 Bug

When running the new DTensor placement test test/spmd/test_dtensor_integration3.py with functionalization enabled (the default), I get the following error:

======================================================================
ERROR: test_xla_placement (__main__.DTensorIntegrationTest3)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/ubuntu/pt2.8_sws/pytorch/xla/test/spmd/test_dtensor_integration3.py", line 74, in test_xla_placement
    outputs_sharded = forward_pure(inputs, in_proj_weight, out_proj_weight)
  File "/home/ubuntu/pt2.8_sws/pytorch/xla/test/spmd/test_dtensor_integration3.py", line 45, in forward_pure
    hidden = torch.matmul(hidden, in_proj_weight.T)
  File "/home/ubuntu/pt2.8_sws/pytorch/xla/torch_xla/distributed/spmd/xla_sharded_tensor.py", line 195, in __torch_function__
    return super().__torch_function__(func, types, args, kwargs)
  File "/home/ubuntu/pytorch/torch/_tensor.py", line 1682, in __torch_function__
    ret = func(*args, **kwargs)
  File "/home/ubuntu/pt2.8_sws/pytorch/xla/torch_xla/distributed/spmd/xla_sharded_tensor.py", line 190, in __torch_dispatch__
    func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
  File "/home/ubuntu/pytorch/torch/_ops.py", line 829, in __call__
    return self._op(*args, **kwargs)
RuntimeError: !at::functionalization::impl::isFunctionalTensor(t) INTERNAL ASSERT FAILED at "/home/ubuntu/pytorch/aten/src/ATen/FunctionalTensorWrapper.cpp":838, please report a bug to PyTorch. The composite op functionalization fallback expects its inputs all not to be functional tensors

To Reproduce

Steps to reproduce the behavior:

python test/spmd/test_dtensor_integration3.py

The test passes with functionalization disabled:

XLA_DISABLE_FUNCTIONALIZATION=1 python test/spmd/test_dtensor_integration3.py

Expected behavior

The test should pass, without crashing, when functionalization is enabled.

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: all
  • torch_xla version: v2.8, v2.9

Additional context
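
The two ops in the failing line, Tensor.T and torch.matmul, are CompositeImplicitAutograd ops, so they decompose before reaching __torch_dispatch__ — the same composite path whose functionalization fallback fires the assert above. As an XLA-free illustration of that decomposition (a minimal sketch, not the torch_xla internals; the OpLogger mode below is a hypothetical helper for this issue, built on the public TorchDispatchMode API):

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode


class OpLogger(TorchDispatchMode):
    """Record the name of every aten op that reaches __torch_dispatch__."""

    def __init__(self):
        super().__init__()
        self.ops = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.ops.append(str(func))
        # The mode is popped while we are inside this handler, so calling
        # func directly runs the real kernel without recursing.
        return func(*args, **(kwargs or {}))


w = torch.randn(4, 3)
with OpLogger() as logger:
    _ = w.T  # composite op: decomposed (e.g. into aten.permute) before dispatch

print(logger.ops)
```

In eager CPU PyTorch the log shows the decomposed aten op rather than the composite one, which is why the failure surfaces inside XLAShardedTensor.__torch_dispatch__ (xla_sharded_tensor.py:190) after its unwrap step: the composite fallback apparently receives tensors that are still functional at that point.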


    Labels

    bug (Something isn't working), distributed (SPMD and other distributed things)
