[torch-xla v2.8] Error "Check failed: state Expected an array shape." when running test/test_mp_reduce_scatter.py #9314

@jeffhataws

Description

🐛 Bug

With torch-xla v2.8, the Neuron team is getting "Check failed: state Expected an array shape." errors when running many training tests that use reduce-scatter. These errors were not present in v2.7. Furthermore, I have narrowed the regression down to the commit that updated the OpenXLA pin (#9045), since the torch-xla nightly from 4/30 still works.

I have also narrowed the failure down to the existing test/test_mp_reduce_scatter.py test case, as shown in the next section.
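
For reference, the crash is hit on the coalesced form of reduce-scatter, i.e. xm.reduce_scatter called with a list of input tensors (the reduce_scatter_coalesced frames in the trace below). A minimal sketch of that call pattern; the tensor shapes and launch wrapper here are illustrative, not the exact test body:

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr


def _mp_fn(index):
  device = xm.xla_device()
  world_size = xr.world_size()
  # Passing a list of tensors selects the coalesced lowering
  # (torch_xla::tensor_methods::reduce_scatter_coalesced in the trace below).
  inputs = [
      torch.ones(world_size * 8, 32, device=device) for _ in range(5)
  ]
  xm.reduce_scatter(
      xm.REDUCE_SUM,
      inputs,
      scale=1.0,
      scatter_dim=0,
      shard_count=world_size)
  xm.mark_step()


if __name__ == '__main__':
  torch_xla.launch(_mp_fn, args=())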

To Reproduce

On a CPU instance:

python -m venv test_venv_pt2.8
source test_venv_pt2.8/bin/activate
pip3 install -U pip
pip3 install torch torchvision --pre --extra-index-url https://download.pytorch.org/whl/nightly/cpu
pip3 install 'torch_xla @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html
cd ~/
git clone https://github.com/pytorch/xla
cd ~/xla/test
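
Optionally, a quick sanity check (not part of the original repro) that the nightly torch and the 2.8 torch_xla wheel are the ones being picked up:

python -c "import torch, torch_xla; print(torch.__version__, torch_xla.__version__)"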

Add 'CPU' to the device list in test_mp_reduce_scatter.py:

   if xm.xla_device_hw(device) in ['TPU', 'CUDA', 'NEURON', 'CPU']:

then run

PJRT_DEVICE=CPU python test_mp_reduce_scatter.py
WARNING:root:MASTER_ADDR environment variable is not set, defaulting to localhost
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR                                                      
F0000 00:00:1749493093.448876   48267 shape.cc:166] Check failed: state Expected an array shape. Got (f32[32,2,32], f32[32,2,32], f32[32,2,32], f32[32,2,32], f32[32,2,32])
This is a programmer error. Please read the Shape object's array properties (e.g. dimensions) only when it's an array shape.
*** Check failure stack trace: ***                                                                                                                                           
    @     0x7ea70c886e99  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @     0x7ea7017dc8c7  xla::Shape::array_state()
    @     0x7ea701a3b1ce  torch_xla::BuildReduceScatterCoalesced()
    @     0x7ea701e3b264  std::_Function_handler<>::_M_invoke()
    @     0x7ea701ddf751  torch_xla::InferOutputShape()
    @     0x7ea701e3b4a1  std::_Function_handler<>::_M_invoke()
    @     0x7ea701e726df  torch_xla::XlaNode::GetOpShape()
    @     0x7ea701e72fa9  torch_xla::XlaNode::XlaNode()
    @     0x7ea701e3c2ec  torch_xla::ReduceScatterCoalesced::ReduceScatterCoalesced()
    @     0x7ea701ade1b7  torch_xla::MakeNode<>()
    @     0x7ea701ade432  torch_xla::tensor_methods::reduce_scatter_coalesced()
    @     0x7ea70194c02f  torch_xla::(anonymous namespace)::InitXlaModuleBindings()::{lambda()#55}::operator()()
    @     0x7ea70196ca19  pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN() 
    @     0x7ea701948e38  pybind11::cpp_function::dispatcher()
    @     0x61fe15737e12  (unknown)
Traceback (most recent call last):
  File "/home/ubuntu/xla/test/test_mp_reduce_scatter.py", line 180, in <module>
    torch_xla.launch(_mp_fn, args=())
  File "/home/ubuntu/test_venv/lib/python3.10/site-packages/torch_xla/torch_xla.py", line 266, in launch
    xmp.spawn(fn, args=args, nprocs=nprocs, start_method=start_method)
  File "/home/ubuntu/test_venv/lib/python3.10/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 43, in spawn
    return pjrt.spawn(fn, nprocs, start_method, args)
  File "/home/ubuntu/test_venv/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 213, in spawn
    run_multiprocess(spawn_fn, start_method=start_method)
  File "/home/ubuntu/test_venv/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 169, in run_multiprocess
    replica_results = list(
  File "/home/ubuntu/test_venv/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 170, in <genexpr>
    itertools.chain.from_iterable(
  File "/usr/lib/python3.10/concurrent/futures/process.py", line 575, in _chain_from_iterable_of_lists
    for element in iterable:
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator 
    yield _result_or_cancel(fs.pop())
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
    return fut.result(timeout)
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 458, in result
    return self.__get_result()
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
concurrent.futures.process.BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.

Expected behavior

No crash, as is the case when running the same test with v2.7:

python -m venv test_venv_pt2.7
source test_venv_pt2.7/bin/activate
pip3 install -U pip
pip install torch torch-xla
cd ~/xla/test
PJRT_DEVICE=CPU python test_mp_reduce_scatter.py

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: CPU, NEURON
  • torch_xla version: 2.8 (TOT)

Additional context

Labels

bug, xla:cpu
