-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathcustom_dm_matmul.py
More file actions
135 lines (122 loc) · 3.99 KB
/
custom_dm_matmul.py
File metadata and controls
135 lines (122 loc) · 3.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
import torch
from ttlang.ttl_api import *
from ttlang.utils.correctness import assert_allclose
@pykernel_gen(
    block_factors=[
        (1, 1),  # lhs block shape (M, K) in tiles per core
        (1, 1),  # rhs block shape (K, N) in tiles per core
        (1, 1),  # out block shape (M, N) in tiles per core
    ],
    grid=(2, 2),  # 2x2 core grid: (GY rows, GX cols)
    memory_space="L1",
    tiled=True,
    # kernel_source_mode="store",
)
def matmul(lhs, rhs, out, block_factors=None, grid=None):
    """Tiled matmul (out = lhs @ rhs) on a 2x2 Tenstorrent core grid.

    Three cooperating kernels are traced and composed into a Program:
      - mm  (compute):      consumes lhs/rhs shards from circular buffers
                            and writes product tiles into the out buffer.
      - dm0 (datamovement): column-0 cores DMA lhs tiles from the tensor
                            and row-multicast them to the rest of their row.
      - dm1 (datamovement): row-0 cores DMA rhs tiles from the tensor and
                            column-multicast them down their column.
    Multicast is synchronized with a receiver_ready / sender_sent
    semaphore handshake in each direction.

    Args:
        lhs, rhs: input tensors; out: pre-allocated output tensor.
        block_factors, grid: injected by @pykernel_gen from the decorator
            arguments above (hence the not-None asserts).
    """
    assert block_factors is not None
    assert grid is not None
    # assert M: lhs and out must agree on the M block dimension
    assert block_factors[0][0] == block_factors[2][0]
    # assert K: lhs columns and rhs rows must agree on the K block dimension
    assert block_factors[0][1] == block_factors[1][0]
    # assert N: rhs and out must agree on the N block dimension
    assert block_factors[1][1] == block_factors[2][1]
    GY = grid[0]  # number of core rows
    GX = grid[1]  # number of core columns
    # NOTE(review): hard-coded K-split factor; presumably tied to the
    # 128x128 inputs used below (2 K-steps per core) — confirm how this
    # should be derived for other shapes.
    GK = 2
    M = block_factors[0][0]  # M tiles per core
    N = block_factors[1][1]  # N tiles per core
    K = block_factors[0][1] * GK  # total K iterations per core
    lhs_accessor = TensorAccessor(lhs)
    rhs_accessor = TensorAccessor(rhs)

    @compute()
    def mm(
        lhs_cb: CircularBuffer,
        rhs_cb: CircularBuffer,
        out_cb: CircularBuffer,
        lhs_receiver_ready: Semaphore,
        lhs_sender_sent: Semaphore,
        rhs_receiver_ready: Semaphore,
        rhs_sender_sent: Semaphore,
    ):
        # One lhs shard is reused across all N output tiles of a row;
        # accumulation across the K loop is presumably handled by the
        # tiled matmul lowering — TODO confirm (each k overwrites via
        # store() at this source level).
        for k in range(K):
            for m in range(M):
                lhs_shard = lhs_cb.pop()
                for n in range(N):
                    rhs_shard = rhs_cb.pop()
                    out_shard = out_cb.reserve()
                    out = lhs_shard @ rhs_shard
                    out_shard.store(out)
                    out_cb.pop()  # compute needs to clear the output

    @datamovement()
    def dm0(
        lhs_cb: CircularBuffer,
        rhs_cb: CircularBuffer,
        out_cb: CircularBuffer,
        lhs_receiver_ready: Semaphore,
        lhs_sender_sent: Semaphore,
        rhs_receiver_ready: Semaphore,
        rhs_sender_sent: Semaphore,
    ):
        # lhs feeder: cores in column 0 read lhs tiles and multicast them
        # across their row; other cores only participate in the handshake.
        cy = core_index(0)
        cx = core_index(1)
        for k in range(K):
            for m in range(M):
                lhs_shard = lhs_cb.reserve()
                if cx == 0:
                    # Pull this core-row's lhs tile (row cy*M+m, K-step k).
                    tx = dma(lhs_accessor[cy * M + m, k], lhs_shard)
                    tx.wait()
                    # NOTE(review): waits for GK-1 acks, but the row has
                    # GX-1 receivers; these are equal only because
                    # GK == GX == 2 here — confirm intended count.
                    lhs_receiver_ready.wait(GK - 1, reset=0)
                    # L1-to-L1 multicast of the shard to cores (cy, 1..GX-1).
                    tx = dma(
                        lhs_shard,
                        lhs_shard,
                        core=(cy, 1),
                        mcast=(1, GX - 1),
                    )
                    tx.wait()
                    lhs_sender_sent.set(1, core=(cy, 1), mcast=(1, GX - 1))
                else:
                    # Receiver: tell the row's sender we are ready, then
                    # block until it signals the shard has landed.
                    lhs_receiver_ready.inc(1, core=(cy, 0))
                    lhs_sender_sent.wait(1, reset=0)

    @datamovement()
    def dm1(
        lhs_cb: CircularBuffer,
        rhs_cb: CircularBuffer,
        out_cb: CircularBuffer,
        lhs_receiver_ready: Semaphore,
        lhs_sender_sent: Semaphore,
        rhs_receiver_ready: Semaphore,
        rhs_sender_sent: Semaphore,
    ):
        # rhs feeder: mirror of dm0 along the other axis — cores in row 0
        # read rhs tiles and multicast them down their column.
        cy = core_index(0)
        cx = core_index(1)
        for k in range(K):
            for m in range(M):
                for n in range(N):
                    rhs_shard = rhs_cb.reserve()
                    if cy == 0:
                        # Pull this core-column's rhs tile (K-step k, col cx*N+n).
                        tx = dma(rhs_accessor[k, cx * N + n], rhs_shard)
                        tx.wait()
                        # NOTE(review): same GK-1 vs GY-1 question as dm0.
                        rhs_receiver_ready.wait(GK - 1, reset=0)
                        # Multicast the shard to cores (1..GY-1, cx).
                        tx = dma(
                            rhs_shard,
                            rhs_shard,
                            core=(1, cx),
                            mcast=(GY - 1, 1),
                        )
                        tx.wait()
                        rhs_sender_sent.set(1, core=(1, cx), mcast=(GY - 1, 1))
                    else:
                        rhs_receiver_ready.inc(1, core=(0, cx))
                        rhs_sender_sent.wait(1, reset=0)

    # Compose the three kernels and launch them on the actual tensors.
    return Program(mm, dm0, dm1)(lhs, rhs, out)
# Smoke test: run the tiled matmul kernel on random 128x128 operands and
# verify against the host-side torch reference.
lhs = torch.randn(128, 128)
rhs = torch.randn(128, 128)
out = torch.zeros_like(lhs)  # destination buffer the kernel fills in

matmul(lhs, rhs, out)

golden = torch.matmul(lhs, rhs)  # reference result computed on host
assert_allclose(out, golden, rtol=1e-2, atol=1e-4)
print("Passed")