forked from kokkos/pykokkos
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathatomic_fetch_op.py
More file actions
119 lines (104 loc) · 3.37 KB
/
atomic_fetch_op.py
File metadata and controls
119 lines (104 loc) · 3.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from functools import reduce
from typing import (
List, Union
)
import operator
from ..views import View
# Atomic operations from:
# https://github.com/kokkos/kokkos/wiki/Kokkos%3A%3Aatomic_fetch_op
def atomic_fetch_add(
    view: View,
    indices: List[int],
    value: Union[int, float],
) -> Union[int, float]:
    """Add ``value`` to one element of ``view`` and return its previous value.

    Python-side counterpart of ``Kokkos::atomic_fetch_add``. This version is
    a plain read-modify-write with no locking.

    :param view: view containing the element to update.
    :param indices: index path to the element; every index but the last is
        applied in turn to ``view``, and the last one selects the element.
    :param value: amount to add.
    :returns: the element's value before the update.
    :raises IndexError: if ``indices`` is empty.
    """
    inner = reduce(operator.getitem, indices[:-1], view)
    # Fetch-ops return the value held *before* the update.
    old_result = inner[indices[-1]]
    inner[indices[-1]] += value
    return old_result
def atomic_fetch_and(
    view: View,
    indices: List[int],
    value: int,
) -> int:
    """Bitwise-AND ``value`` into one element of ``view``; return the old value.

    Python-side counterpart of ``Kokkos::atomic_fetch_and``. This version is
    a plain read-modify-write with no locking.

    :param view: view containing the element to update.
    :param indices: index path to the element; every index but the last is
        applied in turn to ``view``, and the last one selects the element.
    :param value: mask to AND into the element.
    :returns: the element's value before the update.
    :raises IndexError: if ``indices`` is empty.
    """
    inner = reduce(operator.getitem, indices[:-1], view)
    # Fetch-ops return the value held *before* the update.
    old_result = inner[indices[-1]]
    inner[indices[-1]] &= value
    return old_result
def atomic_fetch_div(
    view: View,
    indices: List[int],
    value: Union[int, float],
) -> Union[int, float]:
    """Divide one element of ``view`` by ``value`` and return its old value.

    Python-side counterpart of ``Kokkos::atomic_fetch_div``. This version is
    a plain read-modify-write with no locking. The division is Python's
    ``/=``, so how an integer element is stored back depends on the
    container backing ``view``.

    :param view: view containing the element to update.
    :param indices: index path to the element; every index but the last is
        applied in turn to ``view``, and the last one selects the element.
    :param value: divisor.
    :returns: the element's value before the update.
    :raises IndexError: if ``indices`` is empty.
    """
    inner = reduce(operator.getitem, indices[:-1], view)
    # Fetch-ops return the value held *before* the update.
    old_result = inner[indices[-1]]
    inner[indices[-1]] /= value
    return old_result
def atomic_fetch_lshift(
    view: View,
    indices: List[int],
    value: int,
) -> int:
    """Left-shift one element of ``view`` by ``value``; return the old value.

    Python-side counterpart of ``Kokkos::atomic_fetch_lshift``. This version
    is a plain read-modify-write with no locking.

    :param view: view containing the element to update.
    :param indices: index path to the element; every index but the last is
        applied in turn to ``view``, and the last one selects the element.
    :param value: number of bit positions to shift by.
    :returns: the element's value before the update.
    :raises IndexError: if ``indices`` is empty.
    """
    inner = reduce(operator.getitem, indices[:-1], view)
    # Fetch-ops return the value held *before* the update.
    old_result = inner[indices[-1]]
    inner[indices[-1]] <<= value
    return old_result
def atomic_fetch_max(
    view: View,
    indices: List[int],
    value: Union[int, float],
) -> Union[int, float]:
    """Store ``max(element, value)`` into one element of ``view``.

    Python-side counterpart of ``Kokkos::atomic_fetch_max``. This version is
    a plain read-modify-write with no locking.

    :param view: view containing the element to update.
    :param indices: index path to the element; every index but the last is
        applied in turn to ``view``, and the last one selects the element.
    :param value: candidate maximum.
    :returns: the element's value before the update.
    :raises IndexError: if ``indices`` is empty.
    """
    inner = reduce(operator.getitem, indices[:-1], view)
    # Read the old value from the innermost container. Indexing ``view``
    # directly with only the last index selects the wrong element whenever
    # ``indices`` has more than one entry.
    old_result: Union[int, float] = inner[indices[-1]]
    inner[indices[-1]] = max(old_result, value)
    return old_result
def atomic_fetch_min(
    view: View,
    indices: List[int],
    value: Union[int, float],
) -> Union[int, float]:
    """Store ``min(element, value)`` into one element of ``view``.

    Python-side counterpart of ``Kokkos::atomic_fetch_min``. This version is
    a plain read-modify-write with no locking.

    :param view: view containing the element to update.
    :param indices: index path to the element; every index but the last is
        applied in turn to ``view``, and the last one selects the element.
    :param value: candidate minimum.
    :returns: the element's value before the update.
    :raises IndexError: if ``indices`` is empty.
    """
    inner = reduce(operator.getitem, indices[:-1], view)
    # Read the old value from the innermost container. Indexing ``view``
    # directly with only the last index selects the wrong element whenever
    # ``indices`` has more than one entry.
    old_result: Union[int, float] = inner[indices[-1]]
    inner[indices[-1]] = min(old_result, value)
    return old_result
