9 changes: 9 additions & 0 deletions keras_hub/api/models/__init__.py
@@ -605,6 +605,15 @@
from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import (
    RoformerV2Tokenizer as RoformerV2Tokenizer,
)
from keras_hub.src.models.rwkv7.rwkv7_backbone import (
    RWKV7Backbone as RWKV7Backbone,
)
from keras_hub.src.models.rwkv7.rwkv7_causal_lm import (
    RWKV7CausalLM as RWKV7CausalLM,
)
from keras_hub.src.models.rwkv7.rwkv7_causal_lm_preprocessor import (
    RWKV7CausalLMPreprocessor as RWKV7CausalLMPreprocessor,
)
from keras_hub.src.models.sam.sam_backbone import SAMBackbone as SAMBackbone
from keras_hub.src.models.sam.sam_image_segmenter import (
    SAMImageSegmenter as SAMImageSegmenter,
3 changes: 3 additions & 0 deletions keras_hub/api/tokenizers/__init__.py
@@ -90,6 +90,9 @@
from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import (
    RoformerV2Tokenizer as RoformerV2Tokenizer,
)
from keras_hub.src.models.rwkv7.rwkv7_tokenizer import (
    RWKVTokenizer as RWKVTokenizer,
)
from keras_hub.src.models.siglip.siglip_tokenizer import (
    SigLIPTokenizer as SigLIPTokenizer,
)
185 changes: 185 additions & 0 deletions keras_hub/src/models/rwkv7/rwkv7_backbone.py
@@ -0,0 +1,185 @@
import keras
from keras import ops

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.rwkv7.rwkv7_layer import RWKV7_Block


def rwkv7_kernel_initializer(stddev=0.02):
    return keras.initializers.TruncatedNormal(stddev=stddev)


@keras_hub_export("keras_hub.models.RWKV7Backbone")
class RWKV7Backbone(Backbone):
"""The [RWKV-7](https://arxiv.org/abs/2503.14456) core architecture.
Collaborator: You can add the link to the paper in the next section; in the first line, keep only a description of the backbone.
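One way to read this suggestion (a sketch of possible wording, not the merged code; the paper link is the one already cited above):

```python
class RWKV7Backbone(Backbone):
    """The RWKV-7 core architecture.

    This network implements a modern RNN based on linear attention, as
    described in [the RWKV-7 paper](https://arxiv.org/abs/2503.14456).
    ...
    """
```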


    This network implements a modern RNN architecture based on linear
    attention mechanisms with recurrent processing, as described in the
    RWKV papers. It includes the embedding lookups and RWKV-7 blocks.

    The default constructor gives a fully customizable, randomly initialized
    RWKV-7 model with any number of layers, heads, and embedding dimensions.
    To load preset architectures and weights, use the `from_preset`
    constructor.

    Args:
        hidden_size: int. The size of the embeddings and hidden states.
        head_size: int. The size of each attention head.
        num_layers: int. The number of RWKV-7 blocks.
        vocabulary_size: int. The size of the token vocabulary.
        intermediate_dim: int. The output dimension of the first Dense layer
            in the two-layer feedforward network of each block.
        gate_lora: int. LoRA dimension for gating.
        mv_lora: int. LoRA dimension for value mixing.
        aaa_lora: int. LoRA dimension for alpha parameters.
        decay_lora: int. LoRA dimension for decay parameters.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to
            use for model computations and weights. Note that some
            computations, such as softmax and layer normalization, will
            always be done at float32 precision regardless of dtype.
        dropout_rate: float. Dropout rate for the dropout layer.

    Examples:

    ```python
    input_data = np.ones(shape=(1, 12), dtype="int32")

    # Randomly initialized RWKV-7 decoder with custom config.
    model = keras_hub.models.RWKV7Backbone(
        vocabulary_size=10,
        hidden_size=512,
        num_layers=2,
        head_size=64,
        intermediate_dim=1024,
        dtype="float32",
    )
    model(input_data)
    ```
    """

    def __init__(
        self,
        hidden_size,
        head_size,
        num_layers,
        vocabulary_size,
        intermediate_dim,
        gate_lora=128,
        mv_lora=32,
        aaa_lora=64,
        decay_lora=64,
Comment on lines +75 to +78:

Collaborator: Are these defaults common across all the checkpoint configs? If not, let's not set any default value. If they are common, add "Defaults to xxx" to their respective arg descriptions.

Contributor (author): The 0.1B, 0.3B, 1.5B, and 3B models all use these parameters.

Collaborator: You mean the same values 128, 32, 64, 64? Then it's fine to keep it as it is.

Contributor (author): Yep.
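Following the first comment in this thread, the class docstring's arg descriptions would then spell out these shared defaults; a sketch of the wording (values as confirmed above):

    gate_lora: int. LoRA dimension for gating. Defaults to `128`.
    mv_lora: int. LoRA dimension for value mixing. Defaults to `32`.
    aaa_lora: int. LoRA dimension for alpha parameters. Defaults to `64`.
    decay_lora: int. LoRA dimension for decay parameters. Defaults to `64`.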

        dtype=None,
        dropout_rate=0,
        **kwargs,
    ):
"""Initialize RWKV7 backbone.

Args:
hidden_size: Hidden dimension size.
head_size: Attention head size.
num_layers: Number of RWKV blocks.
vocabulary_size: Size of vocabulary.
intermediate_dim: Intermediate dimension for FFN.
gate_lora: LoRA dimension for gating.
mv_lora: LoRA dimension for value mixing.
aaa_lora: LoRA dimension for alpha parameters.
decay_lora: LoRA dimension for decay parameters.
dtype: Data type for the layer.
dropout_rate: Dropout rate for regularization.
**kwargs: Additional arguments.
"""
Gemini (bot, severity: medium): The `__init__` method's docstring duplicates information already present in the class-level docstring. According to the style guide's example for backbones, the `__init__` method should not have a separate docstring.[1] Removing this will make the code more concise and align it with the repository's conventions.

[1] The style guide example for a backbone class shows arguments documented in the class docstring, not in the `__init__` method.

        # === Layers ===
        self.token_embedding = keras.layers.Embedding(
            input_dim=vocabulary_size,
            output_dim=hidden_size,
            embeddings_initializer=rwkv7_kernel_initializer(),
            dtype=dtype,
            name="token_embedding",
        )
        self.token_embedding.build([None, None])
Collaborator @sachinprasadhs (Nov 4, 2025): This `self.token_embedding.build([None, None])` can be removed, since you will be calling this layer later in the code and it will be built there.

        self.output_layer_norm = keras.layers.LayerNormalization(
            epsilon=1e-5, name="output_norm"
        )
        self.output_layer_norm.build([None, None, hidden_size])
Collaborator @sachinprasadhs (Nov 4, 2025): Same here; you can remove this `self.output_layer_norm.build([None, None, hidden_size])`.
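Both comments point at the same Keras behavior; a minimal runnable sketch of it (standalone, not part of the diff):

```python
import keras

# A layer owns no weights until it is first called; a manual build() is
# unnecessary when the layer is invoked in the functional graph below.
embedding = keras.layers.Embedding(input_dim=10, output_dim=16)
print(embedding.built)  # False
out = embedding(keras.ops.ones((2, 5), dtype="int32"))
print(embedding.built)  # True, built automatically on first call
print(out.shape)        # (2, 5, 16)
```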

        self.dropout = keras.layers.Dropout(
            dropout_rate,
            dtype=dtype,
            name="dropout",
        )
        self.rwkv_layers = []
        for i in range(num_layers):
            layer = RWKV7_Block(
                hidden_size,
                head_size,
                intermediate_dim,
                gate_lora,
                mv_lora,
                aaa_lora,
                decay_lora,
                use_initial_norm=i == 0,
                kernel_initializer=rwkv7_kernel_initializer(),
                dtype=dtype,
                name=f"rwkv_layer_{i}",
            )
            self.rwkv_layers.append(layer)
        self.head = keras.layers.Dense(
            units=vocabulary_size,
            kernel_initializer=rwkv7_kernel_initializer(),
            use_bias=False,
            name="head",
        )

        # === Functional Model ===
        token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="token_ids"
        )
        padding_mask = ops.not_equal(token_id_input, 0)
        x = self.token_embedding(token_id_input)
        padding_mask = ops.cast(padding_mask, dtype=x.dtype)
Gemini (bot, severity: high): The model should accept `padding_mask` as an input instead of deriving it from `token_ids`. This hardcodes the assumption that the padding token ID is 0 and deviates from the repository's style guide for backbone models.[1]

Please update the model to accept `padding_mask` as a `keras.Input` and also update the `super().__init__` call to include it in the inputs dictionary.

Suggested change, from:

    token_id_input = keras.Input(
        shape=(None,), dtype="int32", name="token_ids"
    )
    padding_mask = ops.not_equal(token_id_input, 0)
    x = self.token_embedding(token_id_input)
    padding_mask = ops.cast(padding_mask, dtype=x.dtype)

to:

    token_id_input = keras.Input(
        shape=(None,), dtype="int32", name="token_ids"
    )
    padding_mask_input = keras.Input(
        shape=(None,), dtype="int32", name="padding_mask"
    )
    x = self.token_embedding(token_id_input)
    padding_mask = ops.cast(padding_mask_input, dtype=x.dtype)

[1] Backbone models should accept standardized input names like `token_ids` and `padding_mask` to ensure interoperability.

Collaborator @sachinprasadhs (Nov 4, 2025): Make `padding_mask` a `keras.Input`; let's not assume the padding token id will always be 0 in every use case, which would make it difficult for users to switch from a different backbone to this one.

Contributor (author): No problem. I asked the author; it can be hardcoded.
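Putting the two suggestions together (the `keras.Input` above and the inputs dictionary requested a few comments below), the functional section might look like this; a sketch of the suggested wiring, not the merged code:

```python
# Sketch, assuming padding_mask_input is passed through unchanged.
token_id_input = keras.Input(shape=(None,), dtype="int32", name="token_ids")
padding_mask_input = keras.Input(
    shape=(None,), dtype="int32", name="padding_mask"
)

x = self.token_embedding(token_id_input)
padding_mask = ops.cast(padding_mask_input, dtype=x.dtype)

v_first = None
for rwkv_layer in self.rwkv_layers:
    x, v_first = rwkv_layer(x, v_first, padding_mask)
x = self.dropout(x)
sequence_output = self.head(self.output_layer_norm(x))

super().__init__(
    inputs={
        "token_ids": token_id_input,
        "padding_mask": padding_mask_input,
    },
    outputs=sequence_output,
    dtype=dtype,
    **kwargs,
)
```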

        v_first = None
        for rwkv_layer in self.rwkv_layers:
            x, v_first = rwkv_layer(x, v_first, padding_mask)
        x = self.dropout(x)
        sequence_output = self.output_layer_norm(x)
        sequence_output = self.head(sequence_output)
        super().__init__(
            inputs=token_id_input,
Collaborator: Once `padding_mask` is changed to a `keras.Input`, make the inputs a dictionary, something like:

    inputs={
        "token_ids": token_id_input,
        "padding_mask": padding_mask_input,
    },

            outputs=sequence_output,
            dtype=dtype,
            **kwargs,
        )
        # Initialize the graph to avoid potential errors in some cases.
        self.call(ops.ones([1, 16], "int32"))
Collaborator: This is redundant; the model will already be built in `super().__init__(...)`, and this needs to be changed every time the input signature changes.


        self.num_layers = num_layers
        self.head_size = head_size
        self.hidden_size = hidden_size
        self.gate_lora = gate_lora
        self.mv_lora = mv_lora
        self.aaa_lora = aaa_lora
        self.decay_lora = decay_lora
        self.vocabulary_size = vocabulary_size
        self.dropout_rate = dropout_rate
        self.intermediate_dim = intermediate_dim

    def get_config(self):
        config = {
            "hidden_size": self.hidden_size,
            "head_size": self.head_size,
            "gate_lora": self.gate_lora,
            "mv_lora": self.mv_lora,
            "aaa_lora": self.aaa_lora,
            "decay_lora": self.decay_lora,
            "vocabulary_size": self.vocabulary_size,
            "dropout_rate": self.dropout_rate,
            "intermediate_dim": self.intermediate_dim,
            "num_layers": self.num_layers,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
37 changes: 37 additions & 0 deletions keras_hub/src/models/rwkv7/rwkv7_backbone_test.py
@@ -0,0 +1,37 @@
from keras import ops

from keras_hub.src.models.rwkv7.rwkv7_backbone import RWKV7Backbone
from keras_hub.src.tests.test_case import TestCase


class RWKV7BackboneTest(TestCase):
    def setUp(self):
Collaborator: Add a `test_saved_model` test.

"""
Set up the test case with default arguments and input data.
"""
self.init_kwargs = {
"vocabulary_size": 10,
"hidden_size": 16,
"num_layers": 2,
"head_size": 4,
"intermediate_dim": 32,
"gate_lora": 32,
"mv_lora": 16,
"aaa_lora": 16,
"decay_lora": 16,
}
self.input_data = ops.ones((2, 5), dtype="int32")
self.backbone = RWKV7Backbone(**self.init_kwargs)

    def test_backbone_basics(self):
        """Test basic functionality of the RWKV7 backbone."""
        y = self.backbone(self.input_data)
        self.assertEqual(y.shape, (2, 5, 10))
Gemini (bot, severity: high): The test `test_backbone_basics` is implemented manually. The repository style guide requires using the standardized test routines from `TestCase` to ensure all core functionality is covered, including shape inference, serialization, and handling of variable input shapes.[1] You should use `self.run_backbone_test()` instead:

    def test_backbone_basics(self):
        """Test basic functionality of the RWKV7 backbone."""
        self.run_backbone_test(
            cls=RWKV7Backbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output_shape=(2, 5, 10),
        )

[1] The style guide mandates the use of helper methods like `self.run_backbone_test()` for testing backbones to ensure standardized validation.


    def test_num_parameters(self):
        """Test that the model has the expected number of parameters."""
        self.assertEqual(self.backbone.count_params(), 10208)
Gemini (bot, severity: high): The tests should use the standardized test routines provided by `TestCase`, as required by the style guide.[1] Please replace the custom test logic with calls to `self.run_backbone_test()` and `self.run_model_saving_test()`. This ensures consistency and covers more test cases automatically, such as variable input shapes and serialization. You will also need to add `import pytest`.

Suggested change, from:

    def test_backbone_basics(self):
        """Test basic functionality of the RWKV7 backbone."""
        y = self.backbone(self.input_data)
        self.assertEqual(y.shape, (2, 5, 10))

    def test_num_parameters(self):
        """Test that the model has the expected number of parameters."""
        self.assertEqual(self.backbone.count_params(), 10208)

to:

    def test_backbone_basics(self):
        self.run_backbone_test(
            cls=RWKV7Backbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output_shape=(2, 5, 10),
        )

    def test_saved_model(self):
        self.run_model_saving_test(
            cls=RWKV7Backbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
        )

[1] The style guide requires using helper methods like `self.run_backbone_test()` and `self.run_model_saving_test()` for standardized testing of backbones.

Collaborator @sachinprasadhs (Nov 4, 2025): Address this comment from Gemini.

Contributor (author): After modifying it this way, fp16 fails, but I cannot reproduce this error.

Collaborator: This is our standard way of testing the backbone. We should try to identify the issue rather than come up with a workaround.
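For chasing the fp16 failure mentioned above, a possible starting point (a repro sketch assuming the standard Keras 3 mixed-precision API; not code from this PR):

```python
import keras

from keras_hub.src.models.rwkv7.rwkv7_backbone import RWKV7Backbone

# Hypothetical repro: run the backbone under mixed_float16 and inspect
# output dtypes, mirroring the precision check run_backbone_test performs.
keras.mixed_precision.set_global_policy("mixed_float16")
backbone = RWKV7Backbone(
    vocabulary_size=10,
    hidden_size=16,
    num_layers=2,
    head_size=4,
    intermediate_dim=32,
)
out = backbone(keras.ops.ones((2, 5), dtype="int32"))
print(out.dtype)  # check where float16 enters the computation
keras.mixed_precision.set_global_policy("float32")  # reset for other tests
```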
