Add Mixtral #2196
Conversation
Left a few comments! Please provide a demo colab
)
self._query_dense.build(inputs_shape)

self._key_dense = keras.layers.EinsumDense(
Update the layer names to be compatible with enable_lora.
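As a sketch of what this likely means: KerasHub's enable_lora targets the attention projections by layer name, so the EinsumDense layers would need names such as "query" and "value". The names, equations, and attribute names below are assumptions based on other backbones (e.g. Mistral), not the final Mixtral code:

# Hypothetical naming sketch; initializers and other arguments elided.
self._query_dense = keras.layers.EinsumDense(
    equation="bqm,muh->bquh",
    output_shape=(None, self._num_query_heads, self._head_dim),
    dtype=self.dtype_policy,
    name="query",  # assumed name so backbone.enable_lora() can target it
)
self._value_dense = keras.layers.EinsumDense(
    equation="bsm,mvh->bsvh",
    output_shape=(None, self._num_key_value_heads, self._head_dim),
    dtype=self.dtype_policy,
    name="value",  # assumed name so backbone.enable_lora() can target it
)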
@keras_hub_export("keras_hub.models.MixtralBackbone")
class MixtralBackbone(Backbone):
    """
    The Mixtral Transformer core architecture with hyperparameters.
The docstring's first line should start right after the opening """.
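For reference, a short sketch of the suggested layout, with the summary on the same line as the opening quotes:

class MixtralBackbone(Backbone):
    """The Mixtral Transformer core architecture with hyperparameters.

    ...rest of the docstring...
    """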
preprocessor("League of legends") | ||
# Tokenize a batch of sentences. | ||
sentences = tf.constant(["Taco tuesday", "Fish taco please!"]) |
why tf?
target_ids = keras.ops.roll(generation_ids, shift=-1, axis=1)
embeddings = None
with tf.GradientTape(watch_accessed_variables=True) as tape:
why tf?
We don't recommend using backend-specific examples; for generic usage, use keras.ops or numpy.
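For example, the preprocessor snippet above could avoid tf.constant entirely; a sketch using a plain Python list of strings, which the preprocessor accepts:

# Tokenize a batch of sentences without creating a backend-specific tensor.
sentences = ["Taco tuesday", "Fish taco please!"]
preprocessor(sentences)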
There are some conflicts in the api directory due to the recent changes; please resolve.
conflicts resolved.
We don't recommend using backend-specific examples; for generic usage, use keras.ops or numpy.
@sachinprasadhs like I mentioned above, there are already tf.GradientTape examples in existing model docstrings; those should be cleaned up in a separate PR.
Added a few more comments.
This network implements a Transformer-based decoder network,
Mixtral, as described in
["Mixtral 7B"](https://arxiv.org/pdf/2310.06825.pdf).
The reference provided here is for Mistral, not Mixtral; add the correct reference.
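Presumably the intended citation is the Mixtral paper itself, e.g. something along the lines of:

["Mixtral of Experts"](https://arxiv.org/abs/2401.04088)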
    router_logits, num_experts, top_k, attention_mask=None
):
    """
    Compute the load balancing auxiliary loss for a single MoE layer.
This should be on the same line after """, with a blank line before Args.
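For context, a minimal sketch of such a load-balancing loss with the suggested docstring layout, assuming the standard Mixtral/Switch-Transformer formulation and ignoring the attention mask (this is an illustration, not the PR's implementation):

from keras import ops


def compute_load_balancing_loss(
    router_logits, num_experts, top_k, attention_mask=None
):
    """Compute the load balancing auxiliary loss for a single MoE layer.

    Args:
        router_logits: Tensor of shape (num_tokens, num_experts).
        num_experts: Total number of experts in the layer.
        top_k: Number of experts each token is routed to.
        attention_mask: Optional padding mask (ignored in this sketch).
    """
    # Routing probabilities over experts for every token.
    routing_weights = ops.softmax(router_logits, axis=-1)
    # Hard top-k assignment per token, as a one-hot expert mask of
    # shape (num_tokens, top_k, num_experts).
    _, selected_experts = ops.top_k(routing_weights, k=top_k)
    expert_mask = ops.one_hot(selected_experts, num_experts)
    # Fraction of tokens dispatched to each expert.
    tokens_per_expert = ops.mean(expert_mask, axis=0)
    # Mean routing probability assigned to each expert.
    router_prob_per_expert = ops.mean(routing_weights, axis=0)
    # Penalize uneven token/probability allocation across experts.
    return num_experts * ops.sum(
        tokens_per_expert * ops.expand_dims(router_prob_per_expert, 0)
    )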
from keras import ops


# TODO: Deprecate this in favor of
We don't support Keras 2 anymore in Keras Hub; I guess you can get rid of this.
# Below is a workaround for `ops.triu` for Keras 2.
# TODO(tirthasheshpatel): Use `ops.triu` once Keras 2 support is
# removed.
# causal_mask = ops.triu(causal_mask, k=-self.sliding_window)
Keras 2 support is removed now; you can enable this.
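Concretely, the commented-out line can simply be enabled, assuming the surrounding mask computation stays as in the PR:

# Keras 3 provides ops.triu, so the Keras 2 workaround can be dropped.
causal_mask = ops.triu(causal_mask, k=-self.sliding_window)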
class MixtralCausalLMPreprocessorTest(TestCase):
    def setUp(self):
        self.tokenizer = MixtralTokenizer(
            # Generated using create_mixtral_test_proto.py
This file (create_mixtral_test_proto.py) is missing from the PR.
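A sketch of what such a script usually looks like in KerasHub, training a tiny SentencePiece model on a few sentences; the output file name, vocab size, and special-token ids below are assumptions, not taken from this PR:

import io

import sentencepiece

train_sentences = [
    "The quick brown fox jumped.",
    "I like pizza.",
    "This is a test.",
]

bytes_io = io.BytesIO()
sentencepiece.SentencePieceTrainer.train(
    sentence_iterator=iter(train_sentences),
    model_writer=bytes_io,
    vocab_size=10,  # assumed tiny vocab for tests
    model_type="WORD",
    unk_id=0,
    bos_id=1,
    eos_id=2,
    pad_id=-1,
    unk_piece="<unk>",
    bos_piece="<s>",
    eos_piece="</s>",
)
# Assumed output path; the tests would point MixtralTokenizer at this file.
with open("mixtral_test_vocab.spm", "wb") as f:
    f.write(bytes_io.getbuffer())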
This PR adds Mixtral to Keras Hub.
Reference