Llama3.1 #2132

Status: Open. Wants to merge 49 commits into master.

Conversation

pctablet505 (Collaborator)

Added support for Llama 3.1.

Notebook to verify numerics using the repo

Notebook for actual numeric verification and code modifications

The tokenizer part still remains to be done.

@mattdangerw (Member)

I'm surprised this is coming in as a separate arch. How hard would it be to consolidate and add new config options to existing llama? Ideally a minor release like this is not a new entire set of symbols in the library.

@abheesht17 (Collaborator) commented Mar 11, 2025

> I'm surprised this is coming in as a separate arch. How hard would it be to consolidate and add new config options to existing llama? Ideally a minor release like this is not a new entire set of symbols in the library.

The only difference is some changes in RoPE. Will fix!

@pctablet505 (Collaborator, Author)

I've removed that separate directory from keras-hub/models and made modifications to support Llama 3.1 weights using the existing APIs. We only need to pass the RoPE-related parameters to the backbone.

@abheesht17 (Collaborator) left a comment

Thanks for the PR! Did a quick review.

  • Let's fill in the arg descriptions in the doc-string.
  • We need better argument names; old_context_len is a bit awkward.
  • Also, can the if all(...) condition be simplified? Alternatively, is there a way to set default values such that we don't need the if condition at all?

Pinging @mattdangerw for a review here as well. Wonder if the added ops are too custom (like, for example, old_context_len is weird) for the RotaryEmbedding layer, and they warrant a custom layer for Llama 3.1 after all.
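
For context, here is roughly the Llama 3.1 frequency-adjustment scheme under discussion, written against keras.ops. It is a minimal sketch: the argument names mirror the thread (factor, low_freq_factor, high_freq_factor, old_context_len, with Llama 3.1's published defaults), and the actual code in the PR may differ.

import math

from keras import ops


def adjust_inverse_freq(
    inverse_freq,  # 1D tensor of base RoPE inverse frequencies
    factor=8.0,
    low_freq_factor=1.0,
    high_freq_factor=4.0,
    old_context_len=8192,
):
    """Llama 3.1-style frequency adjustment (illustrative defaults)."""
    # Wavelength of each rotary frequency.
    wavelen = ops.divide(2.0 * math.pi, inverse_freq)
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    # Low frequencies (long wavelengths) are scaled down by `factor`;
    # high frequencies (short wavelengths) are left untouched.
    adjusted = ops.where(
        ops.greater(wavelen, low_freq_wavelen),
        ops.divide(inverse_freq, factor),
        inverse_freq,
    )

    # Medium frequencies are smoothly interpolated between the two regimes.
    smooth = ops.divide(
        ops.divide(old_context_len, wavelen) - low_freq_factor,
        high_freq_factor - low_freq_factor,
    )
    smoothed = (1.0 - smooth) * ops.divide(inverse_freq, factor) + (
        smooth * inverse_freq
    )
    is_medium = ops.logical_and(
        ops.greater_equal(wavelen, high_freq_wavelen),
        ops.less_equal(wavelen, low_freq_wavelen),
    )
    return ops.where(is_medium, smoothed, adjusted)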

        **kwargs,
    ):
        super().__init__(**kwargs)
        self.max_wavelength = max_wavelength
        self.sequence_axis = sequence_axis
        self.feature_axis = feature_axis
        self.scaling_factor = scaling_factor
        self.llow_freq_factor = low_freq_factor

Collaborator:

Why self.llow_freq_factor and not self.low_freq_factor?

Collaborator Author:

That was a typo.

inverse_freq = ops.where(
    is_medium_freq, smoothed_inv_freq, inverse_freq
)
ops.cast(inverse_freq, "float32")

Collaborator:

Should this be inverse_freq = ops.cast(inverse_freq, "float32")? Or is this line meant to be removed?

Collaborator Author:

Removed this line, and verified.

@@ -66,13 +68,19 @@ def __init__(
        scaling_factor=1.0,
        sequence_axis=1,
        feature_axis=-1,
        low_freq_factor=None,
        high_freq_factor=None,
        old_context_len=None,

Collaborator:

This naming is a bit awkward. I don't know what to call it either. Pinging @mattdangerw / @divyashreepathihalli, who might have a better idea.

Comment on lines 160 to 162
old_context_len = self.old_context_len
low_freq_factor = self.llow_freq_factor
high_freq_factor = self.high_freq_factor

Collaborator:

Let's just re-use self.{variable} instead of defining new variables?

Collaborator Author:

Modified it this way.

Comment on lines 71 to 73
low_freq_factor=None,
high_freq_factor=None,
old_context_len=None,

Collaborator:

Need to add description of all these arguments in the doc-string.

Collaborator Author:

Done.

Member:

Yeah, can you explain more of what's going on here? What is "old context length"? Let's probably rename this.

) * inverse_freq / factor + smooth_factor * inverse_freq
is_medium_freq = ops.logical_and(
    ops.cast(
        ops.greater_equal(wavelen, high_freq_wavelen), dtype="int8"

Collaborator:

Does dtype=bool not work here? Why do we need to cast to int8?

Collaborator Author:

Earlier we were using ops.bitwise_and and that was not working; it required an integer data type.
Will need to check ops.logical_and.

Collaborator Author:

Modified to use bool.
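
For reference, ops.logical_and operates directly on boolean tensors, so the int8 round-trip can be dropped once bitwise_and is swapped out. A small standalone illustration (not the PR's exact code):

from keras import ops

wavelen = ops.convert_to_tensor([4.0, 4096.0, 16384.0])
high_freq_wavelen, low_freq_wavelen = 2048.0, 8192.0

# Comparison ops already return boolean tensors, which ops.logical_and
# accepts as-is and which can be passed straight to ops.where as the condition.
is_medium_freq = ops.logical_and(
    ops.greater_equal(wavelen, high_freq_wavelen),
    ops.less_equal(wavelen, low_freq_wavelen),
)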

self.old_context_len,
)
):
factor = self.scaling_factor

Collaborator:

Is this the same as the scaling factor used to scale positions:

positions = positions / ops.cast(self.scaling_factor, "float32")
?

Do we need a separate argument for this particular scaling factor?
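
For what it's worth, the two factors act on different tensors: the base layer's scaling_factor divides the positions (linear RoPE scaling), while the Llama 3.1 factor divides the low-frequency part of inverse_freq, which is presumably why the PR later keeps them as separate arguments (position_scaling_factor vs. rope_frequency_adjustment_factor). A tiny sketch of the difference; the value 8.0 is only illustrative:

from keras import ops

positions = ops.arange(0, 8, dtype="float32")
inverse_freq = ops.convert_to_tensor([1.0, 0.1, 0.01])

# Base RotaryEmbedding: the positions themselves are compressed.
scaled_positions = ops.divide(positions, 8.0)

# Llama 3.1 adjustment: only the inverse frequencies (here all of them,
# for brevity) are divided; positions stay untouched.
scaled_inverse_freq = ops.divide(inverse_freq, 8.0)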

inverse_freq,
)

# otherwise: interpolate between the two, using a smooth factor

Collaborator:

Why "otherwise"? Interpolation is happening irrespective of what you're doing above.


inverse_freq = ops.where(
    ops.greater(wavelen, low_freq_wavelen),
    inverse_freq / factor,

Collaborator:

Wherever tensors are involved, prefer using keras.ops.{op}. So, ops.divide(...) here

Collaborator Author:

Is it required only for operations between two tensors, or also between a scalar and a tensor?
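
For reference, keras.ops functions accept Python scalars as well as tensors for either operand, so ops.divide covers both cases the question raises. A quick standalone check (not code from the PR):

from keras import ops

inverse_freq = ops.convert_to_tensor([1.0, 0.1, 0.01])
factor = 8.0

tensor_by_scalar = ops.divide(inverse_freq, factor)  # tensor / Python float
tensor_by_tensor = ops.divide(inverse_freq, ops.full_like(inverse_freq, factor))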

Comment on lines 151 to 158
if all(
    x is not None
    for x in (
        self.llow_freq_factor,
        self.high_freq_factor,
        self.old_context_len,
    )
):

Collaborator:

Hmmm, is there a better way of specifying this condition?

Collaborator Author:

We need to have a discussion on this. Either we can pass a boolean flag for the Llama 3.1-specific use case, or handle it some other way. These three parameters were newly introduced for Llama 3.1, and a user might provide a value for only one of them, in which case we don't know how to use it.

@abheesht17 abheesht17 requested a review from mattdangerw March 17, 2025 13:41
@pctablet505 pctablet505 added the kokoro:force-run Runs Tests on GPU label Apr 1, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Apr 1, 2025
@pctablet505 pctablet505 added the kokoro:force-run Runs Tests on GPU label Apr 1, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Apr 1, 2025
@pctablet505 (Collaborator, Author) commented Apr 3, 2025

Colab Notebooks:

from keras_hub.src.api_export import keras_hub_export


@keras_hub_export("keras_hub.layers.LlamaRotaryEmbedding")

Member:

Sorry, just noticed we are adding this to keras_hub.layers. We should not. If we have a generic layer used by a few different models, we can put it here with a generic name.

But in this case, we should move this into llama, and not expose it (like all the other llama layers).

Collaborator Author:

Moved it to llama.

args_none = [x is None for x in grouped_args]
if any(args_none) and not all(args_none):
    raise ValueError(
        "Either all of ... should be set, or none of ... should be set"

Member:

You need to actually fill this in; this is user-facing!

Collaborator Author:

Done.
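
One possible phrasing for a user-facing message of this shape; the helper name and exact wording here are hypothetical, not necessarily what the PR ended up with:

def _validate_rope_scaling_args(low_freq_factor, high_freq_factor, old_context_len):
    grouped_args = (low_freq_factor, high_freq_factor, old_context_len)
    args_none = [x is None for x in grouped_args]
    if any(args_none) and not all(args_none):
        raise ValueError(
            "Either all of `low_freq_factor`, `high_freq_factor` and "
            "`old_context_len` must be set, or none of them. Received: "
            f"low_freq_factor={low_freq_factor}, "
            f"high_freq_factor={high_freq_factor}, "
            f"old_context_len={old_context_len}."
        )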

}
if transformers_config.get("rope_scaling", None) is not None:

Member:

You might want to add some validation here for weirder llama uploads. E.g. inside the if block if transformers_config["rope_type"] != "llama3": raise ValueError(help message)

Collaborator Author:

Included this check inside the if block.

Comment on lines 13 to 23
This layer encodes absolute positional information with a rotation
matrix. It calculates the rotary encoding with a mix of sine and
cosine functions with geometrically increasing wavelengths.
Defined and formulated in
[RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864v4).
The input must be a tensor with shape a sequence dimension and a feature
dimension. Typically, this will either an input with shape
`(batch_size, sequence_length, feature_length)` or
`(batch_size, sequence_length, num_heads, feature_length)`.
This layer will return a new tensor with the rotary embedding applied to
the input tensor.

Collaborator:

This reads the same as https://github.com/keras-team/keras-hub/blob/master/keras_hub/src/layers/modeling/rotary_embedding.py#L11-L21. We should probably add a line on how it is different from the original layer.

Collaborator Author:

I've added a line for it.



@keras_hub_export("keras_hub.layers.LlamaRotaryEmbedding")
class LlamaRotaryEmbedding(keras.layers.Layer):

Collaborator:

Quick question - why can we not subclass the original RoPE layer, and override just __init__ and the _get_inv_freq function? We don't need to copy the whole thing over, right?

Collaborator Author:

Done. Passed position_scaling_factor as scaling_factor to the RotaryEmbedding class through super().__init__().
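
Roughly what that subclassing shape looks like, reusing the adjust_inverse_freq helper sketched earlier in this thread. The import path, the _get_inverse_freq hook name, and the argument names are assumptions based on the discussion above, so the PR's actual code may differ:

from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding


class LlamaRotaryEmbedding(RotaryEmbedding):
    """RoPE with the Llama 3.1 frequency adjustment layered on top."""

    def __init__(
        self,
        max_wavelength=10000,
        position_scaling_factor=1.0,
        frequency_adjustment_factor=None,
        low_freq_factor=None,
        high_freq_factor=None,
        old_context_len=None,
        **kwargs,
    ):
        # Position scaling, wavelength, and axis handling stay in the base class.
        super().__init__(
            max_wavelength=max_wavelength,
            scaling_factor=position_scaling_factor,
            **kwargs,
        )
        self.frequency_adjustment_factor = frequency_adjustment_factor
        self.low_freq_factor = low_freq_factor
        self.high_freq_factor = high_freq_factor
        self.old_context_len = old_context_len

    def _get_inverse_freq(self, rotary_dim):
        # Assumed override point; adjust to whatever hook the base layer exposes.
        inverse_freq = super()._get_inverse_freq(rotary_dim)
        if self.low_freq_factor is not None:
            inverse_freq = adjust_inverse_freq(
                inverse_freq,
                factor=self.frequency_adjustment_factor,
                low_freq_factor=self.low_freq_factor,
                high_freq_factor=self.high_freq_factor,
                old_context_len=self.old_context_len,
            )
        return inverse_freq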

@@ -40,10 +40,18 @@ class LlamaBackbone(Backbone):
a three-layer feedforward network for each transformer.
num_key_value_heads (int): The number of key and value attention heads
for each transformer.
rope_max_wavelength (int, optional): The maximum angular wavelength of
rope_max_wavelength: (int, optional): The maximum angular wavelength of

Collaborator:

Nit: this is not how we specify args. It should be

rope_max_wavelength: int. The maximum .... Defaults to None.
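
For the record, a snippet of what that convention looks like for the new args; the wording is illustrative only:

"""
Args:
    rope_max_wavelength: int. The maximum angular wavelength of the
        sine/cosine curves used in the rotary embedding. Defaults to
        `10000`.
    rope_low_freq_factor: float. The low-frequency scaling factor for
        the Llama 3.1 RoPE adjustment. Defaults to `None`.
"""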

@pctablet505 (Collaborator, Author)

@abheesht17
I've made all the necessary changes. Could you please approve it for merge?

This layer will return a new tensor with the rotary embedding applied to
the input tensor.
It is extended from `RotaryEmbedding` layer in `keras_hub.layers`.
It has additional smoothning and interpolation for some frequency ranges.

Collaborator:

smoothning -> smoothening

}

if transformers_config.get("rope_scaling", None) is not None:
    if transformers_config["rope_scaling"]["rope_type"] != "llama3":
        raise ValueError("The config shall be a valid llama3 config.")

Collaborator:

shall -> should?

Collaborator:

@pctablet505 - can you please make this one last change? Thanks!


def __init__(
    self,
    max_wavelength=10000,

Collaborator:

Need to pass this to the superclass

Collaborator Author:

This was done.

calculation of rotary embedding. Defaults to `1.0`
rope_frequency_adjustment_factor: flaot. The scaling factor
used to scale the inverse frequencies. Defaults to `None`.
rope_low_freq_factor: flaot. The low frequency scaling

Collaborator:

flaot --> float. Here, and other places too.

@abheesht17 abheesht17 added the kokoro:force-run Runs Tests on GPU label Apr 9, 2025
@abheesht17 (Collaborator) left a comment

LGTM, thanks!

Did you fix the from_preset() issue you were facing earlier?

@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Apr 9, 2025
@pctablet505 pctablet505 added the kokoro:force-run Runs Tests on GPU label Apr 9, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Apr 9, 2025