The test does not pass with a recent torch_tpu pin. (Note that this does not repro at HEAD yet; I'm in the process of updating the torch_tpu pin and want to reference this bug in doing so.)

Repro and test output:
$ HELION_PRINT_OUTPUT_CODE=1 HELION_BACKEND=pallas pytest -v test/test_atomic_ops.py -k atomic_add_w_tile_attr -s
============================= test session starts ==============================
platform linux -- Python 3.12.12, pytest-9.0.2, pluggy-1.6.0 -- /mnt/disks/workspace/src/miniconda3/envs/pt312/bin/python3.12
cachedir: .pytest_cache
hypothesis profile 'default'
rootdir: /mnt/disks/workspace/src/helion
configfile: pyproject.toml
plugins: jaxtyping-0.3.2, hypothesis-6.151.9
collecting ... WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1778066591.864828 401020 pjrt_api.cc:96] PJRT_Api is set for device type cpu
I0000 00:00:1778066591.893021 401020 pjrt_api.cc:118] GetPjrtApi was found for TPU at /mnt/disks/workspace/src/miniconda3/envs/pt312/lib/python3.12/site-packages/libtpu/libtpu.so
I0000 00:00:1778066591.893054 401020 pjrt_api.cc:96] PJRT_Api is set for device type tpu
I0000 00:00:1778066593.091694 401020 device_rt.cc:86] PjRt runtime initialization deferred for tpu
Successfully renamed PrivateUse1 backend to 'tpu'. Device: device(type='tpu')
Registered Python module for 'tpu'.
Device type: tpu, Device index: default
Initializing TPU distributed runtime
collected 28 items / 27 deselected / 1 selected
test/test_atomic_ops.py::TestAtomicOperations::test_atomic_add_w_tile_attr I0000 00:00:1778066593.768334 401147 pjrt_api.cc:167] The PJRT plugin has PJRT API version 0.103. The framework PJRT API version is 0.104.
I0000 00:00:1778066597.037861 401147 pjrt_c_api_client.cc:197] PjRtCApiClient created.
W0000 00:00:1778066597.037896 401147 pjrt_state.cc:225] Only using 1 device out of all the 8 addressable devices.
I0000 00:00:1778066597.053938 401147 tier2_compilation_cache.cc:165] Tier-2 compilation cache is disabled as requested by the TORCH_TPU_INTERNAL_TIER2_COMPILATION_CACHE environment variable.
I0000 00:00:1778066597.053959 401147 tier3_compilation_cache.cc:65] Backup compilation for tier-3 cache read is disabled as tier-3 cache is disabled.
# Output code written to: /tmp/tmpbtup_5f4/n2/cn2gkud6p2ix7nz4tb4n32opsdz5vr4ynaq5jyunuda256d6rt3q.py
from __future__ import annotations
import torch
import jax.numpy as jnp
from jax.experimental import pallas as pl
import jax.lax as lax
from helion.runtime import default_pallas_launcher as _default_pallas_launcher
_BLOCK_SIZE_0 = int(2)
def _helion_atomic_add_w_tile_attr(y):
    # src[test_atomic_ops.py:89]: for tile in hl.tile(x.size(0)):
    pid_0 = pl.program_id(0)
    offset_0 = pid_0 * _BLOCK_SIZE_0
    # src[test_atomic_ops.py:90]: hl.atomic_add(y, [tile.begin], 1)
    _prev = y[offset_0]
    y[offset_0] = lax.convert_element_type(_prev + 1, jnp.int32)
def atomic_add_w_tile_attr(x: torch.Tensor, *, _launcher=_default_pallas_launcher):
    """Test atomic_add where the index is a symbolic int"""
    # src[test_atomic_ops.py:88]: y = torch.zeros_like(x, device=x.device, dtype=torch.int32)
    y = torch.zeros_like(x, device=x.device, dtype=torch.int32)
    # src[test_atomic_ops.py:89]: for tile in hl.tile(x.size(0)):
    _BLOCK_SIZE_0 = 2
    # src[test_atomic_ops.py:89]: for tile in hl.tile(x.size(0)):
    # src[test_atomic_ops.py:90]: hl.atomic_add(y, [tile.begin], 1)
    _launcher(_helion_atomic_add_w_tile_attr, ((20 + _BLOCK_SIZE_0 - 1) // _BLOCK_SIZE_0,), y, _output_indices=[0], _inplace_indices=[0], _block_spec_info=[((None,), (None,))], _smem_arg_indices=[0])
    # src[test_atomic_ops.py:91]: return y
    return y
def call():
    from torch._dynamo.testing import rand_strided
    # src[test_atomic_ops.py:86]: def atomic_add_w_tile_attr(x: torch.Tensor) -> torch.Tensor:
    # src[test_atomic_ops.py:87]: """Test atomic_add where the index is a symbolic int"""
    # src[test_atomic_ops.py:88]: y = torch.zeros_like(x, device=x.device, dtype=torch.int32)
    # src[test_atomic_ops.py:86-91]: ...
    x = rand_strided(size=(20,), stride=(1,), dtype=torch.float32, device='tpu:0')
    atomic_add_w_tile_attr(x)
if __name__ == '__main__':
    call()
FAILED
=================================== FAILURES ===================================
_______________ TestAtomicOperations.test_atomic_add_w_tile_attr _______________
self = <test.test_atomic_ops.TestAtomicOperations testMethod=test_atomic_add_w_tile_attr>
    @skipIfRefEager(
        "Test is block size dependent which is not supported in ref eager mode"
    )
    def test_atomic_add_w_tile_attr(self):
        """Test atomic_add where the index is a symbolic int"""
        x = torch.randn(20, device=DEVICE)
        code, result = code_and_output(
            atomic_add_w_tile_attr,
            (x,),
            block_sizes=[2],
        )
        expected = torch.tensor([1, 0], device=DEVICE, dtype=torch.int32).repeat(10)
>       torch.testing.assert_close(result, expected)
E       AssertionError: Tensor-likes are not equal!
E
E       Mismatched elements: 1 / 20 (5.0%)
E       Greatest absolute difference: 1376286904 at index (0,)
E       Greatest relative difference: 1376286848.0 at index (0,)
test/test_atomic_ops.py:519: AssertionError
=============================== warnings summary ===============================
../pytorch/torch/jit/_script.py:365: 14 warnings
  /mnt/disks/workspace/src/pytorch/torch/jit/_script.py:365: DeprecationWarning: `torch.jit.script_method` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(
test/test_atomic_ops.py::TestAtomicOperations::test_atomic_add_w_tile_attr
  /mnt/disks/workspace/src/helion/helion/_compiler/compile_environment.py:144: UserWarning: The 'pallas' backend is experimental and may have limited functionality.
    warn_once(
test/test_atomic_ops.py::TestAtomicOperations::test_atomic_add_w_tile_attr
  /mnt/disks/workspace/src/helion/helion/runtime/__init__.py:657: DeprecationWarning: input_output_aliases is deprecated and will be removed soon. Please use donate_argnums instead.
    jax_callable = JaxCallable(
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================== short test summary info ============================
FAILED test/test_atomic_ops.py::TestAtomicOperations::test_atomic_add_w_tile_attr
================ 1 failed, 27 deselected, 16 warnings in 6.91s =================
I0000 00:00:1778066599.255886 401020 compilation_cache.cc:328] Compilation cache evicted.
I0000 00:00:1778066599.255917 401020 compilation_cache.cc:193] CompilationCache final stats: num_cache_reqs=26
num_cache_hits=2 {7.6%}
Only index 0 mismatches (1 / 20 elements), and it is off by more than 1.3 billion, so AFAICT the generated kernel is reading a garbage value at y[0] and incrementing it.
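The two "difference" figures in the failure message are consistent with a single large garbage value at index 0. Assuming `torch.testing.assert_close` reports the relative difference as abs_diff / |expected| computed in float32, and since the expected value at index 0 is 1, the relative difference should simply be the absolute difference rounded to float32. A quick stdlib-only sanity check (the `to_float32` helper is hypothetical, written just for this note):

```python
import struct

# Numbers copied from the assert_close failure message above.
abs_diff = 1376286904             # "Greatest absolute difference: ... at index (0,)"
reported_rel_diff = 1376286848.0  # "Greatest relative difference: ... at index (0,)"
expected_at_0 = 1                 # expected = torch.tensor([1, 0], ...).repeat(10)


def to_float32(value: float) -> float:
    """Round a Python float to the nearest IEEE-754 single-precision value (hypothetical helper)."""
    return struct.unpack("<f", struct.pack("<f", value))[0]


# Assumption: relative difference = abs_diff / |expected|, computed in float32.
# Near 1.38e9, float32 values are spaced 128 apart, so 1376286904 rounds to 1376286848.
rel_diff = to_float32(abs_diff / abs(expected_at_0))
assert rel_diff == reported_rel_diff

# |y[0] - 1| is therefore about 1.38e9, which looks like a garbage int32 rather than a miscount.
print(rel_diff)  # 1376286848.0
```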
This reproduces with a recent torch_tpu version, e.g. after b6053a4d (Implement synchronizeDevice in torch_tpu backend), using the pytest command above. I bisected the problem to #2051: after that PR, the generated Pallas program seems to read garbage values, which it then increments.
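For comparison, here is a minimal pure-Python model of what the generated kernel should compute, assuming every program in the grid sees the zero-initialized `y` and performs the `_prev = y[offset_0]; y[offset_0] = _prev + 1` read-modify-write from the output code. It is a sketch for reasoning about the expected values, not Helion or Pallas code; `emulate_atomic_add_w_tile_attr` is a hypothetical name.

```python
# Hypothetical pure-Python model of the generated kernel's semantics (not Helion/Pallas API).

BLOCK_SIZE = 2  # block_sizes=[2] in the test
N = 20          # x = torch.randn(20, device=DEVICE)


def emulate_atomic_add_w_tile_attr(n: int, block_size: int) -> list[int]:
    y = [0] * n  # y = torch.zeros_like(x, dtype=torch.int32)
    # Same grid size as the _launcher call: (n + block_size - 1) // block_size programs.
    num_programs = (n + block_size - 1) // block_size
    for pid in range(num_programs):  # pid_0 = pl.program_id(0)
        offset = pid * block_size    # offset_0 = pid_0 * _BLOCK_SIZE_0
        prev = y[offset]             # _prev = y[offset_0]
        y[offset] = prev + 1         # y[offset_0] = _prev + 1
    return y


result = emulate_atomic_add_w_tile_attr(N, BLOCK_SIZE)
expected = [1, 0] * 10  # torch.tensor([1, 0]).repeat(10)
assert result == expected
print(result[:4])  # [1, 0, 1, 0]
```

Because the tile offsets 0, 2, ..., 18 are all distinct, no two programs touch the same element, so the bad value at index 0 is hard to explain as a race between programs; it is more consistent with the kernel reading a garbage value from y[0].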