Llama3.1 #2132
base: master
Conversation
code fix
code fix
code fix
I'm surprised this is coming in as a separate arch. How hard would it be to consolidate and add new config options to the existing llama? Ideally a minor release like this is not an entire new set of symbols in the library.
The only difference is some changes in RoPE. Will fix!
I've removed that separate directory from keras-hub/models and made modifications to support Llama 3.1 weights using the old APIs. We only need to pass the RoPE-related parameters in the backbone.
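For reference, the difference shows up in the checkpoint config: the Llama 3.1 uploads on Hugging Face carry a `rope_scaling` block along these lines (values taken from the 8B config and shown here only for illustration), while Llama 3 checkpoints have no such block, so the rest of the architecture can stay shared.

```python
# Illustrative excerpt of a Llama 3.1 config.json, written as a Python dict.
rope_scaling = {
    "rope_type": "llama3",
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}
```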
Thanks for the PR! Did a quick review.
- Let's fill up the arg descriptions in the doc-string.
- We need better argument names. `old_context_len` is a bit awkward.
- Also, can the `if all(...)` condition be simplified? Alternatively, is there a way to set default values such that we don't need the if condition at all?

Pinging @mattdangerw for a review here as well. Wonder if the added ops are too custom (like, for example, `old_context_len` is weird) for the `RotaryEmbedding` layer, and they warrant a custom layer for Llama 3.1 after all.
    **kwargs,
):
    super().__init__(**kwargs)
    self.max_wavelength = max_wavelength
    self.sequence_axis = sequence_axis
    self.feature_axis = feature_axis
    self.scaling_factor = scaling_factor
    self.llow_freq_factor = low_freq_factor
Why `self.llow_freq_factor` and not `self.low_freq_factor`?
That was a typo.
inverse_freq = ops.where(
    is_medium_freq, smoothed_inv_freq, inverse_freq
)
ops.cast(inverse_freq, "float32")
Should this be `inverse_freq = ops.cast(inverse_freq, "float32")`? Or is this line meant to be removed?
removed this one, and verified.
@@ -66,13 +68,19 @@ def __init__(
    scaling_factor=1.0,
    sequence_axis=1,
    feature_axis=-1,
    low_freq_factor=None,
    high_freq_factor=None,
    old_context_len=None,
This naming is a bit awkward. I don't know what to call it either. Pinging @mattdangerw / @divyashreepathihalli, who might have a better idea.
old_context_len = self.old_context_len
low_freq_factor = self.llow_freq_factor
high_freq_factor = self.high_freq_factor
Let's just re-use `self.{variable}` instead of defining new variables?
modified this way
low_freq_factor=None,
high_freq_factor=None,
old_context_len=None,
Need to add description of all these arguments in the doc-string.
done
Yeah, can you explain more of what's going on here? What is "old context length"? Let's probably rename this.
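For background, a sketch of the reference algorithm (not the exact code in this PR): `old_context_len` is the original pre-training context length, 8192 for Llama 3, and it sets the wavelength thresholds between which the inverse frequencies are interpolated. Argument names below follow this PR; in the Hugging Face config the value appears as `original_max_position_embeddings`.

```python
import math

from keras import ops


def scale_inverse_freq(
    inverse_freq,
    factor=8.0,
    low_freq_factor=1.0,
    high_freq_factor=4.0,
    old_context_len=8192,
):
    """Llama 3.1 RoPE frequency adjustment (reference sketch)."""
    # Wavelength of each rotary component, and the two thresholds derived
    # from the original pre-training context length.
    wavelen = 2.0 * math.pi / inverse_freq
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    # Long-wavelength (low frequency) components are divided by `factor`;
    # short-wavelength (high frequency) components are left untouched.
    scaled = ops.where(
        ops.greater(wavelen, low_freq_wavelen),
        inverse_freq / factor,
        inverse_freq,
    )
    # Components between the two thresholds are smoothly interpolated.
    smooth = (old_context_len / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    smoothed = (1.0 - smooth) * inverse_freq / factor + smooth * inverse_freq
    is_medium = ops.logical_and(
        ops.greater_equal(wavelen, high_freq_wavelen),
        ops.less_equal(wavelen, low_freq_wavelen),
    )
    return ops.where(is_medium, smoothed, scaled)
```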
) * inverse_freq / factor + smooth_factor * inverse_freq
is_medium_freq = ops.logical_and(
    ops.cast(
        ops.greater_equal(wavelen, high_freq_wavelen), dtype="int8"
Does `dtype=bool` not work here? Why do we need to cast to `int8`?
Earlier we were using `ops.bitwise_and` and that was not working; it required an integer data type. Will need to check for `ops.logical_and`.
modified with bool
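If it helps, a quick check of the kind being discussed (assuming standard `keras.ops` behavior across backends): `ops.logical_and` accepts boolean tensors directly, so no integer cast is needed once `ops.bitwise_and` is out of the picture.

```python
from keras import ops

wavelen = ops.arange(1.0, 5.0)
is_medium_freq = ops.logical_and(
    ops.greater_equal(wavelen, 2.0),  # boolean tensor
    ops.less_equal(wavelen, 4.0),     # boolean tensor
)
# The boolean mask can be fed straight into ops.where, no int8 cast needed.
result = ops.where(is_medium_freq, wavelen, wavelen * 0.0)
```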
        self.old_context_len,
    )
):
    factor = self.scaling_factor
Is this the same as the scaling factor used to scale positions: `positions = positions / ops.cast(self.scaling_factor, "float32")`? Do we need a separate argument for this particular scaling factor?
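A side note on why the two factors look similar (hedged reasoning, not from the PR itself): RoPE angles are the outer product of positions and inverse frequencies, so dividing positions by a factor is numerically the same as dividing every inverse frequency by it. The Llama 3.1 `factor`, however, only rescales the low-frequency components and interpolates the rest, so it cannot be folded into the plain position scaling.

```python
from keras import ops

scaling_factor = 8.0
positions = ops.arange(0.0, 4.0)
inverse_freq = ops.convert_to_tensor([1.0, 0.1, 0.01])

# Uniform position scaling...
angles_a = ops.expand_dims(positions / scaling_factor, -1) * inverse_freq
# ...equals uniform frequency scaling, element for element.
angles_b = ops.expand_dims(positions, -1) * (inverse_freq / scaling_factor)
```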
    inverse_freq,
)

# otherwise: interpolate between the two, using a smooth factor
Why "otherwise"? Interpolation is happening irrespective of what you're doing above.
inverse_freq = ops.where(
    ops.greater(wavelen, low_freq_wavelen),
    inverse_freq / factor,
Wherever tensors are involved, prefer using `keras.ops.{op}`. So, `ops.divide(...)` here.
Is it required for operations between 2 tensors, or also for 1 scalar and 1 tensor?
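A small illustration of the question (assuming standard `keras.ops` behavior): both spellings work for a tensor and a Python scalar, since backend tensors overload the arithmetic operators, so the preference for `ops.divide` is mainly a style and consistency call.

```python
from keras import ops

inverse_freq = ops.arange(1.0, 5.0)
factor = 8.0

a = inverse_freq / factor             # operator form, tensor / scalar
b = ops.divide(inverse_freq, factor)  # explicit op form, same result
```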
if all(
    x is not None
    for x in (
        self.llow_freq_factor,
        self.high_freq_factor,
        self.old_context_len,
    )
):
Hmmm, is there a better way of specifying this condition?
We need to have a discussion on it. Either we can pass a boolean flag for the Llama 3.1-specific use case, or handle it another way. Currently these 3 parameters were newly introduced for Llama 3.1, and a user might provide a value for only one of them; then we don't know how to use it.
Typo fix
Removed repeated declarations of variables.
Colab Notebooks:
from keras_hub.src.api_export import keras_hub_export


@keras_hub_export("keras_hub.layers.LlamaRotaryEmbedding")
Sorry, just noticed we are adding this to `keras_hub.layers`. We should not. If we have a generic layer used by a few different models, we can put it here with a generic name.
But in this case, we should move this into llama, and not expose it (like all the other llama layers).
Moved it to llama.
args_none = [x is None for x in grouped_args]
if any(args_none) and not all(args_none):
    raise ValueError(
        "Either all of ... should be set, or none of ... should be set"
You need to actually fill this in. This is user facing!
done
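For illustration, a filled-in, user-facing message might read along these lines (hypothetical wording and helper, not the exact text that landed in the PR):

```python
def check_rope_scaling_args(low_freq_factor, high_freq_factor, old_context_len):
    """Hypothetical helper mirroring the any()/all() check in the diff above."""
    grouped_args = (low_freq_factor, high_freq_factor, old_context_len)
    args_none = [x is None for x in grouped_args]
    if any(args_none) and not all(args_none):
        raise ValueError(
            "Either all of `low_freq_factor`, `high_freq_factor` and "
            "`old_context_len` must be set, or none of them. Received: "
            f"low_freq_factor={low_freq_factor}, "
            f"high_freq_factor={high_freq_factor}, "
            f"old_context_len={old_context_len}."
        )
```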
}
if transformers_config.get("rope_scaling", None) is not None:
You might want to add some validation here for weirder llama uploads. E.g. inside the if block: `if transformers_config["rope_type"] != "llama3": raise ValueError(help message)`.
included this check, inside the if block.
This layer encodes absolute positional information with a rotation
matrix. It calculates the rotary encoding with a mix of sine and
cosine functions with geometrically increasing wavelengths.
Defined and formulated in
[RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864v4).
The input must be a tensor with shape a sequence dimension and a feature
dimension. Typically, this will either an input with shape
`(batch_size, sequence_length, feature_length)` or
`(batch_size, sequence_length, num_heads, feature_length)`.
This layer will return a new tensor with the rotary embedding applied to
the input tensor.
This reads the same as https://github.com/keras-team/keras-hub/blob/master/keras_hub/src/layers/modeling/rotary_embedding.py#L11-L21. We should probably add a line on how it is different from the original layer.
I've added a line for it.
@keras_hub_export("keras_hub.layers.LlamaRotaryEmbedding")
class LlamaRotaryEmbedding(keras.layers.Layer):
Quick question - why can we not subclass the original RoPE layer, and override just `__init__` and the `_get_inv_freq` function? We don't need to copy the whole thing over, right?
Done, passed `position_scaling_factor` as `scaling_factor` to the `RotaryEmbedding` class through `super().__init__()`.
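A sketch of what that subclassing approach can look like (illustrative only: the argument names here are placeholders, the base-layer signature is assumed to be `max_wavelength`/`scaling_factor`/`sequence_axis`/`feature_axis`, and the overridden hook should match whatever the base `RotaryEmbedding` actually names it, e.g. the `_get_inv_freq` mentioned above):

```python
from keras_hub.layers import RotaryEmbedding


class LlamaRotaryEmbedding(RotaryEmbedding):
    def __init__(
        self,
        max_wavelength=10000,
        position_scaling_factor=1.0,
        frequency_adjustment_factor=None,
        low_freq_factor=None,
        high_freq_factor=None,
        pretraining_sequence_length=None,
        **kwargs,
    ):
        # Delegate plain position scaling to the base layer.
        super().__init__(
            max_wavelength=max_wavelength,
            scaling_factor=position_scaling_factor,
            **kwargs,
        )
        self.frequency_adjustment_factor = frequency_adjustment_factor
        self.low_freq_factor = low_freq_factor
        self.high_freq_factor = high_freq_factor
        self.pretraining_sequence_length = pretraining_sequence_length

    def _get_inverse_freq(self, rotary_dim):
        # Override only the frequency computation; the method name must
        # match the base layer's hook (check the base class before copying).
        inverse_freq = super()._get_inverse_freq(rotary_dim)
        if self.low_freq_factor is not None:
            # Apply the Llama 3.1 adjustment from the earlier sketch.
            inverse_freq = scale_inverse_freq(
                inverse_freq,
                factor=self.frequency_adjustment_factor,
                low_freq_factor=self.low_freq_factor,
                high_freq_factor=self.high_freq_factor,
                old_context_len=self.pretraining_sequence_length,
            )
        return inverse_freq
```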
@@ -40,10 +40,18 @@ class LlamaBackbone(Backbone):
        a three-layer feedforward network for each transformer.
    num_key_value_heads (int): The number of key and value attention heads
        for each transformer.
    rope_max_wavelength (int, optional): The maximum angular wavelength of
    rope_max_wavelength: (int, optional): The maximum angular wavelength of
Nit: this is not how we specify args. It should be: `rope_max_wavelength: int. The maximum .... Defaults to None.`
@abheesht17
This layer will return a new tensor with the rotary embedding applied to
the input tensor.
It is extended from `RotaryEmbedding` layer in `keras_hub.layers`.
It has additional smoothning and interpolation for some frequency ranges.
smoothning -> smoothening
}

if transformers_config.get("rope_scaling", None) is not None:
    if transformers_config["rope_scaling"]["rope_type"] != "llama3":
        raise ValueError("The config shall be a valid llama3 config.")
shall -> should?
@pctablet505 - can you please make this one last change? Thanks!
def __init__(
    self,
    max_wavelength=10000,
Need to pass this to the superclass
this was done
    calculation of rotary embedding. Defaults to `1.0`
rope_frequency_adjustment_factor: flaot. The scaling factor
    used to scale the inverse frequencies. Defaults to `None`.
rope_low_freq_factor: flaot. The low frequency scaling
flaot --> float. Here, and other places too.
LGTM, thanks!
Did you fix the `from_preset()` issue you were facing earlier?
Added Support for Llama3.1
Notebook to verify numerics using repo
Notebook for actual numeric verification and code modifications
The tokenizer part still remains to be done.