ADD RWKV7 #2421
base: master
Conversation
Code Review
This PR introduces the RWKV-7 model, a powerful RNN architecture, to `keras_hub`. The contribution is significant and includes the backbone, tokenizer, preprocessor, an incomplete task model, and a checkpoint conversion script. The implementation follows the modular structure of `keras_hub`.
However, there are several critical issues that must be addressed before this PR can be merged:
- Missing Tests: The PR lacks unit tests for all new components. According to the contribution guidelines, testing is a mandatory requirement.
- Incomplete `CausalLM` Task: The `RWKV7CausalLM` task model is a stub with `TODO`s, making it non-functional for generation.
- Critical Bugs: There are critical bugs in the tokenizer and preprocessor implementations that will cause runtime errors.
- Style Guide Violations: There are numerous style guide violations, including a filename typo, missing docstrings, and inconsistencies with the recommended model input structure.
I've left detailed comments on these issues. Once these are resolved, this will be a great addition to the library.
def save_assets(self, dir_path):
    path = os.path.join(dir_path, VOCAB_FILENAME)
    with open(path, "wb") as file:
def call_with_cache(
    self,
    token_ids,
    cache,
    cache_update_index,
):
    pass  # TODO

def _build_cache(self, token_ids):
    pass  # TODO

def generate_step(
    self,
    inputs,
    stop_token_ids=None,
):
    pass  # TODO

def score(
    self,
    token_ids,
    padding_mask=None,
    scoring_mode="logits",
    layer_intercept_fn=None,
    target_ids=None,
):
    pass  # TODO
from modelscope import snapshot_download

from keras_hub.src.models.rwkv7.rwkv7_backbone import RWKV7Backbone
from keras_hub.src.models.rwkv7.rwkv7_casual_lm import RWKV7CausalLM
token_ids, padding_mask = self.packer(
    x, sequence_length=sequence_length, add_end_value=False
)
return token_ids
The `generate_preprocess` method returns a single tensor, but `generate_postprocess` expects a dictionary `{'token_ids': ..., 'padding_mask': ...}`. This inconsistency will cause a `TypeError` during text generation. `generate_preprocess` should return a dictionary to match the expected input of `generate_postprocess` and for consistency with the base preprocessor class.1
return {
    "token_ids": token_ids,
    "padding_mask": padding_mask,
}
Style Guide References
Footnotes
- Preprocessors should handle padding, truncation, generating attention masks, and formatting the output into a dictionary of tensors that match the backbone's input signature. ↩
def tokenize(self, inputs):
    self._check_vocabulary()
    tokens = self._tokenizer.encode(inputs)

    def tokens2ids(x):
        return [self.token_to_id(t) for t in x]

    if is_string_dtype(self.dtype):
        if isinstance(inputs, str):
            return tokens2ids(tokens)
        return [tokens2ids(t) for t in tokens]
    return tokens
The logic for handling `dtype='string'` is incorrect. It applies the `tokens2ids` conversion to values that are already integer IDs, because `self._tokenizer.encode()` already returns integer IDs. If `dtype` is `'string'`, you should instead convert the integer IDs back to string tokens.
Suggested change:
def tokenize(self, inputs):
    self._check_vocabulary()
    ids = self._tokenizer.encode(inputs)
    if is_string_dtype(self.dtype):

        def ids_to_tokens(id_list):
            return [self.id_to_token(i) for i in id_list]

        if isinstance(inputs, str):
            return ids_to_tokens(ids)
        return [ids_to_tokens(id_list) for id_list in ids]
    return ids

        "keras_hub.models.RWKVTokenizer",
    ]
)
class RWKVTokenizer(tokenizer.Tokenizer):
The `RWKVTokenizer` class is missing a docstring. Please add a comprehensive docstring that explains the tokenizer, its arguments, and includes usage examples, as per the style guide.1
Style Guide References
Footnotes
- All public classes must have Google-style docstrings with a summary, examples, and documentation for parameters. ↩
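For reference, a docstring could be structured roughly as in the sketch below; the argument names, descriptions, and example are illustrative guesses based on the code in this PR, not the finalized API:

```python
class RWKVTokenizer(tokenizer.Tokenizer):
    """RWKV "world"-vocabulary tokenizer.

    Maps raw strings to integer token IDs (or back to string tokens when
    `dtype="string"`), using the vocabulary format of the reference RWKV
    tokenizer.

    Args:
        vocabulary: A list of vocabulary entries, or a path to a vocabulary
            file with one entry per line. If unset, `load_assets()` must be
            called before the tokenizer is used.
        dtype: Output dtype, either an integer type or `"string"`.

    Example:
    >>> tokenizer = RWKVTokenizer(vocabulary=vocab)  # illustrative only
    >>> tokenizer(["The quick brown fox."])
    """
```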
    - 0.5
)  # soft-clamp to (-inf, -0.5)
@keras_hub_export("keras_hub.models.RWKV7Backbone")
class RWKV7Backbone(Backbone):
The `RWKV7Backbone` class is missing a docstring. Please add a Google-style docstring explaining the model's architecture, its parameters, and include a usage example, as specified in the style guide.1
Style Guide References
Footnotes
- All public classes, methods, and functions must have Google-style docstrings, including a concise summary, comprehensive examples, and documentation for all parameters, return values, and exceptions. ↩
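As a starting point, the class-level docstring could look something like the sketch below; the argument descriptions are inferred from the constructor arguments in this PR and should be checked by the author:

```python
@keras_hub_export("keras_hub.models.RWKV7Backbone")
class RWKV7Backbone(Backbone):
    """RWKV-7 backbone network.

    RWKV-7 is a recurrent architecture that replaces attention with
    time-mix and channel-mix blocks, which allows constant-memory
    autoregressive decoding.

    Args:
        hidden_size: int. Hidden dimension of the model.
        head_size: int. Size of each time-mix head.
        num_layers: int. Number of stacked RWKV blocks.
        vocabulary_size: int. Size of the token vocabulary.
        intermediate_dim: int. Hidden dimension of the channel-mix FFN.
        gate_lora: int. LoRA rank for the gating projection.
        mv_lora: int. LoRA rank for value mixing.
        aaa_lora: int. LoRA rank for the alpha parameters.
        decay_lora: int. LoRA rank for the decay parameters.
        dropout_rate: float. Dropout rate used for regularization.
        dtype: Data type for computations and weights.
    """
```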
@keras_hub_export("keras_hub.models.RWKV7CausalLMPreprocessor")
class RWKV7CausalLMPreprocessor(CausalLMPreprocessor):
The `RWKV7CausalLMPreprocessor` class is missing a docstring. Please add a Google-style docstring explaining its purpose, parameters, and include a usage example, as specified in the style guide.1
Style Guide References
Footnotes
- All public classes must have Google-style docstrings with a summary, examples, and documentation for parameters. ↩
super().__init__(
    inputs=token_id_input,
    outputs=sequence_output,
    dtype=dtype,
    **kwargs,
)
The backbone's `__init__` method only accepts a single `token_ids` tensor as input. For consistency with other models in `keras_hub` and to improve interoperability, the backbone should be modified to accept a dictionary of inputs, including `token_ids` and `padding_mask`.1 The `padding_mask` is currently computed inside the backbone, but it's better practice to have it as an explicit input.
Style Guide References
Footnotes
- Use standardized names for model input arguments to ensure interoperability. For text models, this includes `token_ids` and `padding_mask`. The backbone should accept a dictionary of these inputs. ↩
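For illustration, here is a minimal, standalone sketch of the dictionary-input pattern used by other `keras_hub` text backbones; the layers and sizes are placeholders, not the actual RWKV-7 layers:

```python
import keras

# Declare both tensors as named functional inputs.
token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
padding_mask = keras.Input(shape=(None,), dtype="int32", name="padding_mask")

# Placeholder computation standing in for the embedding + RWKV blocks.
x = keras.layers.Embedding(input_dim=1000, output_dim=64)(token_ids)
mask = keras.ops.expand_dims(keras.ops.cast(padding_mask, x.dtype), axis=-1)
x = x * mask  # zero out padded positions
sequence_output = keras.layers.LayerNormalization()(x)

# Registering a dictionary of inputs lets the backbone consume the same
# structure the preprocessor produces.
backbone = keras.Model(
    inputs={"token_ids": token_ids, "padding_mask": padding_mask},
    outputs=sequence_output,
)
```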
@divyashreepathihalli @mattdangerw Could anybody review my code?
Code Review
This pull request introduces a comprehensive implementation of the RWKV-7 model, a modern RNN architecture, into `keras_hub`. The contribution is well-structured, following the repository's modular design with a backbone, causal LM task, preprocessor, tokenizer, and a checkpoint conversion script. The code is generally of high quality.
My review highlights a few areas for improvement. There are critical bugs in the tokenizer's asset loading and saving logic that need to be addressed. The custom layers in `rwkv7_layer.py` have inconsistent return signatures, which could lead to runtime errors and makes the code harder to maintain. Additionally, the test files do not use the standardized helper methods from `TestCase` as required by the repository's style guide, which is a significant deviation. I have also pointed out some minor typos and opportunities for code clarification.
Overall, this is an excellent and valuable addition. Addressing the feedback will improve the robustness, correctness, and maintainability of the new model.
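On the testing point, a backbone test using the shared helpers would look roughly like the sketch below; the config values are placeholders, the constructor will likely require additional arguments (for example the LoRA ranks), and the dictionary input assumes the backbone is updated to accept `token_ids`/`padding_mask`:

```python
from keras import ops

from keras_hub.src.models.rwkv7.rwkv7_backbone import RWKV7Backbone
from keras_hub.src.tests.test_case import TestCase


class RWKV7BackboneTest(TestCase):
    def setUp(self):
        # Tiny configuration so the test runs quickly; values are placeholders.
        self.init_kwargs = {
            "vocabulary_size": 16,
            "num_layers": 2,
            "hidden_size": 8,
            "head_size": 4,
            "intermediate_dim": 16,
        }
        self.input_data = {
            "token_ids": ops.ones((2, 5), dtype="int32"),
            "padding_mask": ops.ones((2, 5), dtype="int32"),
        }

    def test_backbone_basics(self):
        # run_backbone_test exercises the forward pass, serialization, and
        # dtype policies via the shared keras_hub TestCase helpers.
        self.run_backbone_test(
            cls=RWKV7Backbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output_shape=(2, 5, 8),
        )
```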
with open(path, "wb") as file:
    file.write("\n".join(self.vocabulary))
The `save_assets` method opens the vocabulary file in binary write mode (`"wb"`) but attempts to write a string to it. This will raise a `TypeError`. The file should be opened in text mode with `utf-8` encoding.
Suggested change:
with open(path, "w", encoding="utf-8") as file:
    file.write("\n".join(self.vocabulary))

"""
path = os.path.join(dir_path, VOCAB_FILENAME)
with open(path, "r", encoding="utf-8") as f:
    vocabulary = f.readlines()
The `load_assets` method uses `f.readlines()`, which keeps trailing newline characters (`\n`), so every vocabulary entry read from the file carries a stray `\n` and no longer matches the original tokens when the vocabulary is parsed. You should strip the newlines when reading the file.
Suggested change:
vocabulary = f.read().splitlines()
if train_mode:
    return x, v_first
return x, v_first, last_cache_x, finnal_state
The `call` method in `RWKV7_TimeMix` has two issues:
- Inconsistent return signature: It returns a different number of values depending on `train_mode`, which can cause runtime errors.1
- Typo: The variable `finnal_state` is misspelled. It should be `final_state`.
To fix this, the method should always return the same number of values, and the typo should be corrected. This change should also be propagated to where `finnal_state` is computed on line 448.
Suggested change:
if train_mode:
    return x, v_first, None, None
return x, v_first, last_cache_x, finnal_state
Rules References
Footnotes
- Methods should have a consistent return signature to avoid errors and improve predictability. Returning `None` for values not computed in a certain mode is a good practice. ↩
This is an erroneous issue.
if train_mode:
    return output
return output, last_cache_x
The `call` method in `RWKV7_ChannelMix` has an inconsistent return signature. It returns one value in training mode and two in inference mode. This can lead to errors and makes the layer harder to use. The method should always return the same number of values, using `None` for outputs that are not computed in a particular mode.1
Suggested change:
if train_mode:
    return output, None
return output, last_cache_x
Rules References
Footnotes
- Methods should have a consistent return signature to avoid errors and improve predictability. Returning `None` for values not computed in a certain mode is a good practice. ↩
This is an erroneous issue.
def call(
    self,
    x,
    v_first=None,
    padding_mask=None,
    cache_state=None,
    cache_tmix_x=None,
    cache_cmix_x=None,
    rnn_mode=False,
    train_mode=True,
):
    """Process input through RWKV block.

    Args:
        x: Input tensor.
        v_first: First value for mixing.
        padding_mask: Mask for padding tokens.
        cache_state: Cached recurrent state.
        cache_tmix_x: Cached time mixer values.
        cache_cmix_x: Cached channel mixer values.
        rnn_mode: Whether to use RNN mode.
        train_mode: Whether in training mode.

    Returns:
        Processed output tensor and cache information.
    """
    if padding_mask is not None:
        padding_mask = ops.cast(padding_mask, x.dtype)
        padding_mask = ops.expand_dims(padding_mask, axis=-1)
    if self.use_initial_norm:
        x = self.ln0(x)
    if train_mode:
        xx, v_first = self.att(
            self.ln1(x),
            v_first=v_first,
            padding_mask=padding_mask,
            train_mode=train_mode,
        )
        x = x + xx
        xx = self.ln2(x)
        if padding_mask is not None:
            xx = xx * padding_mask
        x = x + self.ffn(xx, train_mode=train_mode)
        return x, v_first
    else:
        xx, v_first, cache_tmix_x, cache_state = self.att.call(
            self.ln1(x),
            v_first=v_first,
            padding_mask=padding_mask,
            last_cache_x=cache_tmix_x,
            cache_state=cache_state,
            rnn_mode=rnn_mode,
            train_mode=train_mode,
        )
        x = x + xx
        xx = self.ln2(x)
        if padding_mask is not None:
            xx = xx * padding_mask
        xx, cache_cmix_x = self.ffn(xx, cache_cmix_x, train_mode=train_mode)
        x = x + xx
        return x, v_first, cache_state, cache_tmix_x, cache_cmix_x
The `call` method in `RWKV7_Block` is complex and has several issues that should be addressed:
- Inconsistent return signature: The method returns a different number of values for training and inference modes. This is problematic and can lead to errors.1
- Direct `.call()` usage: In inference mode, `self.att.call(...)` is used. It is better to use the standard `self.att(...)` invocation.
- Confusing variable names: The variables used to unpack the result of `self.att.call` (`cache_tmix_x`, `cache_state`) do not match the names of the values returned from `RWKV7_TimeMix` (`last_cache_x`, `finnal_state`), which is confusing.
This method should be refactored to have a single, clear execution path and a consistent return signature. The sub-layers should also be updated to have consistent return signatures.
Rules References
Footnotes
- Methods should have a consistent return signature to avoid errors and improve predictability. Returning `None` for values not computed in a certain mode is a good practice. ↩
Args:
    x: Input tensor.
    last_cache_x: Cached previous values.
"""Initialize RWKV7 backbone.

Args:
    hidden_size: Hidden dimension size.
    head_size: Attention head size.
    num_layers: Number of RWKV blocks.
    vocabulary_size: Size of vocabulary.
    intermediate_dim: Intermediate dimension for FFN.
    gate_lora: LoRA dimension for gating.
    mv_lora: LoRA dimension for value mixing.
    aaa_lora: LoRA dimension for alpha parameters.
    decay_lora: LoRA dimension for decay parameters.
    dtype: Data type for the layer.
    dropout_rate: Dropout rate for regularization.
    **kwargs: Additional arguments.
"""
The `__init__` method's docstring duplicates information already present in the class-level docstring. According to the style guide's example for backbones, the `__init__` method should not have a separate docstring.1 Removing this will make the code more concise and align it with the repository's conventions.
Style Guide References
Footnotes
- The style guide example for a backbone class shows arguments documented in the class docstring, not in the `__init__` method. ↩
self.backbone.rwkv_layers[i].call(
    x,
    v_first=v_first,
    padding_mask=padding_mask,
    cache_state=current_state_cache,
    cache_tmix_x=current_token_cache[:, 0],
    cache_cmix_x=current_token_cache[:, 1],
    rnn_mode=rnn_mode,
    train_mode=False,
)
Directly calling the `.call()` method of a layer is discouraged. It's better to use the `__call__` interface (e.g., `self.backbone.rwkv_layers[i](...)`) to ensure all Keras-specific logic is executed. This makes the code more robust and less prone to issues if the underlying layer implementation changes.
Suggested change:
self.backbone.rwkv_layers[i](
    x,
    v_first=v_first,
    padding_mask=padding_mask,
    cache_state=current_state_cache,
    cache_tmix_x=current_token_cache[:, 0],
    cache_cmix_x=current_token_cache[:, 1],
    rnn_mode=rnn_mode,
    train_mode=False,
)
This is because when using bf16, Keras's `__call__` converts the fp32 state into bf16, so `.call()` needs to be used here.
the final hidden representation of the input tokens, and `cache` is
the decoding cache.
"""
state_cachce, last_token_cache = cache
if not os.path.exists(FLAGS.preset):
    os.makedirs(FLAGS.preset)

souce_model_name = PRESET_MAP[FLAGS.preset]
RWKV7 is one of the strongest RNN models available today, and we now provide a full implementation for it in keras_hub.
📚 References
🔗 Pre-trained Checkpoints (ModelScope)
Numerical-verification and Inference Example notebook
This is the first modern RNN architecture in keras_hub. With the resurgence of recurrent models, more pre-trained RNN backbones will follow; hence this PR also serves as a reference implementation for future work.
Current progress