Skip to content

Commit 7db6e92

Browse files
jahatefJacob HatefQuentin-Anthony
authored
* rwkv-init * annotations * Re-added docs * make dir if not exist * Add RWKV paper and update doc index * add train loop * experiment --------- Co-authored-by: Jacob Hatef <[email protected]> Co-authored-by: Quentin Anthony <[email protected]>
1 parent 285cb37 commit 7db6e92

File tree

6 files changed

+543
-1
lines changed

6 files changed

+543
-1
lines changed

docs/index.html

+4-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,9 @@ <h4>✨ <a href="resnet/index.html">ResNet</a></h4>
124124
<h4><a href="conv_mixer/index.html">ConvMixer</a></h4>
125125
<h4><a href="capsule_networks/index.html">Capsule Networks</a></h4>
126126
<h4><a href="unet/index.html">U-Net</a></h4>
127-
<h4><a href="sketch_rnn/index.html">Sketch RNN</a></h4>
127+
<h4><a href="sketch_rnn/index.html">RNNs</a></h4>
128+
<ul><li><a href="rwkv/index.html">RWKV</a> </li>
129+
<li><a href="sketch_rnn/index.html">Sketch RNN</a></li></ul>
128130
<h4>✨ Graph Neural Networks</h4>
129131
<ul><li><a href="graphs/gat/index.html">Graph Attention Networks (GAT)</a> </li>
130132
<li><a href="graphs/gatv2/index.html">Graph Attention Networks v2 (GATv2)</a></li></ul>
@@ -168,6 +170,7 @@ <h2>Highlighted Research Paper PDFs</h2>
168170
<ul><li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf">Autoregressive Search Engines: Generating Substrings as Document Identifiers</a> </li>
169171
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2203.15556.pdf">Training Compute-Optimal Large Language Models</a> </li>
170172
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/1910.02054.pdf">ZeRO: Memory Optimizations Toward Training Trillion Parameter Models</a> </li>
173+
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/RWKV.pdf">RWKV: Reinventing RNNs for the Transformer Era</a> </li>
171174
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.02311.pdf">PaLM: Scaling Language Modeling with Pathways</a> </li>
172175
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/dall-e-2.pdf">Hierarchical Text-Conditional Image Generation with CLIP Latents</a> </li>
173176
<li><a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2203.14465.pdf">STaR: Self-Taught Reasoner Bootstrapping Reasoning With Reasoning</a> </li>

labml_nn/RWKV/__init__.py

