diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py
index 9562e7df58..6b65ab8cec 100644
--- a/keras_hub/api/layers/__init__.py
+++ b/keras_hub/api/layers/__init__.py
@@ -58,6 +58,9 @@
from keras_hub.src.models.mobilenet.mobilenet_image_converter import (
MobileNetImageConverter,
)
+from keras_hub.src.models.moonshine.moonshine_audio_converter import (
+ MoonshineAudioConverter,
+)
from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import (
PaliGemmaImageConverter,
)
diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py
index 2f510446d7..041531efbb 100644
--- a/keras_hub/api/models/__init__.py
+++ b/keras_hub/api/models/__init__.py
@@ -232,6 +232,16 @@
from keras_hub.src.models.mobilenet.mobilenet_image_classifier_preprocessor import (
MobileNetImageClassifierPreprocessor,
)
+from keras_hub.src.models.moonshine.moonshine_audio_to_text import (
+ MoonshineAudioToText,
+)
+from keras_hub.src.models.moonshine.moonshine_backbone import MoonshineBackbone
+from keras_hub.src.models.moonshine.moonshine_seq_2_seq_lm_preprocessor import (
+ MoonshineSeq2SeqLMPreprocessor,
+)
+from keras_hub.src.models.moonshine.moonshine_tokenizer import (
+ MoonshineTokenizer,
+)
from keras_hub.src.models.object_detector import ObjectDetector
from keras_hub.src.models.object_detector import (
ObjectDetector as ImageObjectDetector,
diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py
index 9f73bfd665..29accb7b08 100644
--- a/keras_hub/api/tokenizers/__init__.py
+++ b/keras_hub/api/tokenizers/__init__.py
@@ -24,6 +24,9 @@
from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer
from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer
+from keras_hub.src.models.moonshine.moonshine_tokenizer import (
+ MoonshineTokenizer,
+)
from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer
from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import (
PaliGemmaTokenizer,
diff --git a/keras_hub/src/models/moonshine/__init__.py b/keras_hub/src/models/moonshine/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/keras_hub/src/models/moonshine/moonshine_audio_converter.py b/keras_hub/src/models/moonshine/moonshine_audio_converter.py
new file mode 100644
index 0000000000..9c713418b3
--- /dev/null
+++ b/keras_hub/src/models/moonshine/moonshine_audio_converter.py
@@ -0,0 +1,260 @@
+import warnings
+
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter
+from keras_hub.src.models.moonshine.moonshine_backbone import MoonshineBackbone
+from keras_hub.src.models.moonshine.moonshine_layers import (
+ moonshine_kernel_initializer,
+)
+from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+@keras_hub_export("keras_hub.layers.MoonshineAudioConverter")
+class MoonshineAudioConverter(AudioConverter):
+ """Moonshine preprocessor and audio converter layer.
+
+ This layer processes raw audio waveforms for the Moonshine ASR model. Audio
+ is expected as a batched tensor at a 16kHz sample rate; a warning is issued
+ for clips shorter than 0.1 seconds or longer than 64 seconds. The layer
+ downsamples and extracts key features from the audio signal through a
+ series of convolutional operations, normalization, and nonlinear
+ activations.
+
+ Args:
+ filter_dim: int. The number of filters for the first convolutional
+ layer. This influences the dimensionality of the feature extraction
+ pipeline and determines the richness of the audio representation.
+ sampling_rate: int, optional. The audio sampling rate in Hz. Defaults to
+ 16,000.
+ padding_value: float, optional. The value for padding. Defaults to 0.0.
+ do_normalize: bool, optional. Whether to normalize inputs. Defaults to
+ False.
+ return_attention_mask: bool, optional. Whether to return an attention
+ mask. Defaults to True.
+ initializer_range: float, optional. The standard deviation for kernel
+ initialization. Defaults to 0.02.
+ **kwargs: Additional keyword arguments passed to the base AudioConverter
+ class for customizing the underlying preprocessing behavior.
+
+ Examples:
+ ```python
+ import keras
+ from keras_hub.layers import MoonshineAudioConverter
+
+ # Create a dummy audio input (1 second at 16kHz).
+ dummy_audio = keras.ops.convert_to_tensor(
+ [[0.1] * 16000],
+ dtype="float32"
+ )
+ dummy_audio = keras.ops.expand_dims(dummy_audio, axis=-1)
+
+ # Initialize the preprocessor.
+ preprocessor = MoonshineAudioConverter(filter_dim=256)
+
+ # Process the audio.
+ features = preprocessor(dummy_audio)
+
+ # Output shapes.
+ print(features["input_values"].shape) # Expected: (1, 40, 256)
+ print(features["attention_mask"].shape) # Expected: (1, 40)
+ ```
+ """
+
+ # References:
+ # Defined and formulated based on the Hugging Face implementation of the
+ # Wav2Vec2FeatureExtractor class (https://github.com/huggingface/transformers/blob/66f29aaaf55c8fe0c3dbcd24beede2ca4effac56/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py#L31-L243)
+ # and the convolutional layer structure defined in the UsefulSensors
+ # implementation of the AudioPreprocessor class (https://github.com/usefulsensors/moonshine/blob/4a000427bd36a1c2c6d20a86c672dbd850b44c88/moonshine/model.py#L6-L32).
+
+ backbone_cls = MoonshineBackbone
+
+ def __init__(
+ self,
+ filter_dim,
+ sampling_rate=16000,
+ padding_value=0.0,
+ do_normalize=False,
+ return_attention_mask=True,
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.filter_dim = filter_dim
+ self.sampling_rate = sampling_rate
+ self.padding_value = padding_value
+ self.do_normalize = do_normalize
+ self.return_attention_mask = return_attention_mask
+ self.initializer_range = initializer_range
+ self.kernel_initializer = moonshine_kernel_initializer(
+ initializer_range=initializer_range
+ )
+
+ self.conv1 = keras.layers.Conv1D(
+ filters=filter_dim,
+ kernel_size=127,
+ strides=64,
+ use_bias=False,
+ kernel_initializer=clone_initializer(self.kernel_initializer),
+ )
+ self.tanh = keras.layers.Activation("tanh")
+ self.group_norm = keras.layers.GroupNormalization(
+ groups=1,
+ axis=-1,
+ epsilon=1e-5,
+ )
+ self.conv2 = keras.layers.Conv1D(
+ filters=2 * filter_dim,
+ kernel_size=7,
+ strides=3,
+ padding="valid",
+ kernel_initializer=clone_initializer(self.kernel_initializer),
+ )
+ self.gelu1 = keras.layers.Activation("gelu")
+ self.conv3 = keras.layers.Conv1D(
+ filters=filter_dim,
+ kernel_size=3,
+ strides=2,
+ padding="valid",
+ kernel_initializer=clone_initializer(self.kernel_initializer),
+ )
+ self.gelu2 = keras.layers.Activation("gelu")
+
+ def build(self, input_shape):
+ self.conv1.build((None, None, 1))
+ self.group_norm.build((None, None, self.filter_dim))
+ self.conv2.build((None, None, self.filter_dim))
+ self.conv3.build((None, None, 2 * self.filter_dim))
+ self.built = True
+
+ def call(
+ self,
+ inputs,
+ sampling_rate=None,
+ padding=None,
+ max_length=None,
+ pad_to_multiple_of=None,
+ return_tensors=None,
+ ):
+ # Validate sampling rate.
+ if sampling_rate is not None and sampling_rate != self.sampling_rate:
+ raise ValueError(
+ f"Expected sampling_rate {self.sampling_rate}, got "
+ f"{sampling_rate}"
+ )
+
+ # Ensure inputs are (batch_size, time_steps, 1).
+ if keras.ops.ndim(inputs) == 2:
+ inputs = keras.ops.expand_dims(inputs, axis=-1)
+ elif keras.ops.ndim(inputs) != 3 or keras.ops.shape(inputs)[-1] != 1:
+ raise ValueError(
+ "Inputs must be mono audio: (batch_size, time_steps, 1)"
+ )
+
+ # Get original length and validate duration.
+ original_length = keras.ops.shape(inputs)[1]
+ duration = original_length / self.sampling_rate
+ # Source: https://github.com/usefulsensors/moonshine/blob/4a000427bd36a1c2c6d20a86c672dbd850b44c88/moonshine/transcribe.py#L20
+ if duration < 0.1 or duration > 64:
+ warnings.warn(
+ f"Audio duration should be between 0.1 and 64 seconds, got "
+ f"{duration:.2f} seconds in a single transcribe call. For "
+ "transcribing longer segments, pre-segment your audio and "
+ "provide shorter segments."
+ )
+ # Handle padding.
+ if padding == "longest":
+ max_length = original_length
+ elif padding == "max_length" and max_length is None:
+ max_length = original_length
+ if max_length is not None:
+ if pad_to_multiple_of:
+ max_length = (
+ (max_length + pad_to_multiple_of - 1) // pad_to_multiple_of
+ ) * pad_to_multiple_of
+ if original_length < max_length:
+ padding_amount = max_length - original_length
+ inputs = keras.ops.pad(
+ inputs,
+ [(0, 0), (0, padding_amount), (0, 0)],
+ constant_values=self.padding_value,
+ )
+
+ # Normalize if enabled.
+ if self.do_normalize:
+ mean = keras.ops.mean(inputs, axis=1, keepdims=True)
+ var = keras.ops.var(inputs, axis=1, keepdims=True)
+ inputs = (inputs - mean) / keras.ops.sqrt(var + 1e-7)
+
+ # Apply convolutional feature extraction.
+ x = self.conv1(inputs)
+ x = self.tanh(x)
+ x = self.group_norm(x)
+ x = self.conv2(x)
+ x = self.gelu1(x)
+ x = self.conv3(x)
+ features = self.gelu2(x)
+
+ # Generate attention mask.
+ output_length = keras.ops.shape(features)[1]
+ attention_mask = None
+ if self.return_attention_mask:
+ # Calculate mask length through the network's downsampling ops.
+ # Step 1: First conv layer (conv1).
+ conv1_out = (original_length - 127 + 1) / 64
+ # Step 2: Second conv layer (conv2).
+ conv2_out = (conv1_out - 7 + 1) / 3
+ # Step 3: Third conv layer (conv3).
+ conv3_out = (conv2_out - 3 + 1) / 2
+
+ # Apply ceil() to get the final mask length as an int.
+ mask_length = keras.ops.cast(
+ keras.ops.ceil(keras.ops.cast(conv3_out, "float32")), "int32"
+ )
+ # Broadcast the mask length to match the batch size.
+ batch_size = keras.ops.shape(inputs)[0]
+ mask_length = keras.ops.broadcast_to(mask_length, [batch_size])
+ indices = keras.ops.arange(output_length, dtype="int32")
+ attention_mask = keras.ops.cast(
+ indices[None, :] < mask_length[:, None], dtype="int32"
+ )
+
+ output = {"input_values": features}
+ if attention_mask is not None:
+ output["attention_mask"] = attention_mask
+
+ return output
+
+ def compute_output_shape(self, input_shape):
+ # [batch_size, time_steps] → [batch_size, time_steps, 1].
+ if len(input_shape) == 2:
+ expanded_shape = (input_shape[0], input_shape[1], 1)
+ else:
+ expanded_shape = input_shape
+ # Compute output shape sequentially.
+ x_shape = self.conv1.compute_output_shape(expanded_shape)
+ x_shape = self.tanh.compute_output_shape(x_shape)
+ x_shape = self.group_norm.compute_output_shape(x_shape)
+ x_shape = self.conv2.compute_output_shape(x_shape)
+ x_shape = self.gelu1.compute_output_shape(x_shape)
+ x_shape = self.conv3.compute_output_shape(x_shape)
+ x_shape = self.gelu2.compute_output_shape(x_shape)
+ output_shape = {"input_values": x_shape}
+ if self.return_attention_mask:
+ # [batch_size, output_time_steps].
+ output_shape["attention_mask"] = (expanded_shape[0], x_shape[1])
+ return output_shape
+
+ def get_config(self):
+ config = super().get_config()
+ config.update(
+ {
+ "filter_dim": self.filter_dim,
+ "sampling_rate": self.sampling_rate,
+ "padding_value": self.padding_value,
+ "do_normalize": self.do_normalize,
+ "return_attention_mask": self.return_attention_mask,
+ "initializer_range": self.initializer_range,
+ }
+ )
+ return config
diff --git a/keras_hub/src/models/moonshine/moonshine_audio_converter_test.py b/keras_hub/src/models/moonshine/moonshine_audio_converter_test.py
new file mode 100644
index 0000000000..ff80fcfe10
--- /dev/null
+++ b/keras_hub/src/models/moonshine/moonshine_audio_converter_test.py
@@ -0,0 +1,114 @@
+import keras
+
+from keras_hub.src.models.moonshine.moonshine_audio_converter import (
+ MoonshineAudioConverter,
+)
+from keras_hub.src.tests.test_case import TestCase
+
+
+class MoonshineAudioConverterTest(TestCase):
+ def setUp(self):
+ super().setUp()
+ self.filter_dim = 256
+ self.preprocessor = MoonshineAudioConverter(filter_dim=self.filter_dim)
+ self.input_data = keras.ops.convert_to_tensor(
+ [[0.1] * 16000], dtype="float32"
+ )
+ self.input_data = keras.ops.expand_dims(self.input_data, axis=-1)
+ self.init_kwargs = {
+ "filter_dim": self.filter_dim,
+ "sampling_rate": 16000,
+ "padding_value": 0.0,
+ "do_normalize": False,
+ "return_attention_mask": True,
+ "initializer_range": 0.02,
+ }
+
+ def test_output_shape(self):
+ output = self.preprocessor(self.input_data)
+ self.assertEqual(
+ keras.ops.shape(output["input_values"]), (1, 40, self.filter_dim)
+ )
+ self.assertEqual(keras.ops.shape(output["attention_mask"]), (1, 40))
+ self.assertAllEqual(
+ output["attention_mask"], keras.ops.ones((1, 40), dtype="int32")
+ )
+
+ def test_padding(self):
+ max_length = 20000
+ output = self.preprocessor(
+ self.input_data, padding="max_length", max_length=max_length
+ )
+ self.assertEqual(
+ keras.ops.shape(output["input_values"]), (1, 50, self.filter_dim)
+ )
+ self.assertEqual(keras.ops.shape(output["attention_mask"]), (1, 50))
+ expected_mask = keras.ops.concatenate(
+ [
+ keras.ops.ones((1, 40), dtype="int32"),
+ keras.ops.zeros((1, 10), dtype="int32"),
+ ],
+ axis=1,
+ )
+ self.assertAllEqual(output["attention_mask"], expected_mask)
+
+ def test_normalization(self):
+ preprocessor_no_norm = MoonshineAudioConverter(
+ filter_dim=self.filter_dim, do_normalize=False
+ )
+ preprocessor_norm = MoonshineAudioConverter(
+ filter_dim=self.filter_dim, do_normalize=True
+ )
+ # Values from 0 to ~1.
+ input_data = keras.ops.arange(16000, dtype="float32") / 16000
+ input_data = keras.ops.expand_dims(input_data, axis=0) # (1, 16000)
+ input_data = keras.ops.expand_dims(input_data, axis=-1) # (1, 16000, 1)
+ output_no_norm = preprocessor_no_norm(input_data)
+ output_norm = preprocessor_norm(input_data)
+ self.assertFalse(
+ keras.ops.all(
+ output_no_norm["input_values"] == output_norm["input_values"]
+ )
+ )
+
+ def test_sampling_rate_validation(self):
+ # Test with the correct sampling rate (should not raise an error).
+ self.preprocessor(
+ self.input_data, sampling_rate=self.preprocessor.sampling_rate
+ )
+ # Test with an incorrect sampling rate (should raise ValueError).
+ with self.assertRaises(ValueError):
+ self.preprocessor(self.input_data, sampling_rate=8000)
+
+ def test_get_config(self):
+ config = self.preprocessor.get_config()
+ self.assertIsInstance(config, dict)
+ self.assertEqual(config["filter_dim"], self.filter_dim)
+ self.assertEqual(config["sampling_rate"], 16000)
+ self.assertEqual(config["padding_value"], 0.0)
+ self.assertEqual(config["do_normalize"], False)
+ self.assertEqual(config["return_attention_mask"], True)
+ self.assertEqual(config["initializer_range"], 0.02)
+
+ def test_correctness(self):
+ audio_input = keras.ops.convert_to_tensor(
+ [[1.0, 2.0, 3.0] + [0.0] * 15997], dtype="float32"
+ )
+ audio_input = keras.ops.expand_dims(audio_input, axis=-1)
+ converter = MoonshineAudioConverter(**self.init_kwargs)
+
+ outputs = converter(audio_input)
+ self.assertIn("input_values", outputs)
+ self.assertIn("attention_mask", outputs)
+
+ self.assertEqual(
+ keras.ops.shape(outputs["input_values"]), (1, 40, self.filter_dim)
+ )
+ self.assertEqual(keras.ops.shape(outputs["attention_mask"]), (1, 40))
+ self.assertAllEqual(
+ outputs["attention_mask"], keras.ops.ones((1, 40), dtype="int32")
+ )
+
+ def test_serialization(self):
+ instance = MoonshineAudioConverter(**self.init_kwargs)
+ self.run_serialization_test(instance=instance)
diff --git a/keras_hub/src/models/moonshine/moonshine_audio_to_text.py b/keras_hub/src/models/moonshine/moonshine_audio_to_text.py
new file mode 100644
index 0000000000..1248f819ea
--- /dev/null
+++ b/keras_hub/src/models/moonshine/moonshine_audio_to_text.py
@@ -0,0 +1,483 @@
+import keras
+import tensorflow as tf
+from keras import tree
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.moonshine.moonshine_backbone import MoonshineBackbone
+from keras_hub.src.models.moonshine.moonshine_seq_2_seq_lm_preprocessor import (
+ MoonshineSeq2SeqLMPreprocessor,
+)
+from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM
+from keras_hub.src.utils.tensor_utils import any_equal
+
+
+@keras_hub_export("keras_hub.models.MoonshineAudioToText")
+class MoonshineAudioToText(Seq2SeqLM):
+ """An end-to-end Moonshine model for audio-to-text tasks.
+
+ A Seq2Seq LM designed for audio-to-text tasks, such as speech recognition.
+ The encoder processes audio features, and the decoder generates text
+ transcriptions. You can finetune `MoonshineAudioToText` for any
+ audio-to-text task (e.g., live transcription or voice commands).
+
+ This model includes a `generate()` method for text generation based on audio
+ inputs and an optional text prompt for the decoder. The generation strategy
+ is controlled by a `sampler` argument passed to `compile()`. By default,
+ `"top_k"` sampling is used.
+
+ Args:
+ backbone: A `keras_hub.models.MoonshineBackbone` instance.
+ preprocessor: A `keras_hub.models.MoonshineSeq2SeqLMPreprocessor` or
+ `None`. If `None`, inputs must be preprocessed before calling the
+ model.
+
+ Examples:
+ ```python
+ # Initialize model from preset.
+ moonshine_lm = keras_hub.models.MoonshineAudioToText.from_preset(
+ "moonshine_base"
+ )
+
+ # Generate with single audio input.
+ audio_tensor = keras.random.normal((1, 16000, 1))
+ moonshine_lm.generate({"audio": audio_tensor})
+
+ # Generate with text prompt.
+ moonshine_lm.generate({"audio": audio_tensor, "text": "quick"})
+
+ # Use different sampling strategy.
+ moonshine_lm.compile(sampler="greedy")
+ moonshine_lm.generate({"audio": audio_tensor})
+ """
+
+ # References:
+ # Defined and formulated based on the Hugging Face implementation of the
+ # MoonshineForConditionalGeneration class (https://github.com/huggingface/transformers/blob/dcbdf7e962c4b36140cc9ee76f870016121e69e5/src/transformers/models/moonshine/modeling_moonshine.py#L1509-L1626).
+
+ backbone_cls = MoonshineBackbone
+ preprocessor_cls = MoonshineSeq2SeqLMPreprocessor
+
+ def __init__(self, backbone, preprocessor=None, **kwargs):
+ # === Layers ===
+ self.backbone = backbone
+ self.preprocessor = preprocessor
+
+ # === Functional Model ===
+ inputs = backbone.input
+ hidden_states = backbone(inputs)["decoder_sequence_output"]
+ outputs = backbone.token_embedding(hidden_states, reverse=True)
+ super().__init__(
+ inputs=inputs,
+ outputs=outputs,
+ **kwargs,
+ )
+
+ def call_decoder_with_cache(
+ self,
+ encoder_hidden_states,
+ encoder_padding_mask,
+ decoder_token_ids,
+ self_attention_cache=None,
+ self_attention_cache_update_index=None,
+ cross_attention_cache=None,
+ decoder_padding_mask=None,
+ ):
+ """Process decoder inputs with attention caching for efficient
+ generation.
+
+ Args:
+ encoder_hidden_states: Tensor. Encoder outputs.
+ encoder_padding_mask: Tensor. Padding mask for encoder outputs.
+ decoder_token_ids: Tensor. Decoder input token IDs.
+ self_attention_cache: Tensor. Cache for self-attention layers.
+ self_attention_cache_update_index: int. Index for cache updates.
+ cross_attention_cache: Tensor. Cross-attention cache computed from the
+ encoder output.
+ decoder_padding_mask: Tensor, optional. Padding mask for the decoder
+ token IDs.
+
+ Returns:
+ Tuple of (logits, hidden_states, self_attention_cache,
+ cross_attention_cache).
+ """
+ tokens = self.backbone.token_embedding(decoder_token_ids)
+ x = tokens
+
+ # Cache management for audio-to-text generation.
+ self_attention_caches = []
+ cross_attention_caches = []
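+ # Note: the self-attention cache is updated on every generated token,
+ # while the cross-attention cache depends only on the encoder output and
+ # is computed once during initialization, then reused unchanged.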
+
+ # Determine if this is initialization or generation.
+ if self_attention_cache_update_index is None:
+ # Initialization: Process full sequence, compute caches.
+ seq_len = keras.ops.shape(decoder_token_ids)[1]
+ positions = keras.ops.arange(0, seq_len, dtype="int32")
+ rotary_embedding = self.backbone.decoder_rotary_embedding(positions)
+
+ self_attention_caches = []
+ cross_attention_caches = []
+ for layer in self.backbone.decoder_blocks:
+ x, cache_k, cache_v, x_attn_cache_k, x_attn_cache_v = layer(
+ [x, encoder_hidden_states, rotary_embedding],
+ use_cache=False,
+ decoder_attention_mask=decoder_padding_mask,
+ encoder_attention_mask=encoder_padding_mask,
+ )
+ # Stack key and value for each layer.
+ self_attention_caches.append(
+ keras.ops.stack([cache_k, cache_v], axis=1)
+ )
+ cross_attention_caches.append(
+ keras.ops.stack([x_attn_cache_k, x_attn_cache_v], axis=1)
+ )
+ self_attention_cache = keras.ops.stack(
+ self_attention_caches, axis=1
+ )
+ cross_attention_cache = keras.ops.stack(
+ cross_attention_caches, axis=1
+ )
+
+ else:
+ position = keras.ops.array(
+ [self_attention_cache_update_index], dtype="int32"
+ )
+ position_ids = keras.ops.expand_dims(position, axis=0)
+ batch_size = keras.ops.shape(decoder_token_ids)[0]
+ if batch_size > 1:
+ position_ids = keras.ops.repeat(
+ position_ids, batch_size, axis=0
+ )
+ rotary_embedding = self.backbone.decoder_rotary_embedding(position)
+
+ for i, layer in enumerate(self.backbone.decoder_blocks):
+ # [batch_size, 2, seq_len, num_heads, head_dim].
+ current_self_cache = self_attention_cache[:, i, :, :, :, :]
+ cache_k = current_self_cache[
+ :, 0, :, :, :
+ ] # [batch_size, seq_len, num_heads, head_dim]
+ cache_v = current_self_cache[
+ :, 1, :, :, :
+ ] # [batch_size, seq_len, num_heads, head_dim]
+ # [batch_size, 2, context_len, num_heads, head_dim].
+ current_cross_cache = cross_attention_cache[:, i, :, :, :, :]
+ x_attn_cache_k = current_cross_cache[
+ :, 0, :, :, :
+ ] # [batch_size, context_len, num_heads, head_dim]
+ x_attn_cache_v = current_cross_cache[
+ :, 1, :, :, :
+ ] # [batch_size, context_len, num_heads, head_dim]
+
+ # Call layer with 7 inputs.
+ x, new_cache_k, new_cache_v = layer(
+ [
+ x,
+ encoder_hidden_states,
+ cache_k,
+ cache_v,
+ x_attn_cache_k,
+ x_attn_cache_v,
+ rotary_embedding,
+ ],
+ use_cache=True,
+ decoder_attention_mask=decoder_padding_mask,
+ encoder_attention_mask=encoder_padding_mask,
+ training=False,
+ )
+ # Update self-attention cache.
+ new_self_cache = keras.ops.stack(
+ [new_cache_k, new_cache_v], axis=1
+ )
+ self_attention_caches.append(new_self_cache)
+
+ # [batch_size, num_layers, 2, seq_len, num_heads, head_dim].
+ self_attention_cache = keras.ops.stack(
+ self_attention_caches, axis=1
+ )
+
+ hidden_states = self.backbone.decoder_post_norm(x)
+ logits = self.backbone.logits(hidden_states)
+ return (
+ logits,
+ hidden_states,
+ self_attention_cache,
+ cross_attention_cache,
+ )
+
+ def call_encoder(self, encoder_input_values, padding_mask):
+ """Process audio input through the encoder stack."""
+ x = encoder_input_values
+ seq_length = keras.ops.shape(x)[1]
+ positions = keras.ops.arange(0, seq_length, dtype="int32")
+ rotary_embedding = self.backbone.encoder_rotary_embedding(positions)
+ if hasattr(self.backbone, "encoder_dropout"):
+ x = self.backbone.encoder_dropout(x, training=False)
+ for transformer_layer in self.backbone.encoder_blocks:
+ x = transformer_layer(
+ inputs=x,
+ rotary_embedding=rotary_embedding,
+ attention_mask=padding_mask,
+ training=False,
+ )
+ if hasattr(self.backbone, "encoder_final_layer_norm"):
+ x = self.backbone.encoder_final_layer_norm(x)
+ return x
+
+ def _build_cache(
+ self,
+ audio_inputs,
+ audio_padding_mask,
+ decoder_token_ids,
+ ):
+ """Initialize and populate attention caches with encoder and decoder
+ outputs."""
+ encoder_hidden_states = self.call_encoder(
+ audio_inputs, padding_mask=audio_padding_mask
+ )
+ _, hidden_states, self_attention_cache, cross_attention_cache = (
+ self.call_decoder_with_cache(
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_padding_mask=audio_padding_mask,
+ decoder_token_ids=decoder_token_ids,
+ self_attention_cache=None,
+ cross_attention_cache=None,
+ )
+ )
+ return (
+ hidden_states,
+ encoder_hidden_states,
+ self_attention_cache,
+ cross_attention_cache,
+ )
+
+ # Source: https://github.com/huggingface/transformers/blob/9e94801146ceeb3b215bbdb9492be74d7d7b7210/src/transformers/generation/utils.py#L1970-L2463
+ def generate_step(self, inputs, stop_token_ids=None):
+ """A compilable generation function for a batch of inputs.
+
+ This function represents the inner, XLA-compilable, generation function
+ for a single batch of inputs. Inputs should have the same structure as
+ model inputs, a dictionary with keys `"encoder_input_values"`,
+ `"encoder_padding_mask"`, `"decoder_token_ids"` and
+ `"decoder_padding_mask"`.
+
+ Args:
+ inputs: A dictionary with four keys - `"encoder_input_values"`,
+ `"encoder_padding_mask"`, `"decoder_token_ids"` and
+ `"decoder_padding_mask"`, with batched tensor values.
+ stop_token_ids: Tuple of ids of end tokens to stop on. If all
+ sequences have produced a new stop token, generation
+ will stop.
+
+ Returns:
+ Dictionary: A dictionary with two keys - `"decoder_token_ids"`
+ containing the updated token sequence with newly generated
+ tokens, and `"decoder_padding_mask"` containing the updated
+ padding mask for the generated sequence.
+ """
+ encoder_input_values = inputs["encoder_input_values"]
+ encoder_padding_mask = inputs["encoder_padding_mask"]
+ decoder_token_ids = inputs["decoder_token_ids"]
+ decoder_padding_mask = inputs["decoder_padding_mask"]
+
+ if (
+ encoder_input_values is None
+ or encoder_padding_mask is None
+ or decoder_token_ids is None
+ ):
+ raise ValueError("Input tensors cannot be None")
+
+ batch_size = keras.ops.shape(encoder_input_values)[0]
+ # Calculate the length of the valid prompt before building the cache.
+ row_lengths = keras.ops.sum(
+ keras.ops.cast(decoder_padding_mask, "int32"),
+ axis=-1,
+ )
+ index = keras.ops.min(row_lengths)
+ # NOTE: For the JAX backend, pre-allocate the cache based on max_length.
+ max_length = keras.ops.shape(decoder_token_ids)[1]
+
+ encoder_hidden_states = self.call_encoder(
+ encoder_input_values=encoder_input_values,
+ padding_mask=encoder_padding_mask,
+ )
+ initial_decoder_token_ids = keras.ops.slice(
+ decoder_token_ids, [0, 0], [batch_size, index]
+ )
+ initial_decoder_padding_mask = keras.ops.slice(
+ decoder_padding_mask, [0, 0], [batch_size, index]
+ )
+ (
+ _,
+ hidden_states,
+ init_self_attention_cache,
+ init_cross_attention_cache,
+ ) = self.call_decoder_with_cache(
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_padding_mask=encoder_padding_mask,
+ decoder_token_ids=initial_decoder_token_ids,
+ self_attention_cache=None,
+ cross_attention_cache=None,
+ decoder_padding_mask=initial_decoder_padding_mask,
+ )
+ self_attention_cache = init_self_attention_cache
+ cross_attention_cache = init_cross_attention_cache
+
+ row_lengths = keras.ops.sum(
+ keras.ops.cast(decoder_padding_mask, "int32"),
+ axis=-1,
+ )
+ index = keras.ops.min(row_lengths)
+
+ def next(prompt, cache, index):
+ if isinstance(cache, tuple) and len(cache) == 1:
+ cache = cache[0]
+ elif isinstance(cache, tuple) and len(cache) == 0:
+ cache = None
+ cache_index = index - 1
+ num_samples = keras.ops.shape(prompt)[0]
+ next_token_input = keras.ops.slice(
+ prompt, [0, cache_index], [num_samples, 1]
+ )
+ single_token_padding_mask = keras.ops.ones_like(
+ next_token_input, dtype="bool"
+ )
+
+ def repeat_tensor(x):
+ if keras.ops.shape(x)[0] == num_samples:
+ return x
+ return keras.ops.repeat(
+ x, repeats=num_samples // batch_size, axis=0
+ )
+
+ logits, hidden_states, new_cache, _ = self.call_decoder_with_cache(
+ encoder_hidden_states=repeat_tensor(encoder_hidden_states),
+ encoder_padding_mask=repeat_tensor(encoder_padding_mask),
+ decoder_token_ids=next_token_input,
+ self_attention_cache=cache,
+ self_attention_cache_update_index=cache_index,
+ cross_attention_cache=repeat_tensor(cross_attention_cache),
+ decoder_padding_mask=single_token_padding_mask,
+ )
+ return (
+ logits[:, 0, :],
+ hidden_states[:, 0, :],
+ new_cache,
+ )
+
+ if keras.config.backend() == "jax":
+ current_prompt = decoder_token_ids
+ current_cache = self_attention_cache
+ current_index = index
+ for _ in range(max_length - index):
+ if stop_token_ids is not None:
+ prompt_mask = keras.ops.cast(
+ current_prompt
+ == (
+ self.preprocessor.tokenizer.pad_token_id
+ if self.preprocessor
+ else -1
+ ),
+ dtype="bool",
+ )
+ valid_token_mask = ~prompt_mask
+ full_range = keras.ops.arange(max_length)
+ generated_range_mask = (full_range >= index) & (
+ full_range < current_index
+ )
+ check_mask = valid_token_mask & keras.ops.expand_dims(
+ generated_range_mask, 0
+ )
+ end_tokens = any_equal(
+ current_prompt, stop_token_ids, check_mask
+ )
+ prompt_done = keras.ops.any(end_tokens, axis=-1)
+ if keras.ops.all(prompt_done):
+ break
+
+ logits, _, current_cache = next(
+ current_prompt, current_cache, current_index
+ )
+ probabilities = self.sampler.compute_probabilities(logits)
+ next_token = self.sampler.get_next_token(probabilities)
+ next_token = keras.ops.cast(next_token, current_prompt.dtype)
+ next_token = next_token[:, None]
+ current_prompt = keras.ops.slice_update(
+ current_prompt, [0, current_index], next_token
+ )
+ current_index += 1
+
+ decoder_token_ids = current_prompt
+ else:
+ decoder_token_ids = self.sampler(
+ next=next,
+ prompt=decoder_token_ids,
+ cache=self_attention_cache,
+ index=index,
+ mask=keras.ops.cast(
+ decoder_token_ids
+ != self.preprocessor.tokenizer.pad_token_id
+ if self.preprocessor is not None
+ else decoder_padding_mask,
+ dtype="bool",
+ ),
+ stop_token_ids=stop_token_ids,
+ hidden_states=hidden_states,
+ model=self,
+ )
+
+ if stop_token_ids is not None:
+ end_locations = any_equal(
+ decoder_token_ids,
+ stop_token_ids,
+ decoder_token_ids == self.preprocessor.tokenizer.pad_token_id
+ if self.preprocessor is not None
+ else False,
+ )
+ end_locations = keras.ops.cast(end_locations, "int32")
+ cumsum = keras.ops.cumsum(end_locations, axis=-1)
+ overflow = cumsum - end_locations
+ decoder_padding_mask = keras.ops.logical_not(
+ keras.ops.cast(overflow, "bool")
+ )
+ else:
+ decoder_padding_mask = keras.ops.ones_like(
+ decoder_token_ids, dtype="bool"
+ )
+
+ return {
+ "decoder_token_ids": decoder_token_ids,
+ "decoder_padding_mask": decoder_padding_mask,
+ }
+
+ def make_generate_function(self):
+ """Create or return the compiled generation function."""
+ if self.generate_function is not None:
+ return self.generate_function
+
+ self.generate_function = self.generate_step
+ if keras.config.backend() == "torch":
+ import torch
+
+ def wrapped_generate_function(
+ inputs,
+ stop_token_ids=None,
+ ):
+ with torch.no_grad():
+ return self.generate_step(inputs, stop_token_ids)
+
+ self.generate_function = wrapped_generate_function
+ elif keras.config.backend() == "tensorflow" and not self.run_eagerly:
+ # `jit_compile` is a property of keras.Model after TF 2.12.
+ # Use `getattr()` for backwards compatibility.
+ # NOTE: Override, explicitly disabled JIT compilation for the
+ # TensorFlow backend.
+ self.generate_function = tf.function(
+ self.generate_step, jit_compile=False
+ )
+ elif keras.config.backend() == "jax" and not self.run_eagerly:
+
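+ # NOTE: For the JAX backend, generation runs eagerly here;
+ # `generate_step` itself unrolls the decoding loop in Python (see its
+ # `jax` branch) instead of relying on a compiled sampler loop.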
+ def wrapped_generate_function(
+ inputs,
+ stop_token_ids=None,
+ ):
+ inputs = tree.map_structure(keras.ops.convert_to_tensor, inputs)
+ return self.generate_step(inputs, stop_token_ids)
+
+ self.generate_function = wrapped_generate_function
+
+ return self.generate_function
diff --git a/keras_hub/src/models/moonshine/moonshine_audio_to_text_test.py b/keras_hub/src/models/moonshine/moonshine_audio_to_text_test.py
new file mode 100644
index 0000000000..ec7d4e0516
--- /dev/null
+++ b/keras_hub/src/models/moonshine/moonshine_audio_to_text_test.py
@@ -0,0 +1,172 @@
+import os
+from unittest.mock import patch
+
+import keras
+import numpy as np
+import pytest
+from keras import ops
+
+from keras_hub.src.models.moonshine.moonshine_audio_converter import (
+ MoonshineAudioConverter,
+)
+from keras_hub.src.models.moonshine.moonshine_audio_to_text import (
+ MoonshineAudioToText,
+)
+from keras_hub.src.models.moonshine.moonshine_backbone import MoonshineBackbone
+from keras_hub.src.models.moonshine.moonshine_seq_2_seq_lm_preprocessor import (
+ MoonshineSeq2SeqLMPreprocessor,
+)
+from keras_hub.src.models.moonshine.moonshine_tokenizer import (
+ MoonshineTokenizer,
+)
+from keras_hub.src.tests.test_case import TestCase
+
+
+class MoonshineAudioToTextTest(TestCase):
+ def setUp(self):
+ self.tokenizer = MoonshineTokenizer(
+ proto=os.path.join(
+ self.get_test_data_dir(), "moonshine_test_vocab.spm"
+ )
+ )
+ self.vocab_size = 1036
+ hidden_dim = 32
+ self.audio_converter = MoonshineAudioConverter(
+ filter_dim=hidden_dim,
+ sampling_rate=16000,
+ do_normalize=False,
+ return_attention_mask=True,
+ padding_value=0.0,
+ initializer_range=0.02,
+ )
+ self.preprocessor = MoonshineSeq2SeqLMPreprocessor(
+ audio_converter=self.audio_converter,
+ tokenizer=self.tokenizer,
+ decoder_sequence_length=10,
+ )
+ self.backbone = MoonshineBackbone(
+ vocabulary_size=self.vocab_size,
+ hidden_dim=hidden_dim,
+ encoder_num_layers=2,
+ decoder_num_layers=2,
+ encoder_num_heads=4,
+ decoder_num_heads=4,
+ intermediate_dim=hidden_dim * 4,
+ feedforward_expansion_factor=4,
+ encoder_use_swiglu_activation=False,
+ decoder_use_swiglu_activation=True,
+ max_position_embeddings=2048,
+ pad_head_dim_to_multiple_of=None,
+ partial_rotary_factor=0.62,
+ dropout=0.0,
+ initializer_range=0.02,
+ rope_theta=10000.0,
+ attention_bias=False,
+ attention_dropout=0.0,
+ rope_scaling=None,
+ )
+ self.init_kwargs = {
+ "preprocessor": self.preprocessor,
+ "backbone": self.backbone,
+ }
+ # NOTE: Since keras.ops.convert_to_tensor() does not support
+ # dtype="string" for the JAX and PyTorch backends, the audio is passed as
+ # an np.ndarray (standing in for a librosa.load() call) alongside raw
+ # Python strings for the text.
+ self.train_data = (
+ {
+ "audio": np.random.normal(size=(2, 16000, 1)),
+ "text": ["quick brown", "earth is round"],
+ },
+ )
+ self.input_data = self.preprocessor(self.train_data[0])[0]
+
+ @pytest.mark.skipif(
+ keras.config.backend() == "torch" or keras.config.backend() == "jax",
+ reason="NotImplementedError: Cannot convert a symbolic tf.Tensor (args_"
+ "0:0) to a numpy array. This error may indicate that you're trying to"
+ "pass a Tensor to a NumPy call, which is not supported.",
+ )
+ def test_causal_lm_basics(self):
+ self.run_task_test(
+ cls=MoonshineAudioToText,
+ init_kwargs=self.init_kwargs,
+ train_data=self.train_data,
+ expected_output_shape=(2, 10, self.tokenizer.vocabulary_size()),
+ )
+
+ def test_generate(self):
+ inputs = {"audio": keras.random.normal((1, 16000, 1)), "text": "quick"}
+ seq_2_seq_lm = MoonshineAudioToText(**self.init_kwargs)
+ output = seq_2_seq_lm.generate(inputs)
+ self.assertTrue("quick" in output)
+
+ seq_2_seq_lm.preprocessor = None
+ preprocessed = self.preprocessor.generate_preprocess(inputs)
+ outputs = seq_2_seq_lm.generate(preprocessed, stop_token_ids=None)
+ self.assertAllEqual(
+ outputs["decoder_token_ids"][:, :2],
+ preprocessed["decoder_token_ids"][:, :2],
+ )
+
+ def test_early_stopping(self):
+ seq_2_seq_lm = MoonshineAudioToText(**self.init_kwargs)
+ call_decoder_with_cache = seq_2_seq_lm.call_decoder_with_cache
+
+ def wrapper(*args, **kwargs):
+ logits, hidden_states, self_cache, cross_cache = (
+ call_decoder_with_cache(*args, **kwargs)
+ )
+ index = self.preprocessor.tokenizer.end_token_id
+ update = ops.ones_like(logits)[:, :, index] * 1.0e9
+ update = ops.expand_dims(update, axis=-1)
+ logits = ops.slice_update(logits, (0, 0, index), update)
+ return logits, hidden_states, self_cache, cross_cache
+
+ with patch.object(
+ seq_2_seq_lm, "call_decoder_with_cache", wraps=wrapper
+ ):
+ inputs = {
+ "audio": keras.random.normal((2, 16000, 1)),
+ "text": ["quick", "earth"],
+ }
+ output = seq_2_seq_lm.generate(inputs)
+ self.assertAllEqual(inputs["text"], output)
+
+ def test_generate_compilation(self):
+ seq_2_seq_lm = MoonshineAudioToText(**self.init_kwargs)
+ seq_2_seq_lm.generate({"audio": keras.random.normal((1, 16000, 1))})
+ first_fn = seq_2_seq_lm.generate_function
+ seq_2_seq_lm.generate({"audio": keras.random.normal((1, 16000, 1))})
+ second_fn = seq_2_seq_lm.generate_function
+ self.assertEqual(first_fn, second_fn)
+ seq_2_seq_lm.compile(sampler="greedy")
+ self.assertIsNone(seq_2_seq_lm.generate_function)
+
+ @pytest.mark.skipif(
+ keras.config.backend() == "jax",
+ reason="Beam search involves state management not supported in the "
+ "JAX manual eager loop override.",
+ )
+ def test_beam_search(self):
+ seq_2_seq_lm = MoonshineAudioToText(**self.init_kwargs)
+ seq_2_seq_lm.compile(sampler="beam")
+ seq_2_seq_lm.generate({"audio": keras.random.normal((1, 16000, 1))})
+
+ @pytest.mark.large
+ def test_saved_model(self):
+ self.run_model_saving_test(
+ cls=MoonshineAudioToText,
+ init_kwargs=self.init_kwargs,
+ input_data=self.input_data,
+ )
+
+ @pytest.mark.extra_large
+ def test_all_presets(self):
+ for preset in MoonshineAudioToText.presets:
+ self.run_preset_test(
+ cls=MoonshineAudioToText,
+ preset=preset,
+ input_data=self.input_data,
+ )
diff --git a/keras_hub/src/models/moonshine/moonshine_backbone.py b/keras_hub/src/models/moonshine/moonshine_backbone.py
new file mode 100644
index 0000000000..f873653202
--- /dev/null
+++ b/keras_hub/src/models/moonshine/moonshine_backbone.py
@@ -0,0 +1,360 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.modeling.reversible_embedding import (
+ ReversibleEmbedding,
+)
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.moonshine.moonshine_decoder import (
+ MoonshineDecoderBlock,
+)
+from keras_hub.src.models.moonshine.moonshine_encoder import (
+ MoonshineEncoderBlock,
+)
+from keras_hub.src.models.moonshine.moonshine_layers import (
+ MoonshineRotaryEmbedding,
+)
+from keras_hub.src.models.moonshine.moonshine_layers import (
+ moonshine_kernel_initializer,
+)
+from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+class Arange(keras.layers.Layer):
+ def call(self, inputs):
+ sequence_length = keras.ops.shape(inputs)[1]
+ return keras.ops.arange(sequence_length, dtype="int32")
+
+
+@keras_hub_export("keras_hub.models.MoonshineBackbone")
+class MoonshineBackbone(Backbone):
+ """Moonshine backbone for speech recognition.
+
+ This class implements an encoder-decoder backbone, as used in the Moonshine
+ ASR system. It combines `MoonshineEncoderBlock` instances for processing
+ input sequences and `MoonshineDecoderBlock` instances for generating output
+ sequences.
+
+ Args:
+ vocabulary_size: int. The size of the vocabulary for the embedding
+ layers.
+ encoder_num_layers: int. The number of stacked encoder blocks.
+ decoder_num_layers: int. The number of stacked decoder blocks.
+ hidden_dim: int. The dimensionality of the model's hidden
+ representations and embeddings.
+ intermediate_dim: int. The dimensionality of the intermediate
+ representations in feedforward networks.
+ encoder_num_heads: int. The number of attention heads in the encoder's
+ multi-head attention.
+ decoder_num_heads: int. The number of attention heads in the decoder's
+ multi-head attention.
+ feedforward_expansion_factor: int, optional. A multiplier applied to
+ `intermediate_dim` to determine the total width of the feedforward
+ network. Defaults to 4.
+ encoder_use_swiglu_activation: bool, optional. When True, uses the
+ SwiGLU activation in the encoder's feedforward network. Defaults to
+ False.
+ decoder_use_swiglu_activation: bool, optional. When True, uses the
+ SwiGLU activation in the decoder's feedforward network. Defaults to
+ True.
+ max_position_embeddings: int, optional. The maximum sequence length for
+ position embeddings. Defaults to 2048.
+ pad_head_dim_to_multiple_of: int, optional. If specified, pads the head
+ dimension to be a multiple of this value for performance
+ optimization. Defaults to None.
+ partial_rotary_factor: float, optional. The fraction of dimensions to
+ apply rotary position embeddings to. Defaults to 0.62.
+ dropout: float, optional. The dropout probability for input dropout
+ layers. Defaults to 0.0.
+ initializer_range: float, optional. The standard deviation of the
+ truncated normal initializer for weights. Defaults to 0.02.
+ rope_theta: float, optional. The base frequency for rotary position
+ embeddings. Defaults to 10,000.0.
+ attention_bias: bool, optional. Whether to use bias in attention
+ mechanisms. Defaults to False.
+ attention_dropout: float, optional. The dropout probability for
+ attention mechanisms. Defaults to 0.0.
+ rope_scaling: dict, optional. The scaling configuration for rotary
+ position embeddings. Defaults to None.
+ dtype: str, optional. The dtype to use for model computations and
+ weights. Defaults to None.
+
+ Examples:
+ ```python
+ import numpy as np
+
+ # Create random input data for demonstration.
+ encoder_input_values = np.random.rand(1, 100, 256).astype("float32")
+ decoder_token_ids = np.random.randint(
+ 0, 1000, size=(1, 20), dtype="int32"
+ )
+
+ # Initialize the Moonshine backbone with specific parameters.
+ backbone = MoonshineBackbone(
+ vocabulary_size=10000,
+ encoder_num_layers=6,
+ decoder_num_layers=6,
+ hidden_dim=256,
+ intermediate_dim=512,
+ encoder_num_heads=8,
+ decoder_num_heads=8,
+ feedforward_expansion_factor=4,
+ encoder_use_swiglu_activation=False,
+ decoder_use_swiglu_activation=True,
+ )
+
+ # Forward pass through the model.
+ outputs = backbone(
+ {
+ "encoder_input_values": encoder_input_values,
+ "decoder_token_ids": decoder_token_ids,
+ }
+ )
+
+ # Display the outputs.
+ print("Encoder output shape:", outputs["encoder_sequence_output"].shape)
+ print("Decoder output shape:", outputs["decoder_sequence_output"].shape)
+ ```
+ """
+
+ # References:
+ # Defined and formulated based on the Hugging Face implementation of the
+ # MoonshineModel class (https://github.com/huggingface/transformers/blob/dcbdf7e962c4b36140cc9ee76f870016121e69e5/src/transformers/models/moonshine/modeling_moonshine.py#L1326-L1486).
+
+ def __init__(
+ self,
+ vocabulary_size,
+ encoder_num_layers,
+ decoder_num_layers,
+ hidden_dim,
+ intermediate_dim,
+ encoder_num_heads,
+ decoder_num_heads,
+ feedforward_expansion_factor=4,
+ encoder_use_swiglu_activation=False,
+ decoder_use_swiglu_activation=True,
+ max_position_embeddings=2048,
+ pad_head_dim_to_multiple_of=None,
+ partial_rotary_factor=0.62,
+ dropout=0.0,
+ initializer_range=0.02,
+ rope_theta=10000.0,
+ attention_bias=False,
+ attention_dropout=0.0,
+ rope_scaling=None,
+ dtype=None,
+ **kwargs,
+ ):
+ # ==== Config ====
+ self.vocabulary_size = vocabulary_size
+ self.encoder_num_layers = encoder_num_layers
+ self.decoder_num_layers = decoder_num_layers
+ self.hidden_dim = hidden_dim
+ self.intermediate_dim = intermediate_dim
+ self.encoder_num_heads = encoder_num_heads
+ self.decoder_num_heads = decoder_num_heads
+ self.feedforward_expansion_factor = feedforward_expansion_factor
+ self.encoder_use_swiglu_activation = encoder_use_swiglu_activation
+ self.decoder_use_swiglu_activation = decoder_use_swiglu_activation
+ self.max_position_embeddings = max_position_embeddings
+ self.pad_head_dim_to_multiple_of = pad_head_dim_to_multiple_of
+ self.partial_rotary_factor = partial_rotary_factor
+ self.dropout = dropout
+ self.initializer_range = initializer_range
+ self.rope_theta = rope_theta
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.rope_scaling = rope_scaling
+ self.embeddings_initializer = moonshine_kernel_initializer(
+ initializer_range=initializer_range
+ )
+
+ # ==== Layers ====
+ encoder_head_dim = hidden_dim // encoder_num_heads
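+ # Optionally round the head dimension up to the nearest multiple for
+ # performance, e.g. 52 -> 56 when pad_head_dim_to_multiple_of=8.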
+ if pad_head_dim_to_multiple_of:
+ encoder_head_dim = (
+ (encoder_head_dim + pad_head_dim_to_multiple_of - 1)
+ // pad_head_dim_to_multiple_of
+ ) * pad_head_dim_to_multiple_of
+
+ decoder_head_dim = hidden_dim // decoder_num_heads
+ if pad_head_dim_to_multiple_of:
+ decoder_head_dim = (
+ (decoder_head_dim + pad_head_dim_to_multiple_of - 1)
+ // pad_head_dim_to_multiple_of
+ ) * pad_head_dim_to_multiple_of
+
+ # Embedding layer for decoder.
+ self.token_embedding = ReversibleEmbedding(
+ input_dim=vocabulary_size,
+ output_dim=hidden_dim,
+ embeddings_initializer=clone_initializer(
+ self.embeddings_initializer
+ ),
+ name="token_embedding",
+ dtype=dtype,
+ )
+
+ # Rotary embeddings for encoder and decoder.
+ self.encoder_rotary_embedding = MoonshineRotaryEmbedding(
+ head_dim=encoder_head_dim,
+ max_position_embeddings=max_position_embeddings,
+ partial_rotary_factor=partial_rotary_factor,
+ base_value=rope_theta,
+ rope_scaling=rope_scaling,
+ name="encoder_rotary_embedding",
+ dtype=dtype,
+ )
+
+ self.decoder_rotary_embedding = MoonshineRotaryEmbedding(
+ head_dim=decoder_head_dim,
+ max_position_embeddings=max_position_embeddings,
+ partial_rotary_factor=partial_rotary_factor,
+ base_value=rope_theta,
+ rope_scaling=rope_scaling,
+ name="decoder_rotary_embedding",
+ dtype=dtype,
+ )
+
+ # Dropout for encoder.
+ self.encoder_dropout = keras.layers.Dropout(
+ dropout, name="encoder_dropout", dtype=dtype
+ )
+ # Dropout for decoder.
+ self.decoder_dropout = keras.layers.Dropout(
+ dropout, name="decoder_dropout", dtype=dtype
+ )
+
+ # Encoder blocks.
+ self.encoder_blocks = []
+ for i in range(encoder_num_layers):
+ encoder_block = MoonshineEncoderBlock(
+ hidden_dim=hidden_dim,
+ intermediate_dim=intermediate_dim,
+ num_heads=encoder_num_heads,
+ feedforward_expansion_factor=feedforward_expansion_factor,
+ use_swiglu_activation=encoder_use_swiglu_activation,
+ pad_head_dim_to_multiple_of=pad_head_dim_to_multiple_of,
+ initializer_range=initializer_range,
+ attention_bias=attention_bias,
+ attention_dropout=attention_dropout,
+ name=f"encoder_block_{i}",
+ dtype=dtype,
+ )
+ self.encoder_blocks.append(encoder_block)
+
+ # Layer normalization for encoder.
+ self.encoder_final_layer_norm = keras.layers.LayerNormalization(
+ epsilon=1e-5,
+ center=False,
+ scale=True,
+ name="encoder_final_layer_norm",
+ dtype=dtype,
+ )
+
+ # Decoder blocks.
+ self.decoder_blocks = []
+ for i in range(decoder_num_layers):
+ decoder_block = MoonshineDecoderBlock(
+ hidden_dim=hidden_dim,
+ intermediate_dim=intermediate_dim,
+ num_heads=decoder_num_heads,
+ feedforward_expansion_factor=feedforward_expansion_factor,
+ use_swiglu_activation=decoder_use_swiglu_activation,
+ pad_head_dim_to_multiple_of=pad_head_dim_to_multiple_of,
+ initializer_range=initializer_range,
+ attention_bias=attention_bias,
+ attention_dropout=attention_dropout,
+ name=f"decoder_block_{i}",
+ dtype=dtype,
+ )
+ self.decoder_blocks.append(decoder_block)
+
+ # Layer normalization for decoder.
+ self.decoder_post_norm = keras.layers.LayerNormalization(
+ epsilon=1e-5,
+ center=False,
+ scale=True,
+ name="decoder_post_norm",
+ dtype=dtype,
+ )
+
+ # === Functional Model ===
+ encoder_input = keras.Input(
+ shape=(None, hidden_dim), name="encoder_input_values", dtype=dtype
+ )
+ decoder_input = keras.Input(
+ shape=(None,), name="decoder_token_ids", dtype="int32"
+ )
+ encoder_padding_mask = keras.Input(
+ shape=(None,), name="encoder_padding_mask", dtype="bool"
+ )
+ decoder_padding_mask = keras.Input(
+ shape=(None,), name="decoder_padding_mask", dtype="bool"
+ )
+
+ # Encoder.
+ encoder_positions = Arange(name="encoder_positions")(encoder_input)
+ encoder_rotary_emb = self.encoder_rotary_embedding(encoder_positions)
+ encoder_hidden_states = self.encoder_dropout(encoder_input)
+ for encoder_block in self.encoder_blocks:
+ encoder_hidden_states = encoder_block(
+ encoder_hidden_states,
+ encoder_rotary_emb,
+ attention_mask=encoder_padding_mask,
+ )
+ encoder_output = self.encoder_final_layer_norm(encoder_hidden_states)
+
+ # Decoder.
+ decoder_positions = Arange(name="decoder_positions")(decoder_input)
+ decoder_rotary_emb = self.decoder_rotary_embedding(decoder_positions)
+ decoder_hidden_states = self.token_embedding(decoder_input)
+ decoder_hidden_states = self.decoder_dropout(decoder_hidden_states)
+ for decoder_block in self.decoder_blocks:
+ decoder_hidden_states, _, _, _, _ = decoder_block(
+ [decoder_hidden_states, encoder_output, decoder_rotary_emb],
+ decoder_attention_mask=decoder_padding_mask,
+ encoder_attention_mask=encoder_padding_mask,
+ )
+ decoder_output = self.decoder_post_norm(decoder_hidden_states)
+
+ super().__init__(
+ inputs={
+ "encoder_input_values": encoder_input,
+ "decoder_token_ids": decoder_input,
+ "encoder_padding_mask": encoder_padding_mask,
+ "decoder_padding_mask": decoder_padding_mask,
+ },
+ outputs={
+ "encoder_sequence_output": encoder_output,
+ "decoder_sequence_output": decoder_output,
+ },
+ dtype=dtype,
+ **kwargs,
+ )
+
+ def get_config(self):
+ config = super().get_config()
+ config.update(
+ {
+ "vocabulary_size": self.vocabulary_size,
+ "encoder_num_layers": self.encoder_num_layers,
+ "decoder_num_layers": self.decoder_num_layers,
+ "hidden_dim": self.hidden_dim,
+ "intermediate_dim": self.intermediate_dim,
+ "encoder_num_heads": self.encoder_num_heads,
+ "decoder_num_heads": self.decoder_num_heads,
+ "feedforward_expansion_factor": self.feedforward_expansion_factor, # noqa: E501
+ "encoder_use_swiglu_activation": self.encoder_use_swiglu_activation, # noqa: E501
+ "decoder_use_swiglu_activation": self.decoder_use_swiglu_activation, # noqa: E501
+ "max_position_embeddings": self.max_position_embeddings,
+ "pad_head_dim_to_multiple_of": self.pad_head_dim_to_multiple_of,
+ "partial_rotary_factor": self.partial_rotary_factor,
+ "dropout": self.dropout,
+ "initializer_range": self.initializer_range,
+ "rope_theta": self.rope_theta,
+ "attention_bias": self.attention_bias,
+ "attention_dropout": self.attention_dropout,
+ "rope_scaling": self.rope_scaling,
+ "dtype": self.dtype,
+ }
+ )
+ return config
+
+ # Projects decoder hidden states to vocabulary logits via the tied
+ # token embedding, for use by task models such as MoonshineAudioToText.
+ def logits(self, decoder_hidden_states):
+ return self.token_embedding(decoder_hidden_states, reverse=True)
diff --git a/keras_hub/src/models/moonshine/moonshine_backbone_test.py b/keras_hub/src/models/moonshine/moonshine_backbone_test.py
new file mode 100644
index 0000000000..18d04973f1
--- /dev/null
+++ b/keras_hub/src/models/moonshine/moonshine_backbone_test.py
@@ -0,0 +1,197 @@
+import keras
+import pytest
+
+from keras_hub.src.models.moonshine.moonshine_backbone import MoonshineBackbone
+from keras_hub.src.tests.test_case import TestCase
+
+
+class MoonshineBackboneTest(TestCase):
+ def setUp(self):
+ self.init_kwargs = {
+ "vocabulary_size": 10000,
+ "encoder_num_layers": 2,
+ "decoder_num_layers": 2,
+ "hidden_dim": 64,
+ "intermediate_dim": 512,
+ "encoder_num_heads": 8,
+ "decoder_num_heads": 8,
+ "feedforward_expansion_factor": 4,
+ "encoder_use_swiglu_activation": False,
+ "decoder_use_swiglu_activation": True,
+ "max_position_embeddings": 2048,
+ "pad_head_dim_to_multiple_of": None,
+ "partial_rotary_factor": 0.62,
+ "dropout": 0.0,
+ "initializer_range": 0.02,
+ "rope_theta": 10000.0,
+ "attention_bias": False,
+ "attention_dropout": 0.0,
+ "rope_scaling": None,
+ }
+ encoder_input_values = keras.random.uniform((2, 16, 64))
+ decoder_token_ids = keras.random.randint(
+ shape=(2, 10), minval=0, maxval=10000
+ )
+ encoder_padding_mask = keras.ops.ones((2, 16), dtype="bool")
+ decoder_padding_mask = keras.ops.ones((2, 10), dtype="bool")
+ self.input_data = {
+ "encoder_input_values": encoder_input_values,
+ "decoder_token_ids": decoder_token_ids,
+ "encoder_padding_mask": encoder_padding_mask,
+ "decoder_padding_mask": decoder_padding_mask,
+ }
+
+ def test_forward_pass(self):
+ backbone = MoonshineBackbone(**self.init_kwargs)
+ outputs = backbone(self.input_data)
+ self.assertEqual(outputs["encoder_sequence_output"].shape, (2, 16, 64))
+ self.assertEqual(outputs["decoder_sequence_output"].shape, (2, 10, 64))
+
+ def test_serialization(self):
+ instance = MoonshineBackbone(**self.init_kwargs)
+ self.run_serialization_test(instance=instance)
+
+ def test_swiglu_feedforward(self):
+ init_kwargs = self.init_kwargs.copy()
+ init_kwargs["encoder_use_swiglu_activation"] = True
+ backbone = MoonshineBackbone(**init_kwargs)
+ outputs = backbone(self.input_data)
+ self.assertEqual(outputs["encoder_sequence_output"].shape, (2, 16, 64))
+ self.assertEqual(outputs["decoder_sequence_output"].shape, (2, 10, 64))
+
+ def test_different_sequence_lengths(self):
+ backbone = MoonshineBackbone(**self.init_kwargs)
+
+ # Short sequences.
+ short_encoder_input_values = keras.random.uniform((2, 8, 64))
+ short_decoder_token_ids = keras.random.randint(
+ shape=(2, 5), minval=0, maxval=10000
+ )
+ short_encoder_padding_mask = keras.ops.ones((2, 8), dtype="bool")
+ short_decoder_padding_mask = keras.ops.ones((2, 5), dtype="bool")
+ short_input_data = {
+ "encoder_input_values": short_encoder_input_values,
+ "decoder_token_ids": short_decoder_token_ids,
+ "encoder_padding_mask": short_encoder_padding_mask,
+ "decoder_padding_mask": short_decoder_padding_mask,
+ }
+ short_outputs = backbone(short_input_data)
+ self.assertEqual(
+ short_outputs["encoder_sequence_output"].shape, (2, 8, 64)
+ )
+ self.assertEqual(
+ short_outputs["decoder_sequence_output"].shape, (2, 5, 64)
+ )
+
+ # Long sequences.
+ long_encoder_input_values = keras.random.uniform((2, 32, 64))
+ long_decoder_token_ids = keras.random.randint(
+ shape=(2, 15), minval=0, maxval=10000
+ )
+ long_encoder_padding_mask = keras.ops.ones((2, 32), dtype="bool")
+ long_decoder_padding_mask = keras.ops.ones((2, 15), dtype="bool")
+ long_input_data = {
+ "encoder_input_values": long_encoder_input_values,
+ "decoder_token_ids": long_decoder_token_ids,
+ "encoder_padding_mask": long_encoder_padding_mask,
+ "decoder_padding_mask": long_decoder_padding_mask,
+ }
+ long_outputs = backbone(long_input_data)
+ self.assertEqual(
+ long_outputs["encoder_sequence_output"].shape, (2, 32, 64)
+ )
+ self.assertEqual(
+ long_outputs["decoder_sequence_output"].shape, (2, 15, 64)
+ )
+
+ def test_predict_model(self):
+ backbone = MoonshineBackbone(**self.init_kwargs)
+ outputs = backbone.predict(self.input_data)
+ self.assertEqual(outputs["encoder_sequence_output"].shape, (2, 16, 64))
+ self.assertEqual(outputs["decoder_sequence_output"].shape, (2, 10, 64))
+
+ def test_varying_batch_sizes(self):
+ backbone = MoonshineBackbone(**self.init_kwargs)
+ for batch_size in [1, 3, 5]:
+ encoder_input_values = keras.random.uniform((batch_size, 16, 64))
+ decoder_token_ids = keras.random.randint(
+ shape=(batch_size, 10), minval=0, maxval=10000
+ )
+ encoder_padding_mask = keras.ops.ones(
+ (batch_size, 16), dtype="bool"
+ )
+ decoder_padding_mask = keras.ops.ones(
+ (batch_size, 10), dtype="bool"
+ )
+ input_data = {
+ "encoder_input_values": encoder_input_values,
+ "decoder_token_ids": decoder_token_ids,
+ "encoder_padding_mask": encoder_padding_mask,
+ "decoder_padding_mask": decoder_padding_mask,
+ }
+ outputs = backbone(input_data)
+ self.assertEqual(
+ outputs["encoder_sequence_output"].shape, (batch_size, 16, 64)
+ )
+ self.assertEqual(
+ outputs["decoder_sequence_output"].shape, (batch_size, 10, 64)
+ )
+
+ def test_attention_parameters(self):
+ init_kwargs = self.init_kwargs.copy()
+ init_kwargs["attention_bias"] = True
+ init_kwargs["attention_dropout"] = 0.1
+ backbone = MoonshineBackbone(**init_kwargs)
+ outputs = backbone(self.input_data)
+ self.assertEqual(outputs["encoder_sequence_output"].shape, (2, 16, 64))
+ self.assertEqual(outputs["decoder_sequence_output"].shape, (2, 10, 64))
+
+ def test_rope_parameters(self):
+ init_kwargs = self.init_kwargs.copy()
+ init_kwargs["rope_theta"] = 5000.0
+ init_kwargs["rope_scaling"] = {"type": "linear", "factor": 2.0}
+ backbone = MoonshineBackbone(**init_kwargs)
+ outputs = backbone(self.input_data)
+ self.assertEqual(outputs["encoder_sequence_output"].shape, (2, 16, 64))
+ self.assertEqual(outputs["decoder_sequence_output"].shape, (2, 10, 64))
+
+ @pytest.mark.large
+ def test_saved_model(self):
+ self.run_model_saving_test(
+ cls=MoonshineBackbone,
+ init_kwargs=self.init_kwargs,
+ input_data=self.input_data,
+ )
+
+ def test_backbone_basics(self):
+ self.run_backbone_test(
+ cls=MoonshineBackbone,
+ init_kwargs=self.init_kwargs,
+ input_data=self.input_data,
+ expected_output_shape={
+ "encoder_sequence_output": (2, 16, 64),
+ "decoder_sequence_output": (2, 10, 64),
+ },
+ run_mixed_precision_check=False,
+ run_quantization_check=False,
+ )
+
+ @pytest.mark.extra_large
+ def test_all_presets(self):
+ for preset in MoonshineBackbone.presets.keys():
+ hidden_size = 288 if preset == "moonshine_tiny_en" else 416
+ encoder_input_values = keras.ops.ones((1, 100, hidden_size))
+ decoder_token_ids = keras.ops.ones((1, 10), dtype="int32")
+ encoder_padding_mask = keras.ops.ones((1, 100), dtype="bool")
+ decoder_padding_mask = keras.ops.ones((1, 10), dtype="bool")
+ input_data = {
+ "encoder_input_values": encoder_input_values,
+ "decoder_token_ids": decoder_token_ids,
+ "encoder_padding_mask": encoder_padding_mask,
+ "decoder_padding_mask": decoder_padding_mask,
+ }
+ self.run_preset_test(
+ cls=MoonshineBackbone,
+ preset=preset,
+ input_data=input_data,
+ )
diff --git a/keras_hub/src/models/moonshine/moonshine_decoder.py b/keras_hub/src/models/moonshine/moonshine_decoder.py
new file mode 100644
index 0000000000..e8c98ec0aa
--- /dev/null
+++ b/keras_hub/src/models/moonshine/moonshine_decoder.py
@@ -0,0 +1,394 @@
+import keras
+
+from keras_hub.src.layers.modeling.transformer_decoder import TransformerDecoder
+from keras_hub.src.models.moonshine.moonshine_layers import MoonshineMLP
+from keras_hub.src.models.moonshine.moonshine_layers import (
+ moonshine_kernel_initializer,
+)
+from keras_hub.src.models.moonshine.moonshine_multi_head_attention import (
+ MoonshineMultiHeadAttention,
+)
+from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+@keras.saving.register_keras_serializable(package="keras_hub")
+class MoonshineDecoderBlock(TransformerDecoder):
+ """Moonshine decoder block for sequence processing.
+
+ This layer implements a decoder block that includes self-attention with
+ causal masking, cross-attention with precomputed key/value pairs, and a
+ feedforward network. It supports both cached and uncached operation modes.
+
+ Args:
+ hidden_dim: int. The dimensionality of the model's hidden
+ representations.
+ intermediate_dim: int. The dimensionality of the intermediate
+ representations in the feedforward network.
+ num_heads: int. The number of attention heads for multi-head attention
+ mechanisms.
+ feedforward_expansion_factor: int, optional. A multiplicative factor for
+ scaling the feedforward network dimension. Defaults to 4.
+ use_swiglu_activation: bool, optional. Whether to use the SwiGLU
+ activation in the feedforward network for improved performance.
+ Defaults to True.
+ pad_head_dim_to_multiple_of: int, optional. If specified, pads the head
+ dimension to be a multiple of this value for performance
+ optimization. Defaults to None.
+ initializer_range: float, optional. The standard deviation of the
+ truncated normal distribution used to initialize model weights.
+ Defaults to 0.02.
+ attention_bias: bool, optional. Whether to add a bias term to the
+ attention computations. Defaults to False.
+ attention_dropout: float, optional. The dropout rate applied to
+ attention weights during training. Defaults to 0.0.
+ dtype: str, optional. The data type to use for model computations and
+ weights. Defaults to None.
+ **kwargs: Additional keyword arguments passed to the base layer.
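+
+    Example (an illustrative sketch of a single uncached forward pass; the
+    dimensions and random inputs below are hypothetical, not taken from a
+    pretrained configuration):
+
+    ```python
+    import keras
+
+    from keras_hub.src.models.moonshine.moonshine_decoder import (
+        MoonshineDecoderBlock,
+    )
+
+    decoder_block = MoonshineDecoderBlock(
+        hidden_dim=64,
+        intermediate_dim=256,
+        num_heads=4,
+    )
+    x = keras.random.normal((2, 10, 64))  # Decoder hidden states.
+    context = keras.random.normal((2, 16, 64))  # Encoder outputs.
+    rotary_embedding = keras.random.normal((10, 4))  # Rotary frequencies.
+    # The uncached call returns the hidden states plus four caches.
+    x_out, cache_k, cache_v, x_attn_cache_k, x_attn_cache_v = decoder_block(
+        [x, context, rotary_embedding]
+    )
+    ```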
+ """
+
+ # References:
+ # Defined and formulated based on the UsefulSensors implementation of the
+ # DecoderLayer class (https://github.com/usefulsensors/moonshine/blob/4a000427bd36a1c2c6d20a86c672dbd850b44c88/moonshine/model.py#L348-L466).
+
+ def __init__(
+ self,
+ hidden_dim,
+ intermediate_dim,
+ num_heads,
+ feedforward_expansion_factor=4,
+ use_swiglu_activation=True,
+ pad_head_dim_to_multiple_of=None,
+ initializer_range=0.02,
+ attention_bias=False,
+ attention_dropout=0.0,
+ dtype=None,
+ **kwargs,
+ ):
+ kwargs.pop("dropout", None)
+ kwargs.pop("activation", None)
+ kwargs.pop("kernel_initializer", None)
+ self.kernel_initializer = moonshine_kernel_initializer(
+ initializer_range=initializer_range
+ )
+ super().__init__(
+ intermediate_dim=intermediate_dim,
+ num_heads=num_heads,
+ dropout=attention_dropout,
+ activation="gelu" if use_swiglu_activation else "silu",
+ kernel_initializer=clone_initializer(self.kernel_initializer),
+ dtype=dtype,
+ **kwargs,
+ )
+ self.initializer_range = initializer_range
+ self.hidden_dim = hidden_dim
+ self.intermediate_dim = intermediate_dim
+ self.num_heads = num_heads
+ self.feedforward_expansion_factor = feedforward_expansion_factor
+ self.use_swiglu_activation = use_swiglu_activation
+ self.pad_head_dim_to_multiple_of = pad_head_dim_to_multiple_of
+ self.attention_dropout = attention_dropout
+ self.attention_bias = attention_bias
+
+ self.head_dim = hidden_dim // num_heads
+ if pad_head_dim_to_multiple_of is not None:
+ self.head_dim = (
+ (self.head_dim + pad_head_dim_to_multiple_of - 1)
+ // pad_head_dim_to_multiple_of
+ ) * pad_head_dim_to_multiple_of
+
+ self.norm1 = keras.layers.LayerNormalization(
+ axis=-1,
+ epsilon=1e-5,
+ center=False,
+ scale=True,
+ dtype=self.dtype,
+ )
+ self.self_attention = MoonshineMultiHeadAttention(
+ num_heads=num_heads,
+ key_dim=self.head_dim,
+ use_bias=False,
+ kernel_initializer=clone_initializer(self.kernel_initializer),
+ attention_bias=attention_bias,
+ attention_dropout=attention_dropout,
+ use_causal_mask=True,
+ apply_rotary_embedding=True,
+ cache_mode="autoregressive",
+ dtype=self.dtype,
+ )
+ self.norm2 = keras.layers.LayerNormalization(
+ axis=-1,
+ epsilon=1e-5,
+ center=False,
+ scale=True,
+ dtype=self.dtype,
+ )
+ self.cross_attention = MoonshineMultiHeadAttention(
+ num_heads=num_heads,
+ key_dim=self.head_dim,
+ use_bias=False,
+ kernel_initializer=clone_initializer(self.kernel_initializer),
+ attention_bias=attention_bias,
+ attention_dropout=attention_dropout,
+ use_causal_mask=False,
+ apply_rotary_embedding=False,
+ cache_mode="precomputed",
+ dtype=self.dtype,
+ )
+ self.norm3 = keras.layers.LayerNormalization(
+ axis=-1,
+ epsilon=1e-5,
+ center=False,
+ scale=True,
+ dtype=self.dtype,
+ )
+ self.ff = MoonshineMLP(
+ hidden_dim=hidden_dim,
+ feedforward_expansion_factor=feedforward_expansion_factor,
+ use_swiglu_activation=use_swiglu_activation,
+ initializer_range=initializer_range,
+ dtype=self.dtype,
+ )
+
+ def build(self, input_shape):
+ if not isinstance(input_shape, (list, tuple)) or len(input_shape) < 2:
+ raise ValueError(
+ "Expected input_shape to be a list of at least two shapes."
+ )
+ decoder_sequence_shape = (
+ input_shape[0]["decoder_token_ids"] # Shape of x
+ if isinstance(input_shape[0], dict)
+ else input_shape[0]
+ )
+ context_shape = (
+ input_shape[1]["input_values"] # Shape of context
+ if isinstance(input_shape[1], dict)
+ else input_shape[1]
+ )
+
+ # Build sublayers.
+ self.norm1.build(decoder_sequence_shape)
+ self.norm2.build(decoder_sequence_shape)
+ self.norm3.build(decoder_sequence_shape)
+
+ self.self_attention.build(
+ query_shape=decoder_sequence_shape,
+ key_shape=decoder_sequence_shape,
+ value_shape=decoder_sequence_shape,
+ )
+
+ self.cross_attention.build(
+ query_shape=decoder_sequence_shape,
+ key_shape=context_shape,
+ value_shape=context_shape,
+ )
+
+ self.ff.build(decoder_sequence_shape)
+ self.built = True
+
+ def compute_output_spec(
+ self,
+ inputs,
+ training=None,
+ use_cache=False,
+ decoder_attention_mask=None,
+ encoder_attention_mask=None,
+ ):
+ if use_cache:
+ # Cached case: expect 7 inputs.
+ if len(inputs) != 7:
+ raise ValueError(
+ "When use_cache=True, expected 7 inputs: "
+ "[x, context, cache_k, cache_v, x_attn_cache_k, "
+ "x_attn_cache_v, rotary_embedding]"
+ )
+ (
+ x,
+ context,
+ cache_k,
+ cache_v,
+ x_attn_cache_k,
+ x_attn_cache_v,
+ rotary_embedding,
+ ) = inputs
+ # Output shape for x is the same as input x_shape but with
+ # hidden_dim.
+ x_shape = x.shape if hasattr(x, "shape") else x
+ output_shape = x_shape[:-1] + (self.hidden_dim,)
+ # New cache shapes are the same as input cache_k_shape and
+ # cache_v_shape.
+ # Note: In practice, sequence length may increase due to
+ # concatenation, but symbolically, it remains None.
+ new_cache_shape = (
+ cache_k.shape if hasattr(cache_k, "shape") else cache_k
+ )
+ return (
+ keras.KerasTensor(shape=output_shape, dtype=self.dtype), # x
+ keras.KerasTensor(
+ shape=new_cache_shape, dtype=self.dtype
+ ), # new_cache_k
+ keras.KerasTensor(
+ shape=new_cache_shape, dtype=self.dtype
+ ), # new_cache_v
+ )
+ else:
+ # Uncached case: expect 3 inputs.
+ if len(inputs) != 3:
+ raise ValueError(
+ "When use_cache=False, expected 3 inputs: [x, context, "
+ "rotary_embedding]"
+ )
+ x, context, rotary_embedding = inputs
+ x_shape = x.shape if hasattr(x, "shape") else x
+ context_shape = (
+ context.shape if hasattr(context, "shape") else context
+ )
+ batch_size = x_shape[0] # None (symbolic)
+ seq_len = x_shape[1] # None (symbolic)
+ context_len = context_shape[1] # None (symbolic)
+ hidden_dim = self.hidden_dim
+ num_heads = self.num_heads
+ head_dim = self.head_dim
+
+ # Define output shapes.
+ output_shape = (batch_size, seq_len, hidden_dim) # x
+ cache_shape_self = (
+ batch_size,
+ seq_len,
+ num_heads,
+ head_dim,
+ ) # Self-attention caches
+ cache_shape_cross = (
+ batch_size,
+ context_len,
+ num_heads,
+ head_dim,
+ ) # Cross-attention caches
+
+ return (
+ keras.KerasTensor(shape=output_shape, dtype=self.dtype), # x
+ keras.KerasTensor(
+ shape=cache_shape_self, dtype=self.dtype
+ ), # cache_k
+ keras.KerasTensor(
+ shape=cache_shape_self, dtype=self.dtype
+ ), # cache_v
+ keras.KerasTensor(
+ shape=cache_shape_cross, dtype=self.dtype
+ ), # x_attn_cache_k
+ keras.KerasTensor(
+ shape=cache_shape_cross, dtype=self.dtype
+ ), # x_attn_cache_v
+ )
+
+ def call(
+ self,
+ inputs,
+ training=None,
+ use_cache=False,
+ decoder_attention_mask=None,
+ encoder_attention_mask=None,
+ self_attention_cache=None,
+ self_attention_cache_update_index=None,
+ ):
+ if use_cache:
+ if not isinstance(inputs, (list, tuple)) or len(inputs) != 7:
+ raise ValueError(
+ "When use_cache=True, expected 7 inputs: "
+ "[x, context, cache_k, cache_v, x_attn_cache_k, "
+ "x_attn_cache_v, rotary_embedding]. "
+ f"Received {len(inputs)} inputs."
+ )
+ (
+ x,
+ context,
+ cache_k, # Self-attn key cache
+ cache_v, # Self-attn value cache
+ x_attn_cache_k, # Cross-attn key cache (precomputed)
+ x_attn_cache_v, # Cross-attn value cache (precomputed)
+ rotary_embedding,
+ ) = inputs
+ else:
+ if not isinstance(inputs, (list, tuple)) or len(inputs) != 3:
+ raise ValueError(
+ "When use_cache=False, expected 3 inputs: [x, context, "
+ f"rotary_embedding]. Received {len(inputs)} inputs."
+ )
+ x, context, rotary_embedding = inputs
+ cache_k, cache_v, x_attn_cache_k, x_attn_cache_v = (
+ None,
+ None,
+ None,
+ None,
+ )
+ residual = x
+ x_norm1 = self.norm1(x)
+
+ if use_cache:
+ x_self_attn, new_cache_k, new_cache_v = self.self_attention(
+ query=x_norm1,
+ key=x_norm1,
+ value=x_norm1,
+ rotary_embedding=rotary_embedding,
+ key_cache=cache_k,
+ value_cache=cache_v,
+ attention_mask=decoder_attention_mask,
+ training=training,
+ )
+ else:
+ x_self_attn, cache_k, cache_v = self.self_attention(
+ query=x_norm1,
+ key=x_norm1,
+ value=x_norm1,
+ rotary_embedding=rotary_embedding,
+ attention_mask=decoder_attention_mask,
+ training=training,
+ )
+ x = x_self_attn + residual
+ residual = x
+ x_norm2 = self.norm2(x)
+
+ if use_cache:
+ x_cross_attn = self.cross_attention(
+ query=x_norm2,
+ key=context,
+ value=context,
+ key_cache=x_attn_cache_k,
+ value_cache=x_attn_cache_v,
+ attention_mask=encoder_attention_mask,
+ training=training,
+ )
+ else:
+ x_cross_attn, x_attn_cache_k, x_attn_cache_v = self.cross_attention(
+ query=x_norm2,
+ key=context,
+ value=context,
+ attention_mask=encoder_attention_mask,
+ training=training,
+ )
+ x = x_cross_attn + residual
+ residual = x
+ x_norm3 = self.norm3(x)
+ x_ff = self.ff(x_norm3)
+ x = x_ff + residual
+
+ if use_cache:
+ return x, new_cache_k, new_cache_v
+ return x, cache_k, cache_v, x_attn_cache_k, x_attn_cache_v
+
+ def get_config(self):
+ config = super().get_config()
+ config.update(
+ {
+ "hidden_dim": self.hidden_dim,
+ "intermediate_dim": self.intermediate_dim,
+ "num_heads": self.num_heads,
+ "feedforward_expansion_factor": self.feedforward_expansion_factor, # noqa: E501
+ "use_swiglu_activation": self.use_swiglu_activation,
+ "pad_head_dim_to_multiple_of": self.pad_head_dim_to_multiple_of, # noqa: E501
+ "initializer_range": self.initializer_range,
+ "attention_bias": self.attention_bias,
+ "attention_dropout": self.attention_dropout,
+ "dtype": self.dtype,
+ }
+ )
+ return config
diff --git a/keras_hub/src/models/moonshine/moonshine_decoder_test.py b/keras_hub/src/models/moonshine/moonshine_decoder_test.py
new file mode 100644
index 0000000000..9161b2cb1c
--- /dev/null
+++ b/keras_hub/src/models/moonshine/moonshine_decoder_test.py
@@ -0,0 +1,204 @@
+import keras
+
+from keras_hub.src.models.moonshine.moonshine_decoder import (
+ MoonshineDecoderBlock,
+)
+from keras_hub.src.tests.test_case import TestCase
+
+
+class MoonshineDecoderTest(TestCase):
+ def setUp(self):
+ super().setUp()
+ self.hidden_dim = 64
+ self.intermediate_dim = 256
+ self.num_heads = 4
+ self.init_kwargs = {
+ "hidden_dim": self.hidden_dim,
+ "intermediate_dim": self.intermediate_dim,
+ "num_heads": self.num_heads,
+ "feedforward_expansion_factor": 4,
+ "use_swiglu_activation": True,
+ "pad_head_dim_to_multiple_of": None,
+ "initializer_range": 0.02,
+ "attention_bias": False,
+ "attention_dropout": 0.0,
+ }
+ self.decoder_block = MoonshineDecoderBlock(**self.init_kwargs)
+ self.batch_size = 2
+ self.seq_len = 10
+ self.encoder_seq_len = 16
+ self.head_dim = self.hidden_dim // self.num_heads # 16
+ self.rotary_dim = int(
+ self.head_dim * 0.62
+ ) # Default partial_rotary_factor = 0.62
+ self.rotary_dim = (self.rotary_dim // 2) * 2 # Ensure even
+ self.rotary_dim = self.rotary_dim // 2 # Half for freqs, e.g., 4
+ self.x = keras.random.normal(
+ (self.batch_size, self.seq_len, self.hidden_dim)
+ )
+ self.context = keras.random.normal(
+ (self.batch_size, self.encoder_seq_len, self.hidden_dim)
+ )
+ self.rotary_embedding = keras.random.normal(
+ (self.seq_len, self.rotary_dim)
+ )
+ self.decoder_attention_mask = keras.ops.ones(
+ (self.batch_size, self.seq_len), dtype="bool"
+ )
+ self.encoder_attention_mask = keras.ops.ones(
+ (self.batch_size, self.encoder_seq_len), dtype="bool"
+ )
+
+ def test_initialization(self):
+ self.assertEqual(self.decoder_block.hidden_dim, self.hidden_dim)
+ self.assertEqual(
+ self.decoder_block.intermediate_dim, self.intermediate_dim
+ )
+ self.assertEqual(self.decoder_block.num_heads, self.num_heads)
+ self.assertTrue(self.decoder_block.use_swiglu_activation)
+
+ def test_forward_pass_without_caching(self):
+ outputs = self.decoder_block(
+ [self.x, self.context, self.rotary_embedding],
+ decoder_attention_mask=self.decoder_attention_mask,
+ encoder_attention_mask=self.encoder_attention_mask,
+ )
+ x, cache_k, cache_v, x_attn_cache_k, x_attn_cache_v = outputs
+ self.assertEqual(
+ x.shape, (self.batch_size, self.seq_len, self.hidden_dim)
+ )
+ self.assertEqual(
+ cache_k.shape,
+ (self.batch_size, self.seq_len, self.num_heads, self.head_dim),
+ )
+ self.assertEqual(
+ cache_v.shape,
+ (self.batch_size, self.seq_len, self.num_heads, self.head_dim),
+ )
+ self.assertEqual(
+ x_attn_cache_k.shape,
+ (
+ self.batch_size,
+ self.encoder_seq_len,
+ self.num_heads,
+ self.head_dim,
+ ),
+ )
+ self.assertEqual(
+ x_attn_cache_v.shape,
+ (
+ self.batch_size,
+ self.encoder_seq_len,
+ self.num_heads,
+ self.head_dim,
+ ),
+ )
+
+ def test_forward_pass_with_padding(self):
+ # Padding in decoder sequence.
+ padded_mask = keras.ops.concatenate(
+ [
+ keras.ops.ones((self.batch_size, 5), dtype="bool"),
+ keras.ops.zeros(
+ (self.batch_size, self.seq_len - 5), dtype="bool"
+ ),
+ ],
+ axis=1,
+ )
+ outputs = self.decoder_block(
+ [self.x, self.context, self.rotary_embedding],
+ decoder_attention_mask=padded_mask,
+ encoder_attention_mask=self.encoder_attention_mask,
+ )
+ x, _, _, _, _ = outputs
+ self.assertEqual(
+ x.shape, (self.batch_size, self.seq_len, self.hidden_dim)
+ )
+
+ def test_autoregressive_caching(self):
+ # First pass to get initial caches.
+ outputs_full = self.decoder_block(
+ [self.x, self.context, self.rotary_embedding],
+ decoder_attention_mask=self.decoder_attention_mask,
+ encoder_attention_mask=self.encoder_attention_mask,
+ )
+ _, cache_k_full, cache_v_full, x_attn_cache_k, x_attn_cache_v = (
+ outputs_full
+ )
+
+ # Autoregressive decoding.
+ for i in range(self.seq_len):
+ x_i = self.x[:, i : i + 1, :]
+ rotary_i = self.rotary_embedding[i : i + 1, :]
+ mask_i = self.decoder_attention_mask[:, i : i + 1]
+ cache_k = None if i == 0 else cache_k_full[:, :i, :, :]
+ cache_v = None if i == 0 else cache_v_full[:, :i, :, :]
+ outputs_i = self.decoder_block(
+ [
+ x_i,
+ self.context,
+ cache_k,
+ cache_v,
+ x_attn_cache_k,
+ x_attn_cache_v,
+ rotary_i,
+ ],
+ use_cache=True,
+ decoder_attention_mask=mask_i,
+ encoder_attention_mask=self.encoder_attention_mask,
+ )
+ x_i_out, new_cache_k, new_cache_v = outputs_i
+ self.assertEqual(
+ x_i_out.shape, (self.batch_size, 1, self.hidden_dim)
+ )
+ self.assertEqual(
+ new_cache_k.shape,
+ (self.batch_size, i + 1, self.num_heads, self.head_dim),
+ )
+ self.assertEqual(
+ new_cache_v.shape,
+ (self.batch_size, i + 1, self.num_heads, self.head_dim),
+ )
+
+ def test_caching_consistency(self):
+ # Full sequence without caching.
+ outputs_full = self.decoder_block(
+ [self.x, self.context, self.rotary_embedding],
+ decoder_attention_mask=self.decoder_attention_mask,
+ encoder_attention_mask=self.encoder_attention_mask,
+ )
+ x_full, _, _, _, _ = outputs_full
+
+ # Autoregressive with caching.
+ x_auto = []
+ cache_k, cache_v = None, None
+ x_attn_cache_k, x_attn_cache_v = (
+ outputs_full[3],
+ outputs_full[4],
+ ) # Precomputed cross-attention caches
+ for i in range(self.seq_len):
+ x_i = self.x[:, i : i + 1, :]
+ rotary_i = self.rotary_embedding[i : i + 1, :]
+ mask_i = self.decoder_attention_mask[:, i : i + 1]
+ outputs_i = self.decoder_block(
+ [
+ x_i,
+ self.context,
+ cache_k,
+ cache_v,
+ x_attn_cache_k,
+ x_attn_cache_v,
+ rotary_i,
+ ],
+ use_cache=True,
+ decoder_attention_mask=mask_i,
+ encoder_attention_mask=self.encoder_attention_mask,
+ )
+ x_i_out, cache_k, cache_v = outputs_i
+ x_auto.append(x_i_out)
+ x_auto = keras.ops.concatenate(x_auto, axis=1)
+ self.assertAllClose(x_full, x_auto, atol=1e-5)
+
+ def test_serialization(self):
+ instance = MoonshineDecoderBlock(**self.init_kwargs)
+ self.run_serialization_test(instance=instance)
diff --git a/keras_hub/src/models/moonshine/moonshine_encoder.py b/keras_hub/src/models/moonshine/moonshine_encoder.py
new file mode 100644
index 0000000000..5e8838384e
--- /dev/null
+++ b/keras_hub/src/models/moonshine/moonshine_encoder.py
@@ -0,0 +1,213 @@
+import keras
+
+from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder
+from keras_hub.src.models.moonshine.moonshine_layers import MoonshineMLP
+from keras_hub.src.models.moonshine.moonshine_layers import (
+ moonshine_kernel_initializer,
+)
+from keras_hub.src.models.moonshine.moonshine_multi_head_attention import (
+ MoonshineMultiHeadAttention,
+)
+from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+@keras.saving.register_keras_serializable(package="keras_hub")
+class MoonshineEncoderBlock(TransformerEncoder):
+ """
+ Moonshine encoder block for sequence processing.
+
+ Implements a standard encoder block with self-attention and feedforward
+ sublayers, including residual connections and layer normalization. The
+ implementation utilizes Moonshine-specific attention and feedforward
+ mechanisms.
+
+ Args:
+ hidden_dim: int. The dimensionality of the model's hidden
+ representations throughout the block.
+ intermediate_dim: int. The dimensionality used in projections before
+ applying non-linearities.
+ num_heads: int. The number of attention heads for multi-head attention
+ computation.
+ feedforward_expansion_factor: int, optional. A multiplier for expanding
+ the dimension in the feedforward network. Defaults to 4.
+ use_swiglu_activation: bool, optional. Whether to use SwiGLU activation
+ (True) or LinearGeLU (False) in the feedforward sublayer. Defaults
+ to False.
+ pad_head_dim_to_multiple_of: int, optional. If specified, pads the head
+ dimension to be a multiple of this value for hardware optimization.
+ Defaults to None.
+ initializer_range: float, optional. The standard deviation of the
+ truncated normal distribution used for weight initialization.
+ Defaults to 0.02.
+ attention_bias: bool, optional. Whether to use a bias term in the
+ attention mechanism. Defaults to False.
+ attention_dropout: float, optional. The dropout rate applied to the
+ attention weights. Defaults to 0.0.
+ dtype: str, optional. The data type to use for model computations and
+ weights. Defaults to None.
+ **kwargs: Additional keyword arguments passed to the base layer.
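+
+    Example (an illustrative sketch of a single forward pass; the dimensions
+    and random inputs below are hypothetical):
+
+    ```python
+    import keras
+
+    from keras_hub.src.models.moonshine.moonshine_encoder import (
+        MoonshineEncoderBlock,
+    )
+
+    encoder_block = MoonshineEncoderBlock(
+        hidden_dim=64,
+        intermediate_dim=256,
+        num_heads=4,
+    )
+    x = keras.random.normal((2, 16, 64))  # Encoder hidden states.
+    rotary_embedding = keras.random.normal((16, 4))  # Rotary frequencies.
+    outputs = encoder_block(x, rotary_embedding=rotary_embedding)
+    # `outputs` has the same shape as `x`: (2, 16, 64).
+    ```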
+ """
+
+ # References:
+ # Defined and formulated based on the UsefulSensors implementation of the
+ # EncoderLayer class (https://github.com/usefulsensors/moonshine/blob/4a000427bd36a1c2c6d20a86c672dbd850b44c88/moonshine/model.py#L124-L161).
+
+ def __init__(
+ self,
+ hidden_dim,
+ intermediate_dim,
+ num_heads,
+ feedforward_expansion_factor=4,
+ use_swiglu_activation=False,
+ pad_head_dim_to_multiple_of=None,
+ dtype=None,
+ initializer_range=0.02,
+ attention_bias=False,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ kwargs.pop("dropout", None)
+ kwargs.pop("activation", None)
+ kwargs.pop("kernel_initializer", None)
+ self.kernel_initializer = moonshine_kernel_initializer(
+ initializer_range=initializer_range
+ )
+ super().__init__(
+ intermediate_dim=intermediate_dim,
+ num_heads=num_heads,
+ dropout=attention_dropout,
+ activation="gelu" if use_swiglu_activation else "silu",
+ kernel_initializer=clone_initializer(self.kernel_initializer),
+ dtype=dtype,
+ **kwargs,
+ )
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.initializer_range = initializer_range
+ self.hidden_dim = hidden_dim
+ self.intermediate_dim = intermediate_dim
+ self.num_heads = num_heads
+ self.feedforward_expansion_factor = feedforward_expansion_factor
+ self.use_swiglu_activation = use_swiglu_activation
+
+ # Self-attention sublayers.
+ self.pad_head_dim_to_multiple_of = pad_head_dim_to_multiple_of
+
+ self.head_dim = hidden_dim // num_heads
+ if pad_head_dim_to_multiple_of is not None:
+ self.head_dim = (
+ (self.head_dim + pad_head_dim_to_multiple_of - 1)
+ // pad_head_dim_to_multiple_of
+ ) * pad_head_dim_to_multiple_of
+
+ self.self_attention_layer = MoonshineMultiHeadAttention(
+ num_heads=num_heads,
+ key_dim=self.head_dim,
+ use_bias=False,
+ kernel_initializer=clone_initializer(self.kernel_initializer),
+ attention_bias=attention_bias,
+ attention_dropout=attention_dropout,
+ use_causal_mask=False,
+ apply_rotary_embedding=True,
+ cache_mode="none",
+ name="self_attention_layer",
+ dtype=self.dtype,
+ )
+ self.self_attention_layer_norm = keras.layers.LayerNormalization(
+ axis=-1,
+ epsilon=1e-5,
+ center=False,
+ scale=True,
+ name="self_attention_layer_norm",
+ dtype=self.dtype,
+ )
+
+ # Feedforward sublayers.
+ self.feedforward_layer_norm = keras.layers.LayerNormalization(
+ axis=-1,
+ epsilon=1e-5,
+ center=False,
+ scale=True,
+ name="feedforward_layer_norm",
+ dtype=self.dtype,
+ )
+ self.feedforward = MoonshineMLP(
+ hidden_dim=hidden_dim,
+ feedforward_expansion_factor=feedforward_expansion_factor,
+ use_swiglu_activation=use_swiglu_activation,
+ initializer_range=initializer_range,
+ name="feedforward",
+ dtype=self.dtype,
+ )
+
+ def build(self, input_shape):
+ if isinstance(input_shape, dict):
+ encoder_input_shape = input_shape["input_values"]
+ else:
+ encoder_input_shape = input_shape
+ # Build self-attention branch.
+ self.self_attention_layer_norm.build(encoder_input_shape)
+ self.self_attention_layer.build(
+ encoder_input_shape, encoder_input_shape, encoder_input_shape
+ )
+ # Build feedforward branch.
+ self.feedforward_layer_norm.build(encoder_input_shape)
+ # The feedforward layer expects the last dimension to be hidden_dim.
+ feed_forward_input_shape = list(encoder_input_shape)
+ feed_forward_input_shape[-1] = self.hidden_dim
+ self.feedforward.build(tuple(feed_forward_input_shape))
+ self.built = True
+
+ def call(
+ self,
+ inputs,
+ rotary_embedding,
+ attention_mask=None,
+ training=None,
+ **kwargs,
+ ):
+ x = inputs
+
+ # Self-attention block with residual connection.
+ attention_residual = x
+ x = self.self_attention_layer_norm(x)
+ x = self.self_attention_layer(
+ query=x,
+ value=x,
+ key=x,
+ rotary_embedding=rotary_embedding,
+ attention_mask=attention_mask,
+ training=training,
+ **kwargs,
+ )
+ x = x + attention_residual
+
+ # Feedforward block with residual connection.
+ ff_residual = x
+ x = self.feedforward_layer_norm(x)
+ x = self.feedforward(x)
+ x = x + ff_residual
+
+ return x
+
+ def compute_output_shape(self, input_shape):
+ return input_shape
+
+ def get_config(self):
+ # ==== Config ====
+ config = super().get_config()
+ config.update(
+ {
+ "hidden_dim": self.hidden_dim,
+ "intermediate_dim": self.intermediate_dim,
+ "num_heads": self.num_heads,
+ "feedforward_expansion_factor": self.feedforward_expansion_factor, # noqa: E501
+ "use_swiglu_activation": self.use_swiglu_activation,
+ "pad_head_dim_to_multiple_of": self.pad_head_dim_to_multiple_of,
+ "initializer_range": self.initializer_range,
+ "attention_bias": self.attention_bias,
+ "attention_dropout": self.attention_dropout,
+ "dtype": self.dtype,
+ }
+ )
+ return config
diff --git a/keras_hub/src/models/moonshine/moonshine_layers.py b/keras_hub/src/models/moonshine/moonshine_layers.py
new file mode 100644
index 0000000000..228e56af26
--- /dev/null
+++ b/keras_hub/src/models/moonshine/moonshine_layers.py
@@ -0,0 +1,315 @@
+import keras
+
+from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+def moonshine_kernel_initializer(initializer_range=0.02):
+ return keras.initializers.TruncatedNormal(stddev=initializer_range)
+
+
+@keras.saving.register_keras_serializable(package="keras_hub")
+class MoonshineRotaryEmbedding(keras.layers.Layer):
+ """
+ Moonshine rotary embedding layer.
+
+ Computes rotary positional embeddings using precomputed inverse frequencies.
+ Supports two RoPE types: "default" and "dynamic".
+
+ - **Default RoPE**: Applies rotary embeddings to a fraction of dimensions
+ controlled by `partial_rotary_factor`.
+    - **Dynamic RoPE**: Updates frequencies dynamically based on the sequence
+      length, aligning functionally with the Hugging Face implementation.
+
+ The layer stores inverse frequency weights as a non-trainable parameter and
+ computes sinusoidal embeddings based on input positions. Unlike KerasHub's
+ `RotaryEmbedding` class, this implementation explicitly requires `head_dim`
+ and applies `partial_rotary_factor` for selective rotary embedding, whereas
+ KerasHub uses `max_wavelength` without partial application.
+
+ Args:
+ head_dim: int. The dimensionality of each attention head, determining
+ the feature space for rotary embeddings.
+ max_position_embeddings: int, optional. The maximum sequence length the
+ model can process, controlling the positional embedding scale.
+ Defaults to 2048.
+ base_value: float, optional. Base value for computing inverse
+ frequencies. Higher values result in longer wavelengths. Defaults to
+ 10000.
+ rope_scaling: dict, optional. Configuration for RoPE scaling, such as
+ `{"rope_type": "default"}` or `{"rope_type": "dynamic"}`.
+ Defaults to `{"rope_type": "default"}` if None.
+ partial_rotary_factor: float, optional. The fraction of `head_dim`
+ dimensions that receive rotary embeddings, balancing rotary and
+ non-rotary components. Defaults to 0.62.
+ dtype: string, optional. The data type for model computations and
+ weights. Defaults to None.
+ **kwargs: Additional keyword arguments passed to the parent class.
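+
+    Example (a minimal sketch of the default RoPE path; the values below are
+    illustrative):
+
+    ```python
+    import keras
+
+    from keras_hub.src.models.moonshine.moonshine_layers import (
+        MoonshineRotaryEmbedding,
+    )
+
+    rotary = MoonshineRotaryEmbedding(head_dim=64)
+    positions = keras.ops.arange(10, dtype="float32")
+    # Returns interleaved frequencies of shape (1, 10, 38), since
+    # int(64 * 0.62) rounded down to an even value is 38.
+    freqs = rotary(positions)
+    ```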
+ """
+
+ # References:
+ # Based on the UsefulSensors implementation of the RotaryEmbedding class (https://github.com/usefulsensors/moonshine/blob/4a000427bd36a1c2c6d20a86c672dbd850b44c88/moonshine/model.py#L176-L193).
+ # Incorporates dynamic RoPE concepts from the Hugging Face implementation (https://github.com/huggingface/transformers/blob/bc30dd1efb99f571d45b2e2131a555d09285ddd8/src/transformers/models/moonshine/modeling_moonshine.py#L311-L369).
+
+ def __init__(
+ self,
+ head_dim,
+ max_position_embeddings=2048,
+ base_value=10000,
+ rope_scaling=None,
+ partial_rotary_factor=0.62,
+ dtype=None,
+ **kwargs,
+ ):
+ super().__init__(dtype=dtype, **kwargs)
+ self.head_dim = head_dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base_value = base_value
+ self.partial_rotary_factor = partial_rotary_factor
+
+ if rope_scaling is None:
+ rope_scaling = {"rope_type": "default"}
+ self.rope_scaling = rope_scaling
+ self.rope_type = rope_scaling.get("rope_type", "default")
+
+ if self.rope_type == "default":
+ self.scaling_factor = 1.0
+ self.attention_scaling = 1.0
+ elif "dynamic" in self.rope_type:
+ self.scaling_factor = 1.0 # Initial scaling, updated dynamically
+ self.attention_scaling = 1.0
+ else:
+ raise NotImplementedError(
+ f"rope_type '{self.rope_type}' not implemented"
+ )
+
+ def build(self, input_shape):
+ # Create and track the non-trainable weight immediately.
+ rotary_dim = int(self.head_dim * self.partial_rotary_factor)
+ rotary_dim = (rotary_dim // 2) * 2
+ if rotary_dim <= 0:
+ raise ValueError(
+ f"Calculated rotary_dim ({rotary_dim}) must be a positive even "
+ f"number. Check head_dim ({self.head_dim}) and "
+ f"partial_rotary_factor ({self.partial_rotary_factor})."
+ )
+ rotary_dim_half = rotary_dim // 2
+
+ # Compute inv_freq.
+ inv_freq = 1.0 / (
+ self.base_value
+ ** (
+ keras.ops.arange(0, rotary_dim_half, dtype=self.dtype)
+ / rotary_dim_half
+ )
+ )
+ inv_freq = inv_freq * self.scaling_factor
+
+ # Set the non-trainable weight using the computed tensor.
+ self.inv_freq = self.add_weight(
+ name="inv_freq",
+ shape=(rotary_dim_half,),
+ initializer=keras.initializers.Constant(inv_freq),
+ trainable=False,
+ dtype=self.dtype,
+ )
+ self.original_inv_freq = keras.ops.convert_to_tensor(self.inv_freq)
+ self.max_sequence_length_cached = self.max_position_embeddings
+ self.built = True
+
+ def call(self, t, position_ids=None):
+ # "Dynamic" RoPE behavior.
+ if "dynamic" in self.rope_type:
+ if position_ids is None:
+ position_ids = keras.ops.expand_dims(t, axis=0)
+ seq_len = keras.ops.max(position_ids) + 1
+ if seq_len > self.max_position_embeddings:
+ scaling = keras.ops.cast(
+ self.max_position_embeddings, self.dtype
+ ) / keras.ops.cast(seq_len, self.dtype)
+ else:
+ scaling = keras.ops.cast(1.0, self.dtype)
+ current_inv_freq = self.original_inv_freq * scaling
+ if seq_len > self.max_sequence_length_cached:
+ self.max_sequence_length_cached = seq_len
+ elif (
+ seq_len < self.max_position_embeddings
+ and self.max_sequence_length_cached
+ > self.max_position_embeddings
+ ):
+ self.max_sequence_length_cached = self.max_position_embeddings
+
+ pos_cast = keras.ops.cast(position_ids, self.dtype)
+ freqs = pos_cast[:, :, None] * current_inv_freq[None, None, :]
+ cos = keras.ops.cos(freqs) * self.attention_scaling
+ sin = keras.ops.sin(freqs) * self.attention_scaling
+ return cos, sin
+ # Original "default" behavior.
+ else:
+ t_cast = keras.ops.cast(t, keras.ops.dtype(self.inv_freq))
+ original_shape = keras.ops.shape(t_cast)
+ is_generation_step = (
+ len(original_shape) == 2 and original_shape[1] == 1
+ )
+ if is_generation_step:
+ # (batch, 1) -> Squeeze to (batch,) for einsum "i,j->ij".
+ t_cast_for_einsum = keras.ops.squeeze(t_cast, axis=1)
+ freqs = keras.ops.einsum(
+ "i,j->ij", t_cast_for_einsum, self.inv_freq
+ ) # Shape (batch, rotary_dim_half)
+ elif len(original_shape) == 1:
+ t_cast_for_einsum = t_cast
+ freqs = keras.ops.einsum(
+ "i,j->ij", t_cast_for_einsum, self.inv_freq
+ ) # Shape (seq_len, rotary_dim_half)
+ else:
+ raise ValueError(
+ f"Unexpected shape for input 't' in "
+ f"MoonshineRotaryEmbedding default path: {original_shape}. "
+ "Expected (seq_len,) or (batch, 1)."
+ )
+ emb = keras.ops.stack((freqs, freqs), axis=-1)
+ shape_list = list(keras.ops.shape(emb))
+ shape_list[-2:] = [-1]
+ emb_flat = keras.ops.reshape(emb, shape_list)
+ if is_generation_step:
+ final_emb = keras.ops.expand_dims(emb_flat, axis=1)
+ else:
+ final_emb = keras.ops.expand_dims(emb_flat, axis=0)
+ return final_emb
+
+ def get_config(self):
+ config = super().get_config()
+ config.update(
+ {
+ "head_dim": self.head_dim,
+ "max_position_embeddings": self.max_position_embeddings,
+ "base_value": self.base_value,
+ "rope_scaling": self.rope_scaling,
+ "partial_rotary_factor": self.partial_rotary_factor,
+ "dtype": self.dtype,
+ }
+ )
+ return config
+
+
+@keras.saving.register_keras_serializable(package="keras_hub")
+class MoonshineMLP(keras.layers.Layer):
+ """
+ Moonshine MLP layer.
+
+ Implements a Multi-Layer Perceptron (MLP) for Moonshine models with support
+ for both `SwiGLU` and `LinearGeLU` activation patterns. The MLP consists of
+ two dense layers with an activation function in between, expanding the input
+ dimension before projecting back to the original dimension.
+
+ Args:
+ hidden_dim: int. The dimensionality of the input and output tensors.
+ feedforward_expansion_factor: float. The factor by which to expand the
+ hidden dimension in the intermediate layer.
+ use_swiglu_activation: bool, optional. If `True`, uses SwiGLU activation
+ (SiLU with gating). If `False`, uses standard GeLU activation.
+ Defaults to `True`.
+ initializer_range: float, optional. The standard deviation for kernel
+ initialization. Defaults to 0.02.
+ dtype: string, optional. The data type for model computations and
+ weights. Defaults to `None`.
+ **kwargs: Additional keyword arguments passed to the parent class.
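+
+    Example (a minimal sketch with illustrative dimensions):
+
+    ```python
+    import keras
+
+    from keras_hub.src.models.moonshine.moonshine_layers import MoonshineMLP
+
+    mlp = MoonshineMLP(
+        hidden_dim=64,
+        feedforward_expansion_factor=4,
+        use_swiglu_activation=True,
+    )
+    x = keras.random.uniform((2, 10, 64))
+    y = mlp(x)  # Shape: (2, 10, 64).
+    ```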
+ """
+
+ # References:
+ # Based on the HuggingFace implementation of the MoonshineEncoderMLP and
+ # MoonshineDecoderMLP classes (https://github.com/huggingface/transformers/blob/fc8764c9a618add64c33e83720f974750bcd0978/src/transformers/models/moonshine/modeling_moonshine.py#L66-L94).
+
+ def __init__(
+ self,
+ hidden_dim,
+ feedforward_expansion_factor,
+ use_swiglu_activation=True,
+ initializer_range=0.02,
+ dtype=None,
+ **kwargs,
+ ):
+ super().__init__(dtype=dtype, **kwargs)
+ self.hidden_dim = hidden_dim
+ self.feedforward_expansion_factor = feedforward_expansion_factor
+ self.use_swiglu_activation = use_swiglu_activation
+ self.kernel_initializer = moonshine_kernel_initializer(
+ initializer_range=initializer_range
+ )
+ self.initializer_range = initializer_range
+
+ if use_swiglu_activation:
+ # First dense layer produces (2 * feedforward_expansion_factor *
+ # hidden_dim) outputs.
+ self.dense_1 = keras.layers.Dense(
+ int(hidden_dim * feedforward_expansion_factor * 2),
+ use_bias=True,
+ name="dense_1",
+ dtype=self.dtype,
+ kernel_initializer=clone_initializer(self.kernel_initializer),
+ )
+ # Activation layer using "silu" (Swish activation).
+ self.activation = keras.layers.Activation(
+ "silu", name="activation", dtype=self.dtype
+ )
+ else:
+ # Taken from pretrained weights.
+ # First dense layer: output dimension is (hidden_dim *
+ # feedforward_expansion_factor).
+ self.dense_1 = keras.layers.Dense(
+ int(hidden_dim * feedforward_expansion_factor),
+ use_bias=True,
+ name="dense_1",
+ dtype=self.dtype,
+ kernel_initializer=clone_initializer(self.kernel_initializer),
+ )
+ self.activation = keras.layers.Activation(
+ "gelu", name="activation", dtype=self.dtype
+ )
+
+ # Second dense layer projects back to hidden_dim.
+ self.dense_2 = keras.layers.Dense(
+ hidden_dim,
+ use_bias=True,
+ name="dense_2",
+ dtype=self.dtype,
+ kernel_initializer=clone_initializer(self.kernel_initializer),
+ )
+
+ def build(self, input_shape):
+ super().build(input_shape)
+ # Build the first dense layer using the original input shape.
+ self.dense_1.build(input_shape)
+ # After dense_1, the output shape becomes: (..., 2 *
+ # feedforward_expansion_factor * hidden_dim).
+ # When splitting, each part will have shape (...,
+ # feedforward_expansion_factor * hidden_dim).
+ new_input_shape = list(input_shape)
+ new_input_shape[-1] = (
+ self.hidden_dim * self.feedforward_expansion_factor
+ )
+ self.dense_2.build(tuple(new_input_shape))
+
+ def call(self, inputs):
+ x = self.dense_1(inputs)
+ if self.use_swiglu_activation:
+ x1, gate = keras.ops.split(x, 2, axis=-1)
+ activated_gate = self.activation(gate)
+ x = x1 * activated_gate
+ else:
+ x = self.activation(x)
+ output = self.dense_2(x)
+ return output
+
+ def get_config(self):
+ config = super().get_config()
+ config.update(
+ {
+ "hidden_dim": self.hidden_dim,
+ "feedforward_expansion_factor": self.feedforward_expansion_factor, # noqa: E501
+ "use_swiglu_activation": self.use_swiglu_activation,
+ "initializer_range": self.initializer_range,
+ "dtype": self.dtype,
+ }
+ )
+ return config
diff --git a/keras_hub/src/models/moonshine/moonshine_layers_test.py b/keras_hub/src/models/moonshine/moonshine_layers_test.py
new file mode 100644
index 0000000000..c8880ddb0e
--- /dev/null
+++ b/keras_hub/src/models/moonshine/moonshine_layers_test.py
@@ -0,0 +1,92 @@
+import keras
+
+from keras_hub.src.models.moonshine.moonshine_layers import MoonshineMLP
+from keras_hub.src.models.moonshine.moonshine_layers import (
+ MoonshineRotaryEmbedding,
+)
+from keras_hub.src.tests.test_case import TestCase
+
+
+class MoonshineLayersTest(TestCase):
+ def test_moonshine_rotary_embedding(self):
+ layer = MoonshineRotaryEmbedding(
+ head_dim=64,
+ max_position_embeddings=2048,
+ base_value=10000,
+ rope_scaling=None,
+ partial_rotary_factor=0.62,
+ dtype="float32",
+ )
+ input_data = keras.ops.arange(10, dtype="float32")
+ output_data = layer(input_data)
+ expected_output_shape = (1, 10, 38)
+ self.assertEqual(keras.ops.shape(output_data), expected_output_shape)
+ self.assertEqual(len(layer.trainable_weights), 0)
+ self.assertEqual(len(layer.non_trainable_weights), 1)
+ self.assertEqual(len(layer.non_trainable_variables), 1)
+
+ def test_moonshine_rotary_embedding_dynamic(self):
+ layer = MoonshineRotaryEmbedding(
+ head_dim=64,
+ max_position_embeddings=10,
+ base_value=10000,
+ rope_scaling={"rope_type": "dynamic"},
+ partial_rotary_factor=1.0,
+ )
+ # Compute original inverse frequencies.
+        # This equals the layer's rotary_dim_half:
+        # (head_dim * partial_rotary_factor) // 2 = (64 * 1.0) // 2 = 32.
+        rotary_dim = 32
+ arange = keras.ops.arange(0, rotary_dim, dtype="float32")
+ original_inv_freq = 1.0 / (10000 ** (arange / rotary_dim))
+
+ # seq_len = 5 < 10.
+ position_ids = keras.ops.arange(5, dtype="int32")[None, :] # [1, 5]
+ cos1, sin1 = layer(None, position_ids=position_ids) # [1, 5, 32]
+ expected_cos1 = keras.ops.cos(original_inv_freq)
+ expected_sin1 = keras.ops.sin(original_inv_freq)
+ self.assertAllClose(cos1[0, 1, :], expected_cos1, rtol=1e-5)
+ self.assertAllClose(sin1[0, 1, :], expected_sin1, rtol=1e-5)
+
+ # seq_len = 15 > 10.
+ position_ids = keras.ops.arange(15, dtype="int32")[None, :] # [1, 15]
+ cos2, sin2 = layer(None, position_ids=position_ids) # [1, 15, 32]
+ scaling = 10 / 15 # 2 / 3
+ expected_cos2 = keras.ops.cos(original_inv_freq * scaling)
+ expected_sin2 = keras.ops.sin(original_inv_freq * scaling)
+ self.assertAllClose(cos2[0, 1, :], expected_cos2, rtol=1e-5)
+ self.assertAllClose(sin2[0, 1, :], expected_sin2, rtol=1e-5)
+
+ # seq_len = 8 < 10, should reset.
+ position_ids = keras.ops.arange(8, dtype="int32")[None, :] # [1, 8]
+ cos3, sin3 = layer(None, position_ids=position_ids) # [1, 8, 32]
+ self.assertAllClose(cos3[0, 1, :], expected_cos1, rtol=1e-5)
+ self.assertAllClose(sin3[0, 1, :], expected_sin1, rtol=1e-5)
+
+ def test_moonshine_mlp_swiglu(self):
+ self.run_layer_test(
+ cls=MoonshineMLP,
+ init_kwargs={
+ "hidden_dim": 64,
+ "feedforward_expansion_factor": 4,
+ "use_swiglu_activation": True,
+ },
+ input_data=keras.random.uniform((2, 10, 64), dtype="float32"),
+ expected_output_shape=(2, 10, 64),
+ expected_num_trainable_weights=4,
+ expected_num_non_trainable_weights=0,
+ run_precision_checks=False,
+ )
+
+ def test_moonshine_mlp_linear_gelu(self):
+ self.run_layer_test(
+ cls=MoonshineMLP,
+ init_kwargs={
+ "hidden_dim": 64,
+ "feedforward_expansion_factor": 4,
+ "use_swiglu_activation": False,
+ },
+ input_data=keras.random.uniform((2, 10, 64), dtype="float32"),
+ expected_output_shape=(2, 10, 64),
+ expected_num_trainable_weights=4,
+ expected_num_non_trainable_weights=0,
+ run_precision_checks=False,
+ )
diff --git a/keras_hub/src/models/moonshine/moonshine_multi_head_attention.py b/keras_hub/src/models/moonshine/moonshine_multi_head_attention.py
new file mode 100644
index 0000000000..072af883da
--- /dev/null
+++ b/keras_hub/src/models/moonshine/moonshine_multi_head_attention.py
@@ -0,0 +1,401 @@
+import keras
+from keras import backend
+
+from keras_hub.src.layers.modeling.cached_multi_head_attention import (
+ CachedMultiHeadAttention,
+)
+from keras_hub.src.models.whisper.whisper_cached_multi_head_attention import (
+ _build_proj_equation,
+)
+from keras_hub.src.models.whisper.whisper_cached_multi_head_attention import (
+ _get_output_shape,
+)
+
+
+# Removed dependence on einops.
+# Source: https://github.com/usefulsensors/moonshine/blob/4a000427bd36a1c2c6d20a86c672dbd850b44c88/moonshine/model.py#L35
+def _rotate_half(x):
+ """
+    Rotates pairs of channels in the last dimension.
+
+    This function groups the last dimension of the input tensor into
+    consecutive pairs and rotates each pair by 90 degrees. For an input of
+    shape `[..., 2*d]`, every pair `(x[..., 2i], x[..., 2i + 1])` is mapped to
+    `(-x[..., 2i + 1], x[..., 2i])`; for example, `[a, b, c, d]` becomes
+    `[-b, a, -d, c]`.
+
+ Args:
+ x: Tensor. Shape `[..., 2*d]`. The input tensor to be rotated.
+
+ Returns:
+ Tensor: A tensor of shape `[..., 2*d]` with the two halves rotated.
+ """
+    # Conditional for the TensorFlow backend.
+ if backend.backend() == "tensorflow":
+ x_shape = keras.ops.shape(x)
+ last_dim = x_shape[-1]
+ d = last_dim // 2
+ x_shape_tensor = keras.ops.convert_to_tensor(x_shape)
+ new_shape = keras.ops.concatenate(
+ [x_shape_tensor[:-1], keras.ops.convert_to_tensor([d, 2])], axis=0
+ )
+ x = keras.ops.reshape(x, new_shape)
+ x1 = x[..., 0]
+ x2 = x[..., 1]
+ x_rotated = keras.ops.stack([-x2, x1], axis=-1)
+ x_rotated = keras.ops.reshape(x_rotated, x_shape)
+ return x_rotated
+
+ # Conditional for PyTorch and JAX backends.
+ if backend.backend() == "torch" or backend.backend() == "jax":
+ x_shape = keras.ops.shape(x)
+ x_shape_tuple = tuple(
+ int(keras.ops.convert_to_numpy(dim).item()) for dim in x_shape
+ )
+ last_dim = x_shape_tuple[-1]
+ d = last_dim // 2
+ new_shape = x_shape_tuple[:-1] + (d, 2)
+ x = keras.ops.reshape(x, new_shape)
+ x1 = x[..., 0]
+ x2 = x[..., 1]
+ x_rotated = keras.ops.stack([-x2, x1], axis=-1)
+ x_rotated = keras.ops.reshape(x_rotated, x_shape_tuple)
+ return x_rotated
+
+ else:
+ raise NotImplementedError(
+ "Backend not supported. Please use TensorFlow, PyTorch, or JAX."
+ )
+
+
+def _apply_rotary_pos_emb(t, freqs):
+ """
+ Applies rotary positional embeddings to the input tensor. Used in on-the-fly
+ computation of rotary positional embeddings in multi-head attention layers.
+
+ Args:
+        t: A tensor with shape `[..., seq_len, num_heads, head_dim]`. The
+            rotary embedding is applied to the first `rot_dim` channels of the
+            last dimension.
+        freqs: A tensor of frequency values with shape `[max_seq_len, rot_dim]`.
+            The first `seq_len` entries are used to compute the rotary
+            embeddings.
+
+ Returns:
+ Tensor: A tensor of the same shape as `t` with the rotary positional
+ embeddings applied to the first `rot_dim` channels of the last dimension
+ and the remaining channels concatenated unchanged.
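+
+    Example (a minimal sketch; the shapes below are hypothetical):
+
+    ```python
+    import keras
+
+    # A projected query of shape (batch, seq_len, num_heads, head_dim).
+    t = keras.random.normal((2, 10, 4, 16))
+    freqs = keras.random.normal((10, 4))  # Frequencies for 4 rotary channels.
+    out = _apply_rotary_pos_emb(t, freqs)  # Same shape as `t`.
+    ```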
+ """
+ rot_dim = keras.ops.shape(freqs)[-1]
+ seq_len = keras.ops.shape(t)[-3]
+ orig_dtype = t.dtype
+ freqs = freqs[:seq_len, :]
+ freqs = keras.ops.reshape(freqs, (seq_len, 1, rot_dim))
+ t_rot = t[..., :rot_dim]
+ t_nonrot = t[..., rot_dim:]
+ t_rotated = t_rot * keras.ops.cos(freqs) + _rotate_half(
+ t_rot
+ ) * keras.ops.sin(freqs)
+ out = keras.ops.concatenate([t_rotated, t_nonrot], axis=-1)
+ return keras.ops.cast(out, orig_dtype)
+
+
+@keras.saving.register_keras_serializable(package="keras_hub")
+class MoonshineMultiHeadAttention(CachedMultiHeadAttention):
+ """
+ Moonshine multi-head attention layer.
+
+ Implements a multi-head attention mechanism for Moonshine models with
+ support for rotary position embeddings and different caching strategies.
+ This layer extends the `CachedMultiHeadAttention` base class to include
+ specialized functionality for Moonshine models, such as rotary embeddings
+ and causal masking.
+
+ Args:
+ num_heads: int. Number of attention heads.
+ key_dim: int. Size of each attention head for key.
+ value_dim: int, optional. Size of each attention head for value. If
+ None, defaults to `key_dim`.
+ attention_bias: bool, optional. Whether to include bias in attention
+ projection layers. Defaults to `False`.
+ attention_dropout: float, optional. Dropout probability for attention
+ weights. Defaults to 0.0.
+ use_causal_mask: bool, optional. Whether to apply causal masking to
+ prevent positions from attending to subsequent positions. Defaults
+ to `False`.
+ apply_rotary_embedding: bool, optional. Whether to apply rotary position
+ embeddings to queries and keys. Defaults to `True`.
+        cache_mode: str, optional. Mode for key-value caching. Must be one of
+            `"none"` (no caching), `"autoregressive"` (incremental caching for
+            autoregressive generation), or `"precomputed"` (use precomputed
+            key-value pairs). Defaults to `"none"`.
+ **kwargs: Additional keyword arguments passed to the parent class.
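+
+    Example (an illustrative self-attention sketch without caching; the
+    dimensions and random inputs below are hypothetical):
+
+    ```python
+    import keras
+
+    from keras_hub.src.models.moonshine.moonshine_multi_head_attention import (
+        MoonshineMultiHeadAttention,
+    )
+
+    attention = MoonshineMultiHeadAttention(
+        num_heads=4,
+        key_dim=16,
+        apply_rotary_embedding=True,
+        cache_mode="none",
+    )
+    x = keras.random.normal((2, 10, 64))
+    rotary_embedding = keras.random.normal((10, 4))
+    output = attention(
+        query=x, value=x, key=x, rotary_embedding=rotary_embedding
+    )  # Shape: (2, 10, 64).
+    ```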
+ """
+
+ # References:
+ # Based on the HuggingFace implementation of the MoonshineAttention class (https://github.com/huggingface/transformers/blob/fc8764c9a618add64c33e83720f974750bcd0978/src/transformers/models/moonshine/modeling_moonshine.py#L184-L315).
+
+ def __init__(
+ self,
+ num_heads,
+ key_dim,
+ value_dim=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ use_causal_mask=False,
+ apply_rotary_embedding=True,
+ cache_mode="none",
+ **kwargs,
+ ):
+ kwargs.pop("use_bias", None)
+ kwargs.pop("dropout", None)
+ super().__init__(
+ num_heads=num_heads,
+ key_dim=key_dim,
+ value_dim=value_dim,
+ use_bias=attention_bias,
+ dropout=attention_dropout,
+ **kwargs,
+ )
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.use_causal_mask = use_causal_mask
+ self.apply_rotary_embedding = apply_rotary_embedding
+ if cache_mode not in ["none", "autoregressive", "precomputed"]:
+ raise ValueError(
+ "cache_mode must be 'none', 'autoregressive', or 'precomputed'"
+ )
+ self.cache_mode = cache_mode
+
+ def build(self, query_shape, value_shape, key_shape=None):
+ # Ensure key_shape is defined.
+ key_shape = value_shape if key_shape is None else key_shape
+ query_rank = len(query_shape)
+ value_rank = len(value_shape)
+ key_rank = len(key_shape)
+
+ # Build query projection layer.
+ einsum_equation, bias_axes, output_rank = _build_proj_equation(
+ free_dims=query_rank - 1, bound_dims=1, output_dims=2
+ )
+ self._query_dense = keras.layers.EinsumDense(
+ einsum_equation,
+ output_shape=_get_output_shape(
+ output_rank - 1, [self._num_heads, self._key_dim]
+ ),
+ bias_axes=bias_axes if self._use_bias else None,
+ name="query",
+ **self._get_common_kwargs_for_sublayer(),
+ )
+ self._query_dense.build(query_shape)
+
+ # Build key projection layer.
+ einsum_equation, bias_axes, output_rank = _build_proj_equation(
+ free_dims=key_rank - 1, bound_dims=1, output_dims=2
+ )
+ self._key_dense = keras.layers.EinsumDense(
+ einsum_equation,
+ output_shape=_get_output_shape(
+ output_rank - 1, [self._num_heads, self._key_dim]
+ ),
+ bias_axes=bias_axes if self._use_bias else None,
+ name="key",
+ **self._get_common_kwargs_for_sublayer(),
+ )
+ self._key_dense.build(key_shape)
+
+ # Build value projection layer.
+ einsum_equation, bias_axes, output_rank = _build_proj_equation(
+ free_dims=value_rank - 1, bound_dims=1, output_dims=2
+ )
+ self._value_dense = keras.layers.EinsumDense(
+ einsum_equation,
+ output_shape=_get_output_shape(
+ output_rank - 1, [self._num_heads, self._value_dim]
+ ),
+ bias_axes=bias_axes if self._use_bias else None,
+ name="value",
+ **self._get_common_kwargs_for_sublayer(),
+ )
+ self._value_dense.build(value_shape)
+
+ # Build the internal attention computation sublayer.
+ self._build_attention(output_rank)
+
+ # Build output projection layer.
+ output_shape = (
+ query_shape[-1] if not self._output_shape else self._output_shape
+ )
+ if isinstance(output_shape, (list, tuple)):
+ output_shape = list(output_shape)
+ else:
+ output_shape = [output_shape]
+
+ einsum_equation, bias_axes, output_rank = _build_proj_equation(
+ free_dims=query_rank - 1,
+ bound_dims=2,
+ output_dims=len(output_shape),
+ )
+ self._output_dense = keras.layers.EinsumDense(
+ einsum_equation,
+ output_shape=_get_output_shape(output_rank - 1, output_shape),
+ bias_axes=bias_axes if self._use_bias else None,
+ name="attention_output",
+ **self._get_common_kwargs_for_sublayer(),
+ )
+ output_dense_input_shape = list(
+ self._query_dense.compute_output_shape(query_shape)
+ )
+ output_dense_input_shape[-1] = self._value_dim
+ self._output_dense.build(tuple(output_dense_input_shape))
+
+ self.built = True
+
+ def _compute_causal_mask(self, query, value=None, for_cache=False):
+ if backend.backend() == "torch" or backend.backend() == "jax":
+ q_seq_length = int(
+ keras.ops.convert_to_numpy(keras.ops.shape(query)[1]).item()
+ )
+ v_seq_length = (
+ int(
+ keras.ops.convert_to_numpy(keras.ops.shape(value)[1]).item()
+ )
+ if value is not None
+ else q_seq_length
+ )
+ elif backend.backend() == "tensorflow":
+ if for_cache:
+ assert value is not None
+ v_seq_length = keras.ops.shape(value)[1]
+ else:
+ v_seq_length = keras.ops.shape(query)[1]
+ q_seq_length = keras.ops.shape(query)[1]
+ n_rows = v_seq_length if for_cache else q_seq_length
+ ones_mask = keras.ops.ones((1, n_rows, v_seq_length), dtype="int32")
+ row_index = keras.ops.cumsum(ones_mask, axis=-2)
+ col_index = keras.ops.cumsum(ones_mask, axis=-1)
+ mask = keras.ops.greater_equal(row_index, col_index)
+
+ if for_cache:
+ mask = mask[:, -q_seq_length:, :]
+
+ return mask
+
+ def call(
+ self,
+ query,
+ value,
+ key,
+ rotary_embedding=None,
+ attention_mask=None,
+ key_cache=None,
+ value_cache=None,
+ training=None,
+ **kwargs,
+ ):
+ # Project inputs.
+ query_proj = self._query_dense(query)
+ if rotary_embedding is not None:
+ query_proj = _apply_rotary_pos_emb(query_proj, rotary_embedding)
+
+ # Handle caching.
+ if self.cache_mode == "none":
+ key_proj = self._key_dense(key)
+ value_proj = self._value_dense(value)
+ if self.apply_rotary_embedding and rotary_embedding is not None:
+ key_proj = _apply_rotary_pos_emb(key_proj, rotary_embedding)
+ final_key = key_proj
+ final_value = value_proj
+ elif self.cache_mode == "autoregressive":
+ if key_cache is None and value_cache is not None:
+ raise ValueError(
+ "key_cache must be provided if value_cache is provided"
+ )
+ new_key = self._key_dense(key)
+ new_value = self._value_dense(value)
+ if self.apply_rotary_embedding and rotary_embedding is not None:
+ new_key = _apply_rotary_pos_emb(new_key, rotary_embedding)
+ if key_cache is not None and value_cache is not None:
+ final_key = keras.ops.concatenate((key_cache, new_key), axis=-3)
+ final_value = keras.ops.concatenate(
+ (value_cache, new_value), axis=-3
+ )
+ else:
+ final_key = new_key
+ final_value = new_value
+ elif self.cache_mode == "precomputed":
+ if key_cache is None and value_cache is not None:
+ raise ValueError(
+ "key_cache must be provided if value_cache is provided"
+ )
+ if key_cache is not None and value_cache is not None:
+ final_key = key_cache
+ final_value = value_cache
+ else:
+ final_key = self._key_dense(key)
+ final_value = self._value_dense(value)
+ else:
+ raise ValueError(f"Invalid cache_mode: {self.cache_mode}")
+
+ # Compute attention mask.
+ if self.use_causal_mask:
+ causal_mask = self._compute_causal_mask(
+ query,
+ final_value if self.cache_mode == "autoregressive" else None,
+ for_cache=(
+ self.cache_mode == "autoregressive"
+ and key_cache is not None
+ ),
+ )
+ # Combine with attention_mask if provided.
+ if attention_mask is not None:
+ # [batch_size, seq_len_k] → [batch_size, 1, 1, seq_len_k].
+ attention_mask_expanded = keras.ops.expand_dims(
+ attention_mask, axis=1
+ )
+ attention_mask_expanded = keras.ops.expand_dims(
+ attention_mask_expanded, axis=-1
+ )
+ final_mask = keras.ops.logical_and(
+ causal_mask, attention_mask_expanded
+ )
+ else:
+ final_mask = causal_mask
+ else:
+ if attention_mask is not None:
+ if self.cache_mode == "none":
+ seq_len = keras.ops.shape(query)[1]
+ final_mask = keras.ops.tile(
+ attention_mask[:, None, :], [1, seq_len, 1]
+ )
+ elif self.cache_mode == "precomputed":
+ final_mask = attention_mask[:, None, None, :]
+ else: # Autoregressive
+ final_mask = attention_mask
+ else:
+ final_mask = None
+
+ attention_kwargs = {
+ k: v for k, v in kwargs.items() if k != "padding_mask"
+ }
+ # Compute attention.
+ attention_output, _ = self._compute_attention(
+ query=query_proj,
+ key=final_key,
+ value=final_value,
+ attention_mask=final_mask,
+ training=training,
+ **attention_kwargs,
+ )
+
+ # Project the attention output.
+ output = self._output_dense(attention_output)
+
+ # Return based on cache_mode.
+ if self.cache_mode == "none":
+ return output
+ elif self.cache_mode == "autoregressive":
+ return output, final_key, final_value
+ elif self.cache_mode == "precomputed":
+ if key_cache is not None and value_cache is not None:
+ return output
+ return output, final_key, final_value
diff --git a/keras_hub/src/models/moonshine/moonshine_multi_head_attention_test.py b/keras_hub/src/models/moonshine/moonshine_multi_head_attention_test.py
new file mode 100644
index 0000000000..2d15cba0c4
--- /dev/null
+++ b/keras_hub/src/models/moonshine/moonshine_multi_head_attention_test.py
@@ -0,0 +1,160 @@
+import keras
+
+from keras_hub.src.models.moonshine.moonshine_multi_head_attention import (
+ MoonshineMultiHeadAttention,
+)
+from keras_hub.src.tests.test_case import TestCase
+
+
+class MoonshineMultiHeadAttentionTest(TestCase):
+ def setUp(self):
+ super().setUp()
+ self.num_heads = 4
+ self.key_dim = 16
+ self.hidden_dim = self.num_heads * self.key_dim
+ self.init_kwargs = {
+ "num_heads": self.num_heads,
+ "key_dim": self.key_dim,
+ "value_dim": None,
+ "attention_bias": False,
+ "attention_dropout": 0.0,
+ "use_causal_mask": False,
+ "apply_rotary_embedding": True,
+ "cache_mode": "none",
+ }
+ self.attention_layer = MoonshineMultiHeadAttention(**self.init_kwargs)
+ self.batch_size = 2
+ self.query_seq_len = 10
+ self.key_seq_len = 16
+ self.rotary_dim = int(
+ self.key_dim * 0.62
+ ) # Default partial_rotary_factor = 0.62
+ self.rotary_dim = (self.rotary_dim // 2) * 2 # Ensure even
+ self.rotary_dim = self.rotary_dim // 2 # Half for freqs, e.g., 4
+ self.query = keras.random.normal(
+ (self.batch_size, self.query_seq_len, self.hidden_dim)
+ )
+ self.key = keras.random.normal(
+ (self.batch_size, self.key_seq_len, self.hidden_dim)
+ )
+ self.value = self.key # For testing purposes
+ self.rotary_embedding = keras.random.normal(
+ (self.query_seq_len, self.rotary_dim)
+ )
+ self.attention_mask = keras.ops.ones(
+ (self.batch_size, self.key_seq_len), dtype="bool"
+ )
+
+ def test_initialization(self):
+ self.assertEqual(self.attention_layer.num_heads, self.num_heads)
+ self.assertEqual(self.attention_layer.key_dim, self.key_dim)
+ self.assertFalse(self.attention_layer.attention_bias)
+ self.assertTrue(self.attention_layer.apply_rotary_embedding)
+
+ def test_forward_pass_without_caching(self):
+ self.attention_layer.apply_rotary_embedding = (
+ False # Test cross-attention
+ )
+ output = self.attention_layer(
+ query=self.query,
+ key=self.key,
+ value=self.value,
+ rotary_embedding=self.rotary_embedding,
+ attention_mask=self.attention_mask,
+ )
+ self.assertEqual(
+ output.shape, (self.batch_size, self.query_seq_len, self.hidden_dim)
+ )
+
+ def test_precomputed_caching(self):
+ self.attention_layer.build(
+ query_shape=(self.batch_size, self.query_seq_len, self.hidden_dim),
+ value_shape=(self.batch_size, self.key_seq_len, self.hidden_dim),
+ key_shape=(self.batch_size, self.key_seq_len, self.hidden_dim),
+ )
+ self.attention_layer.cache_mode = "precomputed"
+ self.attention_layer.apply_rotary_embedding = False
+ key_proj = self.attention_layer._key_dense(self.key)
+ value_proj = self.attention_layer._value_dense(self.value)
+ output_precomputed = self.attention_layer(
+ query=self.query,
+ key=None,
+ value=None,
+ key_cache=key_proj,
+ value_cache=value_proj,
+ rotary_embedding=self.rotary_embedding,
+ attention_mask=self.attention_mask,
+ )
+ self.attention_layer.cache_mode = "none"
+ output_normal = self.attention_layer(
+ query=self.query,
+ key=self.key,
+ value=self.value,
+ rotary_embedding=self.rotary_embedding,
+ attention_mask=self.attention_mask,
+ )
+ self.assertEqual(
+ output_precomputed.shape,
+ (self.batch_size, self.query_seq_len, self.hidden_dim),
+ )
+ self.assertAllClose(output_precomputed, output_normal, atol=1e-5)
+
+ def test_autoregressive_caching(self):
+ self.attention_layer.cache_mode = "autoregressive"
+ self.attention_layer.use_causal_mask = True # Ensure causal attention
+ cache_k, cache_v = None, None
+ outputs_auto = []
+ for i in range(self.query_seq_len):
+ query_i = self.query[:, i : i + 1, :]
+ key_i = self.query[:, i : i + 1, :] # Self-attention
+ value_i = self.query[:, i : i + 1, :]
+ rotary_i = self.rotary_embedding[i : i + 1, :]
+ output_i, new_cache_k, new_cache_v = self.attention_layer(
+ query=query_i,
+ key=key_i,
+ value=value_i,
+ rotary_embedding=rotary_i,
+ key_cache=cache_k,
+ value_cache=cache_v,
+ )
+ outputs_auto.append(output_i)
+ self.assertEqual(
+ output_i.shape, (self.batch_size, 1, self.hidden_dim)
+ )
+ self.assertEqual(
+ new_cache_k.shape,
+ (self.batch_size, i + 1, self.num_heads, self.key_dim),
+ )
+ self.assertEqual(
+ new_cache_v.shape,
+ (self.batch_size, i + 1, self.num_heads, self.key_dim),
+ )
+ cache_k, cache_v = new_cache_k, new_cache_v
+ outputs_auto = keras.ops.concatenate(outputs_auto, axis=1)
+ self.attention_layer.cache_mode = "none"
+ self.attention_layer.use_causal_mask = (
+ True # Consistent with autoregressive
+ )
+ output_full = self.attention_layer(
+ query=self.query,
+ key=self.query,
+ value=self.query,
+ rotary_embedding=self.rotary_embedding,
+ )
+ self.assertAllClose(outputs_auto, output_full, atol=1e-5)
+
+ def test_forward_pass_with_causal_mask(self):
+ self.attention_layer.use_causal_mask = True
+ output = self.attention_layer(
+ query=self.query,
+ key=self.query, # Self-attention for causal test
+ value=self.query,
+ rotary_embedding=self.rotary_embedding,
+ )
+ self.assertEqual(
+ output.shape, (self.batch_size, self.query_seq_len, self.hidden_dim)
+ )
+
+ def test_serialization(self):
+ instance = MoonshineMultiHeadAttention(**self.init_kwargs)
+ self.run_serialization_test(instance=instance)
diff --git a/keras_hub/src/models/moonshine/moonshine_presets.py b/keras_hub/src/models/moonshine/moonshine_presets.py
new file mode 100644
index 0000000000..b394c454c0
--- /dev/null
+++ b/keras_hub/src/models/moonshine/moonshine_presets.py
@@ -0,0 +1,25 @@
+# Metadata for loading pretrained model weights.
+backbone_presets = {
+ "moonshine_tiny_en": {
+ "metadata": {
+ "description": (
+ "Moonshine tiny model for English speech recognition. "
+ "Developed by Useful Sensors for real-time transcription."
+ ),
+ "params": 27092736,
+ "path": "moonshine",
+ },
+ "kaggle_handle": "",
+ },
+ "moonshine_base_en": {
+ "metadata": {
+ "description": (
+ "Moonshine base model for English speech recognition. "
+ "Developed by Useful Sensors for real-time transcription."
+ ),
+ "params": 61513920,
+ "path": "moonshine",
+ },
+ "kaggle_handle": "",
+ },
+}
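+
+
+# Example usage (illustrative sketch; assumes the `kaggle_handle` fields above
+# are populated once the converted weights are published):
+#
+#   import keras_hub
+#
+#   backbone = keras_hub.models.MoonshineBackbone.from_preset(
+#       "moonshine_tiny_en"
+#   )
+#   tokenizer = keras_hub.tokenizers.MoonshineTokenizer.from_preset(
+#       "moonshine_tiny_en"
+#   )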
diff --git a/keras_hub/src/models/moonshine/moonshine_seq_2_seq_lm_preprocessor.py b/keras_hub/src/models/moonshine/moonshine_seq_2_seq_lm_preprocessor.py
new file mode 100644
index 0000000000..2500505bf8
--- /dev/null
+++ b/keras_hub/src/models/moonshine/moonshine_seq_2_seq_lm_preprocessor.py
@@ -0,0 +1,177 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
+from keras_hub.src.models.moonshine.moonshine_backbone import MoonshineBackbone
+from keras_hub.src.models.moonshine.moonshine_tokenizer import (
+ MoonshineTokenizer,
+)
+from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor
+from keras_hub.src.utils.tensor_utils import preprocessing_function
+from keras_hub.src.utils.tensor_utils import strip_to_ragged
+
+
+@keras_hub_export("keras_hub.models.MoonshineSeq2SeqLMPreprocessor")
+class MoonshineSeq2SeqLMPreprocessor(Seq2SeqLMPreprocessor):
+ """Moonshine Seq2Seq LM preprocessor for audio-to-text tasks.
+
+ This preprocessor converts raw audio and text inputs into a format suitable
+ for the `MoonshineAudioToText` model. It processes audio waveforms into
+ features using `MoonshineAudioConverter` for the encoder and tokenizes text
+ using `MoonshineTokenizer` for the decoder. It supports training and
+ generation.
+
+ Args:
+ audio_converter: A `MoonshineAudioConverter` instance to process audio.
+ tokenizer: A `MoonshineTokenizer` instance to tokenize text.
+ encoder_sequence_length: int, optional. Maximum length for audio
+ features. If None, features are variable-length with padding masks.
+ Defaults to None.
+ decoder_sequence_length: int, optional. Maximum length for decoder token
+ sequences. Defaults to 1024.
+ **kwargs: Additional keyword arguments for the parent class.
+
+ Examples:
+ ```python
+ # Create audio converter and tokenizer instances.
+    audio_converter = keras_hub.layers.MoonshineAudioConverter()
+    tokenizer = keras_hub.models.MoonshineTokenizer.from_preset(
+        "moonshine_base_en"
+    )
+
+ # Initialize the preprocessor.
+ preprocessor = keras_hub.models.MoonshineSeq2SeqLMPreprocessor(
+ audio_converter=audio_converter,
+ tokenizer=tokenizer,
+ decoder_sequence_length=8
+ )
+
+ # Prepare input data (audio tensor and text).
+ inputs = {
+ "audio": keras.random.normal((1, 16000, 1)),
+ "text": ["the quick brown fox"]
+ }
+
+ # Process the inputs.
+ preprocessed = preprocessor(inputs)
+ """
+
+ backbone_cls = MoonshineBackbone
+ tokenizer_cls = MoonshineTokenizer
+
+ def __init__(
+ self,
+ audio_converter,
+ tokenizer,
+ encoder_sequence_length=None,
+ decoder_sequence_length=1024,
+ **kwargs,
+ ):
+ super().__init__(tokenizer=tokenizer, **kwargs)
+ self.audio_converter = audio_converter
+ self.encoder_sequence_length = encoder_sequence_length
+ self.decoder_sequence_length = decoder_sequence_length
+ self.encoder_packer = None
+ self.decoder_packer = None
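+        # The audio converter produces the encoder features directly, so only a
+        # decoder-side packer is created (lazily, in `build()`).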
+
+ def build(self, input_shape):
+ self.audio_converter.build(input_shape)
+ self.decoder_packer = StartEndPacker(
+ start_value=self.tokenizer.start_token_id,
+ end_value=self.tokenizer.end_token_id,
+ pad_value=self.tokenizer.pad_token_id,
+ sequence_length=self.decoder_sequence_length,
+ return_padding_mask=True,
+ )
+ self.built = True
+
+ @preprocessing_function
+ def call(
+ self,
+ x,
+ y=None,
+ sample_weight=None,
+ encoder_sequence_length=None,
+ decoder_sequence_length=None,
+ sequence_length=None,
+ ):
+ if not self.built:
+ self.build(None)
+ if isinstance(x, tuple) and len(x) == 1:
+ x = x[0]
+ encoder_sequence_length = (
+ encoder_sequence_length or self.encoder_sequence_length
+ )
+ decoder_sequence_length = (
+ decoder_sequence_length
+ or sequence_length
+ or self.decoder_sequence_length
+ )
+ text = x["text"]
+ audio_features = self.audio_converter(x["audio"])
+ encoder_inputs = audio_features["input_values"]
+ encoder_padding_mask = audio_features["attention_mask"]
+ decoder_inputs = self.tokenizer(text)
+ decoder_token_ids, decoder_padding_mask = self.decoder_packer(
+ decoder_inputs,
+ sequence_length=decoder_sequence_length + 1,
+ )
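+        # The packer pads to `decoder_sequence_length + 1` with start/end
+        # tokens added; the decoder inputs drop the last position and the
+        # labels drop the first, giving the shift-by-one teacher-forcing setup.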
+ x_out = {
+ "encoder_input_values": encoder_inputs,
+ "encoder_padding_mask": encoder_padding_mask,
+ "decoder_token_ids": decoder_token_ids[..., :-1],
+ "decoder_padding_mask": decoder_padding_mask[..., :-1],
+ }
+ y_out = decoder_token_ids[..., 1:]
+ sample_weight_out = decoder_padding_mask[..., 1:]
+
+ return keras.utils.pack_x_y_sample_weight(
+ x_out, y_out, sample_weight_out
+ )
+
+ @preprocessing_function
+ def generate_preprocess(
+ self,
+ x,
+ encoder_sequence_length=None,
+ decoder_sequence_length=None,
+ sequence_length=None,
+ ):
+ if not self.built:
+ self.build(None)
+ if isinstance(x, tuple) and len(x) == 1:
+ x = x[0]
+ decoder_sequence_length = (
+ decoder_sequence_length
+ or sequence_length
+ or self.decoder_sequence_length
+ )
+ audio_features = self.audio_converter(x["audio"])
+ encoder_token_ids = audio_features["input_values"]
+ encoder_padding_mask = audio_features["attention_mask"]
+ decoder_text = x.get("text", [""] * keras.ops.shape(x["audio"])[0])
+ decoder_token_ids = self.tokenizer(decoder_text)
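+        # Only the start token is prepended here (`add_end_value=False`), so
+        # generation can continue from the packed prompt.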
+ decoder_token_ids, decoder_padding_mask = self.decoder_packer(
+ decoder_token_ids,
+ sequence_length=decoder_sequence_length,
+ add_end_value=False,
+ )
+
+ return {
+ "encoder_input_values": encoder_token_ids,
+ "encoder_padding_mask": encoder_padding_mask,
+ "decoder_token_ids": decoder_token_ids,
+ "decoder_padding_mask": decoder_padding_mask,
+ }
+
+ @preprocessing_function
+ def generate_postprocess(self, x):
+ if not self.built:
+ self.build(None)
+ token_ids, padding_mask = (
+ x["decoder_token_ids"],
+ x["decoder_padding_mask"],
+ )
+ ids_to_strip = self.tokenizer.special_token_ids
+ token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip)
+ return self.tokenizer.detokenize(token_ids)
diff --git a/keras_hub/src/models/moonshine/moonshine_seq_2_seq_lm_preprocessor_test.py b/keras_hub/src/models/moonshine/moonshine_seq_2_seq_lm_preprocessor_test.py
new file mode 100644
index 0000000000..1f2f1c9dd2
--- /dev/null
+++ b/keras_hub/src/models/moonshine/moonshine_seq_2_seq_lm_preprocessor_test.py
@@ -0,0 +1,87 @@
+import os
+
+import keras
+import pytest
+
+from keras_hub.src.models.moonshine.moonshine_audio_converter import (
+ MoonshineAudioConverter,
+)
+from keras_hub.src.models.moonshine.moonshine_seq_2_seq_lm_preprocessor import (
+ MoonshineSeq2SeqLMPreprocessor,
+)
+from keras_hub.src.models.moonshine.moonshine_tokenizer import (
+ MoonshineTokenizer,
+)
+from keras_hub.src.tests.test_case import TestCase
+
+
+class MoonshineSeq2SeqLMPreprocessorTest(TestCase):
+ def setUp(self):
+ self.tokenizer = MoonshineTokenizer(
+ proto=os.path.join(
+ self.get_test_data_dir(), "moonshine_test_vocab.spm"
+ )
+ )
+ self.audio_converter = MoonshineAudioConverter(filter_dim=32)
+ self.init_kwargs = {
+ "audio_converter": self.audio_converter,
+ "tokenizer": self.tokenizer,
+ "encoder_sequence_length": None,
+ "decoder_sequence_length": 8,
+ }
+ self.input_data = (
+ {
+ "audio": keras.random.normal((1, 16000, 1)),
+ "text": ["the quick brown fox"],
+ },
+ )
+
+ def test_preprocessor_basics(self):
+ preprocessor = MoonshineSeq2SeqLMPreprocessor(**self.init_kwargs)
+ output = preprocessor.call(self.input_data)
+ x_out, y_out, sample_weight_out = output
+ self.assertIn("encoder_input_values", x_out)
+ self.assertIn("encoder_padding_mask", x_out)
+ self.assertIn("decoder_token_ids", x_out)
+ self.assertIn("decoder_padding_mask", x_out)
+ self.assertAllEqual(
+ keras.ops.shape(x_out["encoder_input_values"]), (1, 40, 32)
+ )
+ self.assertAllEqual(
+ keras.ops.shape(x_out["encoder_padding_mask"]), (1, 40)
+ )
+ self.assertAllEqual(keras.ops.shape(x_out["decoder_token_ids"]), (1, 8))
+ self.assertAllEqual(
+ keras.ops.shape(x_out["decoder_padding_mask"]), (1, 8)
+ )
+ self.assertAllEqual(keras.ops.shape(y_out), (1, 8))
+ self.assertAllEqual(keras.ops.shape(sample_weight_out), (1, 8))
+
+ def test_generate_preprocess(self):
+ preprocessor = MoonshineSeq2SeqLMPreprocessor(**self.init_kwargs)
+ output = preprocessor.generate_preprocess(self.input_data)
+ self.assertIn("encoder_input_values", output)
+ self.assertAllClose(output["decoder_token_ids"].shape, [1, 8])
+
+ def test_generate_postprocess(self):
+ preprocessor = MoonshineSeq2SeqLMPreprocessor(**self.init_kwargs)
+ input_data = {
+ "decoder_token_ids": keras.ops.ones((1, 5), dtype="int32"),
+ "decoder_padding_mask": keras.ops.ones((1, 5)),
+ }
+ output = preprocessor.generate_postprocess(input_data)
+ self.assertIsInstance(output, list)
+ self.assertIsInstance(output[0], str)
+
+ @pytest.mark.extra_large
+ def test_all_presets(self):
+ for preset in MoonshineSeq2SeqLMPreprocessor.presets:
+ self.run_preset_test(
+ cls=MoonshineSeq2SeqLMPreprocessor,
+ preset=preset,
+ input_data=self.input_data,
+ )
+
+ def test_serialization(self):
+ instance = MoonshineSeq2SeqLMPreprocessor(**self.init_kwargs)
+ self.run_serialization_test(instance=instance)
diff --git a/keras_hub/src/models/moonshine/moonshine_tokenizer.py b/keras_hub/src/models/moonshine/moonshine_tokenizer.py
new file mode 100644
index 0000000000..c55a249af3
--- /dev/null
+++ b/keras_hub/src/models/moonshine/moonshine_tokenizer.py
@@ -0,0 +1,105 @@
+import base64
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer
+from keras_hub.src.models.moonshine.moonshine_backbone import MoonshineBackbone
+
+
+@keras_hub_export(
+ [
+ "keras_hub.tokenizers.MoonshineTokenizer",
+ "keras_hub.models.MoonshineTokenizer",
+ ]
+)
+class MoonshineTokenizer(LlamaTokenizer):
+ """
+ Moonshine tokenizer layer based on SentencePiece and LlamaTokenizer.
+
+ This tokenizer class extends the `LlamaTokenizer` to tokenize raw strings
+ to integer sequences while incorporating Moonshine-specific special tokens.
+
+ **Special tokens added:**
+    - **Start token:** "<s>"
+    - **End token:** "</s>"
+    - **Unknown token:** "<unk>"
+    - **Padding token:** "<pad>"
+    - **Position embedding tokens:** "<<ST_0>>" through "<<ST_767>>"
+ - **Hex tokens:** "<0x00>" through "<0xFF>"
+ - **Empty token:** "<>"
+
+ Args:
+ proto: `str` or `bytes`. Either a string path to a SentencePiece proto
+ file or a bytes object containing a serialized SentencePiece proto.
+ See the [SentencePiece repository](https://github.com/google/sentencepiece)
+ for details on the format.
+ **kwargs: Additional keyword arguments passed to the parent
+ `LlamaTokenizer`.
+
+ Examples:
+ ```python
+ from keras_hub.tokenizers import MoonshineTokenizer
+
+ # Initialize tokenizer.
+ tokenizer = MoonshineTokenizer(
+ "keras_hub/src/tests/test_data/moonshine_test_vocab.spm"
+ )
+
+ # Single input example.
+ single_input = "the quick brown fox"
+ single_tokens = tokenizer(single_input)
+ print("Single input tokenization:")
+ print(f"Input text: {single_input}")
+ print(f"Tokenized: {single_tokens}")
+
+ # Batched input example.
+ batch_input = ["the quick brown fox", "the earth is round"]
+ batch_tokens = tokenizer(batch_input)
+ print("Batch input tokenization:")
+ print(f"Input texts: {batch_input}")
+ print(f"Tokenized: {batch_tokens}")
+
+ # Detokenization example.
+ encoded = tokenizer(single_input)
+ decoded = tokenizer.detokenize(encoded)
+ print("Detokenization:")
+ print(f"Original text: {single_input}")
+ print(f"Encoded: {encoded}")
+ print(f"Decoded: {decoded}")
+ ```
+ """
+
+ # References:
+    # The special tokens are defined in Section 3.1 of the Moonshine paper,
+    # "Moonshine: Speech Recognition for Live Transcription and Voice
+    # Commands" (https://arxiv.org/pdf/2410.15608.pdf).
+
+ backbone_cls = MoonshineBackbone
+
+ def __init__(self, proto, **kwargs):
+ super().__init__(proto=proto, **kwargs)
+
+ for i in range(768):
+ self._add_special_token(f"<>", f"st_token_{i}")
+
+ for i in range(256):
+ self._add_special_token(f"<0x{i:02X}>", f"hex_token_{i}")
+
+ self._add_special_token("<>", "empty_token")
+
+ self.start_token_id = self.token_to_id("") # Beginning of sentence
+ self.end_token_id = self.token_to_id("") # End of sentence
+ self.pad_token_id = self.token_to_id("") # Padding token
+ self.unk_token_id = self.token_to_id("") # Unknown token
+
+ def get_config(self):
+ config = super().get_config()
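+        # A serialized proto is `bytes`, which is not JSON-serializable, so it
+        # is base64-encoded here and decoded back in `from_config`.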
+ if isinstance(self.proto, bytes):
+ config["proto"] = base64.b64encode(self.proto).decode("utf-8")
+ else:
+ config["proto"] = self.proto
+ return config
+
+ @classmethod
+ def from_config(cls, config):
+ if "proto" in config and isinstance(config["proto"], str):
+ config["proto"] = base64.b64decode(config["proto"])
+ return super().from_config(config)
diff --git a/keras_hub/src/models/moonshine/moonshine_tokenizer_test.py b/keras_hub/src/models/moonshine/moonshine_tokenizer_test.py
new file mode 100644
index 0000000000..144c10c654
--- /dev/null
+++ b/keras_hub/src/models/moonshine/moonshine_tokenizer_test.py
@@ -0,0 +1,96 @@
+import os
+
+from keras_hub.src.models.moonshine.moonshine_tokenizer import (
+ MoonshineTokenizer,
+)
+from keras_hub.src.tests.test_case import TestCase
+
+
+class MoonshineTokenizerTest(TestCase):
+ def setUp(self):
+ self.init_kwargs = {
+ # Generated using create_moonshine_test_proto.py.
+ "proto": os.path.join(
+ self.get_test_data_dir(), "moonshine_test_vocab.spm"
+ )
+ }
+ self.tokenizer = MoonshineTokenizer(**self.init_kwargs)
+ self.input_data = ["the quick brown fox", "the earth is round"]
+ self.special_token_inputs = [
+ "Hello world!",
+ "Test with <>",
+ "Hex test <0x1F>",
+ "Empty token test <>",
+ ]
+
+ def test_tokenizer_basics(self):
+ self.run_preprocessing_layer_test(
+ cls=MoonshineTokenizer,
+ init_kwargs=self.init_kwargs,
+ input_data=self.input_data,
+ )
+
+ def test_special_tokens_existence(self):
+ self.assertIsNotNone(self.tokenizer.start_token_id)
+ self.assertIsNotNone(self.tokenizer.end_token_id)
+ self.assertIsNotNone(self.tokenizer.pad_token_id)
+
+ self.assertIsNotNone(self.tokenizer.unk_token_id)
+ self.assertIsNotNone(self.tokenizer.token_to_id("<>"))
+
+ self.assertIsNotNone(self.tokenizer.token_to_id("<>"))
+ self.assertIsNotNone(self.tokenizer.token_to_id("<>"))
+ self.assertIsNotNone(self.tokenizer.token_to_id("<>"))
+ self.assertIsNotNone(self.tokenizer.token_to_id("<>"))
+
+ self.assertIsNotNone(self.tokenizer.token_to_id("<0x00>"))
+ self.assertIsNotNone(self.tokenizer.token_to_id("<0x1F>"))
+ self.assertIsNotNone(self.tokenizer.token_to_id("<0xA0>"))
+ self.assertIsNotNone(self.tokenizer.token_to_id("<0xFF>"))
+
+ def test_special_token_ids_mapping(self):
+ self.assertEqual(
+ self.tokenizer.token_to_id(""), self.tokenizer.start_token_id
+ )
+ self.assertEqual(
+ self.tokenizer.token_to_id(""), self.tokenizer.end_token_id
+ )
+ self.assertEqual(
+ self.tokenizer.token_to_id(""), self.tokenizer.pad_token_id
+ )
+ self.assertEqual(
+ self.tokenizer.token_to_id(""), self.tokenizer.unk_token_id
+ )
+
+ def test_special_tokens_tokenization(self):
+ tokenized_st42 = self.tokenizer("<>")
+ self.assertEqual(len(tokenized_st42), 1)
+
+ tokenized_hex = self.tokenizer("<0x1F>")
+ self.assertEqual(len(tokenized_hex), 1)
+
+ tokenized_empty = self.tokenizer("<>")
+ self.assertEqual(len(tokenized_empty), 1)
+
+ def test_detokenization(self):
+ for text in self.input_data + self.special_token_inputs:
+ tokens = self.tokenizer(text)
+ decoded = self.tokenizer.detokenize(tokens)
+ if text in self.input_data:
+ self.assertIn(text.lower(), decoded.lower())
+
+ def test_errors_missing_special_tokens(self):
+ with self.assertRaises(ValueError):
+ MoonshineTokenizer(
+ proto=os.path.join(
+ self.get_test_data_dir(), "no_special_token_vocab.spm"
+ )
+ )
+
+ def test_batch_tokenization(self):
+ batch_tokens = self.tokenizer(self.input_data)
+ self.assertEqual(len(batch_tokens), len(self.input_data))
+
+ def test_serialization(self):
+ instance = MoonshineTokenizer(**self.init_kwargs)
+ self.run_serialization_test(instance=instance)
diff --git a/keras_hub/src/tests/test_data/audio_transcription_tests/female_long_voice_clip_64sec.wav b/keras_hub/src/tests/test_data/audio_transcription_tests/female_long_voice_clip_64sec.wav
new file mode 100644
index 0000000000..6ecff04a48
Binary files /dev/null and b/keras_hub/src/tests/test_data/audio_transcription_tests/female_long_voice_clip_64sec.wav differ
diff --git a/keras_hub/src/tests/test_data/audio_transcription_tests/female_short_voice_clip_17sec.wav b/keras_hub/src/tests/test_data/audio_transcription_tests/female_short_voice_clip_17sec.wav
new file mode 100644
index 0000000000..428c1fee89
Binary files /dev/null and b/keras_hub/src/tests/test_data/audio_transcription_tests/female_short_voice_clip_17sec.wav differ
diff --git a/keras_hub/src/tests/test_data/audio_transcription_tests/male_muffled_voice_clip_46sec.wav b/keras_hub/src/tests/test_data/audio_transcription_tests/male_muffled_voice_clip_46sec.wav
new file mode 100644
index 0000000000..b051143373
Binary files /dev/null and b/keras_hub/src/tests/test_data/audio_transcription_tests/male_muffled_voice_clip_46sec.wav differ
diff --git a/keras_hub/src/tests/test_data/audio_transcription_tests/male_short_voice_clip_3sec.wav b/keras_hub/src/tests/test_data/audio_transcription_tests/male_short_voice_clip_3sec.wav
new file mode 100644
index 0000000000..46d0b072c2
Binary files /dev/null and b/keras_hub/src/tests/test_data/audio_transcription_tests/male_short_voice_clip_3sec.wav differ
diff --git a/keras_hub/src/tests/test_data/llama2_tokenizer_full.spm b/keras_hub/src/tests/test_data/llama2_tokenizer_full.spm
new file mode 100644
index 0000000000..22bccbcb41
Binary files /dev/null and b/keras_hub/src/tests/test_data/llama2_tokenizer_full.spm differ
diff --git a/keras_hub/src/tests/test_data/moonshine_test_vocab.spm b/keras_hub/src/tests/test_data/moonshine_test_vocab.spm
new file mode 100644
index 0000000000..f9ece424de
Binary files /dev/null and b/keras_hub/src/tests/test_data/moonshine_test_vocab.spm differ
diff --git a/tools/checkpoint_conversion/convert_moonshine_checkpoints.py b/tools/checkpoint_conversion/convert_moonshine_checkpoints.py
new file mode 100644
index 0000000000..bea96a748d
--- /dev/null
+++ b/tools/checkpoint_conversion/convert_moonshine_checkpoints.py
@@ -0,0 +1,538 @@
+"""
+Convert Moonshine checkpoints to KerasHub format and provide a complete
+end-to-end example.
+
+The weights are sourced from:
+https://huggingface.co/UsefulSensors/moonshine/tree/main/base
+https://huggingface.co/UsefulSensors/moonshine/tree/main/tiny
+
+The Hugging Face configs are available at:
+https://huggingface.co/UsefulSensors/moonshine-base/blob/main/config.json
+https://huggingface.co/UsefulSensors/moonshine-tiny/blob/main/config.json
+
+Usage:
+```shell
+python -m tools.checkpoint_conversion.convert_moonshine_checkpoints
+```
+"""
+
+import json
+import os
+import warnings
+
+import h5py
+import keras
+
+try:
+ import librosa
+except ImportError:
+ raise ImportError(
+ "Moonshine ASR system requires librosa as a dependency. Please install "
+ "it using 'pip install librosa' before proceeding."
+ )
+import numpy as np
+import requests
+import torch
+from huggingface_hub import hf_hub_download
+from transformers import AutoModel
+
+from keras_hub.src.models.moonshine.moonshine_audio_converter import (
+ MoonshineAudioConverter,
+)
+from keras_hub.src.models.moonshine.moonshine_audio_to_text import (
+ MoonshineAudioToText,
+)
+from keras_hub.src.models.moonshine.moonshine_backbone import MoonshineBackbone
+from keras_hub.src.models.moonshine.moonshine_seq_2_seq_lm_preprocessor import (
+ MoonshineSeq2SeqLMPreprocessor,
+)
+from keras_hub.src.models.moonshine.moonshine_tokenizer import (
+ MoonshineTokenizer,
+)
+
+# Set random seed for reproducibility.
+keras.utils.set_random_seed(50)
+
+
+# Utility function to convert tensors to NumPy based on backend.
+def to_numpy(tensor):
+ if keras.backend.backend() == "torch":
+ return tensor.detach().cpu().numpy()
+ elif keras.backend.backend() == "tensorflow":
+ return tensor.numpy()
+ elif keras.backend.backend() == "jax":
+ import jax
+
+ return jax.device_get(tensor)
+ else:
+ raise ValueError("Unsupported backend")
+
+
+def load_h5_weights(filepath):
+ with h5py.File(filepath, "r") as f:
+ weights = {}
+
+ def recursive_load(group, prefix=""):
+ for key in group.keys():
+ path = f"{prefix}/{key}" if prefix else key
+ if isinstance(group[key], h5py.Dataset):
+ weights[path] = np.array(group[key])
+ else:
+ recursive_load(group[key], path)
+
+ recursive_load(f)
+ return weights
+
+
+# Init.
+presets = ["moonshine-base", "moonshine-tiny"]
+for preset in presets:
+ print(f"\n=== Processing {preset} ===")
+ PRESET_NAME = preset
+ PRESET = f"UsefulSensors/{preset}"
+ EXTRACT_DIR = "./{}"
+
+ extract_dir = EXTRACT_DIR.format(PRESET_NAME)
+ if not os.path.exists(extract_dir):
+ os.makedirs(extract_dir)
+
+ # Download and load config.
+ config_path = os.path.join(extract_dir, "config.json")
+ response = requests.get(
+ f"https://huggingface.co/{PRESET}/raw/main/config.json"
+ )
+ open(config_path, "wb").write(response.content)
+
+ cfg = {}
+ with open(config_path, "r") as pt_cfg_handler:
+ pt_cfg = json.load(pt_cfg_handler)
+
+ # Setup Moonshine config.
+ cfg["vocabulary_size"] = pt_cfg["vocab_size"]
+ cfg["num_layers"] = pt_cfg["encoder_num_hidden_layers"]
+ cfg["filter_dim"] = pt_cfg["hidden_size"]
+ cfg["hidden_dim"] = pt_cfg["hidden_size"]
+ cfg["intermediate_dim"] = pt_cfg["intermediate_size"]
+ cfg["max_sequence_length"] = pt_cfg["max_position_embeddings"]
+ cfg["partial_rotary_factor"] = pt_cfg["partial_rotary_factor"]
+ cfg["rope_theta"] = pt_cfg["rope_theta"]
+ cfg["encoder_num_layers"] = pt_cfg["encoder_num_hidden_layers"]
+ cfg["decoder_num_layers"] = pt_cfg["decoder_num_hidden_layers"]
+ cfg["encoder_num_heads"] = pt_cfg.get("encoder_num_attention_heads", 8)
+ cfg["decoder_num_heads"] = pt_cfg.get("decoder_num_attention_heads", 8)
+ cfg["feedforward_expansion_factor"] = 4
+ cfg["attention_bias"] = pt_cfg["attention_bias"]
+ cfg["attention_dropout"] = pt_cfg["attention_dropout"]
+ cfg["dtype"] = pt_cfg["torch_dtype"]
+ cfg["decoder_use_swiglu_activation"] = (
+ pt_cfg["decoder_hidden_act"] == "silu"
+ )
+ cfg["encoder_use_swiglu_activation"] = (
+ pt_cfg["encoder_hidden_act"] == "silu"
+ )
+ cfg["initializer_range"] = pt_cfg["initializer_range"]
+ cfg["rope_scaling"] = pt_cfg["rope_scaling"]
+
+ # Taken from: https://huggingface.co/UsefulSensors/moonshine-{base/tiny}/blob/main/preprocessor_config.json.
+ cfg["sampling_rate"] = 16000
+ cfg["padding_value"] = 0.0
+ cfg["do_normalize"] = False
+ cfg["return_attention_mask"] = True
+
+ # Download weights.
+ weights_dir = os.path.join(extract_dir, "weights")
+ repo_id = "UsefulSensors/moonshine"
+ variant = preset.split("-")[-1]
+ files = [
+ "encoder.weights.h5",
+ "preprocessor.weights.h5",
+ "decoder.weights.h5",
+ ]
+ for fname in files:
+ file_path = os.path.join(weights_dir, f"{variant}/{fname}")
+ if not os.path.exists(file_path):
+ print(f"Downloading {fname} to {file_path}...")
+ hf_hub_download(
+ repo_id=repo_id,
+ filename=f"{variant}/{fname}",
+ local_dir=weights_dir,
+ )
+
+ # Set weights paths.
+ encoder_weights_path = os.path.join(
+ weights_dir, variant, "encoder.weights.h5"
+ )
+ preprocessor_weights_path = os.path.join(
+ weights_dir, variant, "preprocessor.weights.h5"
+ )
+ decoder_weights_path = os.path.join(
+ weights_dir, variant, "decoder.weights.h5"
+ )
+ hf_wts_encoder = load_h5_weights(encoder_weights_path)
+ hf_wts_preprocessor = load_h5_weights(preprocessor_weights_path)
+ hf_wts_decoder = load_h5_weights(decoder_weights_path)
+
+ # Build Keras models.
+ backbone = MoonshineBackbone(
+ vocabulary_size=cfg["vocabulary_size"],
+ encoder_num_layers=cfg["encoder_num_layers"],
+ decoder_num_layers=cfg["decoder_num_layers"],
+ hidden_dim=cfg["hidden_dim"],
+ intermediate_dim=cfg["intermediate_dim"],
+ encoder_num_heads=cfg["encoder_num_heads"],
+ decoder_num_heads=cfg["decoder_num_heads"],
+ feedforward_expansion_factor=cfg["feedforward_expansion_factor"],
+ decoder_use_swiglu_activation=cfg["decoder_use_swiglu_activation"],
+ encoder_use_swiglu_activation=cfg["encoder_use_swiglu_activation"],
+ max_position_embeddings=cfg["max_sequence_length"],
+ partial_rotary_factor=cfg["partial_rotary_factor"],
+ dropout=cfg["attention_dropout"],
+ initializer_range=cfg["initializer_range"],
+ rope_theta=cfg["rope_theta"],
+ attention_bias=cfg["attention_bias"],
+ attention_dropout=cfg["attention_dropout"],
+ rope_scaling=cfg["rope_scaling"],
+ dtype=cfg["dtype"],
+ )
+
+ # Build tokenizer.
+ tokenizer = MoonshineTokenizer(
+ proto="keras_hub/src/tests/test_data/llama2_tokenizer_full.spm"
+ )
+ # Build audio converter.
+ audio_converter = MoonshineAudioConverter(
+ filter_dim=cfg["filter_dim"],
+ initializer_range=cfg["initializer_range"],
+ sampling_rate=cfg["sampling_rate"],
+ padding_value=cfg["padding_value"],
+ do_normalize=cfg["do_normalize"],
+ return_attention_mask=cfg["return_attention_mask"],
+ )
+ # Build preprocessor.
+ preprocessor = MoonshineSeq2SeqLMPreprocessor(
+ audio_converter=audio_converter,
+ tokenizer=tokenizer,
+ encoder_sequence_length=None,
+ decoder_sequence_length=cfg["max_sequence_length"],
+ )
+ # Build the model.
+ keras_model = MoonshineAudioToText(
+ backbone=backbone,
+ preprocessor=preprocessor,
+ )
+
+ # Build the model with dummy data.
+ dummy_audio = np.zeros((1, 16000), dtype="float32")
+ dummy_text = [""]
+ dummy_inputs = {"audio": dummy_audio, "text": dummy_text}
+ preprocessed_inputs, _, _ = preprocessor(dummy_inputs)
+ keras_model(preprocessed_inputs)
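+    # The forward pass above creates every variable so that the pretrained
+    # weights can be assigned below.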
+
+ # Assign preprocessor weights.
+ base_path = "layers/sequential/layers/"
+ weights = [
+ hf_wts_preprocessor[f"{base_path}conv1d/vars/0"], # conv1 kernel
+ hf_wts_preprocessor[f"{base_path}group_normalization/vars/0"], # gamma
+ hf_wts_preprocessor[f"{base_path}group_normalization/vars/1"], # beta
+ hf_wts_preprocessor[f"{base_path}conv1d_1/vars/0"], # conv2 kernel
+ hf_wts_preprocessor[f"{base_path}conv1d_1/vars/1"], # conv2 bias
+ hf_wts_preprocessor[f"{base_path}conv1d_2/vars/0"], # conv3 kernel
+ hf_wts_preprocessor[f"{base_path}conv1d_2/vars/1"], # conv3 bias
+ ]
+ keras_model.preprocessor.audio_converter.set_weights(weights)
+
+ # Assign encoder weights.
+ keras_model.backbone.encoder_rotary_embedding.inv_freq.assign(
+ hf_wts_encoder["layers/rotary_embedding/vars/0"]
+ )
+
+ for layer_index in range(cfg["encoder_num_layers"]):
+ if layer_index == 0:
+ base_prefix = "layers/functional/layers"
+ else:
+ base_prefix = f"layers/functional_{layer_index}/layers"
+ attention_prefix = f"{base_prefix}/mha_with_rope"
+ ff_prefix = f"{base_prefix}/functional/layers/sequential/layers"
+
+ # Attention weights.
+ keras_model.backbone.encoder_blocks[
+ layer_index
+ ].self_attention_layer._query_dense.kernel.assign(
+ hf_wts_encoder[f"{attention_prefix}/query_dense/vars/0"]
+ )
+ keras_model.backbone.encoder_blocks[
+ layer_index
+ ].self_attention_layer._key_dense.kernel.assign(
+ hf_wts_encoder[f"{attention_prefix}/key_dense/vars/0"]
+ )
+ keras_model.backbone.encoder_blocks[
+ layer_index
+ ].self_attention_layer._value_dense.kernel.assign(
+ hf_wts_encoder[f"{attention_prefix}/value_dense/vars/0"]
+ )
+ keras_model.backbone.encoder_blocks[
+ layer_index
+ ].self_attention_layer._output_dense.kernel.assign(
+ hf_wts_encoder[f"{attention_prefix}/output_dense/vars/0"]
+ )
+
+ # Layer norms.
+ keras_model.backbone.encoder_blocks[
+ layer_index
+ ].self_attention_layer_norm.gamma.assign(
+ hf_wts_encoder[f"{base_prefix}/layer_normalization/vars/0"]
+ )
+ keras_model.backbone.encoder_blocks[
+ layer_index
+ ].feedforward_layer_norm.gamma.assign(
+ hf_wts_encoder[f"{base_prefix}/layer_normalization_1/vars/0"]
+ )
+
+ # Feedforward weights.
+ keras_model.backbone.encoder_blocks[
+ layer_index
+ ].feedforward.dense_1.kernel.assign(
+ hf_wts_encoder[f"{ff_prefix}/dense/vars/0"]
+ )
+ keras_model.backbone.encoder_blocks[
+ layer_index
+ ].feedforward.dense_1.bias.assign(
+ hf_wts_encoder[f"{ff_prefix}/dense/vars/1"]
+ )
+ keras_model.backbone.encoder_blocks[
+ layer_index
+ ].feedforward.dense_2.kernel.assign(
+ hf_wts_encoder[f"{ff_prefix}/dense_1/vars/0"]
+ )
+ keras_model.backbone.encoder_blocks[
+ layer_index
+ ].feedforward.dense_2.bias.assign(
+ hf_wts_encoder[f"{ff_prefix}/dense_1/vars/1"]
+ )
+
+ keras_model.backbone.encoder_final_layer_norm.gamma.assign(
+ hf_wts_encoder["layers/layer_normalization/vars/0"]
+ )
+
+ # Assign decoder weights.
+ keras_model.backbone.token_embedding.embeddings.assign(
+ hf_wts_decoder["layers/reversible_embedding/vars/0"]
+ )
+ keras_model.backbone.decoder_rotary_embedding.inv_freq.assign(
+ hf_wts_decoder["layers/rotary_embedding/vars/0"]
+ )
+
+ for layer_index in range(cfg["decoder_num_layers"]):
+ if layer_index == 0:
+ base_prefix = "layers/functional/layers"
+ else:
+ base_prefix = f"layers/functional_{layer_index}/layers"
+ self_attention_prefix = f"{base_prefix}/mha_causal_with_rope"
+ cross_attention_prefix = f"{base_prefix}/mha_precomputed_kv"
+ ff_prefix = f"{base_prefix}/functional/layers"
+
+ # Self-attention weights.
+ keras_model.backbone.decoder_blocks[
+ layer_index
+ ].self_attention._query_dense.kernel.assign(
+ hf_wts_decoder[f"{self_attention_prefix}/query_dense/vars/0"]
+ )
+ keras_model.backbone.decoder_blocks[
+ layer_index
+ ].self_attention._key_dense.kernel.assign(
+ hf_wts_decoder[f"{self_attention_prefix}/key_dense/vars/0"]
+ )
+ keras_model.backbone.decoder_blocks[
+ layer_index
+ ].self_attention._value_dense.kernel.assign(
+ hf_wts_decoder[f"{self_attention_prefix}/value_dense/vars/0"]
+ )
+ keras_model.backbone.decoder_blocks[
+ layer_index
+ ].self_attention._output_dense.kernel.assign(
+ hf_wts_decoder[f"{self_attention_prefix}/output_dense/vars/0"]
+ )
+
+ # Cross-attention weights.
+ keras_model.backbone.decoder_blocks[
+ layer_index
+ ].cross_attention._query_dense.kernel.assign(
+ hf_wts_decoder[f"{cross_attention_prefix}/query_dense/vars/0"]
+ )
+ keras_model.backbone.decoder_blocks[
+ layer_index
+ ].cross_attention._key_dense.kernel.assign(
+ hf_wts_decoder[f"{cross_attention_prefix}/key_dense/vars/0"]
+ )
+ keras_model.backbone.decoder_blocks[
+ layer_index
+ ].cross_attention._value_dense.kernel.assign(
+ hf_wts_decoder[f"{cross_attention_prefix}/value_dense/vars/0"]
+ )
+ keras_model.backbone.decoder_blocks[
+ layer_index
+ ].cross_attention._output_dense.kernel.assign(
+ hf_wts_decoder[f"{cross_attention_prefix}/output_dense/vars/0"]
+ )
+
+ # Layer norms.
+ keras_model.backbone.decoder_blocks[layer_index].norm1.gamma.assign(
+ hf_wts_decoder[f"{base_prefix}/layer_normalization/vars/0"]
+ )
+ keras_model.backbone.decoder_blocks[layer_index].norm2.gamma.assign(
+ hf_wts_decoder[f"{base_prefix}/layer_normalization_1/vars/0"]
+ )
+ keras_model.backbone.decoder_blocks[layer_index].norm3.gamma.assign(
+ hf_wts_decoder[f"{base_prefix}/layer_normalization_2/vars/0"]
+ )
+
+ # Feedforward weights.
+ keras_model.backbone.decoder_blocks[
+ layer_index
+ ].ff.dense_1.kernel.assign(hf_wts_decoder[f"{ff_prefix}/dense/vars/0"])
+ keras_model.backbone.decoder_blocks[layer_index].ff.dense_1.bias.assign(
+ hf_wts_decoder[f"{ff_prefix}/dense/vars/1"]
+ )
+ keras_model.backbone.decoder_blocks[
+ layer_index
+ ].ff.dense_2.kernel.assign(
+ hf_wts_decoder[f"{ff_prefix}/dense_1/vars/0"]
+ )
+ keras_model.backbone.decoder_blocks[layer_index].ff.dense_2.bias.assign(
+ hf_wts_decoder[f"{ff_prefix}/dense_1/vars/1"]
+ )
+
+ keras_model.backbone.decoder_post_norm.gamma.assign(
+ hf_wts_decoder["layers/layer_normalization/vars/0"]
+ )
+
+ # Save Keras model weights.
+ output_dir = os.path.join(extract_dir, f"{preset}-model.keras")
+ keras_model.save(output_dir)
+ print(f"Saved Keras model weights to {output_dir}")
+
+ # Prepare inputs.
+    sample_audio = [
+        np.random.randn(16000).astype("float32")
+    ]  # Random 1-second audio clip at 16kHz.
+    keras_preprocessed_inputs = keras_model.preprocessor.audio_converter(
+        keras.ops.convert_to_tensor(sample_audio), padding="longest"
+ )
+ encoder_input_values = keras_preprocessed_inputs["input_values"]
+ encoder_padding_mask = keras_preprocessed_inputs["attention_mask"]
+
+ # Prepare raw audio for HF model.
+    raw_audio = np.array(sample_audio)  # Shape: (1, 16000).
+
+ # For HF model, use raw audio instead of preprocessed features.
+ hf_inputs = {
+ "input_values": torch.from_numpy(raw_audio), # Shape: (1, 16000)
+ "decoder_input_ids": torch.randint(
+ 0, cfg["vocabulary_size"], (1, 32), dtype=torch.int32
+ ),
+ }
+ position_ids = torch.arange(0, 32, dtype=torch.long).unsqueeze(0)
+
+ # Prepare Keras inputs for backbone.
+ decoder_token_ids = keras.ops.convert_to_tensor(
+ hf_inputs["decoder_input_ids"]
+ )
+ decoder_padding_mask = keras.ops.cast(
+ keras.ops.not_equal(decoder_token_ids, 0), "bool"
+ )
+
+ # Run Keras backbone.
+ keras_backbone_outputs = keras_model.backbone(
+ {
+ "encoder_input_values": encoder_input_values,
+ "decoder_token_ids": decoder_token_ids,
+ "encoder_padding_mask": encoder_padding_mask,
+ "decoder_padding_mask": decoder_padding_mask,
+ },
+ training=False,
+ )
+ keras_encoder_output = to_numpy(
+ keras_backbone_outputs["encoder_sequence_output"]
+ )
+ keras_decoder_output = to_numpy(
+ keras_backbone_outputs["decoder_sequence_output"]
+ )
+
+ # Run Hugging Face model and compute outputs.
+ hf_model = AutoModel.from_pretrained(PRESET)
+ hf_model.eval()
+ with torch.no_grad():
+ hf_outputs = hf_model(
+ input_values=hf_inputs["input_values"],
+ decoder_input_ids=hf_inputs["decoder_input_ids"],
+ decoder_position_ids=position_ids,
+ output_hidden_states=True,
+ )
+ hf_encoder_hidden_states = hf_outputs.encoder_hidden_states[-1]
+ hf_encoder_output_np = hf_encoder_hidden_states.numpy()
+ hf_decoder_hidden_states = hf_outputs.last_hidden_state
+ hf_decoder_output_np = hf_decoder_hidden_states.numpy()
+
+ # Compute absolute differences between HF and Keras outputs.
+ encoder_abs_diff = np.abs(keras_encoder_output - hf_encoder_output_np)
+ encoder_min_abs_diff = np.min(encoder_abs_diff)
+ encoder_max_abs_diff = np.max(encoder_abs_diff)
+ decoder_abs_diff = np.abs(keras_decoder_output - hf_decoder_output_np)
+ decoder_min_abs_diff = np.min(decoder_abs_diff)
+ decoder_max_abs_diff = np.max(decoder_abs_diff)
+ # Print differences.
+ print(f"\n=== Differences for {preset} ===")
+ if preset == "moonshine-tiny":
+ warnings.warn(
+ "Note: The 'moonshine-tiny' numerics results differ between "
+ "implementations. This discrepancy stems from a bug in the HF "
+ "implementation, likely in the rotary embeddings calculation. The "
+ "bug causes failures with longer transcripts, while the Keras "
+ "implementation handles these correctly, as demonstrated in the "
+ "notebook."
+ )
+ print(
+ f"Encoder output absolute differences: min={encoder_min_abs_diff}, "
+ f"max={encoder_max_abs_diff}"
+ )
+ print(
+ f"Decoder output absolute differences: min={decoder_min_abs_diff}, "
+ f"max={decoder_max_abs_diff}"
+ )
+ # Test: End-to-End ASR Examples.
+ print(f"\n=== End-to-End ASR Example for {preset} ===")
+ # Test 1: Male Clear Voice, Snippet (Length - 3 Sec)
+ print("\nTest: Male Clear Voice, Snippet (Length - 3 Sec)")
+ audio_path = "keras_hub/src/tests/test_data/audio_transcription_tests/male_short_voice_clip_3sec.wav" # noqa: E501
+ audio, sr = librosa.load(audio_path, sr=cfg["sampling_rate"])
+ audio = audio.reshape(1, -1)
+ inputs = {"audio": audio, "text": [""]}
+ transcription = keras_model.generate(inputs)
+ print("Transcription:", transcription)
+
+ # Test 2: Female Clear Voice, Excerpt (Length - 17 Sec)
+ print("\nTest: Female Clear Voice, Excerpt (Length - 17 Sec)")
+ audio_path = "keras_hub/src/tests/test_data/audio_transcription_tests/female_short_voice_clip_17sec.wav" # noqa: E501
+ audio, sr = librosa.load(audio_path, sr=cfg["sampling_rate"])
+ audio = audio.reshape(1, -1)
+ inputs = {"audio": audio, "text": [""]}
+ transcription = keras_model.generate(inputs)
+ print("Transcription:", transcription)
+
+ # Test 3: Male Muffled Voice, Manuscript (Length - 46 Sec)
+ print("\nTest: Male Muffled Voice, Manuscript (Length - 46 Sec)")
+ audio_path = "keras_hub/src/tests/test_data/audio_transcription_tests/male_muffled_voice_clip_46sec.wav" # noqa: E501
+ audio, sr = librosa.load(audio_path, sr=cfg["sampling_rate"])
+ audio = audio.reshape(1, -1)
+ inputs = {"audio": audio, "text": [""]}
+ transcription = keras_model.generate(inputs, max_length=200)
+ print("Transcription:", transcription)
+
+ # Test 4: Female Clear Voice, Odyssey (Maximum Length - 64 Sec)
+ print("\nTest: Female Clear Voice, Odyssey (Maximum Length - 64 Sec)")
+ audio_path = "keras_hub/src/tests/test_data/audio_transcription_tests/female_long_voice_clip_64sec.wav" # noqa: E501
+ audio, sr = librosa.load(audio_path, sr=cfg["sampling_rate"])
+ audio = audio.reshape(1, -1)
+ inputs = {"audio": audio, "text": [""]}
+ transcription = keras_model.generate(inputs, max_length=200)
+ print("Transcription:", transcription)
diff --git a/tools/sentencepiece_testing/create_moonshine_test_proto.py b/tools/sentencepiece_testing/create_moonshine_test_proto.py
new file mode 100644
index 0000000000..ef2173d01b
--- /dev/null
+++ b/tools/sentencepiece_testing/create_moonshine_test_proto.py
@@ -0,0 +1,42 @@
+import os
+
+from tools.sentencepiece_testing.utils import train_sentencepiece
+
+
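+# Regenerates `keras_hub/src/tests/test_data/moonshine_test_vocab.spm`. Run from
+# the repository root, for example:
+#   python -m tools.sentencepiece_testing.create_moonshine_test_proto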
+def create_moonshine_test_vocab():
+ special_tokens = (
+ [
+ "<>", # Empty token
+ ]
+ + [f"<0x{i:02X}>" for i in range(256)]
+ + [f"<>" for i in range(768)]
+ )
+ training_texts = ["the quick brown fox", "the earth is round"]
+ test_data_dir = os.path.join(
+ os.path.dirname(os.path.abspath(__file__)),
+ "..",
+ "..",
+ "keras_hub",
+ "src",
+ "tests",
+ "test_data",
+ )
+ os.makedirs(test_data_dir, exist_ok=True)
+ model_prefix = "moonshine_test_vocab"
+ target_path = os.path.join(test_data_dir, f"{model_prefix}.spm")
+ train_sentencepiece(
+ training_texts,
+ target_path,
+ vocab_size=11 + len(special_tokens),
+ model_type="WORD",
+        pad_id=0,  # <pad> token
+        unk_id=1,  # <unk> token
+        bos_id=2,  # <s> token
+        eos_id=3,  # </s> token
+ user_defined_symbols=special_tokens,
+ )
+ print(f"Moonshine test vocabulary created at: {target_path}")
+
+
+if __name__ == "__main__":
+ create_moonshine_test_vocab()