Add Mixtral #2196

Open · wants to merge 13 commits into master

Conversation

@kanpuriyanawab (Collaborator) commented Apr 2, 2025

This PR adds Mixtral to Keras Hub.

Reference

@kanpuriyanawab marked this pull request as ready for review on April 10, 2025, 08:40

@kanpuriyanawab (Collaborator, Author) commented:

Output matching:

[image: output matching]

@divyashreepathihalli (Collaborator) left a comment:

Left a few comments! Please provide a demo Colab.

)
self._query_dense.build(inputs_shape)

self._key_dense = keras.layers.EinsumDense(

Collaborator:

update the layer names to be compatible with enable_lora
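
For context (my reading of the current keras_hub code, not something stated in this PR): `Backbone.enable_lora()` finds the attention projections by layer name, matching names like `"query"` and `"value"`, so the `EinsumDense` layers above need to be built with those names. A standalone sketch:

```python
import keras

# Illustrative only (shapes are made up): what matters for enable_lora() is
# the `name` given to each projection, since the backbone looks up layers
# named "query"/"value" (or "query_dense"/"value_dense") when injecting LoRA.
query_dense = keras.layers.EinsumDense(
    "abc,cde->abde",
    output_shape=(None, 8, 16),  # (sequence, num_heads, head_dim)
    name="query",
)
value_dense = keras.layers.EinsumDense(
    "abc,cde->abde",
    output_shape=(None, 8, 16),
    name="value",
)
```

If the projections keep other names, `enable_lora()` would, as far as I can tell, silently skip them.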

@keras_hub_export("keras_hub.models.MixtralBackbone")
class MixtralBackbone(Backbone):
"""
The Mixtral Transformer core architecture with hyperparameters.

Collaborator:

The docstring's first line should follow directly after the opening """.
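
i.e., roughly this (only the placement of the summary line matters; the rest of the docstring stays as it is):

```python
class MixtralBackbone(Backbone):
    """The Mixtral Transformer core architecture with hyperparameters.

    ... rest of the docstring unchanged ...
    """
```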

preprocessor("League of legends")
# Tokenize a batch of sentences.
sentences = tf.constant(["Taco tuesday", "Fish taco please!"])

Collaborator:

why tf?

target_ids = keras.ops.roll(generation_ids, shift=-1, axis=1)
embeddings = None
with tf.GradientTape(watch_accessed_variables=True) as tape:

Collaborator:

why tf?

@kanpuriyanawab (Collaborator, Author):

borrowed docstring

[screenshot]


Collaborator:

We don't recommend using backend-specific examples. For generic usage, use keras.ops or numpy.
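
For instance, the quoted docstring example could be made backend-agnostic along these lines (a sketch; `preprocessor` is the instance constructed earlier in that docstring):

```python
# Tokenize a single sentence.
preprocessor("League of legends")

# Tokenize a batch of sentences using a plain Python list instead of a
# backend-specific tensor type such as tf.constant.
sentences = ["Taco tuesday", "Fish taco please!"]
preprocessor(sentences)
```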

Collaborator:

There are some conflicts in the api directory due to the recent changes; please resolve.

@kanpuriyanawab (Collaborator, Author):

conflicts resolved.

@kanpuriyanawab (Collaborator, Author):

> We don't recommend using backend-specific examples. For generic usage, use keras.ops or numpy.

@sachinprasadhs, as I mentioned above, there are already tf.GradientTape examples in existing model docstrings; those should be cleaned up in a separate PR.

@kanpuriyanawab (Collaborator, Author) commented:

Mixtral output matching:

[screenshots]

@sachinprasadhs (Collaborator) left a comment:

Added a few more comments.

This network implements a Transformer-based decoder network,
Mixtral, as described in
["Mixtral 7B"](https://arxiv.org/pdf/2310.06825.pdf).

Collaborator:

The reference provided here is for Mistral, not Mixtral; add the correct reference.

router_logits, num_experts, top_k, attention_mask=None
):
"""
Compute the load balancing auxiliary loss for a single MoE layer.

Collaborator:

This should be on the same line as the opening """, followed by a blank line before Args.
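
For reference, the Switch-Transformer-style load-balancing loss that Mixtral uses can be sketched with `keras.ops` roughly as follows; this is a generic sketch rather than this PR's implementation, and the `attention_mask` handling is omitted:

```python
from keras import ops


def load_balancing_loss_sketch(router_logits, num_experts, top_k):
    """Sketch of the load-balancing auxiliary loss for one MoE layer.

    `router_logits` has shape `(num_tokens, num_experts)`.
    """
    # Router probabilities for every token.
    routing_weights = ops.softmax(router_logits, axis=-1)
    # Experts actually selected for each token under top-k routing.
    _, selected_experts = ops.top_k(routing_weights, k=top_k)
    # One-hot per selection: shape (num_tokens, top_k, num_experts).
    expert_mask = ops.one_hot(selected_experts, num_experts)
    # Fraction of tokens whose routing includes each expert.
    tokens_per_expert = ops.mean(ops.max(expert_mask, axis=1), axis=0)
    # Mean router probability assigned to each expert.
    router_prob_per_expert = ops.mean(routing_weights, axis=0)
    # Scale by num_experts, following the Switch Transformer formulation.
    return ops.sum(tokens_per_expert * router_prob_per_expert) * num_experts
```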

from keras import ops


# TODO: Deprecate this in favor of

Collaborator:

We don't support Keras 2 anymore in Keras Hub; I guess you can get rid of this.

# Below is a workaround for `ops.triu` for Keras 2.
# TODO(tirthasheshpatel): Use `ops.triu` once Keras 2 support is
# removed.
# causal_mask = ops.triu(causal_mask, k=-self.sliding_window)

Collaborator:

Keras 2 support is removed now; you can enable this.
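
For what it's worth, a standalone sketch of the sliding-window causal mask the re-enabled `ops.triu` line would produce (sizes are illustrative):

```python
from keras import ops

seq_len, sliding_window = 6, 3
i = ops.expand_dims(ops.arange(seq_len), axis=1)  # query positions (column)
j = ops.expand_dims(ops.arange(seq_len), axis=0)  # key positions (row)
# Standard causal mask: a query may attend to keys at or before its position.
causal_mask = ops.cast(ops.less_equal(j, i), "int32")
# Sliding window: additionally drop keys more than `sliding_window` positions
# behind the query, which is what `ops.triu(..., k=-self.sliding_window)` does.
causal_mask = ops.triu(causal_mask, k=-sliding_window)
```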

class MixtralCausalLMPreprocessorTest(TestCase):
def setUp(self):
self.tokenizer = MixtralTokenizer(
# Generated using create_mixtral_test_proto.py

Collaborator:

This file is missing.
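
For reference, the other SentencePiece-based models in the repo generate their test protos with small scripts (under tools/sentencepiece_testing/, if I remember right); a sketch of what create_mixtral_test_proto.py might contain (the vocab size, special-token ids, and output file name are assumptions by analogy, not taken from this PR):

```python
import io

import sentencepiece

# Train a tiny word-level SentencePiece model on two sentences and dump the
# proto used by the tokenizer/preprocessor tests. All settings below are
# assumed by analogy with the other *_test_proto scripts, not from this PR.
bytes_io = io.BytesIO()
sentencepiece.SentencePieceTrainer.train(
    sentence_iterator=iter(["the quick brown fox", "the earth is round"]),
    model_writer=bytes_io,
    vocab_size=10,
    model_type="WORD",
    unk_id=0,
    bos_id=1,
    eos_id=2,
    pad_id=-1,
)

with open("mixtral_test_vocab.spm", "wb") as f:
    f.write(bytes_io.getvalue())
```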
