🐛 Bug
Calling torch_xla.device() or setting the random seed before use_spmd() later produces a SIGSEGV when unmarked tensors are used in a computation.
To Reproduce
import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh
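# Either of the next two calls, made before use_spmd(), triggers the crash below.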
xm.set_rng_state(42)
torch_xla.device()
# Enable XLA SPMD execution mode.
xr.use_spmd()
# Device mesh, this and partition spec as well as the input tensor shape define the individual shard shape.
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('data', 'model'))
t = torch.randn(8192, 4096).to(torch_xla.device())
# Mesh partitioning, each device holds 1/8-th of the input
partition_spec = ('data', 'model')
xs.mark_sharding(t, mesh, partition_spec)
x = torch.randn(1024, 8192).to(torch_xla.device())
mult = torch.matmul(x, t)
torch_xla.sync()
print("Result shape:", mult.shape)
print("total sum of result:", mult.sum().item())Steps to reproduce the behavior:
- Comment out
set_rng_state(...)ordevice()calls - Run and observe it works
- Uncomment either one of them
*** SIGSEGV (@0x1f0), see go/stacktraces#s15 received by PID 55852 (TID 57319) on cpu 126; stack trace: ***
PC: @ 0x71a5473605d9 (unknown) std::_Function_handler<>::_M_invoke()
@ 0x71a4c9e9abc5 1904 (unknown)
@ 0x71a709042520 3184 (unknown)
@ 0x71a550c3c4de 32 std::_Function_handler<>::_M_invoke()
@ 0x71a547fb7072 320 Eigen::ThreadPoolDevice::parallelFor()
@ 0x71a550c402c5 608 tsl::thread::ThreadPool::ParallelFor()
@ 0x71a547e96b4d 1376 torch_xla::runtime::PjRtComputationClient::ExecuteReplicated()
@ 0x71a547c3ce75 816 torch_xla::XLAGraphExecutor::ScheduleSyncTensorsGraph()::{lambda()#1}::operator()()
@ 0x71a64003f4b8 192 torch::lazy::MultiWait::Complete()
@ 0x71a550c3c488 64 absl::lts_20250512::internal_any_invocable::RemoteInvoker<>()
@ 0x71a550c322c2 96 tsl::(anonymous namespace)::PThread::ThreadFn()
@ 0x71a709094ac3 (unknown) (unknown)
https://symbolize.stripped_domain/r/?trace=71a5473605d9,71a4c9e9abc4,71a70904251f,71a550c3c4dd,71a547fb7071,71a550c402c4,71a547e96b4c,71a547c3ce74,71a64003f4b7,71a550c3c487,71a550c322c1,71a709094ac2&map=
E0120 20:49:43.195168 57319 coredump_hook.cc:301] RAW: Remote crash data gathering hook invoked.
E0120 20:49:43.195187 57319 coredump_hook.cc:340] RAW: Skipping coredump since rlimit was 0 at process start.
E0120 20:49:43.195191 57319 client.cc:270] RAW: Coroner client retries enabled, will retry for up to 30 sec.
E0120 20:49:43.195195 57319 coredump_hook.cc:396] RAW: Sending fingerprint to remote end.
E0120 20:49:43.195228 57319 coredump_hook.cc:405] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E0120 20:49:43.195234 57319 coredump_hook.cc:457] RAW: Dumping core locally.
E0120 20:49:43.452738 57319 process_state.cc:808] RAW: Raising signal 11 with default behavior
Segmentation fault (core dumped)
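For reference, the same script with those two calls moved after use_spmd() does not crash on my setup. A minimal sketch; the rest of the script is identical to the repro above:
import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh
# Enable SPMD before anything else that touches the XLA runtime.
xr.use_spmd()
# These are now safe to call.
xm.set_rng_state(42)
torch_xla.device()
# ... mesh setup, mark_sharding and matmul exactly as in the repro above ...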
Expected behavior
use_spmd() should either raise an error when the runtime has already been touched (to prevent confusing failures later on) or handle prior calls to device(), set_rng_state() and similar methods gracefully.
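Concretely, something along these lines is what I have in mind. This is purely illustrative pseudocode; _runtime_touched is a made-up flag, not real torch_xla state:
# Illustrative guard I'd expect inside use_spmd(); _runtime_touched is a
# placeholder for whatever internal state tracks runtime initialization.
_runtime_touched = False

def use_spmd():
    if _runtime_touched:
        raise RuntimeError(
            "use_spmd() must be called before torch_xla.device(), "
            "xm.set_rng_state(), or any other call that initializes the runtime."
        )
    # ... enable SPMD as today ...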
Environment
- Reproducible on XLA backend: TPU (v6e, 8 chips on 1 node), libtpu version 0.0.21.
- torch_xla version: 2.9.0
Additional Details
It took me multiple days to realize this was caused by setting the seed before calling use_spmd(). In my case the seed was set via Accelerate:
from accelerate.utils import set_seed
set_seed(42)
This calls xm.set_rng_state(seed) for XLA devices. I suspect the underlying torch_xla._XLAC._xla_get_default_device() call is the cause: somehow some tensors end up on the virtual (SPMD) device while others end up on an actual device.
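Assuming the ordering really is the root cause, the Accelerate-side workaround is simply to enable SPMD before seeding; a sketch (assuming set_seed is the only early XLA touch point):
import torch_xla.runtime as xr
from accelerate.utils import set_seed
# Enable SPMD first, then seed; the reverse order triggers the crash described above.
xr.use_spmd()
set_seed(42)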
While trying to debug this, I ended up marking every tensor I created as sharded. For example,
x = torch.randn(1024, 8192).to(torch_xla.device())
xs.mark_sharding(x, mesh, partition_spec)
fixes the crash for that tensor. However, whenever an underlying library creates a new tensor that is not marked, the crash comes back. I also inspected the torch_xla._XLAC._get_xla_sharding_spec values in the working and the crashing setup, but they look the same in both; the output is shown below.
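Roughly the check I ran (a sketch reconstructed from memory, reusing mesh and partition_spec from the repro above):
t = torch.randn(8192, 4096).to(torch_xla.device())
print("Before sharding:", torch_xla._XLAC._get_xla_sharding_spec(t))
xs.mark_sharding(t, mesh, partition_spec)
print("After sharding:", torch_xla._XLAC._get_xla_sharding_spec(t))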
Before sharding:
After sharding: {devices=[8,1]0,1,2,3,4,5,6,7}
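For now my stopgap is to funnel tensor creation through a small helper so nothing I create is left unmarked. A sketch; sharded_randn is my own name, not a torch_xla API:
import torch
import torch_xla
import torch_xla.distributed.spmd as xs

def sharded_randn(shape, mesh, partition_spec):
    # Create the tensor on the XLA device and mark it immediately,
    # so it never reaches execution unmarked.
    t = torch.randn(*shape).to(torch_xla.device())
    xs.mark_sharding(t, mesh, partition_spec)
    return t

# e.g. x = sharded_randn((1024, 8192), mesh, ('data', 'model'))
This keeps my own tensors safe, but it does not help with tensors created inside third-party libraries, which is why I would still prefer use_spmd() to fail loudly.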