+328
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,328 @@
1+
"""
2+
3+
---
4+
title: Receptance Weighted Key Value (RWKV)
5+
summary: >
6+
This implements the RWKV model
7+
using PyTorch with explanations.
8+
---
9+
10+
# Receptance Weighted Key Value (RWKV)
11+
12+
## TODO: add a Colab notebook
13+
14+
This is a tutorial/implementation of RWKV
15+
from paper [RWKV: Reinventing RNNs for the Transformer Era](https://arxiv.org/pdf/2305.13048.pdf)
16+
in [PyTorch](https://pytorch.org/).
17+
18+
Full definition of a RWKV Language Model, all of it in this single file.
19+
References:
20+
1) the official RWKV PyTorch implementation released by Bo Peng:
21+
https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py
22+
2) huggingface/transformers PyTorch implementation:
23+
https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py
24+
"""
25+
26+
27+
import math,time
28+
import os
29+
import inspect
30+
from dataclasses import dataclass
31+
32+
import torch
33+
import torch.nn as nn
34+
from torch.nn import functional as F
35+
36+
from labml_helpers.module import Module
37+
38+
39+
PREV_X_TIME = 0
40+
NUM_STATE = 1
41+
DEN_STATE = 2
42+
MAX_STATE = 3
43+
PREV_X_CHANNEL = 4
44+
45+
"""
46+
## Layernorm with bias
47+
"""
48+
class LayerNorm(Module):
49+
def __init__(self, ndim, bias):
50+
super().__init__()
51+
self.weight = nn.Parameter(torch.ones(ndim))
52+
self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
53+
54+
def forward(self, input):
55+
return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
56+
57+
"""
58+
# L2 loss wrapper
59+
https://github.com/BlinkDL/RWKV-LM/blob/cca1b5e8e597cf40675882bb10b46287c844e35c/RWKV-v4/src/model.py#L21
60+
"""
61+
class L2Wrap(torch.autograd.Function):
62+
@staticmethod
63+
def forward(ctx, loss, y):
64+
ctx.save_for_backward(y)
65+
return loss
66+
@staticmethod
67+
def backward(ctx, grad_output):
68+
y = ctx.saved_tensors[0]
69+
# to encourage the logits to be close to 0
70+
factor = 1e-4 / (y.shape[0] * y.shape[1])
71+
maxx, ids = torch.max(y, -1, keepdim=True)
72+
gy = torch.zeros_like(y)
73+
gy.scatter_(-1, ids, maxx * factor)
74+
return (grad_output, gy)
75+
76+
class ChannelMixing(Module):
77+
"""
78+
## Channel Mixing
79+
"""
80+
def __init__(self,config,layer_id):
81+
super().__init__()
82+
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
83+
# token shifting
84+
self.layer_id = layer_id
85+
86+
n_embd = config.n_embd
87+
intermediate_size = (
88+
config.intermediate_size if config.intermediate_size is not None else 4 * n_embd
89+
)
90+
91+
## Learnable Matrix
92+
self.key_proj = nn.Linear(n_embd,intermediate_size,bias=False)
93+
self.value_proj = nn.Linear(intermediate_size,n_embd,bias=False)
94+
self.receptance_proj = nn.Linear(n_embd,n_embd,bias=False)
95+
96+
## Learnable Vector
97+
self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
98+
self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd))
99+
100+
def forward(self,x,state=None):
101+
# x = (Batch,Time,Channel)
102+
if state is not None:
103+
prev_x = state[self.layer_id,:,[PREV_X_CHANNEL],:]
104+
state[self.layer_id,:,[PREV_X_CHANNEL],:] = x
105+
else:
106+
prev_x = self.time_shift(x)
107+
108+
"""
109+
### $r_t=W_r \cdot (\mu_r x_t + (1-\mu_r)x_{t-1})$
110+
"""
111+
receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance)
112+
receptance = self.receptance_proj(receptance)
113+
114+
"""
115+
### $k_t=W_k \cdot (\mu_k x_t + (1-\mu_k)x_{t-1})$
116+
"""
117+
key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key)
118+
key = self.key_proj(key)
119+
120+
"""
121+
### $V_t=W_v \cdot max(k_t,0)^2$
122+
"""
123+
value = self.value_proj(torch.square(torch.relu(key)))
124+
125+
"""
126+
### $o_t=\sigma(r_t) \odot v_t$
127+
"""
128+
out = F.sigmoid(receptance) * value
129+
return out, state
130+
131+
"""
132+
## Time Mixing
133+
"""
134+
class TimeMixing(Module):
135+
def __init__(self,config,layer_id):
136+
super().__init__()
137+
self.config = config
138+
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
139+
self.layer_id = layer_id
140+
141+
n_embd = config.n_embd
142+
attn_sz = n_embd
143+
144+
## learnable matrix
145+
self.key_proj = nn.Linear(n_embd, attn_sz, bias=False)
146+
self.value_proj = nn.Linear(n_embd, attn_sz, bias=False)
147+
self.receptance_proj = nn.Linear(n_embd, attn_sz, bias=False)
148+
self.output_proj = nn.Linear(attn_sz, n_embd, bias=False)
149+
150+
## learnable vector
151+
self.time_decay = nn.Parameter(torch.empty(attn_sz))
152+
self.time_first = nn.Parameter(torch.empty(attn_sz))
153+
self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
154+
self.time_mix_value = nn.Parameter(torch.empty(1, 1, n_embd))
155+
self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd))
156+
157+
def forward(self,x,state=None):
158+
# x = (Batch,Time,Channel)
159+
if state is not None:
160+
prev_x = state[self.layer_id,:,[PREV_X_TIME],:]
161+
state[self.layer_id,:,[PREV_X_TIME],:] = x
162+
else:
163+
prev_x = self.time_shift(x)
164+
165+
"""
166+
### $r_t=W_r \cdot (\mu_r x_t + (1-\mu_r)x_{t-1})$
167+
"""
168+
receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance)
169+
receptance = self.receptance_proj(receptance)
170+
171+
"""
172+
### $k_t=W_k \cdot (\mu_k x_t + (1-\mu_k)x_{t-1})$
173+
"""
174+
key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key)
175+
key = self.key_proj(key)
176+
177+
"""
178+
### $v_t=W_v \cdot (\mu_v x_t + (1-\mu_v)x_{t-1})$
179+
"""
180+
value = x * self.time_mix_value + prev_x * (1 - self.time_mix_value)
181+
value = self.value_proj(value)
182+
183+
"""
184+
## WKV calculation
185+
"""
186+
_, seq_length, _ = key.size()
187+
output = torch.zeros_like(key)
188+
189+
if state is None:
190+
num_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
191+
den_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
192+
max_state = torch.zeros_like(key[:, 0], dtype=torch.float32) - 1e38
193+
else:
194+
num_state = state[self.layer_id,:,NUM_STATE,:]
195+
den_state = state[self.layer_id,:,DEN_STATE,:]
196+
max_state = state[self.layer_id,:,MAX_STATE,:]
197+
198+
time_decay = -torch.exp(self.time_decay)
199+
200+
for current_index in range(seq_length):
201+
current_key = key[:, current_index].float()
202+
current_value = value[:, current_index]
203+
204+
"""
205+
### $wkv_t=\frac{\sum^{t-1}_{i=1}d^{-(t-1-i)w+k_i}v_i+e^{u+k_t}v_t}{\sum^{t-1}_{i=1}e^{-(t-1-i)w+k_i}+e^{u+k_t}}$
206+
"""
207+
max_for_output = torch.maximum(max_state, current_key + self.time_first)
208+
e1 = torch.exp(max_state - max_for_output)
209+
e2 = torch.exp(current_key + self.time_first - max_for_output)
210+
numerator = e1 * num_state + e2 * current_value
211+
denominator = e1 * den_state + e2
212+
output[:, current_index] = (numerator / denominator).to(output.dtype)
213+
214+
# Update state for next iteration
215+
max_for_state = torch.maximum(max_state + time_decay, current_key)
216+
e1 = torch.exp(max_state + time_decay - max_for_state)
217+
e2 = torch.exp(current_key - max_for_state)
218+
num_state = e1 * num_state + e2 * current_value
219+
den_state = e1 * den_state + e2
220+
max_state = max_for_state
221+
222+
"""
223+
### update states
224+
"""
225+
state[self.layer_id,:,NUM_STATE,:] = num_state
226+
state[self.layer_id,:,DEN_STATE,:] = den_state
227+
state[self.layer_id,:,MAX_STATE,:] = max_state
228+
wkv, state = self.wkv_function(key,value,use_customized_cuda_kernel=self.config.use_customized_cuda_kernel,state=state)
229+
230+
"""
231+
### $o_t=W_o \cdot (\sigma(r_t) \odot wkv_t)$
232+
"""
233+
rwkv = F.sigmoid(receptance) * wkv
234+
rwkv = self.output_proj(rwkv)
235+
236+
return rwkv, state
237+
238+
"""
239+
## RWKV block element
240+
"""
241+
class Block(Module):
242+
243+
def __init__(self, config,layer_id):
244+
super().__init__()
245+
self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
246+
self.attn = TimeMixing(config,layer_id)
247+
self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
248+
self.ffn = ChannelMixing(config,layer_id)
249+
250+
def forward(self, x, state = None):
251+
# state: [batch_size, 5 , n_embd]
252+
"""
253+
## time mixing
254+
"""
255+
residual = x
256+
x,state = self.attn(self.ln_1(x),state=state)
257+
x = x + residual
258+
"""
259+
## channel mixing
260+
"""
261+
residual = x
262+
x, state = self.ffn(self.ln_2(x),state=state)
263+
x = x + residual
264+
return x, state
265+
266+
class RWKV(Module):
267+
def __init__(self, config,lr_init=0.0008):
268+
super().__init__()
269+
assert config.vocab_size is not None
270+
assert config.block_size is not None
271+
self.config = config
272+
self.lr_init = lr_init ## used to initialize embedding parameters
273+
self.n_layer = config.n_layer
274+
self.n_embd = config.n_embd
275+
"""
276+
## Initiate model layers
277+
"""
278+
self.rwkv = nn.ModuleDict(dict(
279+
wte = nn.Embedding(config.vocab_size, config.n_embd),
280+
ln_p = LayerNorm(config.n_embd, bias=config.bias),
281+
h = nn.ModuleList([Block(config,layer_id) for layer_id in range(config.n_layer)]),
282+
ln_f = LayerNorm(config.n_embd, bias=config.bias),
283+
))
284+
"""
285+
## Output linear layer
286+
"""
287+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
288+
289+
290+
def forward(self, idx, targets=None, state=None, return_state=False):
291+
b, t = idx.size()
292+
assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
293+
294+
"""
295+
## Embedding Layer
296+
"""
297+
x = self.rwkv.wte(idx)
298+
"""
299+
## Layer Norm
300+
"""
301+
x = self.rwkv.ln_p(x)
302+
"""
303+
## RWKV Blocks
304+
"""
305+
for block_idx,block in enumerate(self.rwkv.h):
306+
x, state = block(x,state)
307+
x = self.rwkv.ln_f(x)
308+
309+
"""
310+
## Logit Layer and loss Function (for training)
311+
"""
312+
if targets is not None:
313+
# if we are given some desired targets also calculate the loss
314+
logits = self.lm_head(x)
315+
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
316+
if self.training:
317+
loss = L2Wrap.apply(loss,logits)
318+
else:
319+
# inference-time mini-optimization: only forward the lm_head on the very last position
320+
logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
321+
loss = None
322+
"""
323+
## Return Logits and loss
324+
"""
325+
if return_state:
326+
return logits, loss, state
327+
else:
328+
return logits, loss

labml_nn/RWKV/configs.py

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import copy
2+
3+
import torch.nn as nn
4+
5+
from labml.configs import BaseConfigs, option, calculate, aggregate
6+
from labml_helpers.module import Module
7+
8+
9+
class RWKVConfigs(BaseConfigs):
10+
"""
11+
<a id="TransformerConfigs"></a>
12+
13+
## Transformer Configurations
14+
15+
This defines configurations for a transformer.
16+
The configurations are calculate using option functions.
17+
These are lazy loaded and therefore only the necessary modules
18+
are calculated.
19+
"""
20+
# Number of attention heads
21+
n_heads: int = 8
22+
# Transformer embedding size
23+
d_model: int = 512
24+
# Number of layers
25+
n_layers: int = 6
26+
# Dropout probability
27+
dropout: float = 0.1
28+
# Number of tokens in the source vocabulary (for token embeddings)
29+
n_src_vocab: int
30+
# Number of tokens in the target vocabulary (to generate logits for prediction)
31+
n_tgt_vocab: int

0 commit comments

Comments
 (0)