diff --git a/pykokkos/interface/atomic/atomic_fetch_op.py b/pykokkos/interface/atomic/atomic_fetch_op.py index e42e2752..14846351 100644 --- a/pykokkos/interface/atomic/atomic_fetch_op.py +++ b/pykokkos/interface/atomic/atomic_fetch_op.py @@ -45,17 +45,19 @@ def atomic_fetch_max( view: View, indices: List[int], value: Union[int, float]) -> Union[int, float]: + old_result: int = view[indices[-1]] inner = reduce(operator.getitem, indices[:-1], view) inner[indices[-1]] = max(inner[indices[-1]], value) - return inner[indices[-1]] + return old_result def atomic_fetch_min( view: View, indices: List[int], value: Union[int, float]) -> Union[int, float]: + old_result: int = view[indices[-1]] inner = reduce(operator.getitem, indices[:-1], view) inner[indices[-1]] = min(inner[indices[-1]], value) - return inner[indices[-1]] + return old_result def atomic_fetch_mod( view: View, @@ -110,4 +112,8 @@ def atomic_compare_exchange( indices: List[int], comparison_value: int, new_value: int) -> int: - pass + inner = reduce(operator.getitem, indices[:-1], view) + old_result = inner[indices[-1]] + if old_result == comparison_value: + inner[indices[-1]] = new_value + return old_result