-
Notifications
You must be signed in to change notification settings - Fork 305
ADD RWKV7 #2421
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
ADD RWKV7 #2421
Changes from 10 commits
195ef79
7bc36b5
7d4a7a1
e5bb446
afcff31
ec0baf3
bd6c618
4201a7f
897a64b
ff11f94
ce13d54
0e36b4a
7218888
cc5815b
dd80464
5e8723d
f223002
b2b1573
c5ebeec
14111c8
a88ae01
7f8bda7
00200a8
e97b458
75a4415
8c3638b
468dce1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,185 @@ | ||||||||||||||||||||||||||||||||||||
| import keras | ||||||||||||||||||||||||||||||||||||
| from keras import ops | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| from keras_hub.src.api_export import keras_hub_export | ||||||||||||||||||||||||||||||||||||
| from keras_hub.src.models.backbone import Backbone | ||||||||||||||||||||||||||||||||||||
| from keras_hub.src.models.rwkv7.rwkv7_layer import RWKV7_Block | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def rwkv7_kernel_initializer(stddev=0.02): | ||||||||||||||||||||||||||||||||||||
| return keras.initializers.TruncatedNormal(stddev=stddev) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| @keras_hub_export("keras_hub.models.RWKV7Backbone") | ||||||||||||||||||||||||||||||||||||
| class RWKV7Backbone(Backbone): | ||||||||||||||||||||||||||||||||||||
| """The [RWKV-7](https://arxiv.org/abs/2503.14456) core architecture. | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| This network implements a Modern RNN architecture based on linear | ||||||||||||||||||||||||||||||||||||
| attention mechanisms with recurrent processing, as described in the | ||||||||||||||||||||||||||||||||||||
| RWKV papers. It includes the embedding lookups and RWKV-7 blocks. | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| The default constructor gives a fully customizable, randomly initialized | ||||||||||||||||||||||||||||||||||||
| RWKV-7 model with any number of layers, heads, and embedding dimensions. | ||||||||||||||||||||||||||||||||||||
| To load preset architectures and weights, use the `from_preset` | ||||||||||||||||||||||||||||||||||||
| constructor. | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||
| hidden_size: int. The size of the transformer encoding and pooling | ||||||||||||||||||||||||||||||||||||
| layers. | ||||||||||||||||||||||||||||||||||||
| head_size: int. The size of each attention head. | ||||||||||||||||||||||||||||||||||||
| num_layers: int. The number of transformer layers. | ||||||||||||||||||||||||||||||||||||
| vocabulary_size: int. The size of the token vocabulary. | ||||||||||||||||||||||||||||||||||||
| intermediate_dim: int. The output dimension of the first Dense layer in | ||||||||||||||||||||||||||||||||||||
| a two-layer feedforward network for each transformer. | ||||||||||||||||||||||||||||||||||||
| gate_lora: int. LoRA dimension for gating. | ||||||||||||||||||||||||||||||||||||
| mv_lora: int. LoRA dimension for value mixing. | ||||||||||||||||||||||||||||||||||||
| aaa_lora: int. LoRA dimension for alpha parameters. | ||||||||||||||||||||||||||||||||||||
| decay_lora: int. LoRA dimension for decay parameters. | ||||||||||||||||||||||||||||||||||||
| dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use | ||||||||||||||||||||||||||||||||||||
| for model computations and weights. Note that some computations, | ||||||||||||||||||||||||||||||||||||
| such as softmax and layer normalization, will always be done at | ||||||||||||||||||||||||||||||||||||
| float32 precision regardless of dtype. | ||||||||||||||||||||||||||||||||||||
| dropout_rate: float. Dropout rate for the dropout layer. | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Examples: | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| ```python | ||||||||||||||||||||||||||||||||||||
| input_data = np.ones(shape=(1, 12), dtype="int32") | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # Randomly initialized RWKV-7 decoder with custom config. | ||||||||||||||||||||||||||||||||||||
| model = keras_hub.models.RWKV7Backbone( | ||||||||||||||||||||||||||||||||||||
| vocabulary_size=10, | ||||||||||||||||||||||||||||||||||||
| hidden_size=512, | ||||||||||||||||||||||||||||||||||||
| num_layers=2, | ||||||||||||||||||||||||||||||||||||
| head_size=64, | ||||||||||||||||||||||||||||||||||||
| intermediate_dim=1024, | ||||||||||||||||||||||||||||||||||||
| dtype="float32" | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| model(input_data) | ||||||||||||||||||||||||||||||||||||
| ``` | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||
| hidden_size, | ||||||||||||||||||||||||||||||||||||
| head_size, | ||||||||||||||||||||||||||||||||||||
| num_layers, | ||||||||||||||||||||||||||||||||||||
| vocabulary_size, | ||||||||||||||||||||||||||||||||||||
| intermediate_dim, | ||||||||||||||||||||||||||||||||||||
| gate_lora=128, | ||||||||||||||||||||||||||||||||||||
| mv_lora=32, | ||||||||||||||||||||||||||||||||||||
| aaa_lora=64, | ||||||||||||||||||||||||||||||||||||
| decay_lora=64, | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
+75
to
+78
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are these defaults common accross all the checkpoint config, if not then let's not set any default value. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 0.1B, 0.3B, 1.5B, and 3B models all use this parameter. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You mean same values 128, 32, 64, 64? Then it's fine to keep it as it is. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
yep |
||||||||||||||||||||||||||||||||||||
| dtype=None, | ||||||||||||||||||||||||||||||||||||
| dropout_rate=0, | ||||||||||||||||||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||
| """Initialize RWKV7 backbone. | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||
| hidden_size: Hidden dimension size. | ||||||||||||||||||||||||||||||||||||
| head_size: Attention head size. | ||||||||||||||||||||||||||||||||||||
| num_layers: Number of RWKV blocks. | ||||||||||||||||||||||||||||||||||||
| vocabulary_size: Size of vocabulary. | ||||||||||||||||||||||||||||||||||||
| intermediate_dim: Intermediate dimension for FFN. | ||||||||||||||||||||||||||||||||||||
| gate_lora: LoRA dimension for gating. | ||||||||||||||||||||||||||||||||||||
| mv_lora: LoRA dimension for value mixing. | ||||||||||||||||||||||||||||||||||||
| aaa_lora: LoRA dimension for alpha parameters. | ||||||||||||||||||||||||||||||||||||
| decay_lora: LoRA dimension for decay parameters. | ||||||||||||||||||||||||||||||||||||
| dtype: Data type for the layer. | ||||||||||||||||||||||||||||||||||||
| dropout_rate: Dropout rate for regularization. | ||||||||||||||||||||||||||||||||||||
| **kwargs: Additional arguments. | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
| # === Layers === | ||||||||||||||||||||||||||||||||||||
| self.token_embedding = keras.layers.Embedding( | ||||||||||||||||||||||||||||||||||||
| input_dim=vocabulary_size, | ||||||||||||||||||||||||||||||||||||
| output_dim=hidden_size, | ||||||||||||||||||||||||||||||||||||
| embeddings_initializer=rwkv7_kernel_initializer(), | ||||||||||||||||||||||||||||||||||||
| dtype=dtype, | ||||||||||||||||||||||||||||||||||||
| name="token_embedding", | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| self.token_embedding.build([None, None]) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| self.output_layer_norm = keras.layers.LayerNormalization( | ||||||||||||||||||||||||||||||||||||
| epsilon=1e-5, name="output_norm" | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| self.output_layer_norm.build([None, None, hidden_size]) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
| self.dropout = keras.layers.Dropout( | ||||||||||||||||||||||||||||||||||||
| dropout_rate, | ||||||||||||||||||||||||||||||||||||
| dtype=dtype, | ||||||||||||||||||||||||||||||||||||
| name="dropout", | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| self.rwkv_layers = [] | ||||||||||||||||||||||||||||||||||||
| for i in range(num_layers): | ||||||||||||||||||||||||||||||||||||
| layer = RWKV7_Block( | ||||||||||||||||||||||||||||||||||||
| hidden_size, | ||||||||||||||||||||||||||||||||||||
| head_size, | ||||||||||||||||||||||||||||||||||||
| intermediate_dim, | ||||||||||||||||||||||||||||||||||||
| gate_lora, | ||||||||||||||||||||||||||||||||||||
| mv_lora, | ||||||||||||||||||||||||||||||||||||
| aaa_lora, | ||||||||||||||||||||||||||||||||||||
| decay_lora, | ||||||||||||||||||||||||||||||||||||
| use_initial_norm=i == 0, | ||||||||||||||||||||||||||||||||||||
| kernel_initializer=rwkv7_kernel_initializer(), | ||||||||||||||||||||||||||||||||||||
| dtype=dtype, | ||||||||||||||||||||||||||||||||||||
| name=f"rwkv_layer_{i}", | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| self.rwkv_layers.append(layer) | ||||||||||||||||||||||||||||||||||||
| self.head = keras.layers.Dense( | ||||||||||||||||||||||||||||||||||||
| units=vocabulary_size, | ||||||||||||||||||||||||||||||||||||
| kernel_initializer=rwkv7_kernel_initializer(), | ||||||||||||||||||||||||||||||||||||
| use_bias=False, | ||||||||||||||||||||||||||||||||||||
| name="head", | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| # === Functional Model === | ||||||||||||||||||||||||||||||||||||
| token_id_input = keras.Input( | ||||||||||||||||||||||||||||||||||||
| shape=(None,), dtype="int32", name="token_ids" | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| padding_mask = ops.not_equal(token_id_input, 0) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| x = self.token_embedding(token_id_input) | ||||||||||||||||||||||||||||||||||||
| padding_mask = ops.cast(padding_mask, dtype=x.dtype) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
| token_id_input = keras.Input( | |
| shape=(None,), dtype="int32", name="token_ids" | |
| ) | |
| padding_mask = ops.not_equal(token_id_input, 0) | |
| x = self.token_embedding(token_id_input) | |
| padding_mask = ops.cast(padding_mask, dtype=x.dtype) | |
| token_id_input = keras.Input( | |
| shape=(None,), dtype="int32", name="token_ids" | |
| ) | |
| padding_mask_input = keras.Input( | |
| shape=(None,), dtype="int32", name="padding_mask" | |
| ) | |
| x = self.token_embedding(token_id_input) | |
| padding_mask = ops.cast(padding_mask_input, dtype=x.dtype) |
Style Guide References
Footnotes
-
Backbone models should accept standardized input names like
token_idsandpadding_maskto ensure interoperability. ↩
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make padding_mask askeras.Input, let's not assume the padding token id will be always 0 in all the use case scenarios, which would make it difficult for the users to switch from different backbone to this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No problem, I asked the author, it can be hardcoded.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Once the padding_mask is changed to keras.Input, make the inputs as dictionary which looks something like inputs={ "token_ids": token_id_input, "padding_mask": padding_mask_input, },
sachinprasadhs marked this conversation as resolved.
Show resolved
Hide resolved
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is redundant, model will be already built in super().__init__(...) and this needs to be changed everytime the input signature is changed.
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,37 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from keras import ops | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from keras_hub.src.models.rwkv7.rwkv7_backbone import RWKV7Backbone | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from keras_hub.src.tests.test_case import TestCase | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class RWKV7BackboneTest(TestCase): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def setUp(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add test_saved_model test |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Set up the test case with default arguments and input data. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.init_kwargs = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "vocabulary_size": 10, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "hidden_size": 16, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "num_layers": 2, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "head_size": 4, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "intermediate_dim": 32, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "gate_lora": 32, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "mv_lora": 16, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "aaa_lora": 16, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "decay_lora": 16, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.input_data = ops.ones((2, 5), dtype="int32") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.backbone = RWKV7Backbone(**self.init_kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_backbone_basics(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Test basic functionality of the RWKV7 backbone. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| y = self.backbone(self.input_data) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.assertEqual(y.shape, (2, 5, 10)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_num_parameters(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Test that the model has the expected number of parameters. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.assertEqual(self.backbone.count_params(), 10208) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_backbone_basics(self): | |
| """ | |
| Test basic functionality of the RWKV7 backbone. | |
| """ | |
| y = self.backbone(self.input_data) | |
| self.assertEqual(y.shape, (2, 5, 10)) | |
| def test_num_parameters(self): | |
| """ | |
| Test that the model has the expected number of parameters. | |
| """ | |
| self.assertEqual(self.backbone.count_params(), 10208) | |
| def test_backbone_basics(self): | |
| self.run_backbone_test( | |
| cls=RWKV7Backbone, | |
| init_kwargs=self.init_kwargs, | |
| input_data=self.input_data, | |
| expected_output_shape=(2, 5, 10), | |
| ) | |
| def test_saved_model(self): | |
| self.run_model_saving_test( | |
| cls=RWKV7Backbone, | |
| init_kwargs=self.init_kwargs, | |
| input_data=self.input_data, | |
| ) |
Style Guide References
Footnotes
-
The style guide requires using helper methods like
self.run_backbone_test()andself.run_model_saving_test()for standardized testing of backbones. ↩
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Address this comment from Gemini.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After modifying it this way, fp16 will fail, but I cannot reproduce this error.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is our standard way of testing the backbone. We should try to identify the issue than coming up with the workaround.
Uh oh!
There was an error while loading. Please reload this page.