🐛 Bug
With torch-xla v2.8, the Neuron team is seeing "Check failed: state Expected an array shape." errors when running many training tests that use reduce-scatter. These errors were not present in v2.7. I have narrowed the regression down to the commit that updated the OpenXLA pin (#9045): the torch-xla nightly from 4/30 still works.
I have also reduced the failure to the existing test/test_mp_reduce_scatter.py, as shown in the next section.
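For reference, the failing pattern boils down to a coalesced reduce-scatter, i.e. xm.reduce_scatter called with a list of tensors. A minimal sketch of that call shape (the tensor sizes and reduce parameters here are illustrative, not copied from the test body):

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr


def _mp_fn(index):
    device = xm.xla_device()
    world_size = xr.world_size()

    # A list input selects the coalesced lowering
    # (torch_xla::tensor_methods::reduce_scatter_coalesced in the stack
    # trace below); a single-tensor input takes a different path.
    tensors = [
        torch.rand(32, 2 * world_size, 32, device=device) for _ in range(5)
    ]
    results = xm.reduce_scatter(
        xm.REDUCE_SUM,
        tensors,
        scale=1.0 / world_size,
        scatter_dim=1,
        shard_count=world_size)
    xm.mark_step()


if __name__ == '__main__':
    torch_xla.launch(_mp_fn, args=())
```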
To Reproduce
On a CPU instance:
python -m venv test_venv_pt2.8
source test_venv_pt2.8/bin/activate
pip3 install -U pip
pip3 install torch torchvision --pre --extra-index-url https://download.pytorch.org/whl/nightly/cpu
pip3 install 'torch_xla @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html
cd ~/
git clone https://github.com/pytorch/xla
cd ~/xla/test
Add 'CPU' to the device list in test_mp_reduce_scatter.py:
if xm.xla_device_hw(device) in ['TPU', 'CUDA', 'NEURON', 'CPU']:
then run:
PJRT_DEVICE=CPU python test_mp_reduce_scatter.py
WARNING:root:MASTER_ADDR environment variable is not set, defaulting to localhost
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
F0000 00:00:1749493093.448876 48267 shape.cc:166] Check failed: state Expected an array shape. Got (f32[32,2,32], f32[32,2,32], f32[32,2,32], f32[32,2,32], f32[32,2,32])
This is a programmer error. Please read the Shape object's array properties (e.g. dimensions) only when it's an array shape.
*** Check failure stack trace: ***
@ 0x7ea70c886e99 absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
@ 0x7ea7017dc8c7 xla::Shape::array_state()
@ 0x7ea701a3b1ce torch_xla::BuildReduceScatterCoalesced()
@ 0x7ea701e3b264 std::_Function_handler<>::_M_invoke()
@ 0x7ea701ddf751 torch_xla::InferOutputShape()
@ 0x7ea701e3b4a1 std::_Function_handler<>::_M_invoke()
@ 0x7ea701e726df torch_xla::XlaNode::GetOpShape()
@ 0x7ea701e72fa9 torch_xla::XlaNode::XlaNode()
@ 0x7ea701e3c2ec torch_xla::ReduceScatterCoalesced::ReduceScatterCoalesced()
@ 0x7ea701ade1b7 torch_xla::MakeNode<>()
@ 0x7ea701ade432 torch_xla::tensor_methods::reduce_scatter_coalesced()
@ 0x7ea70194c02f torch_xla::(anonymous namespace)::InitXlaModuleBindings()::{lambda()#55}::operator()()
@ 0x7ea70196ca19 pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
@ 0x7ea701948e38 pybind11::cpp_function::dispatcher()
@ 0x61fe15737e12 (unknown)
Traceback (most recent call last):
File "/home/ubuntu/xla/test/test_mp_reduce_scatter.py", line 180, in <module>
torch_xla.launch(_mp_fn, args=())
File "/home/ubuntu/test_venv/lib/python3.10/site-packages/torch_xla/torch_xla.py", line 266, in launch
xmp.spawn(fn, args=args, nprocs=nprocs, start_method=start_method)
File "/home/ubuntu/test_venv/lib/python3.10/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 43, in spawn
return pjrt.spawn(fn, nprocs, start_method, args)
File "/home/ubuntu/test_venv/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 213, in spawn
run_multiprocess(spawn_fn, start_method=start_method)
File "/home/ubuntu/test_venv/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 169, in run_multiprocess
replica_results = list(
File "/home/ubuntu/test_venv/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 170, in <genexpr>
itertools.chain.from_iterable(
File "/usr/lib/python3.10/concurrent/futures/process.py", line 575, in _chain_from_iterable_of_lists
for element in iterable:
File "/usr/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
yield _result_or_cancel(fs.pop())
File "/usr/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
return fut.result(timeout)
File "/usr/lib/python3.10/concurrent/futures/_base.py", line 458, in result
return self.__get_result()
File "/usr/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
raise self._exception
concurrent.futures.process.BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.
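Reading the output above: the tuple shape in the check failure has one f32[32,2,32] entry per input tensor, i.e. it is the coalesced op's output shape, and the stack trace shows BuildReduceScatterCoalesced reading array-only shape properties while InferOutputShape runs. Consistent with that, only the list (coalesced) form of the call appears to hit this path. A sketch of the distinction (argument values illustrative):

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

device = xm.xla_device()
ws = xr.world_size()
t = torch.rand(32, 2 * ws, 32, device=device)

# Single-tensor form: the inferred output is a plain array shape; per the
# stack trace this is not the path that fails.
out = xm.reduce_scatter(xm.REDUCE_SUM, t, 1.0 / ws, 1, ws)

# List form: the inferred output is a tuple shape
# (f32[32,2,32], f32[32,2,32], ...) -- the path that trips the check.
outs = xm.reduce_scatter(xm.REDUCE_SUM, [t, t.clone()], 1.0 / ws, 1, ws)
```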
Expected behavior
No crash. The same test runs cleanly with v2.7:
python -m venv test_venv_pt2.7
source test_venv_pt2.7/bin/activate
pip3 install -U pip
pip install torch torch-xla
cd ~/xla/test
PJRT_DEVICE=CPU python test_mp_reduce_scatter.py
Environment
- Reproducible on XLA backend: CPU, NEURON
- torch_xla version: 2.8 (TOT)
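Additional context
As a possible stopgap (an untested sketch, not the fix): issuing per-tensor reduce-scatters instead of a single list call avoids the coalesced lowering that crashes, at the cost of the coalescing benefit. The helper name below is hypothetical:

```python
import torch_xla.core.xla_model as xm


def reduce_scatter_uncoalesced(tensors, scale, scatter_dim, shard_count):
    # Hypothetical helper (untested against this regression): issue one
    # single-tensor reduce-scatter per input so the coalesced lowering
    # is never reached.
    return [
        xm.reduce_scatter(xm.REDUCE_SUM, t, scale, scatter_dim, shard_count)
        for t in tensors
    ]
```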