-
Notifications
You must be signed in to change notification settings - Fork 35
Expand file tree
/
Copy pathassert_zig_zag.py
More file actions
205 lines (155 loc) · 5.32 KB
/
assert_zig_zag.py
File metadata and controls
205 lines (155 loc) · 5.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
import os
import click
from math import ceil
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.amp import autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from ring_attention_pytorch import RingAttention
from ring_attention_pytorch.distributed import all_gather_variable_dim
from einops import rearrange
from ring_attention_pytorch.ring_attention import apply_rotary_pos_emb
from ring_attention_pytorch.zig_zag_attention import (
zig_zag_pad_seq,
zig_zag_attn,
zig_zag_shard
)
def abs_diff(x, y):
    """Return the largest absolute element-wise difference between ``x`` and ``y``.

    ``amax()`` is called with no dimension, so every element is reduced and the
    result is a 0-dim tensor.
    """
    delta = x - y
    return delta.abs().amax()
def setup(
    rank,
    world_size,
    use_cuda
):
    """Join the default process group for this worker.

    Uses a fixed localhost rendezvous (port 12355). Selects the ``nccl``
    backend when ``use_cuda`` is set and ``gloo`` otherwise. On CUDA, binds
    this process to GPU index ``rank``.
    """
    os.environ.update(MASTER_ADDR = 'localhost', MASTER_PORT = '12355')

    backend = 'nccl' if use_cuda else 'gloo'
    dist.init_process_group(backend, rank = rank, world_size = world_size)

    if use_cuda:
        torch.cuda.set_device(rank)
def cleanup():
    """Destroy the default process group.

    This is the counterpart of ``setup`` and should be called once per worker
    when it finishes.
    """
    dist.destroy_process_group()
def _zig_zag_forward(attention, tokens, dim_head, rotary):
    """Run ``attention``'s projections over ``tokens`` using the zig zag sharding helpers.

    The steps are:
    1. Pad and shard the sequence (``zig_zag_pad_seq`` / ``zig_zag_shard``).
    2. Project to q/k/v with ``attention.to_qkv``.
    3. Optionally apply rotary embeddings computed from the q indices to both q and k.
    4. Attend via ``zig_zag_attn`` with an explicit causal mask.
    5. Project back with ``attention.to_out``.
    6. Apply the ``gather_seq`` callable returned by ``zig_zag_shard``.
    7. Strip the padding.
    """
    padded, unpad = zig_zag_pad_seq(tokens)
    (shard, q_positions, kv_positions), gather_seq = zig_zag_shard(padded, all_gather_batch = True)

    projected = rearrange(attention.to_qkv(shard), 'b n (h d) -> b h n d', d = dim_head)
    q, k, v = projected.split(attention.qkv_head_breakdown, dim = -3)

    if rotary:
        rotations = attention.rotary_embed(q_positions)
        q = apply_rotary_pos_emb(rotations, q, head_dim_first = True)
        k = apply_rotary_pos_emb(rotations, k, head_dim_first = True)

    # causal: a query may attend to a key only if the key's index is <= the query's index
    causal_mask = q_positions[:, None] >= kv_positions[None, :]

    attended = zig_zag_attn(q, k, v, attn_mask = causal_mask)
    merged = rearrange(attended, 'b h n d -> b n (h d)')

    return unpad(gather_seq(attention.to_out(merged)))


def start(
    rank,
    world_size,
    batch_size,
    batch_size_var_len,
    seq_len,
    num_sharded_batches,
    dim,
    heads,
    num_grouped_query_heads,
    dim_head,
    use_cuda,
    rotary
):
    """Worker body spawned by ``test``: check zig zag attention against regular attention.

    The same random input goes through two paths:
    - the DDP-wrapped ``RingAttention`` module (with ``ring_attn = False``);
    - the module's own projections driven through the zig zag helpers.

    On rank 0 it asserts that the outputs and the input gradients match.

    ``num_sharded_batches`` is accepted (it is part of the spawn argument list)
    but is not used by this function.
    """
    setup(rank, world_size, use_cuda)

    # NOTE: the module is built before the input is sampled; this order fixes RNG consumption
    attention = RingAttention(
        dim = dim,
        dim_head = dim_head,
        heads = heads,
        num_grouped_query_heads = num_grouped_query_heads,
        causal = True,
        rotary_embed = rotary,
        ring_attn = False,
        use_cuda_kernel = use_cuda
    )

    # with --batch-size-var-len each rank gets a different batch size (batch_size + rank)
    local_batch = batch_size + rank if batch_size_var_len else batch_size
    seq = torch.randn(local_batch, seq_len, dim)

    if use_cuda:
        seq = seq.cuda(rank)
        attention.cuda(rank)

    # two independent leaf tensors so each path accumulates its own input gradient
    regular_input = seq.clone().requires_grad_()
    zig_zag_input = seq.clone().requires_grad_()

    # reference path; the DDP wrapper is kept bound so it stays alive through backward
    ddp_attention = DDP(attention)
    out = ddp_attention(regular_input)
    out.mean().backward()

    zig_zag_out = _zig_zag_forward(attention, zig_zag_input, dim_head, rotary)
    zig_zag_out.mean().backward()

    if rank == 0:
        output_atol = 1e-2 if use_cuda else 1e-6
        assert torch.allclose(out.cpu(), zig_zag_out.cpu(), atol = output_atol), 'output is not the same'

        assert torch.allclose(
            regular_input.grad,
            zig_zag_input.grad,
            atol = 1e-2
        ), 'grad is not the same'

        print('✅ outputs and gradients are same between zig zag attention and regular attention')

    cleanup()
@click.command()
@click.option('--world-size', default = 8, help = 'number of machines / processes')
@click.option('--batch-size', default = 2, help = 'test batch size')
@click.option('--num-sharded-batches', default = 1, help = 'number of sharded batches')
@click.option('--batch-size-var-len', is_flag = True, help = 'test variable lengthed batch sizes')
@click.option('--use-cuda', is_flag = True, help = 'whether to test with CUDA and NCCL')
@click.option('--rotary', is_flag = True, help = 'whether to test with rotary embeddings')
@click.option('--seq-len', default = 31, help = 'sequence length to test')
@click.option('--model-dim', default = 8, help = 'model dimensions for testing')
@click.option('--heads', default = 8, help = 'number of query attention heads')
@click.option('--num-grouped-query-heads', default = 2, help = 'number of query attention head groups')
@click.option('--dim-head', default = 16, help = 'dimension of each attention head')
def test(
    world_size: int,
    batch_size: int,
    num_sharded_batches: int,
    batch_size_var_len: bool,
    use_cuda: bool,
    rotary: bool,
    seq_len: int,
    model_dim: int,
    heads: int,
    num_grouped_query_heads: int,
    dim_head: int,
):
    """Compare zig zag attention against regular attention across WORLD_SIZE processes."""

    # validate with a click error rather than `assert`, which is stripped under `python -O`
    if use_cuda:
        num_devices = torch.cuda.device_count()
        if world_size > num_devices:
            raise click.BadParameter(
                f'world size {world_size} must be less than or equal to the number of cuda devices {num_devices}',
                param_hint = '--world-size'
            )

    # mp.spawn prepends the process index, which `start` receives as `rank`
    mp.spawn(
        start,
        args = (
            world_size,
            batch_size,
            batch_size_var_len,
            seq_len,
            num_sharded_batches,
            model_dim,
            heads,
            num_grouped_query_heads,
            dim_head,
            use_cuda,
            rotary
        ),
        nprocs = world_size,
        join = True
    )


if __name__ == '__main__':
    test()