
[Pallas] test_atomic_add_w_tile_attr broken with recent torch_tpu #2304

@cota


The test fails with a recent torch_tpu pin. (Note that this does not reproduce at HEAD yet: I'm in the process of updating the torch_tpu pin and want to reference this bug when doing so.)

Repro:

$ HELION_PRINT_OUTPUT_CODE=1 HELION_BACKEND=pallas pytest -v test/test_atomic_ops.py -k atomic_add_w_tile_attr -s

============================= test session starts ==============================
platform linux -- Python 3.12.12, pytest-9.0.2, pluggy-1.6.0 -- /mnt/disks/workspace/src/miniconda3/envs/pt312/bin/python3.12
cachedir: .pytest_cache
hypothesis profile 'default'
rootdir: /mnt/disks/workspace/src/helion
configfile: pyproject.toml
plugins: jaxtyping-0.3.2, hypothesis-6.151.9
collecting ... WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1778066591.864828  401020 pjrt_api.cc:96] PJRT_Api is set for device type cpu
I0000 00:00:1778066591.893021  401020 pjrt_api.cc:118] GetPjrtApi was found for TPU at /mnt/disks/workspace/src/miniconda3/envs/pt312/lib/python3.12/site-packages/libtpu/libtpu.so
I0000 00:00:1778066591.893054  401020 pjrt_api.cc:96] PJRT_Api is set for device type tpu
I0000 00:00:1778066593.091694  401020 device_rt.cc:86] PjRt runtime initialization deferred for tpu
Successfully renamed PrivateUse1 backend to 'tpu'. Device: device(type='tpu')
Registered Python module for 'tpu'.
Device type: tpu, Device index: default
Initializing TPU distributed runtime
collected 28 items / 27 deselected / 1 selected

test/test_atomic_ops.py::TestAtomicOperations::test_atomic_add_w_tile_attr I0000 00:00:1778066593.768334  401147 pjrt_api.cc:167] The PJRT plugin has PJRT API version 0.103. The framework PJRT API version is 0.104.
I0000 00:00:1778066597.037861  401147 pjrt_c_api_client.cc:197] PjRtCApiClient created.
W0000 00:00:1778066597.037896  401147 pjrt_state.cc:225] Only using 1 device out of all the 8 addressable devices.
I0000 00:00:1778066597.053938  401147 tier2_compilation_cache.cc:165] Tier-2 compilation cache is disabled as requested by the TORCH_TPU_INTERNAL_TIER2_COMPILATION_CACHE environment variable.
I0000 00:00:1778066597.053959  401147 tier3_compilation_cache.cc:65] Backup compilation for tier-3 cache read is disabled as tier-3 cache is disabled.
# Output code written to: /tmp/tmpbtup_5f4/n2/cn2gkud6p2ix7nz4tb4n32opsdz5vr4ynaq5jyunuda256d6rt3q.py
from __future__ import annotations

import torch
import jax.numpy as jnp
from jax.experimental import pallas as pl
import jax.lax as lax
from helion.runtime import default_pallas_launcher as _default_pallas_launcher

_BLOCK_SIZE_0 = int(2)

def _helion_atomic_add_w_tile_attr(y):
    # src[test_atomic_ops.py:89]: for tile in hl.tile(x.size(0)):
    pid_0 = pl.program_id(0)
    offset_0 = pid_0 * _BLOCK_SIZE_0
    # src[test_atomic_ops.py:90]: hl.atomic_add(y, [tile.begin], 1)
    _prev = y[offset_0]
    y[offset_0] = lax.convert_element_type(_prev + 1, jnp.int32)

def atomic_add_w_tile_attr(x: torch.Tensor, *, _launcher=_default_pallas_launcher):
    """Test atomic_add where the index is a symbolic int"""
    # src[test_atomic_ops.py:88]: y = torch.zeros_like(x, device=x.device, dtype=torch.int32)
    y = torch.zeros_like(x, device=x.device, dtype=torch.int32)
    # src[test_atomic_ops.py:89]: for tile in hl.tile(x.size(0)):
    _BLOCK_SIZE_0 = 2
    # src[test_atomic_ops.py:89]: for tile in hl.tile(x.size(0)):
    # src[test_atomic_ops.py:90]:     hl.atomic_add(y, [tile.begin], 1)
    _launcher(_helion_atomic_add_w_tile_attr, ((20 + _BLOCK_SIZE_0 - 1) // _BLOCK_SIZE_0,), y, _output_indices=[0], _inplace_indices=[0], _block_spec_info=[((None,), (None,))], _smem_arg_indices=[0])
    # src[test_atomic_ops.py:91]: return y
    return y

def call():
    from torch._dynamo.testing import rand_strided
    # src[test_atomic_ops.py:86]: def atomic_add_w_tile_attr(x: torch.Tensor) -> torch.Tensor:
    # src[test_atomic_ops.py:87]:     """Test atomic_add where the index is a symbolic int"""
    # src[test_atomic_ops.py:88]:     y = torch.zeros_like(x, device=x.device, dtype=torch.int32)
    # src[test_atomic_ops.py:86-91]: ...
    x = rand_strided(size=(20,), stride=(1,), dtype=torch.float32, device='tpu:0')
    atomic_add_w_tile_attr(x)
if __name__ == '__main__':
    call()
FAILED

=================================== FAILURES ===================================
_______________ TestAtomicOperations.test_atomic_add_w_tile_attr _______________

self = <test.test_atomic_ops.TestAtomicOperations testMethod=test_atomic_add_w_tile_attr>

    @skipIfRefEager(
        "Test is block size dependent which is not supported in ref eager mode"
    )
    def test_atomic_add_w_tile_attr(self):
        """Test atomic_add where the index is a symbolic int"""
        x = torch.randn(20, device=DEVICE)
        code, result = code_and_output(
            atomic_add_w_tile_attr,
            (x,),
            block_sizes=[2],
        )
    
        expected = torch.tensor([1, 0], device=DEVICE, dtype=torch.int32).repeat(10)
>       torch.testing.assert_close(result, expected)
E       AssertionError: Tensor-likes are not equal!
E       
E       Mismatched elements: 1 / 20 (5.0%)
E       Greatest absolute difference: 1376286904 at index (0,)
E       Greatest relative difference: 1376286848.0 at index (0,)

test/test_atomic_ops.py:519: AssertionError
=============================== warnings summary ===============================
../pytorch/torch/jit/_script.py:365: 14 warnings
  /mnt/disks/workspace/src/pytorch/torch/jit/_script.py:365: DeprecationWarning: `torch.jit.script_method` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

test/test_atomic_ops.py::TestAtomicOperations::test_atomic_add_w_tile_attr
  /mnt/disks/workspace/src/helion/helion/_compiler/compile_environment.py:144: UserWarning: The 'pallas' backend is experimental and may have limited functionality.
    warn_once(

test/test_atomic_ops.py::TestAtomicOperations::test_atomic_add_w_tile_attr
  /mnt/disks/workspace/src/helion/helion/runtime/__init__.py:657: DeprecationWarning: input_output_aliases is deprecated and will be removed soon. Please use donate_argnums instead.
    jax_callable = JaxCallable(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================== short test summary info ============================
FAILED test/test_atomic_ops.py::TestAtomicOperations::test_atomic_add_w_tile_attr
================ 1 failed, 27 deselected, 16 warnings in 6.91s =================
I0000 00:00:1778066599.255886  401020 compilation_cache.cc:328] Compilation cache evicted.
I0000 00:00:1778066599.255917  401020 compilation_cache.cc:193] CompilationCache final stats: num_cache_reqs=26
num_cache_hits=2 {7.6%}

The test fails with torch_tpu versions after b6053a4d ("Implement synchronizeDevice in torch_tpu backend"). I bisected the problem to #2051. AFAICT, after that PR the generated Pallas kernel reads garbage values from y and increments them, instead of incrementing the zeros that torch.zeros_like wrote into y. In the run above, only index 0 is affected (1 of 20 elements mismatched).
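To make the suspected failure mode concrete, here is a minimal pure-Python model of what the generated kernel computes. It is a sketch, not Helion or Pallas code: the helpers expected_result and emulate_kernel are hypothetical, and the model assumes that each grid program performs the plain read-modify-write shown in _helion_atomic_add_w_tile_attr (_prev = y[offset_0]; y[offset_0] = _prev + 1) on whatever y's buffer holds when the kernel starts. With a zero-initialized buffer it reproduces the test's expected [1, 0] * 10. With a stale value at index 0 it produces exactly one mismatch, which matches the shape of the failure above. It does not explain why the buffer would hold stale data after #2051.

```python
# Hypothetical pure-Python model of the generated kernel and its launch grid.
# Assumption: each program does a plain read-modify-write on y's buffer,
# mirroring _helion_atomic_add_w_tile_attr; this is not Helion/Pallas code.

N = 20            # x.size(0) in the test
BLOCK_SIZE_0 = 2  # block_sizes=[2] in the test


def expected_result(n: int = N, block: int = BLOCK_SIZE_0) -> list[int]:
    """Reference semantics: add 1 at tile.begin once per tile, starting from zeros."""
    y = [0] * n
    for begin in range(0, n, block):
        y[begin] += 1
    return y


def emulate_kernel(initial_y: list[int], block: int = BLOCK_SIZE_0) -> list[int]:
    """Run every program id sequentially over a copy of `initial_y`.

    `initial_y` models the contents of y's buffer when the kernel starts:
    zeros if the in-place alias of y is honored, stale memory otherwise.
    """
    y = list(initial_y)
    num_programs = (len(y) + block - 1) // block  # same grid formula as the launcher call
    for pid in range(num_programs):
        offset = pid * block   # offset_0 = pid_0 * _BLOCK_SIZE_0
        prev = y[offset]       # _prev = y[offset_0]
        y[offset] = prev + 1   # y[offset_0] = _prev + 1
    return y


if __name__ == "__main__":
    expected = [1, 0] * 10
    assert expected_result() == expected
    assert emulate_kernel([0] * N) == expected

    # Hypothetical stale word at index 0 (arbitrary value, not the one observed).
    stale = [0] * N
    stale[0] = 12345
    got = emulate_kernel(stale)
    mismatches = [i for i, (a, b) in enumerate(zip(got, expected)) if a != b]
    print(mismatches)  # [0]
```

Because the kernel never writes y without first reading it, any value left in the buffer at launch flows straight into the result; the output is only correct if the launcher actually hands the kernel the zero-filled y.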
