
[Bug] Calling torch_xla.device() or setting random seed before use_spmd() produces SIGSEGV for unmarked tensors #9735

@Dogacel

Description

🐛 Bug

Calling torch_xla.device() or setting random seed before use_spmd() produces a SIGSEGV for unmarked tensors.

To Reproduce

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

from torch_xla.distributed.spmd import Mesh

xm.set_rng_state(42)

torch_xla.device()

# Enable XLA SPMD execution mode.
xr.use_spmd()

# Device mesh, this and partition spec as well as the input tensor shape define the individual shard shape.
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('data', 'model'))

t = torch.randn(8192, 4096).to(torch_xla.device())

# Mesh partitioning, each device holds 1/8-th of the input
partition_spec = ('data', 'model')
xs.mark_sharding(t, mesh, partition_spec)


x = torch.randn(1024, 8192).to(torch_xla.device())


mult = torch.matmul(x, t)
torch_xla.sync()

print("Result shape:", mult.shape)
print("total sum of result:", mult.sum().item())

Steps to reproduce the behavior:

  1. Comment out both the set_rng_state(...) and torch_xla.device() calls.
  2. Run the script and observe that it completes successfully.
  3. Uncomment either one of them and run again; the script crashes with the SIGSEGV below:
*** SIGSEGV (@0x1f0), see go/stacktraces#s15 received by PID 55852 (TID 57319) on cpu 126; stack trace: ***
PC: @     0x71a5473605d9  (unknown)  std::_Function_handler<>::_M_invoke()
    @     0x71a4c9e9abc5       1904  (unknown)
    @     0x71a709042520       3184  (unknown)
    @     0x71a550c3c4de         32  std::_Function_handler<>::_M_invoke()
    @     0x71a547fb7072        320  Eigen::ThreadPoolDevice::parallelFor()
    @     0x71a550c402c5        608  tsl::thread::ThreadPool::ParallelFor()
    @     0x71a547e96b4d       1376  torch_xla::runtime::PjRtComputationClient::ExecuteReplicated()
    @     0x71a547c3ce75        816  torch_xla::XLAGraphExecutor::ScheduleSyncTensorsGraph()::{lambda()#1}::operator()()
    @     0x71a64003f4b8        192  torch::lazy::MultiWait::Complete()
    @     0x71a550c3c488         64  absl::lts_20250512::internal_any_invocable::RemoteInvoker<>()
    @     0x71a550c322c2         96  tsl::(anonymous namespace)::PThread::ThreadFn()
    @     0x71a709094ac3  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=71a5473605d9,71a4c9e9abc4,71a70904251f,71a550c3c4dd,71a547fb7071,71a550c402c4,71a547e96b4c,71a547c3ce74,71a64003f4b7,71a550c3c487,71a550c322c1,71a709094ac2&map=
E0120 20:49:43.195168   57319 coredump_hook.cc:301] RAW: Remote crash data gathering hook invoked.
E0120 20:49:43.195187   57319 coredump_hook.cc:340] RAW: Skipping coredump since rlimit was 0 at process start.
E0120 20:49:43.195191   57319 client.cc:270] RAW: Coroner client retries enabled, will retry for up to 30 sec.
E0120 20:49:43.195195   57319 coredump_hook.cc:396] RAW: Sending fingerprint to remote end.
E0120 20:49:43.195228   57319 coredump_hook.cc:405] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E0120 20:49:43.195234   57319 coredump_hook.cc:457] RAW: Dumping core locally.
E0120 20:49:43.452738   57319 process_state.cc:808] RAW: Raising signal 11 with default behavior
Segmentation fault (core dumped)

Expected behavior

Either use_spmd() should raise an error when it detects that device() or the RNG state was touched first, so the problem surfaces immediately, or it should handle those earlier calls gracefully instead of segfaulting later.
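
For reference, a minimal sketch of a workaround, assuming the problem is purely the call ordering: move xr.use_spmd() to the very top so it runs before any device or RNG call. This is just the reproduction script above with the calls reordered, nothing else changed:

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

from torch_xla.distributed.spmd import Mesh

# Enable XLA SPMD execution mode before anything touches the device or RNG state.
xr.use_spmd()

# Only now set the seed and request the device.
xm.set_rng_state(42)
device = torch_xla.device()

num_devices = xr.global_runtime_device_count()
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, (num_devices, 1), ('data', 'model'))

t = torch.randn(8192, 4096).to(device)
xs.mark_sharding(t, mesh, ('data', 'model'))

x = torch.randn(1024, 8192).to(device)
mult = torch.matmul(x, t)
torch_xla.sync()

print("Result shape:", mult.shape)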

Environment

  • Reproducible on XLA backend: TPU (v6e, 8 chips, 1 node), libtpu version 0.0.21.
  • torch_xla version: 2.9.0

Additional Details

It took me multiple days to figure out that this was caused by setting the seed before calling use_spmd().

from accelerate.utils import set_seed
set_seed(42)

This calls xm.set_rng_state(seed) for XLA devices. I suspect the underlying torch_xla._XLAC._xla_get_default_device() call is the trigger: somehow some tensors end up on the virtual device while others end up on the actual device.
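
To illustrate what I mean, a minimal diagnostic sketch (this relies on the private _XLAC helper named above, so treat it as an assumption about how to observe the behavior rather than a supported API):

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

# Default device before any SPMD setup.
print("before use_spmd():", torch_xla._XLAC._xla_get_default_device())

# Setting the RNG state presumably initializes the (non-virtual) default device.
xm.set_rng_state(42)

xr.use_spmd()

# After use_spmd(), the default device is expected to be the virtual SPMD device.
print("after use_spmd():", torch_xla._XLAC._xla_get_default_device())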

While debugging this issue, I ended up explicitly marking every tensor I create as sharded. For example, adding

x = torch.randn(1024, 8192).to(torch_xla.device())

xs.mark_sharding(x, mesh, partition_spec)

fixes the crash for that tensor. However, whenever the underlying library creates a new tensor that is not marked, the issue comes back. I also inspected the torch_xla._XLAC._get_xla_sharding_spec values in the buggy and non-buggy versions, but they look the same:

Before sharding:
After sharding: {devices=[8,1]0,1,2,3,4,5,6,7}
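
For completeness, the values above came from printing the private helper directly, roughly like this (continuing from the reproduction script, with t, mesh, and partition_spec as defined there):

print("Before sharding:", torch_xla._XLAC._get_xla_sharding_spec(t))
xs.mark_sharding(t, mesh, partition_spec)
print("After sharding:", torch_xla._XLAC._get_xla_sharding_spec(t))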