def atomic_fetch_mod(
    view: View,
    indices: List[int],
    value: int,
) -> int:
    """Replace one element of ``view`` by ``element % value``; return the old value.

    Python-side counterpart of ``Kokkos::atomic_fetch_mod``. This version is
    a plain read-modify-write with no locking.

    :param view: view containing the element to update.
    :param indices: index path to the element; every index but the last is
        applied in turn to ``view``, and the last one selects the element.
    :param value: modulus.
    :returns: the element's value before the update.
    :raises IndexError: if ``indices`` is empty.
    """
    inner = reduce(operator.getitem, indices[:-1], view)
    # Fetch-ops return the value held *before* the update.
    old_result = inner[indices[-1]]
    inner[indices[-1]] %= value
    return old_result
def atomic_fetch_mul(
    view: View,
    indices: List[int],
    value: Union[int, float],
) -> Union[int, float]:
    """Multiply one element of ``view`` by ``value`` and return its old value.

    Python-side counterpart of ``Kokkos::atomic_fetch_mul``. This version is
    a plain read-modify-write with no locking.

    :param view: view containing the element to update.
    :param indices: index path to the element; every index but the last is
        applied in turn to ``view``, and the last one selects the element.
    :param value: multiplier.
    :returns: the element's value before the update.
    :raises IndexError: if ``indices`` is empty.
    """
    inner = reduce(operator.getitem, indices[:-1], view)
    # Fetch-ops return the value held *before* the update.
    old_result = inner[indices[-1]]
    inner[indices[-1]] *= value
    return old_result
def atomic_fetch_or(
    view: View,
    indices: List[int],
    value: int,
) -> int:
    """Bitwise-OR ``value`` into one element of ``view``; return the old value.

    Python-side counterpart of ``Kokkos::atomic_fetch_or``. This version is
    a plain read-modify-write with no locking.

    :param view: view containing the element to update.
    :param indices: index path to the element; every index but the last is
        applied in turn to ``view``, and the last one selects the element.
    :param value: bits to OR into the element.
    :returns: the element's value before the update.
    :raises IndexError: if ``indices`` is empty.
    """
    inner = reduce(operator.getitem, indices[:-1], view)
    # Fetch-ops return the value held *before* the update.
    old_result = inner[indices[-1]]
    inner[indices[-1]] |= value
    return old_result
def atomic_fetch_rshift(
    view: View,
    indices: List[int],
    value: int,
) -> int:
    """Right-shift one element of ``view`` by ``value``; return the old value.

    Python-side counterpart of ``Kokkos::atomic_fetch_rshift``. This version
    is a plain read-modify-write with no locking.

    :param view: view containing the element to update.
    :param indices: index path to the element; every index but the last is
        applied in turn to ``view``, and the last one selects the element.
    :param value: number of bit positions to shift by.
    :returns: the element's value before the update.
    :raises IndexError: if ``indices`` is empty.
    """
    inner = reduce(operator.getitem, indices[:-1], view)
    # Fetch-ops return the value held *before* the update.
    old_result = inner[indices[-1]]
    inner[indices[-1]] >>= value
    return old_result
def atomic_fetch_sub(
    view: View,
    indices: List[int],
    value: Union[int, float],
) -> Union[int, float]:
    """Subtract ``value`` from one element of ``view``; return its old value.

    Python-side counterpart of ``Kokkos::atomic_fetch_sub``. This version is
    a plain read-modify-write with no locking.

    :param view: view containing the element to update.
    :param indices: index path to the element; every index but the last is
        applied in turn to ``view``, and the last one selects the element.
    :param value: amount to subtract.
    :returns: the element's value before the update.
    :raises IndexError: if ``indices`` is empty.
    """
    inner = reduce(operator.getitem, indices[:-1], view)
    # Fetch-ops return the value held *before* the update.
    old_result = inner[indices[-1]]
    inner[indices[-1]] -= value
    return old_result
def atomic_fetch_xor(
    view: View,
    indices: List[int],
    value: int,
) -> int:
    """Bitwise-XOR ``value`` into one element of ``view``; return the old value.

    Python-side counterpart of ``Kokkos::atomic_fetch_xor``. This version is
    a plain read-modify-write with no locking.

    :param view: view containing the element to update.
    :param indices: index path to the element; every index but the last is
        applied in turn to ``view``, and the last one selects the element.
    :param value: bits to XOR into the element.
    :returns: the element's value before the update.
    :raises IndexError: if ``indices`` is empty.
    """
    inner = reduce(operator.getitem, indices[:-1], view)
    # Fetch-ops return the value held *before* the update.
    old_result = inner[indices[-1]]
    inner[indices[-1]] ^= value
    return old_result
def atomic_compare_exchange(
    view: View,
    indices: List[int],
    comparison_value: int,
    new_value: int,
) -> int:
    """Conditionally overwrite one element of ``view``.

    If the element currently equals ``comparison_value``, it is replaced by
    ``new_value``. In both cases the value observed before any write is
    returned. This version is a plain read-then-write with no locking.

    :param view: view containing the element to examine.
    :param indices: index path to the element; every index but the last is
        applied in turn to ``view``, and the last one selects the element.
    :param comparison_value: value the element must equal for the swap.
    :param new_value: value stored when the comparison succeeds.
    :returns: the element's value before the (possible) exchange.
    :raises IndexError: if ``indices`` is empty.
    """
    container = view
    for step in indices[:-1]:
        container = container[step]

    slot = indices[-1]
    observed = container[slot]
    if not (observed == comparison_value):
        # Mismatch: leave the element untouched.
        return observed

    container[slot] = new_value
    return observed