## 🐛 Bug

Can't allocate random tensors with `device='jax'`. It fails with the following error message:

```
jaxlib._jax.XlaRuntimeError: INVALID_ARGUMENT: Unable to replace a PyArray with a PyArray from a different client.
```

## To Reproduce

As simple as:

```python
import torch
import torchax

torchax.enable_globally()

# Raises the XlaRuntimeError above
torch.randn(3, 3, 28, 28, device='jax')
```

## Expected behavior

A random tensor should be created without crashing.

## Environment

- Reproducible on XLA backend [CPU/TPU/CUDA]: TPU
- torch_xla version: 2.8.0.dev

## Additional context

Reverting this PR helps: https://github.com/pytorch/xla/pull/9305/files