diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index 9562e7df58..6b65ab8cec 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -58,6 +58,9 @@ from keras_hub.src.models.mobilenet.mobilenet_image_converter import ( MobileNetImageConverter, ) +from keras_hub.src.models.moonshine.moonshine_audio_converter import ( + MoonshineAudioConverter, +) from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import ( PaliGemmaImageConverter, ) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 2f510446d7..041531efbb 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -232,6 +232,16 @@ from keras_hub.src.models.mobilenet.mobilenet_image_classifier_preprocessor import ( MobileNetImageClassifierPreprocessor, ) +from keras_hub.src.models.moonshine.moonshine_audio_to_text import ( + MoonshineAudioToText, +) +from keras_hub.src.models.moonshine.moonshine_backbone import MoonshineBackbone +from keras_hub.src.models.moonshine.moonshine_seq_2_seq_lm_preprocessor import ( + MoonshineSeq2SeqLMPreprocessor, +) +from keras_hub.src.models.moonshine.moonshine_tokenizer import ( + MoonshineTokenizer, +) from keras_hub.src.models.object_detector import ObjectDetector from keras_hub.src.models.object_detector import ( ObjectDetector as ImageObjectDetector, diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py index 9f73bfd665..29accb7b08 100644 --- a/keras_hub/api/tokenizers/__init__.py +++ b/keras_hub/api/tokenizers/__init__.py @@ -24,6 +24,9 @@ from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer +from keras_hub.src.models.moonshine.moonshine_tokenizer import ( + MoonshineTokenizer, +) from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import ( PaliGemmaTokenizer, diff --git a/keras_hub/src/models/moonshine/__init__.py b/keras_hub/src/models/moonshine/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/keras_hub/src/models/moonshine/moonshine_audio_converter.py b/keras_hub/src/models/moonshine/moonshine_audio_converter.py new file mode 100644 index 0000000000..9c713418b3 --- /dev/null +++ b/keras_hub/src/models/moonshine/moonshine_audio_converter.py @@ -0,0 +1,260 @@ +import warnings + +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter +from keras_hub.src.models.moonshine.moonshine_backbone import MoonshineBackbone +from keras_hub.src.models.moonshine.moonshine_layers import ( + moonshine_kernel_initializer, +) +from keras_hub.src.utils.keras_utils import clone_initializer + + +@keras_hub_export("keras_hub.layers.MoonshineAudioConverter") +class MoonshineAudioConverter(AudioConverter): + """Moonshine preprocessor and audio converter layer. + + This layer processes raw audio waveforms for the Moonshine ASR model. Audio + is formatted as a batched tensor at a 16kHz sample rate and validated for + length (0.1 to 64 seconds). The layer downsamples and extracts key features + from the audio signal through a series of convolutional operations, + normalization, and nonlinear activations. + + Args: + filter_dim: int. The number of filters for the first convolutional + layer. 
This influences the dimensionality of the feature extraction + pipeline and determines the richness of the audio representation. + sampling_rate: int, optional. The audio sampling rate in Hz. Defaults to + 16,000. + padding_value: float, optional. The value for padding. Defaults to 0.0. + do_normalize: bool, optional. Whether to normalize inputs. Defaults to + False. + return_attention_mask: bool, optional. Whether to return an attention + mask. Defaults to True. + initializer_range: float, optional. The standard deviation for kernel + initialization. Defaults to 0.02. + **kwargs: Additional keyword arguments passed to the base AudioConverter + class for customizing the underlying preprocessing behavior. + + Examples: + ```python + import keras + from keras_hub.layers import MoonshineAudioConverter + + # Create a dummy audio input (1 second at 16kHz). + dummy_audio = keras.ops.convert_to_tensor( + [[0.1] * 16000], + dtype="float32" + ) + dummy_audio = keras.ops.expand_dims(dummy_audio, axis=-1) + + # Initialize the preprocessor. + preprocessor = MoonshineAudioConverter(filter_dim=256) + + # Process the audio. + features = preprocessor(dummy_audio) + + # Output shapes. + print(features["input_values"].shape) # Expected: (1, 40, 256) + print(features["attention_mask"].shape) # Expected: (1, 40) + ``` + """ + + # References: + # Defined and formulated based on the Hugging Face implementation of the + # Wav2Vec2FeatureExtractor class (https://github.com/huggingface/transformers/blob/66f29aaaf55c8fe0c3dbcd24beede2ca4effac56/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py#L31-L243) + # and the convolutional layer structure defined in the UsefulSensors + # implementation of the AudioPreprocessor class (https://github.com/usefulsensors/moonshine/blob/4a000427bd36a1c2c6d20a86c672dbd850b44c88/moonshine/model.py#L6-L32). 
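+    # Note: the three Conv1D stages use strides of 64, 3, and 2 (an overall
+    # downsampling factor of 384), so one second of 16kHz audio produces
+    # roughly 40 output frames, matching the docstring example above.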
+
+    backbone_cls = MoonshineBackbone
+
+    def __init__(
+        self,
+        filter_dim,
+        sampling_rate=16000,
+        padding_value=0.0,
+        do_normalize=False,
+        return_attention_mask=True,
+        initializer_range=0.02,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.filter_dim = filter_dim
+        self.sampling_rate = sampling_rate
+        self.padding_value = padding_value
+        self.do_normalize = do_normalize
+        self.return_attention_mask = return_attention_mask
+        self.initializer_range = initializer_range
+        self.kernel_initializer = moonshine_kernel_initializer(
+            initializer_range=initializer_range
+        )
+
+        self.conv1 = keras.layers.Conv1D(
+            filters=filter_dim,
+            kernel_size=127,
+            strides=64,
+            use_bias=False,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+        )
+        self.tanh = keras.layers.Activation("tanh")
+        self.group_norm = keras.layers.GroupNormalization(
+            groups=1,
+            axis=-1,
+            epsilon=1e-5,
+        )
+        self.conv2 = keras.layers.Conv1D(
+            filters=2 * filter_dim,
+            kernel_size=7,
+            strides=3,
+            padding="valid",
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+        )
+        self.gelu1 = keras.layers.Activation("gelu")
+        self.conv3 = keras.layers.Conv1D(
+            filters=filter_dim,
+            kernel_size=3,
+            strides=2,
+            padding="valid",
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+        )
+        self.gelu2 = keras.layers.Activation("gelu")
+
+    def build(self, input_shape):
+        self.conv1.build((None, None, 1))
+        self.group_norm.build((None, None, self.filter_dim))
+        self.conv2.build((None, None, self.filter_dim))
+        self.conv3.build((None, None, 2 * self.filter_dim))
+        self.built = True
+
+    def call(
+        self,
+        inputs,
+        sampling_rate=None,
+        padding=None,
+        max_length=None,
+        pad_to_multiple_of=None,
+        return_tensors=None,
+    ):
+        # Validate sampling rate.
+        if sampling_rate is not None and sampling_rate != self.sampling_rate:
+            raise ValueError(
+                f"Expected sampling_rate {self.sampling_rate}, got "
+                f"{sampling_rate}"
+            )
+
+        # Ensure inputs are (batch_size, time_steps, 1).
+        if keras.ops.ndim(inputs) == 2:
+            inputs = keras.ops.expand_dims(inputs, axis=-1)
+        elif keras.ops.ndim(inputs) != 3 or keras.ops.shape(inputs)[-1] != 1:
+            raise ValueError(
+                "Inputs must be mono audio: (batch_size, time_steps, 1)"
+            )
+
+        # Get original length and validate duration.
+        original_length = keras.ops.shape(inputs)[1]
+        duration = original_length / self.sampling_rate
+        # Source: https://github.com/usefulsensors/moonshine/blob/4a000427bd36a1c2c6d20a86c672dbd850b44c88/moonshine/transcribe.py#L20
+        if duration < 0.1 or duration > 64:
+            warnings.warn(
+                f"Audio duration must be between 0.1 and 64 seconds, got "
+                f"{duration:.2f} seconds in a single transcribe call. For "
+                "transcribing longer segments, pre-segment your audio and "
+                "provide shorter segments."
+            )
+        # Handle padding.
+        if padding == "longest":
+            max_length = original_length
+        elif padding == "max_length" and max_length is None:
+            max_length = original_length
+        if max_length is not None:
+            if pad_to_multiple_of:
+                max_length = (
+                    (max_length + pad_to_multiple_of - 1) // pad_to_multiple_of
+                ) * pad_to_multiple_of
+            if original_length < max_length:
+                padding_amount = max_length - original_length
+                inputs = keras.ops.pad(
+                    inputs,
+                    [(0, 0), (0, padding_amount), (0, 0)],
+                    constant_values=self.padding_value,
+                )
+
+        # Normalize if enabled.
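+        # When enabled, this applies per-example zero-mean, unit-variance
+        # normalization along the time axis.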
+ if self.do_normalize: + mean = keras.ops.mean(inputs, axis=1, keepdims=True) + var = keras.ops.var(inputs, axis=1, keepdims=True) + inputs = (inputs - mean) / keras.ops.sqrt(var + 1e-7) + + # Apply convolutional feature extraction. + x = self.conv1(inputs) + x = self.tanh(x) + x = self.group_norm(x) + x = self.conv2(x) + x = self.gelu1(x) + x = self.conv3(x) + features = self.gelu2(x) + + # Generate attention mask. + output_length = keras.ops.shape(features)[1] + attention_mask = None + if self.return_attention_mask: + # Calculate mask length through the network's downsampling ops. + # Step 1: First conv layer (conv1). + conv1_out = (original_length - 127 + 1) / 64 + # Step 2: Second conv layer (conv2). + conv2_out = (conv1_out - 7 + 1) / 3 + # Step 3: Third conv layer (conv3). + conv3_out = (conv2_out - 3 + 1) / 2 + + # Apply ceil() to get the final mask length as an int. + mask_length = keras.ops.cast( + keras.ops.ceil(keras.ops.cast(conv3_out, "float32")), "int32" + ) + # Broadcast the mask length to match the batch size. + batch_size = keras.ops.shape(inputs)[0] + mask_length = keras.ops.broadcast_to(mask_length, [batch_size]) + indices = keras.ops.arange(output_length, dtype="int32") + attention_mask = keras.ops.cast( + indices[None, :] < mask_length[:, None], dtype="int32" + ) + + output = {"input_values": features} + if attention_mask is not None: + output["attention_mask"] = attention_mask + + return output + + def compute_output_shape(self, input_shape): + # [batch_size, time_steps] → [batch_size, time_steps, 1]. + if len(input_shape) == 2: + expanded_shape = (input_shape[0], input_shape[1], 1) + else: + expanded_shape = input_shape + # Compute output shape sequentially. + x_shape = self.conv1.compute_output_shape(expanded_shape) + x_shape = self.tanh.compute_output_shape(x_shape) + x_shape = self.group_norm.compute_output_shape(x_shape) + x_shape = self.conv2.compute_output_shape(x_shape) + x_shape = self.gelu1.compute_output_shape(x_shape) + x_shape = self.conv3.compute_output_shape(x_shape) + x_shape = self.gelu2.compute_output_shape(x_shape) + output_shape = {"input_values": x_shape} + if self.return_attention_mask: + # [batch_size, output_time_steps]. 
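+            # The mask shares its time dimension with `input_values`.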
+ output_shape["attention_mask"] = (expanded_shape[0], x_shape[1]) + return output_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "filter_dim": self.filter_dim, + "sampling_rate": self.sampling_rate, + "padding_value": self.padding_value, + "do_normalize": self.do_normalize, + "return_attention_mask": self.return_attention_mask, + "initializer_range": self.initializer_range, + } + ) + return config diff --git a/keras_hub/src/models/moonshine/moonshine_audio_converter_test.py b/keras_hub/src/models/moonshine/moonshine_audio_converter_test.py new file mode 100644 index 0000000000..ff80fcfe10 --- /dev/null +++ b/keras_hub/src/models/moonshine/moonshine_audio_converter_test.py @@ -0,0 +1,114 @@ +import keras + +from keras_hub.src.models.moonshine.moonshine_audio_converter import ( + MoonshineAudioConverter, +) +from keras_hub.src.tests.test_case import TestCase + + +class MoonshineAudioConverterTest(TestCase): + def setUp(self): + super().setUp() + self.filter_dim = 256 + self.preprocessor = MoonshineAudioConverter(filter_dim=self.filter_dim) + self.input_data = keras.ops.convert_to_tensor( + [[0.1] * 16000], dtype="float32" + ) + self.input_data = keras.ops.expand_dims(self.input_data, axis=-1) + self.init_kwargs = { + "filter_dim": self.filter_dim, + "sampling_rate": 16000, + "padding_value": 0.0, + "do_normalize": False, + "return_attention_mask": True, + "initializer_range": 0.02, + } + + def test_output_shape(self): + output = self.preprocessor(self.input_data) + self.assertEqual( + keras.ops.shape(output["input_values"]), (1, 40, self.filter_dim) + ) + self.assertEqual(keras.ops.shape(output["attention_mask"]), (1, 40)) + self.assertAllEqual( + output["attention_mask"], keras.ops.ones((1, 40), dtype="int32") + ) + + def test_padding(self): + max_length = 20000 + output = self.preprocessor( + self.input_data, padding="max_length", max_length=max_length + ) + self.assertEqual( + keras.ops.shape(output["input_values"]), (1, 50, self.filter_dim) + ) + self.assertEqual(keras.ops.shape(output["attention_mask"]), (1, 50)) + expected_mask = keras.ops.concatenate( + [ + keras.ops.ones((1, 40), dtype="int32"), + keras.ops.zeros((1, 10), dtype="int32"), + ], + axis=1, + ) + self.assertAllEqual(output["attention_mask"], expected_mask) + + def test_normalization(self): + preprocessor_no_norm = MoonshineAudioConverter( + filter_dim=self.filter_dim, do_normalize=False + ) + preprocessor_norm = MoonshineAudioConverter( + filter_dim=self.filter_dim, do_normalize=True + ) + input_data = keras.ops.arange(16000, dtype="float32") / 16000 # Values + # from 0 to ~1 + input_data = keras.ops.expand_dims(input_data, axis=0) # (1, 16000) + input_data = keras.ops.expand_dims(input_data, axis=-1) # (1, 16000, 1) + output_no_norm = preprocessor_no_norm(input_data) + output_norm = preprocessor_norm(input_data) + self.assertFalse( + keras.ops.all( + output_no_norm["input_values"] == output_norm["input_values"] + ) + ) + + def test_sampling_rate_validation(self): + # Test with the correct sampling rate (should not raise an error). + self.preprocessor( + self.input_data, sampling_rate=self.preprocessor.sampling_rate + ) + # Test with an incorrect sampling rate (should raise ValueError). 
+ with self.assertRaises(ValueError): + self.preprocessor(self.input_data, sampling_rate=8000) + + def test_get_config(self): + config = self.preprocessor.get_config() + self.assertIsInstance(config, dict) + self.assertEqual(config["filter_dim"], self.filter_dim) + self.assertEqual(config["sampling_rate"], 16000) + self.assertEqual(config["padding_value"], 0.0) + self.assertEqual(config["do_normalize"], False) + self.assertEqual(config["return_attention_mask"], True) + self.assertEqual(config["initializer_range"], 0.02) + + def test_correctness(self): + audio_input = keras.ops.convert_to_tensor( + [[1.0, 2.0, 3.0] + [0.0] * 15997], dtype="float32" + ) + audio_input = keras.ops.expand_dims(audio_input, axis=-1) + converter = MoonshineAudioConverter(**self.init_kwargs) + + outputs = converter(audio_input) + self.assertIn("input_values", outputs) + self.assertIn("attention_mask", outputs) + + self.assertEqual( + keras.ops.shape(outputs["input_values"]), (1, 40, self.filter_dim) + ) + self.assertEqual(keras.ops.shape(outputs["attention_mask"]), (1, 40)) + self.assertAllEqual( + outputs["attention_mask"], keras.ops.ones((1, 40), dtype="int32") + ) + + def test_serialization(self): + instance = MoonshineAudioConverter(**self.init_kwargs) + self.run_serialization_test(instance=instance) diff --git a/keras_hub/src/models/moonshine/moonshine_audio_to_text.py b/keras_hub/src/models/moonshine/moonshine_audio_to_text.py new file mode 100644 index 0000000000..1248f819ea --- /dev/null +++ b/keras_hub/src/models/moonshine/moonshine_audio_to_text.py @@ -0,0 +1,483 @@ +import keras +import tensorflow as tf +from keras import tree + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.moonshine.moonshine_backbone import MoonshineBackbone +from keras_hub.src.models.moonshine.moonshine_seq_2_seq_lm_preprocessor import ( + MoonshineSeq2SeqLMPreprocessor, +) +from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM +from keras_hub.src.utils.tensor_utils import any_equal + + +@keras_hub_export("keras_hub.models.MoonshineAudioToText") +class MoonshineAudioToText(Seq2SeqLM): + """An end-to-end Moonshine model for audio-to-text tasks. + + A Seq2Seq LM designed for audio-to-text tasks, such as speech recognition. + The encoder processes audio features, and the decoder generates text + transcriptions. You can finetune `MoonshineAudioToText` for any + audio-to-text task (e.g., live transcription or voice commands). + + This model includes a `generate()` method for text generation based on audio + inputs and an optional text prompt for the decoder. The generation strategy + is controlled by a `sampler` argument passed to `compile()`. By default, + `"top_k"` sampling is used. + + Args: + backbone: A `keras_hub.models.MoonshineBackbone` instance. + preprocessor: A `keras_hub.models.MoonshineSeq2SeqLMPreprocessor` or + `None`. If `None`, inputs must be preprocessed before calling the + model. + + Examples: + ```python + # Initialize model from preset. + moonshine_lm = keras_hub.models.MoonshineAudioToText.from_preset( + "moonshine_base" + ) + + # Generate with single audio input. + audio_tensor = keras.random.normal((1, 16000, 1)) + moonshine_lm.generate({"audio": audio_tensor}) + + # Generate with text prompt. + moonshine_lm.generate({"audio": audio_tensor, "text": "quick"}) + + # Use different sampling strategy. 
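+    # Recompiling with a new sampler resets the cached generate function.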
+ moonshine_lm.compile(sampler="greedy") + moonshine_lm.generate({"audio": audio_tensor}) + """ + + # References: + # Defined and formulated based on the Hugging Face implementation of the + # MoonshineForConditionalGeneration class (https://github.com/huggingface/transformers/blob/dcbdf7e962c4b36140cc9ee76f870016121e69e5/src/transformers/models/moonshine/modeling_moonshine.py#L1509-L1626). + + backbone_cls = MoonshineBackbone + preprocessor_cls = MoonshineSeq2SeqLMPreprocessor + + def __init__(self, backbone, preprocessor=None, **kwargs): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional Model === + inputs = backbone.input + hidden_states = backbone(inputs)["decoder_sequence_output"] + outputs = backbone.token_embedding(hidden_states, reverse=True) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + def call_decoder_with_cache( + self, + encoder_hidden_states, + encoder_padding_mask, + decoder_token_ids, + self_attention_cache=None, + self_attention_cache_update_index=None, + cross_attention_cache=None, + decoder_padding_mask=None, + ): + """Process decoder inputs with attention caching for efficient + generation. + + Args: + encoder_hidden_states: Tensor. Encoder outputs. + encoder_padding_mask: Tensor. Padding mask for encoder outputs. + decoder_token_ids: Tensor. Decoder input token IDs. + self_attention_cache: Tensor. Cache for self-attention layers. + self_attention_cache_update_index: int. Index for cache updates. + decoder_attention_mask: Tensor, optional. Mask for decoder attention + + Returns: + Tuple of (logits, hidden_states, self_attention_cache). + """ + tokens = self.backbone.token_embedding(decoder_token_ids) + x = tokens + + # Cache management for audio-to-text generation. + self_attention_caches = [] + cross_attention_caches = [] + + # Determine if this is initialization or generation. + if self_attention_cache_update_index is None: + # Initialization: Process full sequence, compute caches. + seq_len = keras.ops.shape(decoder_token_ids)[1] + positions = keras.ops.arange(0, seq_len, dtype="int32") + rotary_embedding = self.backbone.decoder_rotary_embedding(positions) + + self_attention_caches = [] + cross_attention_caches = [] + for layer in self.backbone.decoder_blocks: + x, cache_k, cache_v, x_attn_cache_k, x_attn_cache_v = layer( + [x, encoder_hidden_states, rotary_embedding], + use_cache=False, + decoder_attention_mask=decoder_padding_mask, + encoder_attention_mask=encoder_padding_mask, + ) + # Stack key and value for each layer. + self_attention_caches.append( + keras.ops.stack([cache_k, cache_v], axis=1) + ) + cross_attention_caches.append( + keras.ops.stack([x_attn_cache_k, x_attn_cache_v], axis=1) + ) + self_attention_cache = keras.ops.stack( + self_attention_caches, axis=1 + ) + cross_attention_cache = keras.ops.stack( + cross_attention_caches, axis=1 + ) + + else: + position = keras.ops.array( + [self_attention_cache_update_index], dtype="int32" + ) + position_ids = keras.ops.expand_dims(position, axis=0) + batch_size = keras.ops.shape(decoder_token_ids)[0] + if batch_size > 1: + position_ids = keras.ops.repeat( + position_ids, batch_size, axis=0 + ) + rotary_embedding = self.backbone.decoder_rotary_embedding(position) + + for i, layer in enumerate(self.backbone.decoder_blocks): + # [batch_size, 2, seq_len, num_heads, head_dim]. 
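+                # The stacked self-attention cache has shape
+                # [batch_size, num_layers, 2, seq_len, num_heads, head_dim],
+                # with keys at index 0 and values at index 1 of the third axis.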
+ current_self_cache = self_attention_cache[:, i, :, :, :, :] + cache_k = current_self_cache[ + :, 0, :, :, : + ] # [batch_size, seq_len, num_heads, head_dim] + cache_v = current_self_cache[ + :, 1, :, :, : + ] # [batch_size, seq_len, num_heads, head_dim] + # [batch_size, 2, context_len, num_heads, head_dim]. + current_cross_cache = cross_attention_cache[:, i, :, :, :, :] + x_attn_cache_k = current_cross_cache[ + :, 0, :, :, : + ] # [batch_size, context_len, num_heads, head_dim] + x_attn_cache_v = current_cross_cache[ + :, 1, :, :, : + ] # [batch_size, context_len, num_heads, head_dim] + + # Call layer with 7 inputs. + x, new_cache_k, new_cache_v = layer( + [ + x, + encoder_hidden_states, + cache_k, + cache_v, + x_attn_cache_k, + x_attn_cache_v, + rotary_embedding, + ], + use_cache=True, + decoder_attention_mask=decoder_padding_mask, + encoder_attention_mask=encoder_padding_mask, + training=False, + ) + # Update self-attention cache. + new_self_cache = keras.ops.stack( + [new_cache_k, new_cache_v], axis=1 + ) + self_attention_caches.append(new_self_cache) + + # [batch_size, num_layers, 2, seq_len, num_heads, head_dim]. + self_attention_cache = keras.ops.stack( + self_attention_caches, axis=1 + ) + + hidden_states = self.backbone.decoder_post_norm(x) + logits = self.backbone.logits(hidden_states) + return ( + logits, + hidden_states, + self_attention_cache, + cross_attention_cache, + ) + + def call_encoder(self, encoder_input_values, padding_mask): + """Process audio input through the encoder stack.""" + x = encoder_input_values + seq_length = keras.ops.shape(x)[1] + positions = keras.ops.arange(0, seq_length, dtype="int32") + rotary_embedding = self.backbone.encoder_rotary_embedding(positions) + if hasattr(self.backbone, "encoder_dropout"): + x = self.backbone.encoder_dropout(x, training=False) + for transformer_layer in self.backbone.encoder_blocks: + x = transformer_layer( + inputs=x, + rotary_embedding=rotary_embedding, + attention_mask=padding_mask, + training=False, + ) + if hasattr(self.backbone, "encoder_final_layer_norm"): + x = self.backbone.encoder_final_layer_norm(x) + return x + + def _build_cache( + self, + audio_inputs, + audio_padding_mask, + decoder_token_ids, + ): + """Initialize and populate attention caches with encoder and decoder + outputs.""" + encoder_hidden_states = self.call_encoder( + audio_inputs, padding_mask=audio_padding_mask + ) + _, hidden_states, self_attention_cache, cross_attention_cache = ( + self.call_decoder_with_cache( + encoder_hidden_states=encoder_hidden_states, + encoder_padding_mask=audio_padding_mask, + decoder_token_ids=decoder_token_ids, + self_attention_cache=None, + cross_attention_cache=None, + ) + ) + return ( + hidden_states, + encoder_hidden_states, + self_attention_cache, + cross_attention_cache, + ) + + # Source: https://github.com/huggingface/transformers/blob/9e94801146ceeb3b215bbdb9492be74d7d7b7210/src/transformers/generation/utils.py#L1970-L2463 + def generate_step(self, inputs, stop_token_ids=None): + """A compilable generation function for a batch of inputs. + + This function represents the inner, XLA-compilable, generation function + for a single batch of inputs. Inputs should have the same structure as + model inputs, a dictionary with keys `"encoder_input_values"`, + `"encoder_padding_mask"`, `"decoder_token_ids"` and + `"decoder_padding_mask"`. 
+ + Args: + inputs: A dictionary with four keys - `"encoder_input_values"`, + `"encoder_padding_mask"`, `"decoder_token_ids"` and + `"decoder_padding_mask"`, with batched tensor values. + stop_token_ids: Tuple of id's of end token's to stop on. If all + sequences have produced a new stop token, generation + will stop. + + Returns: + Dictionary: A dictionary with two keys - `"decoder_token_ids"` + containing the updated token sequence with newly generated + tokens, and `"decoder_padding_mask"` containing the updated + padding mask for the generated sequence. + """ + encoder_input_values = inputs["encoder_input_values"] + encoder_padding_mask = inputs["encoder_padding_mask"] + decoder_token_ids = inputs["decoder_token_ids"] + decoder_padding_mask = inputs["decoder_padding_mask"] + + if ( + encoder_input_values is None + or encoder_padding_mask is None + or decoder_token_ids is None + ): + raise ValueError("Input tensors cannot be None") + + batch_size = keras.ops.shape(encoder_input_values)[0] + # Calculate the length of the valid prompt before building the cache. + row_lengths = keras.ops.sum( + keras.ops.cast(decoder_padding_mask, "int32"), + axis=-1, + ) + index = keras.ops.min(row_lengths) + # NOTE: For the JAX backend, pre-allocate the cache based on max_length. + max_length = keras.ops.shape(decoder_token_ids)[1] + + encoder_hidden_states = self.call_encoder( + encoder_input_values=encoder_input_values, + padding_mask=encoder_padding_mask, + ) + initial_decoder_token_ids = keras.ops.slice( + decoder_token_ids, [0, 0], [batch_size, index] + ) + initial_decoder_padding_mask = keras.ops.slice( + decoder_padding_mask, [0, 0], [batch_size, index] + ) + ( + _, + hidden_states, + init_self_attention_cache, + init_cross_attention_cache, + ) = self.call_decoder_with_cache( + encoder_hidden_states=encoder_hidden_states, + encoder_padding_mask=encoder_padding_mask, + decoder_token_ids=initial_decoder_token_ids, + self_attention_cache=None, + cross_attention_cache=None, + decoder_padding_mask=initial_decoder_padding_mask, + ) + self_attention_cache = init_self_attention_cache + cross_attention_cache = init_cross_attention_cache + + row_lengths = keras.ops.sum( + keras.ops.cast(decoder_padding_mask, "int32"), + axis=-1, + ) + index = keras.ops.min(row_lengths) + + def next(prompt, cache, index): + if isinstance(cache, tuple) and len(cache) == 1: + cache = cache[0] + elif isinstance(cache, tuple) and len(cache) == 0: + cache = None + cache_index = index - 1 + num_samples = keras.ops.shape(prompt)[0] + next_token_input = keras.ops.slice( + prompt, [0, cache_index], [num_samples, 1] + ) + single_token_padding_mask = keras.ops.ones_like( + next_token_input, dtype="bool" + ) + + def repeat_tensor(x): + if keras.ops.shape(x)[0] == num_samples: + return x + return keras.ops.repeat( + x, repeats=num_samples // batch_size, axis=0 + ) + + logits, hidden_states, new_cache, _ = self.call_decoder_with_cache( + encoder_hidden_states=repeat_tensor(encoder_hidden_states), + encoder_padding_mask=repeat_tensor(encoder_padding_mask), + decoder_token_ids=next_token_input, + self_attention_cache=cache, + self_attention_cache_update_index=cache_index, + cross_attention_cache=repeat_tensor(cross_attention_cache), + decoder_padding_mask=single_token_padding_mask, + ) + return ( + logits[:, 0, :], + hidden_states[:, 0, :], + new_cache, + ) + + if keras.config.backend() == "jax": + current_prompt = decoder_token_ids + current_cache = self_attention_cache + current_index = index + for _ in range(max_length - index): + if 
stop_token_ids is not None: + prompt_mask = keras.ops.cast( + current_prompt + == ( + self.preprocessor.tokenizer.pad_token_id + if self.preprocessor + else -1 + ), + dtype="bool", + ) + valid_token_mask = ~prompt_mask + full_range = keras.ops.arange(max_length) + generated_range_mask = (full_range >= index) & ( + full_range < current_index + ) + check_mask = valid_token_mask & keras.ops.expand_dims( + generated_range_mask, 0 + ) + end_tokens = any_equal( + current_prompt, stop_token_ids, check_mask + ) + prompt_done = keras.ops.any(end_tokens, axis=-1) + if keras.ops.all(prompt_done): + break + + logits, _, current_cache = next( + current_prompt, current_cache, current_index + ) + probabilities = self.sampler.compute_probabilities(logits) + next_token = self.sampler.get_next_token(probabilities) + next_token = keras.ops.cast(next_token, current_prompt.dtype) + next_token = next_token[:, None] + current_prompt = keras.ops.slice_update( + current_prompt, [0, current_index], next_token + ) + current_index += 1 + + decoder_token_ids = current_prompt + else: + decoder_token_ids = self.sampler( + next=next, + prompt=decoder_token_ids, + cache=self_attention_cache, + index=index, + mask=keras.ops.cast( + decoder_token_ids + != self.preprocessor.tokenizer.pad_token_id + if self.preprocessor is not None + else decoder_padding_mask, + dtype="bool", + ), + stop_token_ids=stop_token_ids, + hidden_states=hidden_states, + model=self, + ) + + if stop_token_ids is not None: + end_locations = any_equal( + decoder_token_ids, + stop_token_ids, + decoder_token_ids == self.preprocessor.tokenizer.pad_token_id + if self.preprocessor is not None + else False, + ) + end_locations = keras.ops.cast(end_locations, "int32") + cumsum = keras.ops.cumsum(end_locations, axis=-1) + overflow = cumsum - end_locations + decoder_padding_mask = keras.ops.logical_not( + keras.ops.cast(overflow, "bool") + ) + else: + decoder_padding_mask = keras.ops.ones_like( + decoder_token_ids, dtype="bool" + ) + + return { + "decoder_token_ids": decoder_token_ids, + "decoder_padding_mask": decoder_padding_mask, + } + + def make_generate_function(self): + """Create or return the compiled generation function.""" + if self.generate_function is not None: + return self.generate_function + + self.generate_function = self.generate_step + if keras.config.backend() == "torch": + import torch + + def wrapped_generate_function( + inputs, + stop_token_ids=None, + ): + with torch.no_grad(): + return self.generate_step(inputs, stop_token_ids) + + self.generate_function = wrapped_generate_function + elif keras.config.backend() == "tensorflow" and not self.run_eagerly: + # `jit_compile` is a property of keras.Model after TF 2.12. + # Use `getattr()` for backwards compatibility. + # NOTE: Override, explicitly disabled JIT compilation for the + # TensorFlow backend. 
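+            # (`tf.function` with `jit_compile=False` still traces a graph,
+            # it just skips XLA compilation.)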
+ self.generate_function = tf.function( + self.generate_step, jit_compile=False + ) + elif keras.config.backend() == "jax" and not self.run_eagerly: + + def wrapped_generate_function( + inputs, + stop_token_ids=None, + ): + inputs = tree.map_structure(keras.ops.convert_to_tensor, inputs) + return self.generate_step(inputs, stop_token_ids) + + self.generate_function = wrapped_generate_function + + return self.generate_function diff --git a/keras_hub/src/models/moonshine/moonshine_audio_to_text_test.py b/keras_hub/src/models/moonshine/moonshine_audio_to_text_test.py new file mode 100644 index 0000000000..ec7d4e0516 --- /dev/null +++ b/keras_hub/src/models/moonshine/moonshine_audio_to_text_test.py @@ -0,0 +1,172 @@ +import os +from unittest.mock import patch + +import keras +import numpy as np +import pytest +from keras import ops + +from keras_hub.src.models.moonshine.moonshine_audio_converter import ( + MoonshineAudioConverter, +) +from keras_hub.src.models.moonshine.moonshine_audio_to_text import ( + MoonshineAudioToText, +) +from keras_hub.src.models.moonshine.moonshine_backbone import MoonshineBackbone +from keras_hub.src.models.moonshine.moonshine_seq_2_seq_lm_preprocessor import ( + MoonshineSeq2SeqLMPreprocessor, +) +from keras_hub.src.models.moonshine.moonshine_tokenizer import ( + MoonshineTokenizer, +) +from keras_hub.src.tests.test_case import TestCase + + +class MoonshineAudioToTextTest(TestCase): + def setUp(self): + self.tokenizer = MoonshineTokenizer( + proto=os.path.join( + self.get_test_data_dir(), "moonshine_test_vocab.spm" + ) + ) + self.vocab_size = 1036 + hidden_dim = 32 + self.audio_converter = MoonshineAudioConverter( + filter_dim=hidden_dim, + sampling_rate=16000, + do_normalize=False, + return_attention_mask=True, + padding_value=0.0, + initializer_range=0.02, + ) + self.preprocessor = MoonshineSeq2SeqLMPreprocessor( + audio_converter=self.audio_converter, + tokenizer=self.tokenizer, + decoder_sequence_length=10, + ) + self.backbone = MoonshineBackbone( + vocabulary_size=self.vocab_size, + hidden_dim=hidden_dim, + encoder_num_layers=2, + decoder_num_layers=2, + encoder_num_heads=4, + decoder_num_heads=4, + intermediate_dim=hidden_dim * 4, + feedforward_expansion_factor=4, + encoder_use_swiglu_activation=False, + decoder_use_swiglu_activation=True, + max_position_embeddings=2048, + pad_head_dim_to_multiple_of=None, + partial_rotary_factor=0.62, + dropout=0.0, + initializer_range=0.02, + rope_theta=10000.0, + attention_bias=False, + attention_dropout=0.0, + rope_scaling=None, + ) + self.init_kwargs = { + "preprocessor": self.preprocessor, + "backbone": self.backbone, + } + # NOTE: Since keras.ops.convert_to_tensor() does not support + # dtype="string" for the JAX and PyTorch backends, the only way to pass + # inputs that aren't a mix of tensors and non-tensors is to use a + # library-specific function. Using an np.ndarray here as a substitute to + # a librosa.load() call. + self.train_data = ( + { + "audio": np.random.normal(size=(2, 16000, 1)), + "text": ["quick brown", "earth is round"], + }, + ) + self.input_data = self.preprocessor(self.train_data[0])[0] + + @pytest.mark.skipif( + keras.config.backend() == "torch" or keras.config.backend() == "jax", + reason="NotImplementedError: Cannot convert a symbolic tf.Tensor (args_" + "0:0) to a numpy array. 
This error may indicate that you're trying to" + "pass a Tensor to a NumPy call, which is not supported.", + ) + def test_causal_lm_basics(self): + self.run_task_test( + cls=MoonshineAudioToText, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 10, self.tokenizer.vocabulary_size()), + ) + + def test_generate(self): + inputs = {"audio": keras.random.normal((1, 16000, 1)), "text": "quick"} + seq_2_seq_lm = MoonshineAudioToText(**self.init_kwargs) + output = seq_2_seq_lm.generate(inputs) + self.assertTrue("quick" in output) + + seq_2_seq_lm.preprocessor = None + preprocessed = self.preprocessor.generate_preprocess(inputs) + outputs = seq_2_seq_lm.generate(preprocessed, stop_token_ids=None) + self.assertAllEqual( + outputs["decoder_token_ids"][:, :2], + preprocessed["decoder_token_ids"][:, :2], + ) + + def test_early_stopping(self): + seq_2_seq_lm = MoonshineAudioToText(**self.init_kwargs) + call_decoder_with_cache = seq_2_seq_lm.call_decoder_with_cache + + def wrapper(*args, **kwargs): + logits, hidden_states, self_cache, cross_cache = ( + call_decoder_with_cache(*args, **kwargs) + ) + index = self.preprocessor.tokenizer.end_token_id + update = ops.ones_like(logits)[:, :, index] * 1.0e9 + update = ops.expand_dims(update, axis=-1) + logits = ops.slice_update(logits, (0, 0, index), update) + return logits, hidden_states, self_cache, cross_cache + + with patch.object( + seq_2_seq_lm, "call_decoder_with_cache", wraps=wrapper + ): + inputs = { + "audio": keras.random.normal((2, 16000, 1)), + "text": ["quick", "earth"], + } + output = seq_2_seq_lm.generate(inputs) + self.assertAllEqual(inputs["text"], output) + + def test_generate_compilation(self): + seq_2_seq_lm = MoonshineAudioToText(**self.init_kwargs) + seq_2_seq_lm.generate({"audio": keras.random.normal((1, 16000, 1))}) + first_fn = seq_2_seq_lm.generate_function + seq_2_seq_lm.generate({"audio": keras.random.normal((1, 16000, 1))}) + second_fn = seq_2_seq_lm.generate_function + self.assertEqual(first_fn, second_fn) + seq_2_seq_lm.compile(sampler="greedy") + self.assertIsNone(seq_2_seq_lm.generate_function) + + @pytest.mark.skipif( + keras.config.backend() == "jax", + reason="Beam search involves state management not supported in the " + "JAX manual eager loop override.", + ) + def test_beam_search(self): + seq_2_seq_lm = MoonshineAudioToText(**self.init_kwargs) + seq_2_seq_lm.compile(sampler="beam") + seq_2_seq_lm.generate({"audio": keras.random.normal((1, 16000, 1))}) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=MoonshineAudioToText, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in MoonshineAudioToText.presets: + self.run_preset_test( + cls=MoonshineAudioToText, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/moonshine/moonshine_backbone.py b/keras_hub/src/models/moonshine/moonshine_backbone.py new file mode 100644 index 0000000000..f873653202 --- /dev/null +++ b/keras_hub/src/models/moonshine/moonshine_backbone.py @@ -0,0 +1,360 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.modeling.reversible_embedding import ( + ReversibleEmbedding, +) +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.moonshine.moonshine_decoder import ( + MoonshineDecoderBlock, +) +from keras_hub.src.models.moonshine.moonshine_encoder import ( + MoonshineEncoderBlock, +) +from 
keras_hub.src.models.moonshine.moonshine_layers import ( + MoonshineRotaryEmbedding, +) +from keras_hub.src.models.moonshine.moonshine_layers import ( + moonshine_kernel_initializer, +) +from keras_hub.src.utils.keras_utils import clone_initializer + + +class Arange(keras.layers.Layer): + def call(self, inputs): + sequence_length = keras.ops.shape(inputs)[1] + return keras.ops.arange(sequence_length, dtype="int32") + + +@keras_hub_export("keras_hub.models.MoonshineBackbone") +class MoonshineBackbone(Backbone): + """Moonshine backbone for speech recognition. + + This class implements an encoder-decoder backbone, as used in the Moonshine + ASR system. It combines `MoonshineEncoderBlock` instances for processing + input sequences and `MoonshineDecoderBlock` instances for generating output + sequences. + + Args: + vocabulary_size: int. The size of the vocabulary for the embedding + layers. + encoder_num_layers: int. The number of stacked encoder blocks. + decoder_num_layers: int. The number of stacked decoder blocks. + hidden_dim: int. The dimensionality of the model's hidden + representations and embeddings. + intermediate_dim: int. The dimensionality of the intermediate + representations in feedforward networks. + encoder_num_heads: int. The number of attention heads in the encoder's + multi-head attention. + decoder_num_heads: int. The number of attention heads in the decoder's + multi-head attention. + feedforward_expansion_factor: int, optional. A multiplier applied to + `intermediate_dim` to determine the total width of the feedforward + network. Defaults to 4. + use_swiglu_activation: bool, optional. When True, uses the SwiGLU + activation in the feedforward network for improved performance. + Defaults to False. + max_position_embeddings: int, optional. The maximum sequence length for + position embeddings. Defaults to 2048. + pad_head_dim_to_multiple_of: int, optional. If specified, pads the head + dimension to be a multiple of this value for performance + optimization. Defaults to None. + partial_rotary_factor: float, optional. The fraction of dimensions to + apply rotary position embeddings to. Defaults to 0.62. + dropout: float, optional. The dropout probability for input dropout + layers. Defaults to 0.0. + initializer_range: float, optional. The standard deviation of the + truncated normal initializer for weights. Defaults to 0.02. + rope_theta: float, optional. The base frequency for rotary position + embeddings. Defaults to 10,000.0. + attention_bias: bool, optional. Whether to use bias in attention + mechanisms. Defaults to False. + attention_dropout: float, optional. The dropout probability for + attention mechanisms. Defaults to 0.0. + rope_scaling: dict, optional. The scaling configuration for rotary + position embeddings. Defaults to None. + dtype: str, optional. The dtype to use for model computations and + weights. Defaults to None. + + Examples: + ```python + # Create random input data for demonstration. + encoder_input_values = np.random.rand(1, 100, 256).astype("float32") + decoder_token_ids = np.random.randint( + 0, 1000, size=(1, 20), dtype="int32" + ) + + # Initialize the Moonshine backbone with specific parameters. + backbone = MoonshineBackbone( + vocabulary_size=10000, + encoder_num_layers=6, + decoder_num_layers=6, + hidden_dim=256, + intermediate_dim=512, + encoder_num_heads=8, + decoder_num_heads=8, + feedforward_expansion_factor=4, + use_swiglu_activation=True, + ) + + # Forward pass through the model. 
+ outputs = backbone( + { + "encoder_input_values": encoder_input_values, + "decoder_token_ids": decoder_token_ids, + } + ) + + # Display the outputs. + print("Encoder output shape:", outputs["encoder_sequence_output"].shape) + print("Decoder output shape:", outputs["decoder_sequence_output"].shape) + ``` + """ + + # References: + # Defined and formulated based on the Hugging Face implementation of the + # MoonshineModel class (https://github.com/huggingface/transformers/blob/dcbdf7e962c4b36140cc9ee76f870016121e69e5/src/transformers/models/moonshine/modeling_moonshine.py#L1326-L1486). + + def __init__( + self, + vocabulary_size, + encoder_num_layers, + decoder_num_layers, + hidden_dim, + intermediate_dim, + encoder_num_heads, + decoder_num_heads, + feedforward_expansion_factor=4, + encoder_use_swiglu_activation=False, + decoder_use_swiglu_activation=True, + max_position_embeddings=2048, + pad_head_dim_to_multiple_of=None, + partial_rotary_factor=0.62, + dropout=0.0, + initializer_range=0.02, + rope_theta=10000.0, + attention_bias=False, + attention_dropout=0.0, + rope_scaling=None, + dtype=None, + **kwargs, + ): + # ==== Config ==== + self.vocabulary_size = vocabulary_size + self.encoder_num_layers = encoder_num_layers + self.decoder_num_layers = decoder_num_layers + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.encoder_num_heads = encoder_num_heads + self.decoder_num_heads = decoder_num_heads + self.feedforward_expansion_factor = feedforward_expansion_factor + self.encoder_use_swiglu_activation = encoder_use_swiglu_activation + self.decoder_use_swiglu_activation = decoder_use_swiglu_activation + self.max_position_embeddings = max_position_embeddings + self.pad_head_dim_to_multiple_of = pad_head_dim_to_multiple_of + self.partial_rotary_factor = partial_rotary_factor + self.dropout = dropout + self.initializer_range = initializer_range + self.rope_theta = rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.rope_scaling = rope_scaling + self.embeddings_initializer = moonshine_kernel_initializer( + initializer_range=initializer_range + ) + + # ==== Layers ==== + encoder_head_dim = hidden_dim // encoder_num_heads + if pad_head_dim_to_multiple_of: + encoder_head_dim = ( + (encoder_head_dim + pad_head_dim_to_multiple_of - 1) + // pad_head_dim_to_multiple_of + ) * pad_head_dim_to_multiple_of + + decoder_head_dim = hidden_dim // decoder_num_heads + if pad_head_dim_to_multiple_of: + decoder_head_dim = ( + (decoder_head_dim + pad_head_dim_to_multiple_of - 1) + // pad_head_dim_to_multiple_of + ) * pad_head_dim_to_multiple_of + + # Embedding layer for decoder. + self.token_embedding = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + embeddings_initializer=clone_initializer( + self.embeddings_initializer + ), + name="token_embedding", + dtype=dtype, + ) + + # Rotary embeddings for encoder and decoder. 
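+        # Separate rotary tables are needed because the encoder and decoder
+        # head dimensions (after optional padding) can differ.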
+ self.encoder_rotary_embedding = MoonshineRotaryEmbedding( + head_dim=encoder_head_dim, + max_position_embeddings=max_position_embeddings, + partial_rotary_factor=partial_rotary_factor, + base_value=rope_theta, + rope_scaling=rope_scaling, + name="encoder_rotary_embedding", + dtype=dtype, + ) + + self.decoder_rotary_embedding = MoonshineRotaryEmbedding( + head_dim=decoder_head_dim, + max_position_embeddings=max_position_embeddings, + partial_rotary_factor=partial_rotary_factor, + base_value=rope_theta, + rope_scaling=rope_scaling, + name="decoder_rotary_embedding", + dtype=dtype, + ) + + # Dropout for encoder. + self.encoder_dropout = keras.layers.Dropout( + dropout, name="encoder_dropout", dtype=dtype + ) + # Dropout for decoder. + self.decoder_dropout = keras.layers.Dropout( + dropout, name="decoder_dropout", dtype=dtype + ) + + # Encoder blocks. + self.encoder_blocks = [] + for i in range(encoder_num_layers): + encoder_block = MoonshineEncoderBlock( + hidden_dim=hidden_dim, + intermediate_dim=intermediate_dim, + num_heads=encoder_num_heads, + feedforward_expansion_factor=feedforward_expansion_factor, + use_swiglu_activation=encoder_use_swiglu_activation, + pad_head_dim_to_multiple_of=pad_head_dim_to_multiple_of, + initializer_range=initializer_range, + attention_bias=attention_bias, + attention_dropout=attention_dropout, + name=f"encoder_block_{i}", + dtype=dtype, + ) + self.encoder_blocks.append(encoder_block) + + # Layer normalization for encoder. + self.encoder_final_layer_norm = keras.layers.LayerNormalization( + epsilon=1e-5, + center=False, + scale=True, + name="encoder_final_layer_norm", + dtype=dtype, + ) + + # Decoder blocks. + self.decoder_blocks = [] + for i in range(decoder_num_layers): + decoder_block = MoonshineDecoderBlock( + hidden_dim=hidden_dim, + intermediate_dim=intermediate_dim, + num_heads=decoder_num_heads, + feedforward_expansion_factor=feedforward_expansion_factor, + use_swiglu_activation=decoder_use_swiglu_activation, + pad_head_dim_to_multiple_of=pad_head_dim_to_multiple_of, + initializer_range=initializer_range, + attention_bias=attention_bias, + attention_dropout=attention_dropout, + name=f"decoder_block_{i}", + dtype=dtype, + ) + self.decoder_blocks.append(decoder_block) + + # Layer normalization for decoder. + self.decoder_post_norm = keras.layers.LayerNormalization( + epsilon=1e-5, + center=False, + scale=True, + name="decoder_post_norm", + dtype=dtype, + ) + + # === Functional Model === + encoder_input = keras.Input( + shape=(None, hidden_dim), name="encoder_input_values", dtype=dtype + ) + decoder_input = keras.Input( + shape=(None,), name="decoder_token_ids", dtype="int32" + ) + encoder_padding_mask = keras.Input( + shape=(None,), name="encoder_padding_mask", dtype="bool" + ) + decoder_padding_mask = keras.Input( + shape=(None,), name="decoder_padding_mask", dtype="bool" + ) + + # Encoder. + encoder_positions = Arange(name="encoder_positions")(encoder_input) + encoder_rotary_emb = self.encoder_rotary_embedding(encoder_positions) + encoder_hidden_states = self.encoder_dropout(encoder_input) + for encoder_block in self.encoder_blocks: + encoder_hidden_states = encoder_block( + encoder_hidden_states, + encoder_rotary_emb, + attention_mask=encoder_padding_mask, + ) + encoder_output = self.encoder_final_layer_norm(encoder_hidden_states) + + # Decoder. 
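+        # Each decoder block self-attends causally over the token embeddings
+        # and cross-attends to the encoder output.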
+ decoder_positions = Arange(name="decoder_positions")(decoder_input) + decoder_rotary_emb = self.decoder_rotary_embedding(decoder_positions) + decoder_hidden_states = self.token_embedding(decoder_input) + decoder_hidden_states = self.decoder_dropout(decoder_hidden_states) + for decoder_block in self.decoder_blocks: + decoder_hidden_states, _, _, _, _ = decoder_block( + [decoder_hidden_states, encoder_output, decoder_rotary_emb], + decoder_attention_mask=decoder_padding_mask, + encoder_attention_mask=encoder_padding_mask, + ) + decoder_output = self.decoder_post_norm(decoder_hidden_states) + + super().__init__( + inputs={ + "encoder_input_values": encoder_input, + "decoder_token_ids": decoder_input, + "encoder_padding_mask": encoder_padding_mask, + "decoder_padding_mask": decoder_padding_mask, + }, + outputs={ + "encoder_sequence_output": encoder_output, + "decoder_sequence_output": decoder_output, + }, + dtype=dtype, + **kwargs, + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "encoder_num_layers": self.encoder_num_layers, + "decoder_num_layers": self.decoder_num_layers, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "encoder_num_heads": self.encoder_num_heads, + "decoder_num_heads": self.decoder_num_heads, + "feedforward_expansion_factor": self.feedforward_expansion_factor, # noqa: E501 + "encoder_use_swiglu_activation": self.encoder_use_swiglu_activation, # noqa: E501 + "decoder_use_swiglu_activation": self.decoder_use_swiglu_activation, # noqa: E501 + "max_position_embeddings": self.max_position_embeddings, + "pad_head_dim_to_multiple_of": self.pad_head_dim_to_multiple_of, + "partial_rotary_factor": self.partial_rotary_factor, + "dropout": self.dropout, + "initializer_range": self.initializer_range, + "rope_theta": self.rope_theta, + "attention_bias": self.attention_bias, + "attention_dropout": self.attention_dropout, + "rope_scaling": self.rope_scaling, + "dtype": self.dtype, + } + ) + return config + + # Use the MoonshineBackbone class as part of a trainable model. 
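+    # `logits` maps decoder hidden states back to vocabulary space by reusing
+    # the token embedding weights in reverse (weight tying).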
+ def logits(self, decoder_hidden_states): + return self.token_embedding(decoder_hidden_states, reverse=True) diff --git a/keras_hub/src/models/moonshine/moonshine_backbone_test.py b/keras_hub/src/models/moonshine/moonshine_backbone_test.py new file mode 100644 index 0000000000..18d04973f1 --- /dev/null +++ b/keras_hub/src/models/moonshine/moonshine_backbone_test.py @@ -0,0 +1,197 @@ +import keras +import pytest + +from keras_hub.src.models.moonshine.moonshine_backbone import MoonshineBackbone +from keras_hub.src.tests.test_case import TestCase + + +class MoonshineBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "vocabulary_size": 10000, + "encoder_num_layers": 2, + "decoder_num_layers": 2, + "hidden_dim": 64, + "intermediate_dim": 512, + "encoder_num_heads": 8, + "decoder_num_heads": 8, + "feedforward_expansion_factor": 4, + "encoder_use_swiglu_activation": False, + "decoder_use_swiglu_activation": True, + "max_position_embeddings": 2048, + "pad_head_dim_to_multiple_of": None, + "partial_rotary_factor": 0.62, + "dropout": 0.0, + "initializer_range": 0.02, + "rope_theta": 10000.0, + "attention_bias": False, + "attention_dropout": 0.0, + "rope_scaling": None, + } + encoder_input_values = keras.random.uniform((2, 16, 64)) + decoder_token_ids = keras.random.randint( + shape=(2, 10), minval=0, maxval=10000 + ) + encoder_padding_mask = keras.ops.ones((2, 16), dtype="bool") + decoder_padding_mask = keras.ops.ones((2, 10), dtype="bool") + self.input_data = { + "encoder_input_values": encoder_input_values, + "decoder_token_ids": decoder_token_ids, + "encoder_padding_mask": encoder_padding_mask, + "decoder_padding_mask": decoder_padding_mask, + } + + def test_forward_pass(self): + backbone = MoonshineBackbone(**self.init_kwargs) + outputs = backbone(self.input_data) + self.assertEqual(outputs["encoder_sequence_output"].shape, (2, 16, 64)) + self.assertEqual(outputs["decoder_sequence_output"].shape, (2, 10, 64)) + + def test_serialization(self): + instance = MoonshineBackbone(**self.init_kwargs) + self.run_serialization_test(instance=instance) + + def test_swiglu_feedforward(self): + init_kwargs = self.init_kwargs.copy() + init_kwargs["encoder_use_swiglu_activation"] = True + backbone = MoonshineBackbone(**init_kwargs) + outputs = backbone(self.input_data) + self.assertEqual(outputs["encoder_sequence_output"].shape, (2, 16, 64)) + self.assertEqual(outputs["decoder_sequence_output"].shape, (2, 10, 64)) + + def test_different_sequence_lengths(self): + backbone = MoonshineBackbone(**self.init_kwargs) + + # Short sequences. + short_encoder_input_values = keras.random.uniform((2, 8, 64)) + short_decoder_token_ids = keras.random.randint( + shape=(2, 5), minval=0, maxval=10000 + ) + short_encoder_padding_mask = keras.ops.ones((2, 8), dtype="bool") + short_decoder_padding_mask = keras.ops.ones((2, 5), dtype="bool") + short_input_data = { + "encoder_input_values": short_encoder_input_values, + "decoder_token_ids": short_decoder_token_ids, + "encoder_padding_mask": short_encoder_padding_mask, + "decoder_padding_mask": short_decoder_padding_mask, + } + short_outputs = backbone(short_input_data) + self.assertEqual( + short_outputs["encoder_sequence_output"].shape, (2, 8, 64) + ) + self.assertEqual( + short_outputs["decoder_sequence_output"].shape, (2, 5, 64) + ) + + # Long sequences. 
+ long_encoder_input_values = keras.random.uniform((2, 32, 64)) + long_decoder_token_ids = keras.random.randint( + shape=(2, 15), minval=0, maxval=10000 + ) + long_encoder_padding_mask = keras.ops.ones((2, 32), dtype="bool") + long_decoder_padding_mask = keras.ops.ones((2, 15), dtype="bool") + long_input_data = { + "encoder_input_values": long_encoder_input_values, + "decoder_token_ids": long_decoder_token_ids, + "encoder_padding_mask": long_encoder_padding_mask, + "decoder_padding_mask": long_decoder_padding_mask, + } + long_outputs = backbone(long_input_data) + self.assertEqual( + long_outputs["encoder_sequence_output"].shape, (2, 32, 64) + ) + self.assertEqual( + long_outputs["decoder_sequence_output"].shape, (2, 15, 64) + ) + + def test_predict_model(self): + backbone = MoonshineBackbone(**self.init_kwargs) + outputs = backbone.predict(self.input_data) + self.assertEqual(outputs["encoder_sequence_output"].shape, (2, 16, 64)) + self.assertEqual(outputs["decoder_sequence_output"].shape, (2, 10, 64)) + + def test_varying_batch_sizes(self): + backbone = MoonshineBackbone(**self.init_kwargs) + for batch_size in [1, 3, 5]: + encoder_input_values = keras.random.uniform((batch_size, 16, 64)) + decoder_token_ids = keras.random.randint( + shape=(batch_size, 10), minval=0, maxval=10000 + ) + encoder_padding_mask = keras.ops.ones( + (batch_size, 16), dtype="bool" + ) + decoder_padding_mask = keras.ops.ones( + (batch_size, 10), dtype="bool" + ) + input_data = { + "encoder_input_values": encoder_input_values, + "decoder_token_ids": decoder_token_ids, + "encoder_padding_mask": encoder_padding_mask, + "decoder_padding_mask": decoder_padding_mask, + } + outputs = backbone(input_data) + self.assertEqual( + outputs["encoder_sequence_output"].shape, (batch_size, 16, 64) + ) + self.assertEqual( + outputs["decoder_sequence_output"].shape, (batch_size, 10, 64) + ) + + def test_attention_parameters(self): + init_kwargs = self.init_kwargs.copy() + init_kwargs["attention_bias"] = True + init_kwargs["attention_dropout"] = 0.1 + backbone = MoonshineBackbone(**init_kwargs) + outputs = backbone(self.input_data) + self.assertEqual(outputs["encoder_sequence_output"].shape, (2, 16, 64)) + self.assertEqual(outputs["decoder_sequence_output"].shape, (2, 10, 64)) + + def test_rope_parameters(self): + init_kwargs = self.init_kwargs.copy() + init_kwargs["rope_theta"] = 5000.0 + init_kwargs["rope_scaling"] = {"type": "linear", "factor": 2.0} + backbone = MoonshineBackbone(**init_kwargs) + outputs = backbone(self.input_data) + self.assertEqual(outputs["encoder_sequence_output"].shape, (2, 16, 64)) + self.assertEqual(outputs["decoder_sequence_output"].shape, (2, 10, 64)) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=MoonshineBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + def test_backbone_basics(self): + self.run_backbone_test( + cls=MoonshineBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape={ + "encoder_sequence_output": (2, 16, 64), + "decoder_sequence_output": (2, 10, 64), + }, + run_mixed_precision_check=False, + run_quantization_check=False, + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in MoonshineBackbone.presets.keys(): + hidden_size = 288 if preset == "moonshine_tiny_en" else 416 + encoder_input_values = keras.ops.ones((1, 100, hidden_size)) + decoder_token_ids = keras.ops.ones((1, 10), dtype="int32") + encoder_padding_mask = keras.ops.ones((1, 100), dtype="bool") + 
decoder_padding_mask = keras.ops.ones((1, 10), dtype="bool") + input_data = { + "encoder_input_values": encoder_input_values, + "decoder_token_ids": decoder_token_ids, + "encoder_padding_mask": encoder_padding_mask, + "decoder_padding_mask": decoder_padding_mask, + } + self.run_preset_test( + cls=MoonshineBackbone, + preset=preset, + input_data=input_data, + ) diff --git a/keras_hub/src/models/moonshine/moonshine_decoder.py b/keras_hub/src/models/moonshine/moonshine_decoder.py new file mode 100644 index 0000000000..e8c98ec0aa --- /dev/null +++ b/keras_hub/src/models/moonshine/moonshine_decoder.py @@ -0,0 +1,394 @@ +import keras + +from keras_hub.src.layers.modeling.transformer_decoder import TransformerDecoder +from keras_hub.src.models.moonshine.moonshine_layers import MoonshineMLP +from keras_hub.src.models.moonshine.moonshine_layers import ( + moonshine_kernel_initializer, +) +from keras_hub.src.models.moonshine.moonshine_multi_head_attention import ( + MoonshineMultiHeadAttention, +) +from keras_hub.src.utils.keras_utils import clone_initializer + + +@keras.saving.register_keras_serializable(package="keras_hub") +class MoonshineDecoderBlock(TransformerDecoder): + """Moonshine decoder block for sequence processing. + + This layer implements a decoder block that includes self-attention with + causal masking, cross-attention with precomputed key/value pairs, and a + feedforward network. It supports both cached and uncached operation modes. + + Args: + hidden_dim: int. The dimensionality of the model's hidden + representations. + intermediate_dim: int. The dimensionality of the intermediate + representations in the feedforward network. + num_heads: int. The number of attention heads for multi-head attention + mechanisms. + feedforward_expansion_factor: int, optional. A multiplicative factor for + scaling the feedforward network dimension. Defaults to 4. + use_swiglu_activation: bool, optional. Whether to use the SwiGLU + activation in the feedforward network for improved performance. + Defaults to True. + pad_head_dim_to_multiple_of: int, optional. If specified, pads the head + dimension to be a multiple of this value for performance + optimization. Defaults to None. + initializer_range: float, optional. The standard deviation of the + truncated normal distribution used to initialize model weights. + Defaults to 0.02. + attention_bias: bool, optional. Whether to add a bias term to the + attention computations. Defaults to False. + attention_dropout: float, optional. The dropout rate applied to + attention weights during training. Defaults to 0.0. + dtype: str, optional. The data type to use for model computations and + weights. Defaults to None. + **kwargs: Additional keyword arguments passed to the base layer. + """ + + # References: + # Defined and formulated based on the UsefulSensors implementation of the + # DecoderLayer class (https://github.com/usefulsensors/moonshine/blob/4a000427bd36a1c2c6d20a86c672dbd850b44c88/moonshine/model.py#L348-L466). 
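+    # The block supports two calling modes: an uncached pass that returns
+    # freshly computed self- and cross-attention key/value caches, and a
+    # cached autoregressive pass that consumes and updates them (see `call`
+    # and `compute_output_spec`).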
+ + def __init__( + self, + hidden_dim, + intermediate_dim, + num_heads, + feedforward_expansion_factor=4, + use_swiglu_activation=True, + pad_head_dim_to_multiple_of=None, + initializer_range=0.02, + attention_bias=False, + attention_dropout=0.0, + dtype=None, + **kwargs, + ): + kwargs.pop("dropout", None) + kwargs.pop("activation", None) + kwargs.pop("kernel_initializer", None) + self.kernel_initializer = moonshine_kernel_initializer( + initializer_range=initializer_range + ) + super().__init__( + intermediate_dim=intermediate_dim, + num_heads=num_heads, + dropout=attention_dropout, + activation="gelu" if use_swiglu_activation else "silu", + kernel_initializer=clone_initializer(self.kernel_initializer), + dtype=dtype, + **kwargs, + ) + self.initializer_range = initializer_range + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.num_heads = num_heads + self.feedforward_expansion_factor = feedforward_expansion_factor + self.use_swiglu_activation = use_swiglu_activation + self.pad_head_dim_to_multiple_of = pad_head_dim_to_multiple_of + self.attention_dropout = attention_dropout + self.attention_bias = attention_bias + + self.head_dim = hidden_dim // num_heads + if pad_head_dim_to_multiple_of is not None: + self.head_dim = ( + (self.head_dim + pad_head_dim_to_multiple_of - 1) + // pad_head_dim_to_multiple_of + ) * pad_head_dim_to_multiple_of + + self.norm1 = keras.layers.LayerNormalization( + axis=-1, + epsilon=1e-5, + center=False, + scale=True, + dtype=self.dtype, + ) + self.self_attention = MoonshineMultiHeadAttention( + num_heads=num_heads, + key_dim=self.head_dim, + use_bias=False, + kernel_initializer=clone_initializer(self.kernel_initializer), + attention_bias=attention_bias, + attention_dropout=attention_dropout, + use_causal_mask=True, + apply_rotary_embedding=True, + cache_mode="autoregressive", + dtype=self.dtype, + ) + self.norm2 = keras.layers.LayerNormalization( + axis=-1, + epsilon=1e-5, + center=False, + scale=True, + dtype=self.dtype, + ) + self.cross_attention = MoonshineMultiHeadAttention( + num_heads=num_heads, + key_dim=self.head_dim, + use_bias=False, + kernel_initializer=clone_initializer(self.kernel_initializer), + attention_bias=attention_bias, + attention_dropout=attention_dropout, + use_causal_mask=False, + apply_rotary_embedding=False, + cache_mode="precomputed", + dtype=self.dtype, + ) + self.norm3 = keras.layers.LayerNormalization( + axis=-1, + epsilon=1e-5, + center=False, + scale=True, + dtype=self.dtype, + ) + self.ff = MoonshineMLP( + hidden_dim=hidden_dim, + feedforward_expansion_factor=feedforward_expansion_factor, + use_swiglu_activation=use_swiglu_activation, + initializer_range=initializer_range, + dtype=self.dtype, + ) + + def build(self, input_shape): + if not isinstance(input_shape, (list, tuple)) or len(input_shape) < 2: + raise ValueError( + "Expected input_shape to be a list of at least two shapes." + ) + decoder_sequence_shape = ( + input_shape[0]["decoder_token_ids"] # Shape of x + if isinstance(input_shape[0], dict) + else input_shape[0] + ) + context_shape = ( + input_shape[1]["input_values"] # Shape of context + if isinstance(input_shape[1], dict) + else input_shape[1] + ) + + # Build sublayers. 
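+ # All three layer norms and the feedforward operate on the decoder
+ # hidden shape; self-attention is built with query=key=value shapes,
+ # while cross-attention keys/values use the encoder context shape.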
+ self.norm1.build(decoder_sequence_shape) + self.norm2.build(decoder_sequence_shape) + self.norm3.build(decoder_sequence_shape) + + self.self_attention.build( + query_shape=decoder_sequence_shape, + key_shape=decoder_sequence_shape, + value_shape=decoder_sequence_shape, + ) + + self.cross_attention.build( + query_shape=decoder_sequence_shape, + key_shape=context_shape, + value_shape=context_shape, + ) + + self.ff.build(decoder_sequence_shape) + self.built = True + + def compute_output_spec( + self, + inputs, + training=None, + use_cache=False, + decoder_attention_mask=None, + encoder_attention_mask=None, + ): + if use_cache: + # Cached case: expect 7 inputs. + if len(inputs) != 7: + raise ValueError( + "When use_cache=True, expected 7 inputs: " + "[x, context, cache_k, cache_v, x_attn_cache_k, " + "x_attn_cache_v, rotary_embedding]" + ) + ( + x, + context, + cache_k, + cache_v, + x_attn_cache_k, + x_attn_cache_v, + rotary_embedding, + ) = inputs + # Output shape for x is the same as input x_shape but with + # hidden_dim. + x_shape = x.shape if hasattr(x, "shape") else x + output_shape = x_shape[:-1] + (self.hidden_dim,) + # New cache shapes are the same as input cache_k_shape and + # cache_v_shape. + # Note: In practice, sequence length may increase due to + # concatenation, but symbolically, it remains None. + new_cache_shape = ( + cache_k.shape if hasattr(cache_k, "shape") else cache_k + ) + return ( + keras.KerasTensor(shape=output_shape, dtype=self.dtype), # x + keras.KerasTensor( + shape=new_cache_shape, dtype=self.dtype + ), # new_cache_k + keras.KerasTensor( + shape=new_cache_shape, dtype=self.dtype + ), # new_cache_v + ) + else: + # Uncached case: expect 3 inputs. + if len(inputs) != 3: + raise ValueError( + "When use_cache=False, expected 3 inputs: [x, context, " + "rotary_embedding]" + ) + x, context, rotary_embedding = inputs + x_shape = x.shape if hasattr(x, "shape") else x + context_shape = ( + context.shape if hasattr(context, "shape") else context + ) + batch_size = x_shape[0] # None (symbolic) + seq_len = x_shape[1] # None (symbolic) + context_len = context_shape[1] # None (symbolic) + hidden_dim = self.hidden_dim + num_heads = self.num_heads + head_dim = self.head_dim + + # Define output shapes. + output_shape = (batch_size, seq_len, hidden_dim) # x + cache_shape_self = ( + batch_size, + seq_len, + num_heads, + head_dim, + ) # Self-attention caches + cache_shape_cross = ( + batch_size, + context_len, + num_heads, + head_dim, + ) # Cross-attention caches + + return ( + keras.KerasTensor(shape=output_shape, dtype=self.dtype), # x + keras.KerasTensor( + shape=cache_shape_self, dtype=self.dtype + ), # cache_k + keras.KerasTensor( + shape=cache_shape_self, dtype=self.dtype + ), # cache_v + keras.KerasTensor( + shape=cache_shape_cross, dtype=self.dtype + ), # x_attn_cache_k + keras.KerasTensor( + shape=cache_shape_cross, dtype=self.dtype + ), # x_attn_cache_v + ) + + def call( + self, + inputs, + training=None, + use_cache=False, + decoder_attention_mask=None, + encoder_attention_mask=None, + self_attention_cache=None, + self_attention_cache_update_index=None, + ): + if use_cache: + if not isinstance(inputs, (list, tuple)) or len(inputs) != 7: + raise ValueError( + "When use_cache=True, expected 7 inputs: " + "[x, context, cache_k, cache_v, x_attn_cache_k, " + "x_attn_cache_v, rotary_embedding]. " + f"Received {len(inputs)} inputs." 
+ ) + ( + x, + context, + cache_k, # Self-attn key cache + cache_v, # Self-attn value cache + x_attn_cache_k, # Cross-attn key cache (precomputed) + x_attn_cache_v, # Cross-attn value cache (precomputed) + rotary_embedding, + ) = inputs + else: + if not isinstance(inputs, (list, tuple)) or len(inputs) != 3: + raise ValueError( + "When use_cache=False, expected 3 inputs: [x, context, " + f"rotary_embedding]. Received {len(inputs)} inputs." + ) + x, context, rotary_embedding = inputs + cache_k, cache_v, x_attn_cache_k, x_attn_cache_v = ( + None, + None, + None, + None, + ) + residual = x + x_norm1 = self.norm1(x) + + if use_cache: + x_self_attn, new_cache_k, new_cache_v = self.self_attention( + query=x_norm1, + key=x_norm1, + value=x_norm1, + rotary_embedding=rotary_embedding, + key_cache=cache_k, + value_cache=cache_v, + attention_mask=decoder_attention_mask, + training=training, + ) + else: + x_self_attn, cache_k, cache_v = self.self_attention( + query=x_norm1, + key=x_norm1, + value=x_norm1, + rotary_embedding=rotary_embedding, + attention_mask=decoder_attention_mask, + training=training, + ) + x = x_self_attn + residual + residual = x + x_norm2 = self.norm2(x) + + if use_cache: + x_cross_attn = self.cross_attention( + query=x_norm2, + key=context, + value=context, + key_cache=x_attn_cache_k, + value_cache=x_attn_cache_v, + attention_mask=encoder_attention_mask, + training=training, + ) + else: + x_cross_attn, x_attn_cache_k, x_attn_cache_v = self.cross_attention( + query=x_norm2, + key=context, + value=context, + attention_mask=encoder_attention_mask, + training=training, + ) + x = x_cross_attn + residual + residual = x + x_norm3 = self.norm3(x) + x_ff = self.ff(x_norm3) + x = x_ff + residual + + if use_cache: + return x, new_cache_k, new_cache_v + return x, cache_k, cache_v, x_attn_cache_k, x_attn_cache_v + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "num_heads": self.num_heads, + "feedforward_expansion_factor": self.feedforward_expansion_factor, # noqa: E501 + "use_swiglu_activation": self.use_swiglu_activation, + "pad_head_dim_to_multiple_of": self.pad_head_dim_to_multiple_of, # noqa: E501 + "initializer_range": self.initializer_range, + "attention_bias": self.attention_bias, + "attention_dropout": self.attention_dropout, + "dtype": self.dtype, + } + ) + return config diff --git a/keras_hub/src/models/moonshine/moonshine_decoder_test.py b/keras_hub/src/models/moonshine/moonshine_decoder_test.py new file mode 100644 index 0000000000..9161b2cb1c --- /dev/null +++ b/keras_hub/src/models/moonshine/moonshine_decoder_test.py @@ -0,0 +1,204 @@ +import keras + +from keras_hub.src.models.moonshine.moonshine_decoder import ( + MoonshineDecoderBlock, +) +from keras_hub.src.tests.test_case import TestCase + + +class MoonshineDecoderTest(TestCase): + def setUp(self): + super().setUp() + self.hidden_dim = 64 + self.intermediate_dim = 256 + self.num_heads = 4 + self.init_kwargs = { + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "num_heads": self.num_heads, + "feedforward_expansion_factor": 4, + "use_swiglu_activation": True, + "pad_head_dim_to_multiple_of": None, + "initializer_range": 0.02, + "attention_bias": False, + "attention_dropout": 0.0, + } + self.decoder_block = MoonshineDecoderBlock(**self.init_kwargs) + self.batch_size = 2 + self.seq_len = 10 + self.encoder_seq_len = 16 + self.head_dim = self.hidden_dim // self.num_heads # 16 + self.rotary_dim 
= int( + self.head_dim * 0.62 + ) # Default partial_rotary_factor = 0.62 + self.rotary_dim = (self.rotary_dim // 2) * 2 # Ensure even + self.rotary_dim = self.rotary_dim // 2 # Half for freqs, e.g., 4 + self.x = keras.random.normal( + (self.batch_size, self.seq_len, self.hidden_dim) + ) + self.context = keras.random.normal( + (self.batch_size, self.encoder_seq_len, self.hidden_dim) + ) + self.rotary_embedding = keras.random.normal( + (self.seq_len, self.rotary_dim) + ) + self.decoder_attention_mask = keras.ops.ones( + (self.batch_size, self.seq_len), dtype="bool" + ) + self.encoder_attention_mask = keras.ops.ones( + (self.batch_size, self.encoder_seq_len), dtype="bool" + ) + + def test_initialization(self): + self.assertEqual(self.decoder_block.hidden_dim, self.hidden_dim) + self.assertEqual( + self.decoder_block.intermediate_dim, self.intermediate_dim + ) + self.assertEqual(self.decoder_block.num_heads, self.num_heads) + self.assertTrue(self.decoder_block.use_swiglu_activation) + + def test_forward_pass_without_caching(self): + outputs = self.decoder_block( + [self.x, self.context, self.rotary_embedding], + decoder_attention_mask=self.decoder_attention_mask, + encoder_attention_mask=self.encoder_attention_mask, + ) + x, cache_k, cache_v, x_attn_cache_k, x_attn_cache_v = outputs + self.assertEqual( + x.shape, (self.batch_size, self.seq_len, self.hidden_dim) + ) + self.assertEqual( + cache_k.shape, + (self.batch_size, self.seq_len, self.num_heads, self.head_dim), + ) + self.assertEqual( + cache_v.shape, + (self.batch_size, self.seq_len, self.num_heads, self.head_dim), + ) + self.assertEqual( + x_attn_cache_k.shape, + ( + self.batch_size, + self.encoder_seq_len, + self.num_heads, + self.head_dim, + ), + ) + self.assertEqual( + x_attn_cache_v.shape, + ( + self.batch_size, + self.encoder_seq_len, + self.num_heads, + self.head_dim, + ), + ) + + def test_forward_pass_with_padding(self): + # Padding in decoder sequence. + padded_mask = keras.ops.concatenate( + [ + keras.ops.ones((self.batch_size, 5), dtype="bool"), + keras.ops.zeros( + (self.batch_size, self.seq_len - 5), dtype="bool" + ), + ], + axis=1, + ) + outputs = self.decoder_block( + [self.x, self.context, self.rotary_embedding], + decoder_attention_mask=padded_mask, + encoder_attention_mask=self.encoder_attention_mask, + ) + x, _, _, _, _ = outputs + self.assertEqual( + x.shape, (self.batch_size, self.seq_len, self.hidden_dim) + ) + + def test_autoregressive_caching(self): + # First pass to get initial caches. + outputs_full = self.decoder_block( + [self.x, self.context, self.rotary_embedding], + decoder_attention_mask=self.decoder_attention_mask, + encoder_attention_mask=self.encoder_attention_mask, + ) + _, cache_k_full, cache_v_full, x_attn_cache_k, x_attn_cache_v = ( + outputs_full + ) + + # Autoregressive decoding. 
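+ # Decode one position at a time with use_cache=True, passing the
+ # self-attention caches sliced to the tokens seen so far plus the
+ # precomputed cross-attention caches, and check that the returned
+ # self-attention caches grow by one position per step.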
+ for i in range(self.seq_len): + x_i = self.x[:, i : i + 1, :] + rotary_i = self.rotary_embedding[i : i + 1, :] + mask_i = self.decoder_attention_mask[:, i : i + 1] + cache_k = None if i == 0 else cache_k_full[:, :i, :, :] + cache_v = None if i == 0 else cache_v_full[:, :i, :, :] + outputs_i = self.decoder_block( + [ + x_i, + self.context, + cache_k, + cache_v, + x_attn_cache_k, + x_attn_cache_v, + rotary_i, + ], + use_cache=True, + decoder_attention_mask=mask_i, + encoder_attention_mask=self.encoder_attention_mask, + ) + x_i_out, new_cache_k, new_cache_v = outputs_i + self.assertEqual( + x_i_out.shape, (self.batch_size, 1, self.hidden_dim) + ) + self.assertEqual( + new_cache_k.shape, + (self.batch_size, i + 1, self.num_heads, self.head_dim), + ) + self.assertEqual( + new_cache_v.shape, + (self.batch_size, i + 1, self.num_heads, self.head_dim), + ) + + def test_caching_consistency(self): + # Full sequence without caching. + outputs_full = self.decoder_block( + [self.x, self.context, self.rotary_embedding], + decoder_attention_mask=self.decoder_attention_mask, + encoder_attention_mask=self.encoder_attention_mask, + ) + x_full, _, _, _, _ = outputs_full + + # Autoregressive with caching. + x_auto = [] + cache_k, cache_v = None, None + x_attn_cache_k, x_attn_cache_v = ( + outputs_full[3], + outputs_full[4], + ) # Precomputed cross-attention caches + for i in range(self.seq_len): + x_i = self.x[:, i : i + 1, :] + rotary_i = self.rotary_embedding[i : i + 1, :] + mask_i = self.decoder_attention_mask[:, i : i + 1] + outputs_i = self.decoder_block( + [ + x_i, + self.context, + cache_k, + cache_v, + x_attn_cache_k, + x_attn_cache_v, + rotary_i, + ], + use_cache=True, + decoder_attention_mask=mask_i, + encoder_attention_mask=self.encoder_attention_mask, + ) + x_i_out, cache_k, cache_v = outputs_i + x_auto.append(x_i_out) + x_auto = keras.ops.concatenate(x_auto, axis=1) + self.assertAllClose(x_full, x_auto, atol=1e-5) + + def test_serialization(self): + instance = MoonshineDecoderBlock(**self.init_kwargs) + self.run_serialization_test(instance=instance) diff --git a/keras_hub/src/models/moonshine/moonshine_encoder.py b/keras_hub/src/models/moonshine/moonshine_encoder.py new file mode 100644 index 0000000000..5e8838384e --- /dev/null +++ b/keras_hub/src/models/moonshine/moonshine_encoder.py @@ -0,0 +1,213 @@ +import keras + +from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder +from keras_hub.src.models.moonshine.moonshine_layers import MoonshineMLP +from keras_hub.src.models.moonshine.moonshine_layers import ( + moonshine_kernel_initializer, +) +from keras_hub.src.models.moonshine.moonshine_multi_head_attention import ( + MoonshineMultiHeadAttention, +) +from keras_hub.src.utils.keras_utils import clone_initializer + + +@keras.saving.register_keras_serializable(package="keras_hub") +class MoonshineEncoderBlock(TransformerEncoder): + """ + Moonshine encoder block for sequence processing. + + Implements a standard encoder block with self-attention and feedforward + sublayers, including residual connections and layer normalization. The + implementation utilizes Moonshine-specific attention and feedforward + mechanisms. + + Args: + hidden_dim: int. The dimensionality of the model's hidden + representations throughout the block. + intermediate_dim: int. The dimensionality used in projections before + applying non-linearities. + num_heads: int. The number of attention heads for multi-head attention + computation. + feedforward_expansion_factor: int, optional. 
A multiplier for expanding + the dimension in the feedforward network. Defaults to 4. + use_swiglu_activation: bool, optional. Whether to use SwiGLU activation + (True) or LinearGeLU (False) in the feedforward sublayer. Defaults + to False. + pad_head_dim_to_multiple_of: int, optional. If specified, pads the head + dimension to be a multiple of this value for hardware optimization. + Defaults to None. + initializer_range: float, optional. The standard deviation of the + truncated normal distribution used for weight initialization. + Defaults to 0.02. + attention_bias: bool, optional. Whether to use a bias term in the + attention mechanism. Defaults to False. + attention_dropout: float, optional. The dropout rate applied to the + attention weights. Defaults to 0.0. + dtype: str, optional. The data type to use for model computations and + weights. Defaults to None. + **kwargs: Additional keyword arguments passed to the base layer. + """ + + # References: + # Defined and formulated based on the UsefulSensors implementation of the + # EncoderLayer class (https://github.com/usefulsensors/moonshine/blob/4a000427bd36a1c2c6d20a86c672dbd850b44c88/moonshine/model.py#L124-L161). + + def __init__( + self, + hidden_dim, + intermediate_dim, + num_heads, + feedforward_expansion_factor=4, + use_swiglu_activation=False, + pad_head_dim_to_multiple_of=None, + dtype=None, + initializer_range=0.02, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + kwargs.pop("dropout", None) + kwargs.pop("activation", None) + kwargs.pop("kernel_initializer", None) + self.kernel_initializer = moonshine_kernel_initializer( + initializer_range=initializer_range + ) + super().__init__( + intermediate_dim=intermediate_dim, + num_heads=num_heads, + dropout=attention_dropout, + activation="gelu" if use_swiglu_activation else "silu", + kernel_initializer=clone_initializer(self.kernel_initializer), + dtype=dtype, + **kwargs, + ) + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.num_heads = num_heads + self.feedforward_expansion_factor = feedforward_expansion_factor + self.use_swiglu_activation = use_swiglu_activation + + # Self-attention sublayers. + self.pad_head_dim_to_multiple_of = pad_head_dim_to_multiple_of + + self.head_dim = hidden_dim // num_heads + if pad_head_dim_to_multiple_of is not None: + self.head_dim = ( + (self.head_dim + pad_head_dim_to_multiple_of - 1) + // pad_head_dim_to_multiple_of + ) * pad_head_dim_to_multiple_of + + self.self_attention_layer = MoonshineMultiHeadAttention( + num_heads=num_heads, + key_dim=self.head_dim, + use_bias=False, + kernel_initializer=clone_initializer(self.kernel_initializer), + attention_bias=attention_bias, + attention_dropout=attention_dropout, + use_causal_mask=False, + apply_rotary_embedding=True, + cache_mode="none", + name="self_attention_layer", + dtype=self.dtype, + ) + self.self_attention_layer_norm = keras.layers.LayerNormalization( + axis=-1, + epsilon=1e-5, + center=False, + scale=True, + name="self_attention_layer_norm", + dtype=self.dtype, + ) + + # Feedforward sublayers. 
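+ # Pre-normalization followed by the Moonshine MLP (SwiGLU or
+ # LinearGeLU, depending on `use_swiglu_activation`).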
+ self.feedforward_layer_norm = keras.layers.LayerNormalization( + axis=-1, + epsilon=1e-5, + center=False, + scale=True, + name="feedforward_layer_norm", + dtype=self.dtype, + ) + self.feedforward = MoonshineMLP( + hidden_dim=hidden_dim, + feedforward_expansion_factor=feedforward_expansion_factor, + use_swiglu_activation=use_swiglu_activation, + initializer_range=initializer_range, + name="feedforward", + dtype=self.dtype, + ) + + def build(self, input_shape): + if isinstance(input_shape, dict): + encoder_input_shape = input_shape["input_values"] + else: + encoder_input_shape = input_shape + # Build self-attention branch. + self.self_attention_layer_norm.build(encoder_input_shape) + self.self_attention_layer.build( + encoder_input_shape, encoder_input_shape, encoder_input_shape + ) + # Build feedforward branch. + self.feedforward_layer_norm.build(encoder_input_shape) + # The feedforward layer expects the last dimension to be hidden_dim. + feed_forward_input_shape = list(encoder_input_shape) + feed_forward_input_shape[-1] = self.hidden_dim + self.feedforward.build(tuple(feed_forward_input_shape)) + self.built = True + + def call( + self, + inputs, + rotary_embedding, + attention_mask=None, + training=None, + **kwargs, + ): + x = inputs + + # Self-attention block with residual connection. + attention_residual = x + x = self.self_attention_layer_norm(x) + x = self.self_attention_layer( + query=x, + value=x, + key=x, + rotary_embedding=rotary_embedding, + attention_mask=attention_mask, + training=training, + **kwargs, + ) + x = x + attention_residual + + # Feedforward block with residual connection. + ff_residual = x + x = self.feedforward_layer_norm(x) + x = self.feedforward(x) + x = x + ff_residual + + return x + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + # ==== Config ==== + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "num_heads": self.num_heads, + "feedforward_expansion_factor": self.feedforward_expansion_factor, # noqa: E501 + "use_swiglu_activation": self.use_swiglu_activation, + "pad_head_dim_to_multiple_of": self.pad_head_dim_to_multiple_of, + "initializer_range": self.initializer_range, + "attention_bias": self.attention_bias, + "attention_dropout": self.attention_dropout, + "dtype": self.dtype, + } + ) + return config diff --git a/keras_hub/src/models/moonshine/moonshine_layers.py b/keras_hub/src/models/moonshine/moonshine_layers.py new file mode 100644 index 0000000000..228e56af26 --- /dev/null +++ b/keras_hub/src/models/moonshine/moonshine_layers.py @@ -0,0 +1,315 @@ +import keras + +from keras_hub.src.utils.keras_utils import clone_initializer + + +def moonshine_kernel_initializer(initializer_range=0.02): + return keras.initializers.TruncatedNormal(stddev=initializer_range) + + +@keras.saving.register_keras_serializable(package="keras_hub") +class MoonshineRotaryEmbedding(keras.layers.Layer): + """ + Moonshine rotary embedding layer. + + Computes rotary positional embeddings using precomputed inverse frequencies. + Supports two RoPE types: "default" and "dynamic". + + - **Default RoPE**: Applies rotary embeddings to a fraction of dimensions + controlled by `partial_rotary_factor`. + - **Dynamic RoPE**: Updates frequencies dynamically based on sequence length + aligning functionally with the Hugging Face implementation. 
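+
+ In both cases the stored inverse frequencies follow the standard RoPE
+ formulation, `inv_freq[j] = 1 / base_value ** (j / rotary_dim_half)` for
+ `j = 0, ..., rotary_dim_half - 1`, where
+ `rotary_dim_half = int(head_dim * partial_rotary_factor) // 2`.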
+ + The layer stores inverse frequency weights as a non-trainable parameter and + computes sinusoidal embeddings based on input positions. Unlike KerasHub's + `RotaryEmbedding` class, this implementation explicitly requires `head_dim` + and applies `partial_rotary_factor` for selective rotary embedding, whereas + KerasHub uses `max_wavelength` without partial application. + + Args: + head_dim: int. The dimensionality of each attention head, determining + the feature space for rotary embeddings. + max_position_embeddings: int, optional. The maximum sequence length the + model can process, controlling the positional embedding scale. + Defaults to 2048. + base_value: float, optional. Base value for computing inverse + frequencies. Higher values result in longer wavelengths. Defaults to + 10000. + rope_scaling: dict, optional. Configuration for RoPE scaling, such as + `{"rope_type": "default"}` or `{"rope_type": "dynamic"}`. + Defaults to `{"rope_type": "default"}` if None. + partial_rotary_factor: float, optional. The fraction of `head_dim` + dimensions that receive rotary embeddings, balancing rotary and + non-rotary components. Defaults to 0.62. + dtype: string, optional. The data type for model computations and + weights. Defaults to None. + **kwargs: Additional keyword arguments passed to the parent class. + """ + + # References: + # Based on the UsefulSensors implementation of the RotaryEmbedding class (https://github.com/usefulsensors/moonshine/blob/4a000427bd36a1c2c6d20a86c672dbd850b44c88/moonshine/model.py#L176-L193). + # Incorporates dynamic RoPE concepts from the Hugging Face implementation (https://github.com/huggingface/transformers/blob/bc30dd1efb99f571d45b2e2131a555d09285ddd8/src/transformers/models/moonshine/modeling_moonshine.py#L311-L369). + + def __init__( + self, + head_dim, + max_position_embeddings=2048, + base_value=10000, + rope_scaling=None, + partial_rotary_factor=0.62, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.head_dim = head_dim + self.max_position_embeddings = max_position_embeddings + self.base_value = base_value + self.partial_rotary_factor = partial_rotary_factor + + if rope_scaling is None: + rope_scaling = {"rope_type": "default"} + self.rope_scaling = rope_scaling + self.rope_type = rope_scaling.get("rope_type", "default") + + if self.rope_type == "default": + self.scaling_factor = 1.0 + self.attention_scaling = 1.0 + elif "dynamic" in self.rope_type: + self.scaling_factor = 1.0 # Initial scaling, updated dynamically + self.attention_scaling = 1.0 + else: + raise NotImplementedError( + f"rope_type '{self.rope_type}' not implemented" + ) + + def build(self, input_shape): + # Create and track the non-trainable weight immediately. + rotary_dim = int(self.head_dim * self.partial_rotary_factor) + rotary_dim = (rotary_dim // 2) * 2 + if rotary_dim <= 0: + raise ValueError( + f"Calculated rotary_dim ({rotary_dim}) must be a positive even " + f"number. Check head_dim ({self.head_dim}) and " + f"partial_rotary_factor ({self.partial_rotary_factor})." + ) + rotary_dim_half = rotary_dim // 2 + + # Compute inv_freq. + inv_freq = 1.0 / ( + self.base_value + ** ( + keras.ops.arange(0, rotary_dim_half, dtype=self.dtype) + / rotary_dim_half + ) + ) + inv_freq = inv_freq * self.scaling_factor + + # Set the non-trainable weight using the computed tensor. 
+ self.inv_freq = self.add_weight( + name="inv_freq", + shape=(rotary_dim_half,), + initializer=keras.initializers.Constant(inv_freq), + trainable=False, + dtype=self.dtype, + ) + self.original_inv_freq = keras.ops.convert_to_tensor(self.inv_freq) + self.max_sequence_length_cached = self.max_position_embeddings + self.built = True + + def call(self, t, position_ids=None): + # "Dynamic" RoPE behavior. + if "dynamic" in self.rope_type: + if position_ids is None: + position_ids = keras.ops.expand_dims(t, axis=0) + seq_len = keras.ops.max(position_ids) + 1 + if seq_len > self.max_position_embeddings: + scaling = keras.ops.cast( + self.max_position_embeddings, self.dtype + ) / keras.ops.cast(seq_len, self.dtype) + else: + scaling = keras.ops.cast(1.0, self.dtype) + current_inv_freq = self.original_inv_freq * scaling + if seq_len > self.max_sequence_length_cached: + self.max_sequence_length_cached = seq_len + elif ( + seq_len < self.max_position_embeddings + and self.max_sequence_length_cached + > self.max_position_embeddings + ): + self.max_sequence_length_cached = self.max_position_embeddings + + pos_cast = keras.ops.cast(position_ids, self.dtype) + freqs = pos_cast[:, :, None] * current_inv_freq[None, None, :] + cos = keras.ops.cos(freqs) * self.attention_scaling + sin = keras.ops.sin(freqs) * self.attention_scaling + return cos, sin + # Original "default" behavior. + else: + t_cast = keras.ops.cast(t, keras.ops.dtype(self.inv_freq)) + original_shape = keras.ops.shape(t_cast) + is_generation_step = ( + len(original_shape) == 2 and original_shape[1] == 1 + ) + if is_generation_step: + # (batch, 1) -> Squeeze to (batch,) for einsum "i,j->ij". + t_cast_for_einsum = keras.ops.squeeze(t_cast, axis=1) + freqs = keras.ops.einsum( + "i,j->ij", t_cast_for_einsum, self.inv_freq + ) # Shape (batch, rotary_dim_half) + elif len(original_shape) == 1: + t_cast_for_einsum = t_cast + freqs = keras.ops.einsum( + "i,j->ij", t_cast_for_einsum, self.inv_freq + ) # Shape (seq_len, rotary_dim_half) + else: + raise ValueError( + f"Unexpected shape for input 't' in " + f"MoonshineRotaryEmbedding default path: {original_shape}. " + "Expected (seq_len,) or (batch, 1)." + ) + emb = keras.ops.stack((freqs, freqs), axis=-1) + shape_list = list(keras.ops.shape(emb)) + shape_list[-2:] = [-1] + emb_flat = keras.ops.reshape(emb, shape_list) + if is_generation_step: + final_emb = keras.ops.expand_dims(emb_flat, axis=1) + else: + final_emb = keras.ops.expand_dims(emb_flat, axis=0) + return final_emb + + def get_config(self): + config = super().get_config() + config.update( + { + "head_dim": self.head_dim, + "max_position_embeddings": self.max_position_embeddings, + "base_value": self.base_value, + "rope_scaling": self.rope_scaling, + "partial_rotary_factor": self.partial_rotary_factor, + "dtype": self.dtype, + } + ) + return config + + +@keras.saving.register_keras_serializable(package="keras_hub") +class MoonshineMLP(keras.layers.Layer): + """ + Moonshine MLP layer. + + Implements a Multi-Layer Perceptron (MLP) for Moonshine models with support + for both `SwiGLU` and `LinearGeLU` activation patterns. The MLP consists of + two dense layers with an activation function in between, expanding the input + dimension before projecting back to the original dimension. + + Args: + hidden_dim: int. The dimensionality of the input and output tensors. + feedforward_expansion_factor: float. The factor by which to expand the + hidden dimension in the intermediate layer. + use_swiglu_activation: bool, optional. 
If `True`, uses SwiGLU activation + (SiLU with gating). If `False`, uses standard GeLU activation. + Defaults to `True`. + initializer_range: float, optional. The standard deviation for kernel + initialization. Defaults to 0.02. + dtype: string, optional. The data type for model computations and + weights. Defaults to `None`. + **kwargs: Additional keyword arguments passed to the parent class. + """ + + # References: + # Based on the HuggingFace implementation of the MoonshineEncoderMLP and + # MoonshineDecoderMLP classes (https://github.com/huggingface/transformers/blob/fc8764c9a618add64c33e83720f974750bcd0978/src/transformers/models/moonshine/modeling_moonshine.py#L66-L94). + + def __init__( + self, + hidden_dim, + feedforward_expansion_factor, + use_swiglu_activation=True, + initializer_range=0.02, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.hidden_dim = hidden_dim + self.feedforward_expansion_factor = feedforward_expansion_factor + self.use_swiglu_activation = use_swiglu_activation + self.kernel_initializer = moonshine_kernel_initializer( + initializer_range=initializer_range + ) + self.initializer_range = initializer_range + + if use_swiglu_activation: + # First dense layer produces (2 * feedforward_expansion_factor * + # hidden_dim) outputs. + self.dense_1 = keras.layers.Dense( + int(hidden_dim * feedforward_expansion_factor * 2), + use_bias=True, + name="dense_1", + dtype=self.dtype, + kernel_initializer=clone_initializer(self.kernel_initializer), + ) + # Activation layer using "silu" (Swish activation). + self.activation = keras.layers.Activation( + "silu", name="activation", dtype=self.dtype + ) + else: + # Taken from pretrained weights. + # First dense layer: output dimension is (hidden_dim * + # feedforward_expansion_factor). + self.dense_1 = keras.layers.Dense( + int(hidden_dim * feedforward_expansion_factor), + use_bias=True, + name="dense_1", + dtype=self.dtype, + kernel_initializer=clone_initializer(self.kernel_initializer), + ) + self.activation = keras.layers.Activation( + "gelu", name="activation", dtype=self.dtype + ) + + # Second dense layer projects back to hidden_dim. + self.dense_2 = keras.layers.Dense( + hidden_dim, + use_bias=True, + name="dense_2", + dtype=self.dtype, + kernel_initializer=clone_initializer(self.kernel_initializer), + ) + + def build(self, input_shape): + super().build(input_shape) + # Build the first dense layer using the original input shape. + self.dense_1.build(input_shape) + # After dense_1, the output shape becomes: (..., 2 * + # feedforward_expansion_factor * hidden_dim). + # When splitting, each part will have shape (..., + # feedforward_expansion_factor * hidden_dim). 
+ new_input_shape = list(input_shape) + new_input_shape[-1] = ( + self.hidden_dim * self.feedforward_expansion_factor + ) + self.dense_2.build(tuple(new_input_shape)) + + def call(self, inputs): + x = self.dense_1(inputs) + if self.use_swiglu_activation: + x1, gate = keras.ops.split(x, 2, axis=-1) + activated_gate = self.activation(gate) + x = x1 * activated_gate + else: + x = self.activation(x) + output = self.dense_2(x) + return output + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "feedforward_expansion_factor": self.feedforward_expansion_factor, # noqa: E501 + "use_swiglu_activation": self.use_swiglu_activation, + "initializer_range": self.initializer_range, + "dtype": self.dtype, + } + ) + return config diff --git a/keras_hub/src/models/moonshine/moonshine_layers_test.py b/keras_hub/src/models/moonshine/moonshine_layers_test.py new file mode 100644 index 0000000000..c8880ddb0e --- /dev/null +++ b/keras_hub/src/models/moonshine/moonshine_layers_test.py @@ -0,0 +1,92 @@ +import keras + +from keras_hub.src.models.moonshine.moonshine_layers import MoonshineMLP +from keras_hub.src.models.moonshine.moonshine_layers import ( + MoonshineRotaryEmbedding, +) +from keras_hub.src.tests.test_case import TestCase + + +class MoonshineLayersTest(TestCase): + def test_moonshine_rotary_embedding(self): + layer = MoonshineRotaryEmbedding( + head_dim=64, + max_position_embeddings=2048, + base_value=10000, + rope_scaling=None, + partial_rotary_factor=0.62, + dtype="float32", + ) + input_data = keras.ops.arange(10, dtype="float32") + output_data = layer(input_data) + expected_output_shape = (1, 10, 38) + self.assertEqual(keras.ops.shape(output_data), expected_output_shape) + self.assertEqual(len(layer.trainable_weights), 0) + self.assertEqual(len(layer.non_trainable_weights), 1) + self.assertEqual(len(layer.non_trainable_variables), 1) + + def test_moonshine_rotary_embedding_dynamic(self): + layer = MoonshineRotaryEmbedding( + head_dim=64, + max_position_embeddings=10, + base_value=10000, + rope_scaling={"rope_type": "dynamic"}, + partial_rotary_factor=1.0, + ) + # Compute original inverse frequencies. + rotary_dim = 32 # Derived from head_dim = 64, partial_rotary_factor = 1 + arange = keras.ops.arange(0, rotary_dim, dtype="float32") + original_inv_freq = 1.0 / (10000 ** (arange / rotary_dim)) + + # seq_len = 5 < 10. + position_ids = keras.ops.arange(5, dtype="int32")[None, :] # [1, 5] + cos1, sin1 = layer(None, position_ids=position_ids) # [1, 5, 32] + expected_cos1 = keras.ops.cos(original_inv_freq) + expected_sin1 = keras.ops.sin(original_inv_freq) + self.assertAllClose(cos1[0, 1, :], expected_cos1, rtol=1e-5) + self.assertAllClose(sin1[0, 1, :], expected_sin1, rtol=1e-5) + + # seq_len = 15 > 10. + position_ids = keras.ops.arange(15, dtype="int32")[None, :] # [1, 15] + cos2, sin2 = layer(None, position_ids=position_ids) # [1, 15, 32] + scaling = 10 / 15 # 2 / 3 + expected_cos2 = keras.ops.cos(original_inv_freq * scaling) + expected_sin2 = keras.ops.sin(original_inv_freq * scaling) + self.assertAllClose(cos2[0, 1, :], expected_cos2, rtol=1e-5) + self.assertAllClose(sin2[0, 1, :], expected_sin2, rtol=1e-5) + + # seq_len = 8 < 10, should reset. 
+ position_ids = keras.ops.arange(8, dtype="int32")[None, :] # [1, 8] + cos3, sin3 = layer(None, position_ids=position_ids) # [1, 8, 32] + self.assertAllClose(cos3[0, 1, :], expected_cos1, rtol=1e-5) + self.assertAllClose(sin3[0, 1, :], expected_sin1, rtol=1e-5) + + def test_moonshine_mlp_swiglu(self): + self.run_layer_test( + cls=MoonshineMLP, + init_kwargs={ + "hidden_dim": 64, + "feedforward_expansion_factor": 4, + "use_swiglu_activation": True, + }, + input_data=keras.random.uniform((2, 10, 64), dtype="float32"), + expected_output_shape=(2, 10, 64), + expected_num_trainable_weights=4, + expected_num_non_trainable_weights=0, + run_precision_checks=False, + ) + + def test_moonshine_mlp_linear_gelu(self): + self.run_layer_test( + cls=MoonshineMLP, + init_kwargs={ + "hidden_dim": 64, + "feedforward_expansion_factor": 4, + "use_swiglu_activation": False, + }, + input_data=keras.random.uniform((2, 10, 64), dtype="float32"), + expected_output_shape=(2, 10, 64), + expected_num_trainable_weights=4, + expected_num_non_trainable_weights=0, + run_precision_checks=False, + ) diff --git a/keras_hub/src/models/moonshine/moonshine_multi_head_attention.py b/keras_hub/src/models/moonshine/moonshine_multi_head_attention.py new file mode 100644 index 0000000000..072af883da --- /dev/null +++ b/keras_hub/src/models/moonshine/moonshine_multi_head_attention.py @@ -0,0 +1,401 @@ +import keras +from keras import backend + +from keras_hub.src.layers.modeling.cached_multi_head_attention import ( + CachedMultiHeadAttention, +) +from keras_hub.src.models.whisper.whisper_cached_multi_head_attention import ( + _build_proj_equation, +) +from keras_hub.src.models.whisper.whisper_cached_multi_head_attention import ( + _get_output_shape, +) + + +# Removed dependence on einops. +# Source: https://github.com/usefulsensors/moonshine/blob/4a000427bd36a1c2c6d20a86c672dbd850b44c88/moonshine/model.py#L35 +def _rotate_half(x): + """ + Rotates the two halves of the last dimension. + + This function splits the last dimension of the input tensor into two equal + halves and swaps them with a sign inversion. Specifically, for an input of + shape `[..., 2*d]`, it returns a tensor of the same shape where `[x1, x2]` + is transformed into `[-x2, x1]`. + + Args: + x: Tensor. Shape `[..., 2*d]`. The input tensor to be rotated. + + Returns: + Tensor: A tensor of shape `[..., 2*d]` with the two halves rotated. + """ + # Conditional for Tensorflow backend. + if backend.backend() == "tensorflow": + x_shape = keras.ops.shape(x) + last_dim = x_shape[-1] + d = last_dim // 2 + x_shape_tensor = keras.ops.convert_to_tensor(x_shape) + new_shape = keras.ops.concatenate( + [x_shape_tensor[:-1], keras.ops.convert_to_tensor([d, 2])], axis=0 + ) + x = keras.ops.reshape(x, new_shape) + x1 = x[..., 0] + x2 = x[..., 1] + x_rotated = keras.ops.stack([-x2, x1], axis=-1) + x_rotated = keras.ops.reshape(x_rotated, x_shape) + return x_rotated + + # Conditional for PyTorch and JAX backends. + if backend.backend() == "torch" or backend.backend() == "jax": + x_shape = keras.ops.shape(x) + x_shape_tuple = tuple( + int(keras.ops.convert_to_numpy(dim).item()) for dim in x_shape + ) + last_dim = x_shape_tuple[-1] + d = last_dim // 2 + new_shape = x_shape_tuple[:-1] + (d, 2) + x = keras.ops.reshape(x, new_shape) + x1 = x[..., 0] + x2 = x[..., 1] + x_rotated = keras.ops.stack([-x2, x1], axis=-1) + x_rotated = keras.ops.reshape(x_rotated, x_shape_tuple) + return x_rotated + + else: + raise NotImplementedError( + "Backend not supported. 
Please use TensorFlow, PyTorch, or JAX." + ) + + +def _apply_rotary_pos_emb(t, freqs): + """ + Applies rotary positional embeddings to the input tensor. Used in on-the-fly + computation of rotary positional embeddings in multi-head attention layers. + + Args: + t: A tensor with shape `[..., seq_len, ..., hidden_dim]` where the + rotary embedding is applied to the first `rot_dim` channels of the + last dimension. + freqs: A tensor of frequency values with shape `[max_seq_len, rot_dim]`. + The last `seq_len` entries are used to compute the rotary + embeddings. + + Returns: + Tensor: A tensor of the same shape as `t` with the rotary positional + embeddings applied to the first `rot_dim` channels of the last dimension + and the remaining channels concatenated unchanged. + """ + rot_dim = keras.ops.shape(freqs)[-1] + seq_len = keras.ops.shape(t)[-3] + orig_dtype = t.dtype + freqs = freqs[:seq_len, :] + freqs = keras.ops.reshape(freqs, (seq_len, 1, rot_dim)) + t_rot = t[..., :rot_dim] + t_nonrot = t[..., rot_dim:] + t_rotated = t_rot * keras.ops.cos(freqs) + _rotate_half( + t_rot + ) * keras.ops.sin(freqs) + out = keras.ops.concatenate([t_rotated, t_nonrot], axis=-1) + return keras.ops.cast(out, orig_dtype) + + +@keras.saving.register_keras_serializable(package="keras_hub") +class MoonshineMultiHeadAttention(CachedMultiHeadAttention): + """ + Moonshine multi-head attention layer. + + Implements a multi-head attention mechanism for Moonshine models with + support for rotary position embeddings and different caching strategies. + This layer extends the `CachedMultiHeadAttention` base class to include + specialized functionality for Moonshine models, such as rotary embeddings + and causal masking. + + Args: + num_heads: int. Number of attention heads. + key_dim: int. Size of each attention head for key. + value_dim: int, optional. Size of each attention head for value. If + None, defaults to `key_dim`. + attention_bias: bool, optional. Whether to include bias in attention + projection layers. Defaults to `False`. + attention_dropout: float, optional. Dropout probability for attention + weights. Defaults to 0.0. + use_causal_mask: bool, optional. Whether to apply causal masking to + prevent positions from attending to subsequent positions. Defaults + to `False`. + apply_rotary_embedding: bool, optional. Whether to apply rotary position + embeddings to queries and keys. Defaults to `True`. + cache_mode: str, optional. Mode for key-value caching. Must be one of: + 'none': No caching. + 'autoregressive': Incremental caching for autoregressive generation. + 'precomputed': Use precomputed key-value pairs. Defaults to None. + **kwargs: Additional keyword arguments passed to the parent class. + """ + + # References: + # Based on the HuggingFace implementation of the MoonshineAttention class (https://github.com/huggingface/transformers/blob/fc8764c9a618add64c33e83720f974750bcd0978/src/transformers/models/moonshine/modeling_moonshine.py#L184-L315). 
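+ # Illustrative usage sketch (values are examples only, mirroring
+ # moonshine_multi_head_attention_test.py): with cache_mode="none" the
+ # layer acts as standard multi-head attention with rotary embeddings
+ # applied on the fly.
+ #   attn = MoonshineMultiHeadAttention(num_heads=4, key_dim=16)
+ #   q = keras.random.normal((2, 10, 64))
+ #   rotary = keras.random.normal((10, 4))  # key_dim=16 -> 4 rotary freqs
+ #   out = attn(query=q, key=q, value=q, rotary_embedding=rotary)
+ #   # out.shape == (2, 10, 64)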
+ + def __init__( + self, + num_heads, + key_dim, + value_dim=None, + attention_bias=False, + attention_dropout=0.0, + use_causal_mask=False, + apply_rotary_embedding=True, + cache_mode="none", + **kwargs, + ): + kwargs.pop("use_bias", None) + kwargs.pop("dropout", None) + super().__init__( + num_heads=num_heads, + key_dim=key_dim, + value_dim=value_dim, + use_bias=attention_bias, + dropout=attention_dropout, + **kwargs, + ) + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.use_causal_mask = use_causal_mask + self.apply_rotary_embedding = apply_rotary_embedding + if cache_mode not in ["none", "autoregressive", "precomputed"]: + raise ValueError( + "cache_mode must be 'none', 'autoregressive', or 'precomputed'" + ) + self.cache_mode = cache_mode + + def build(self, query_shape, value_shape, key_shape=None): + # Ensure key_shape is defined. + key_shape = value_shape if key_shape is None else key_shape + query_rank = len(query_shape) + value_rank = len(value_shape) + key_rank = len(key_shape) + + # Build query projection layer. + einsum_equation, bias_axes, output_rank = _build_proj_equation( + free_dims=query_rank - 1, bound_dims=1, output_dims=2 + ) + self._query_dense = keras.layers.EinsumDense( + einsum_equation, + output_shape=_get_output_shape( + output_rank - 1, [self._num_heads, self._key_dim] + ), + bias_axes=bias_axes if self._use_bias else None, + name="query", + **self._get_common_kwargs_for_sublayer(), + ) + self._query_dense.build(query_shape) + + # Build key projection layer. + einsum_equation, bias_axes, output_rank = _build_proj_equation( + free_dims=key_rank - 1, bound_dims=1, output_dims=2 + ) + self._key_dense = keras.layers.EinsumDense( + einsum_equation, + output_shape=_get_output_shape( + output_rank - 1, [self._num_heads, self._key_dim] + ), + bias_axes=bias_axes if self._use_bias else None, + name="key", + **self._get_common_kwargs_for_sublayer(), + ) + self._key_dense.build(key_shape) + + # Build value projection layer. + einsum_equation, bias_axes, output_rank = _build_proj_equation( + free_dims=value_rank - 1, bound_dims=1, output_dims=2 + ) + self._value_dense = keras.layers.EinsumDense( + einsum_equation, + output_shape=_get_output_shape( + output_rank - 1, [self._num_heads, self._value_dim] + ), + bias_axes=bias_axes if self._use_bias else None, + name="value", + **self._get_common_kwargs_for_sublayer(), + ) + self._value_dense.build(value_shape) + + # Build the internal attention computation sublayer. + self._build_attention(output_rank) + + # Build output projection layer. 
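+ # Projects the concatenated per-head outputs (num_heads * value_dim)
+ # back to the query's feature dimension (or `output_shape`, if set).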
+ output_shape = ( + query_shape[-1] if not self._output_shape else self._output_shape + ) + if isinstance(output_shape, (list, tuple)): + output_shape = list(output_shape) + else: + output_shape = [output_shape] + + einsum_equation, bias_axes, output_rank = _build_proj_equation( + free_dims=query_rank - 1, + bound_dims=2, + output_dims=len(output_shape), + ) + self._output_dense = keras.layers.EinsumDense( + einsum_equation, + output_shape=_get_output_shape(output_rank - 1, output_shape), + bias_axes=bias_axes if self._use_bias else None, + name="attention_output", + **self._get_common_kwargs_for_sublayer(), + ) + output_dense_input_shape = list( + self._query_dense.compute_output_shape(query_shape) + ) + output_dense_input_shape[-1] = self._value_dim + self._output_dense.build(tuple(output_dense_input_shape)) + + self.built = True + + def _compute_causal_mask(self, query, value=None, for_cache=False): + if backend.backend() == "torch" or backend.backend() == "jax": + q_seq_length = int( + keras.ops.convert_to_numpy(keras.ops.shape(query)[1]).item() + ) + v_seq_length = ( + int( + keras.ops.convert_to_numpy(keras.ops.shape(value)[1]).item() + ) + if value is not None + else q_seq_length + ) + elif backend.backend() == "tensorflow": + if for_cache: + assert value is not None + v_seq_length = keras.ops.shape(value)[1] + else: + v_seq_length = keras.ops.shape(query)[1] + q_seq_length = keras.ops.shape(query)[1] + n_rows = v_seq_length if for_cache else q_seq_length + ones_mask = keras.ops.ones((1, n_rows, v_seq_length), dtype="int32") + row_index = keras.ops.cumsum(ones_mask, axis=-2) + col_index = keras.ops.cumsum(ones_mask, axis=-1) + mask = keras.ops.greater_equal(row_index, col_index) + + if for_cache: + mask = mask[:, -q_seq_length:, :] + + return mask + + def call( + self, + query, + value, + key, + rotary_embedding=None, + attention_mask=None, + key_cache=None, + value_cache=None, + training=None, + **kwargs, + ): + # Project inputs. + query_proj = self._query_dense(query) + if rotary_embedding is not None: + query_proj = _apply_rotary_pos_emb(query_proj, rotary_embedding) + + # Handle caching. + if self.cache_mode == "none": + key_proj = self._key_dense(key) + value_proj = self._value_dense(value) + if self.apply_rotary_embedding and rotary_embedding is not None: + key_proj = _apply_rotary_pos_emb(key_proj, rotary_embedding) + final_key = key_proj + final_value = value_proj + elif self.cache_mode == "autoregressive": + if key_cache is None and value_cache is not None: + raise ValueError( + "key_cache must be provided if value_cache is provided" + ) + new_key = self._key_dense(key) + new_value = self._value_dense(value) + if self.apply_rotary_embedding and rotary_embedding is not None: + new_key = _apply_rotary_pos_emb(new_key, rotary_embedding) + if key_cache is not None and value_cache is not None: + final_key = keras.ops.concatenate((key_cache, new_key), axis=-3) + final_value = keras.ops.concatenate( + (value_cache, new_value), axis=-3 + ) + else: + final_key = new_key + final_value = new_value + elif self.cache_mode == "precomputed": + if key_cache is None and value_cache is not None: + raise ValueError( + "key_cache must be provided if value_cache is provided" + ) + if key_cache is not None and value_cache is not None: + final_key = key_cache + final_value = value_cache + else: + final_key = self._key_dense(key) + final_value = self._value_dense(value) + else: + raise ValueError(f"Invalid cache_mode: {self.cache_mode}") + + # Compute attention mask. 
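+ # For causal self-attention, intersect a lower-triangular mask with any
+ # provided padding mask; for non-causal (cross) attention, broadcast the
+ # padding mask to the rank expected by the attention computation.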
+ if self.use_causal_mask: + causal_mask = self._compute_causal_mask( + query, + final_value if self.cache_mode == "autoregressive" else None, + for_cache=( + self.cache_mode == "autoregressive" + and key_cache is not None + ), + ) + # Combine with attention_mask if provided. + if attention_mask is not None: + # [batch_size, seq_len_k] → [batch_size, 1, 1, seq_len_k]. + attention_mask_expanded = keras.ops.expand_dims( + attention_mask, axis=1 + ) + attention_mask_expanded = keras.ops.expand_dims( + attention_mask_expanded, axis=-1 + ) + final_mask = keras.ops.logical_and( + causal_mask, attention_mask_expanded + ) + else: + final_mask = causal_mask + else: + if attention_mask is not None: + if self.cache_mode == "none": + seq_len = keras.ops.shape(query)[1] + final_mask = keras.ops.tile( + attention_mask[:, None, :], [1, seq_len, 1] + ) + elif self.cache_mode == "precomputed": + final_mask = attention_mask[:, None, None, :] + else: # Autoregressive + final_mask = attention_mask + else: + final_mask = None + + attention_kwargs = { + k: v for k, v in kwargs.items() if k != "padding_mask" + } + # Compute attention. + attention_output, _ = self._compute_attention( + query=query_proj, + key=final_key, + value=final_value, + attention_mask=final_mask, + training=training, + **attention_kwargs, + ) + + # Project the attention output. + output = self._output_dense(attention_output) + + # Return based on cache_mode. + if self.cache_mode == "none": + return output + elif self.cache_mode == "autoregressive": + return output, final_key, final_value + elif self.cache_mode == "precomputed": + if key_cache is not None and value_cache is not None: + return output + return output, final_key, final_value diff --git a/keras_hub/src/models/moonshine/moonshine_multi_head_attention_test.py b/keras_hub/src/models/moonshine/moonshine_multi_head_attention_test.py new file mode 100644 index 0000000000..2d15cba0c4 --- /dev/null +++ b/keras_hub/src/models/moonshine/moonshine_multi_head_attention_test.py @@ -0,0 +1,160 @@ +import keras + +from keras_hub.src.models.moonshine.moonshine_multi_head_attention import ( + MoonshineMultiHeadAttention, +) +from keras_hub.src.tests.test_case import TestCase + + +class MoonshineMultiHeadAttentionTest(TestCase): + def setUp(self): + super().setUp() + self.num_heads = 4 + self.key_dim = 16 + self.hidden_dim = self.num_heads * self.key_dim + self.init_kwargs = { + "num_heads": self.num_heads, + "key_dim": self.key_dim, + "value_dim": None, + "attention_bias": False, + "attention_dropout": 0.0, + "use_causal_mask": False, + "apply_rotary_embedding": True, + "cache_mode": "none", + } + self.attention_layer = MoonshineMultiHeadAttention(**self.init_kwargs) + self.batch_size = 2 + self.query_seq_len = 10 + self.key_seq_len = 16 + self.rotary_dim = int( + self.key_dim * 0.62 + ) # Default partial_rotary_factor = 0.62 + self.rotary_dim = (self.rotary_dim // 2) * 2 # Ensure even + self.rotary_dim = self.rotary_dim // 2 # Half for freqs, e.g., 4 + self.query = keras.random.normal( + (self.batch_size, self.query_seq_len, self.hidden_dim) + ) + self.key = keras.random.normal( + (self.batch_size, self.key_seq_len, self.hidden_dim) + ) + self.value = self.key # For testing purposes + self.rotary_embedding = keras.random.normal( + (self.query_seq_len, self.rotary_dim) + ) + self.attention_mask = keras.ops.ones( + (self.batch_size, self.key_seq_len), dtype="bool" + ) + + def test_initialization(self): + self.assertEqual(self.attention_layer.num_heads, self.num_heads) + 
self.assertEqual(self.attention_layer.key_dim, self.key_dim) + self.assertFalse(self.attention_layer.attention_bias) + self.assertTrue(self.attention_layer.apply_rotary_embedding) + + def test_forward_pass_without_caching(self): + self.attention_layer.apply_rotary_embedding = ( + False # Test cross-attention + ) + output = self.attention_layer( + query=self.query, + key=self.key, + value=self.value, + rotary_embedding=self.rotary_embedding, + attention_mask=self.attention_mask, + ) + self.assertEqual( + output.shape, (self.batch_size, self.query_seq_len, self.hidden_dim) + ) + + def test_precomputed_caching(self): + self.attention_layer.build( + query_shape=(self.batch_size, self.query_seq_len, self.hidden_dim), + value_shape=(self.batch_size, self.key_seq_len, self.hidden_dim), + key_shape=(self.batch_size, self.key_seq_len, self.hidden_dim), + ) + self.attention_layer.cache_mode = "precomputed" + self.attention_layer.apply_rotary_embedding = False + key_proj = self.attention_layer._key_dense(self.key) + value_proj = self.attention_layer._value_dense(self.value) + output_precomputed = self.attention_layer( + query=self.query, + key=None, + value=None, + key_cache=key_proj, + value_cache=value_proj, + rotary_embedding=self.rotary_embedding, + attention_mask=self.attention_mask, + ) + self.attention_layer.cache_mode = "none" + output_normal = self.attention_layer( + query=self.query, + key=self.key, + value=self.value, + rotary_embedding=self.rotary_embedding, + attention_mask=self.attention_mask, + ) + self.assertEqual( + output_precomputed.shape, + (self.batch_size, self.query_seq_len, self.hidden_dim), + ) + self.assertAllClose(output_precomputed, output_normal, atol=1e-5) + + def test_autoregressive_caching(self): + self.attention_layer.cache_mode = "autoregressive" + self.attention_layer.use_causal_mask = True # Ensure causal attention + cache_k, cache_v = None, None + outputs_auto = [] + for i in range(self.query_seq_len): + query_i = self.query[:, i : i + 1, :] + key_i = self.query[:, i : i + 1, :] # Self-attention + value_i = self.query[:, i : i + 1, :] + rotary_i = self.rotary_embedding[i : i + 1, :] + output_i, new_cache_k, new_cache_v = self.attention_layer( + query=query_i, + key=key_i, + value=value_i, + rotary_embedding=rotary_i, + key_cache=cache_k, + value_cache=cache_v, + ) + outputs_auto.append(output_i) + self.assertEqual( + output_i.shape, (self.batch_size, 1, self.hidden_dim) + ) + self.assertEqual( + new_cache_k.shape, + (self.batch_size, i + 1, self.num_heads, self.key_dim), + ) + self.assertEqual( + new_cache_v.shape, + (self.batch_size, i + 1, self.num_heads, self.key_dim), + ) + cache_k, cache_v = new_cache_k, new_cache_v + outputs_auto = keras.ops.concatenate(outputs_auto, axis=1) + self.attention_layer.cache_mode = "none" + self.attention_layer.use_causal_mask = ( + True # Consistent with autoregressive + ) + output_full = self.attention_layer( + query=self.query, + key=self.query, + value=self.query, + rotary_embedding=self.rotary_embedding, + ) + self.assertAllClose(outputs_auto, output_full, atol=1e-5) + + def test_forward_pass_with_causal_mask(self): + self.attention_layer.use_causal_mask = True + output = self.attention_layer( + query=self.query, + key=self.query, # Self-attention for causal test + value=self.query, + rotary_embedding=self.rotary_embedding, + ) + self.assertEqual( + output.shape, (self.batch_size, self.query_seq_len, self.hidden_dim) + ) + + def test_serialization(self): + instance = MoonshineMultiHeadAttention(**self.init_kwargs) + 
self.run_serialization_test(instance=instance) diff --git a/keras_hub/src/models/moonshine/moonshine_presets.py b/keras_hub/src/models/moonshine/moonshine_presets.py new file mode 100644 index 0000000000..b394c454c0 --- /dev/null +++ b/keras_hub/src/models/moonshine/moonshine_presets.py @@ -0,0 +1,25 @@ +# Metadata for loading pretrained model weights. +backbone_presets = { + "moonshine_tiny_en": { + "metadata": { + "description": ( + "Moonshine tiny model for English speech recognition. " + "Developed by Useful Sensors for real-time transcription." + ), + "params": 27092736, + "path": "moonshine", + }, + "kaggle_handle": "", + }, + "moonshine_base_en": { + "metadata": { + "description": ( + "Moonshine base model for English speech recognition. " + "Developed by Useful Sensors for real-time transcription." + ), + "params": 61513920, + "path": "moonshine", + }, + "kaggle_handle": "", + }, +} diff --git a/keras_hub/src/models/moonshine/moonshine_seq_2_seq_lm_preprocessor.py b/keras_hub/src/models/moonshine/moonshine_seq_2_seq_lm_preprocessor.py new file mode 100644 index 0000000000..2500505bf8 --- /dev/null +++ b/keras_hub/src/models/moonshine/moonshine_seq_2_seq_lm_preprocessor.py @@ -0,0 +1,177 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker +from keras_hub.src.models.moonshine.moonshine_backbone import MoonshineBackbone +from keras_hub.src.models.moonshine.moonshine_tokenizer import ( + MoonshineTokenizer, +) +from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor +from keras_hub.src.utils.tensor_utils import preprocessing_function +from keras_hub.src.utils.tensor_utils import strip_to_ragged + + +@keras_hub_export("keras_hub.models.MoonshineSeq2SeqLMPreprocessor") +class MoonshineSeq2SeqLMPreprocessor(Seq2SeqLMPreprocessor): + """Moonshine Seq2Seq LM preprocessor for audio-to-text tasks. + + This preprocessor converts raw audio and text inputs into a format suitable + for the `MoonshineAudioToText` model. It processes audio waveforms into + features using `MoonshineAudioConverter` for the encoder and tokenizes text + using `MoonshineTokenizer` for the decoder. It supports training and + generation. + + Args: + audio_converter: A `MoonshineAudioConverter` instance to process audio. + tokenizer: A `MoonshineTokenizer` instance to tokenize text. + encoder_sequence_length: int, optional. Maximum length for audio + features. If None, features are variable-length with padding masks. + Defaults to None. + decoder_sequence_length: int, optional. Maximum length for decoder token + sequences. Defaults to 1024. + **kwargs: Additional keyword arguments for the parent class. + + Examples: + ```python + # Create audio converter and tokenizer instances. + audio_converter = keras_hub.models.MoonshineAudioConverter() + tokenizer = keras_hub.models.MoonshineTokenizer.from_preset( + "moonshine_base" + ) + + # Initialize the preprocessor. + preprocessor = keras_hub.models.MoonshineSeq2SeqLMPreprocessor( + audio_converter=audio_converter, + tokenizer=tokenizer, + decoder_sequence_length=8 + ) + + # Prepare input data (audio tensor and text). + inputs = { + "audio": keras.random.normal((1, 16000, 1)), + "text": ["the quick brown fox"] + } + + # Process the inputs. 
+ preprocessed = preprocessor(inputs) + """ + + backbone_cls = MoonshineBackbone + tokenizer_cls = MoonshineTokenizer + + def __init__( + self, + audio_converter, + tokenizer, + encoder_sequence_length=None, + decoder_sequence_length=1024, + **kwargs, + ): + super().__init__(tokenizer=tokenizer, **kwargs) + self.audio_converter = audio_converter + self.encoder_sequence_length = encoder_sequence_length + self.decoder_sequence_length = decoder_sequence_length + self.encoder_packer = None + self.decoder_packer = None + + def build(self, input_shape): + self.audio_converter.build(input_shape) + self.decoder_packer = StartEndPacker( + start_value=self.tokenizer.start_token_id, + end_value=self.tokenizer.end_token_id, + pad_value=self.tokenizer.pad_token_id, + sequence_length=self.decoder_sequence_length, + return_padding_mask=True, + ) + self.built = True + + @preprocessing_function + def call( + self, + x, + y=None, + sample_weight=None, + encoder_sequence_length=None, + decoder_sequence_length=None, + sequence_length=None, + ): + if not self.built: + self.build(None) + if isinstance(x, tuple) and len(x) == 1: + x = x[0] + encoder_sequence_length = ( + encoder_sequence_length or self.encoder_sequence_length + ) + decoder_sequence_length = ( + decoder_sequence_length + or sequence_length + or self.decoder_sequence_length + ) + text = x["text"] + audio_features = self.audio_converter(x["audio"]) + encoder_inputs = audio_features["input_values"] + encoder_padding_mask = audio_features["attention_mask"] + decoder_inputs = self.tokenizer(text) + decoder_token_ids, decoder_padding_mask = self.decoder_packer( + decoder_inputs, + sequence_length=decoder_sequence_length + 1, + ) + x_out = { + "encoder_input_values": encoder_inputs, + "encoder_padding_mask": encoder_padding_mask, + "decoder_token_ids": decoder_token_ids[..., :-1], + "decoder_padding_mask": decoder_padding_mask[..., :-1], + } + y_out = decoder_token_ids[..., 1:] + sample_weight_out = decoder_padding_mask[..., 1:] + + return keras.utils.pack_x_y_sample_weight( + x_out, y_out, sample_weight_out + ) + + @preprocessing_function + def generate_preprocess( + self, + x, + encoder_sequence_length=None, + decoder_sequence_length=None, + sequence_length=None, + ): + if not self.built: + self.build(None) + if isinstance(x, tuple) and len(x) == 1: + x = x[0] + decoder_sequence_length = ( + decoder_sequence_length + or sequence_length + or self.decoder_sequence_length + ) + audio_features = self.audio_converter(x["audio"]) + encoder_token_ids = audio_features["input_values"] + encoder_padding_mask = audio_features["attention_mask"] + decoder_text = x.get("text", [""] * keras.ops.shape(x["audio"])[0]) + decoder_token_ids = self.tokenizer(decoder_text) + decoder_token_ids, decoder_padding_mask = self.decoder_packer( + decoder_token_ids, + sequence_length=decoder_sequence_length, + add_end_value=False, + ) + + return { + "encoder_input_values": encoder_token_ids, + "encoder_padding_mask": encoder_padding_mask, + "decoder_token_ids": decoder_token_ids, + "decoder_padding_mask": decoder_padding_mask, + } + + @preprocessing_function + def generate_postprocess(self, x): + if not self.built: + self.build(None) + token_ids, padding_mask = ( + x["decoder_token_ids"], + x["decoder_padding_mask"], + ) + ids_to_strip = self.tokenizer.special_token_ids + token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip) + return self.tokenizer.detokenize(token_ids) diff --git a/keras_hub/src/models/moonshine/moonshine_seq_2_seq_lm_preprocessor_test.py 
b/keras_hub/src/models/moonshine/moonshine_seq_2_seq_lm_preprocessor_test.py new file mode 100644 index 0000000000..1f2f1c9dd2 --- /dev/null +++ b/keras_hub/src/models/moonshine/moonshine_seq_2_seq_lm_preprocessor_test.py @@ -0,0 +1,87 @@ +import os + +import keras +import pytest + +from keras_hub.src.models.moonshine.moonshine_audio_converter import ( + MoonshineAudioConverter, +) +from keras_hub.src.models.moonshine.moonshine_seq_2_seq_lm_preprocessor import ( + MoonshineSeq2SeqLMPreprocessor, +) +from keras_hub.src.models.moonshine.moonshine_tokenizer import ( + MoonshineTokenizer, +) +from keras_hub.src.tests.test_case import TestCase + + +class MoonshineSeq2SeqLMPreprocessorTest(TestCase): + def setUp(self): + self.tokenizer = MoonshineTokenizer( + proto=os.path.join( + self.get_test_data_dir(), "moonshine_test_vocab.spm" + ) + ) + self.audio_converter = MoonshineAudioConverter(filter_dim=32) + self.init_kwargs = { + "audio_converter": self.audio_converter, + "tokenizer": self.tokenizer, + "encoder_sequence_length": None, + "decoder_sequence_length": 8, + } + self.input_data = ( + { + "audio": keras.random.normal((1, 16000, 1)), + "text": ["the quick brown fox"], + }, + ) + + def test_preprocessor_basics(self): + preprocessor = MoonshineSeq2SeqLMPreprocessor(**self.init_kwargs) + output = preprocessor.call(self.input_data) + x_out, y_out, sample_weight_out = output + self.assertIn("encoder_input_values", x_out) + self.assertIn("encoder_padding_mask", x_out) + self.assertIn("decoder_token_ids", x_out) + self.assertIn("decoder_padding_mask", x_out) + self.assertAllEqual( + keras.ops.shape(x_out["encoder_input_values"]), (1, 40, 32) + ) + self.assertAllEqual( + keras.ops.shape(x_out["encoder_padding_mask"]), (1, 40) + ) + self.assertAllEqual(keras.ops.shape(x_out["decoder_token_ids"]), (1, 8)) + self.assertAllEqual( + keras.ops.shape(x_out["decoder_padding_mask"]), (1, 8) + ) + self.assertAllEqual(keras.ops.shape(y_out), (1, 8)) + self.assertAllEqual(keras.ops.shape(sample_weight_out), (1, 8)) + + def test_generate_preprocess(self): + preprocessor = MoonshineSeq2SeqLMPreprocessor(**self.init_kwargs) + output = preprocessor.generate_preprocess(self.input_data) + self.assertIn("encoder_input_values", output) + self.assertAllClose(output["decoder_token_ids"].shape, [1, 8]) + + def test_generate_postprocess(self): + preprocessor = MoonshineSeq2SeqLMPreprocessor(**self.init_kwargs) + input_data = { + "decoder_token_ids": keras.ops.ones((1, 5), dtype="int32"), + "decoder_padding_mask": keras.ops.ones((1, 5)), + } + output = preprocessor.generate_postprocess(input_data) + self.assertIsInstance(output, list) + self.assertIsInstance(output[0], str) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in MoonshineSeq2SeqLMPreprocessor.presets: + self.run_preset_test( + cls=MoonshineSeq2SeqLMPreprocessor, + preset=preset, + input_data=self.input_data, + ) + + def test_serialization(self): + instance = MoonshineSeq2SeqLMPreprocessor(**self.init_kwargs) + self.run_serialization_test(instance=instance) diff --git a/keras_hub/src/models/moonshine/moonshine_tokenizer.py b/keras_hub/src/models/moonshine/moonshine_tokenizer.py new file mode 100644 index 0000000000..c55a249af3 --- /dev/null +++ b/keras_hub/src/models/moonshine/moonshine_tokenizer.py @@ -0,0 +1,105 @@ +import base64 + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer +from keras_hub.src.models.moonshine.moonshine_backbone import MoonshineBackbone 
+ + +@keras_hub_export( + [ + "keras_hub.tokenizers.MoonshineTokenizer", + "keras_hub.models.MoonshineTokenizer", + ] +) +class MoonshineTokenizer(LlamaTokenizer): + """ + Moonshine tokenizer layer based on SentencePiece and LlamaTokenizer. + + This tokenizer class extends the `LlamaTokenizer` to tokenize raw strings + to integer sequences while incorporating Moonshine-specific special tokens. + + **Special tokens added:** + - **Start token:** "" + - **End token:** "" + - **Unknown token:** "" + - **Padding token:** "" + - **Position embedding tokens:** "<>" through "<>" + - **Hex tokens:** "<0x00>" through "<0xFF>" + - **Empty token:** "<>" + + Args: + proto: `str` or `bytes`. Either a string path to a SentencePiece proto + file or a bytes object containing a serialized SentencePiece proto. + See the [SentencePiece repository](https://github.com/google/sentencepiece) + for details on the format. + **kwargs: Additional keyword arguments passed to the parent + `LlamaTokenizer`. + + Examples: + ```python + from keras_hub.tokenizers import MoonshineTokenizer + + # Initialize tokenizer. + tokenizer = MoonshineTokenizer( + "keras_hub/src/tests/test_data/moonshine_test_vocab.spm" + ) + + # Single input example. + single_input = "the quick brown fox" + single_tokens = tokenizer(single_input) + print("Single input tokenization:") + print(f"Input text: {single_input}") + print(f"Tokenized: {single_tokens}") + + # Batched input example. + batch_input = ["the quick brown fox", "the earth is round"] + batch_tokens = tokenizer(batch_input) + print("Batch input tokenization:") + print(f"Input texts: {batch_input}") + print(f"Tokenized: {batch_tokens}") + + # Detokenization example. + encoded = tokenizer(single_input) + decoded = tokenizer.detokenize(encoded) + print("Detokenization:") + print(f"Original text: {single_input}") + print(f"Encoded: {encoded}") + print(f"Decoded: {decoded}") + ``` + """ + + # References: + # Defined in Section 3.1 of the Moonshine paper, "Moonshine: Speech + # Recognition for Live Transcription and Voice Commands" (https://arxiv.org/pdf/2410.15608.pdf) + + backbone_cls = MoonshineBackbone + + def __init__(self, proto, **kwargs): + super().__init__(proto=proto, **kwargs) + + for i in range(768): + self._add_special_token(f"<>", f"st_token_{i}") + + for i in range(256): + self._add_special_token(f"<0x{i:02X}>", f"hex_token_{i}") + + self._add_special_token("<>", "empty_token") + + self.start_token_id = self.token_to_id("") # Beginning of sentence + self.end_token_id = self.token_to_id("") # End of sentence + self.pad_token_id = self.token_to_id("") # Padding token + self.unk_token_id = self.token_to_id("") # Unknown token + + def get_config(self): + config = super().get_config() + if isinstance(self.proto, bytes): + config["proto"] = base64.b64encode(self.proto).decode("utf-8") + else: + config["proto"] = self.proto + return config + + @classmethod + def from_config(cls, config): + if "proto" in config and isinstance(config["proto"], str): + config["proto"] = base64.b64decode(config["proto"]) + return super().from_config(config) diff --git a/keras_hub/src/models/moonshine/moonshine_tokenizer_test.py b/keras_hub/src/models/moonshine/moonshine_tokenizer_test.py new file mode 100644 index 0000000000..144c10c654 --- /dev/null +++ b/keras_hub/src/models/moonshine/moonshine_tokenizer_test.py @@ -0,0 +1,96 @@ +import os + +from keras_hub.src.models.moonshine.moonshine_tokenizer import ( + MoonshineTokenizer, +) +from keras_hub.src.tests.test_case import TestCase + + +class 
MoonshineTokenizerTest(TestCase): + def setUp(self): + self.init_kwargs = { + # Generated using create_moonshine_test_proto.py. + "proto": os.path.join( + self.get_test_data_dir(), "moonshine_test_vocab.spm" + ) + } + self.tokenizer = MoonshineTokenizer(**self.init_kwargs) + self.input_data = ["the quick brown fox", "the earth is round"] + self.special_token_inputs = [ + "Hello world!", + "Test with <>", + "Hex test <0x1F>", + "Empty token test <>", + ] + + def test_tokenizer_basics(self): + self.run_preprocessing_layer_test( + cls=MoonshineTokenizer, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + def test_special_tokens_existence(self): + self.assertIsNotNone(self.tokenizer.start_token_id) + self.assertIsNotNone(self.tokenizer.end_token_id) + self.assertIsNotNone(self.tokenizer.pad_token_id) + + self.assertIsNotNone(self.tokenizer.unk_token_id) + self.assertIsNotNone(self.tokenizer.token_to_id("<>")) + + self.assertIsNotNone(self.tokenizer.token_to_id("<>")) + self.assertIsNotNone(self.tokenizer.token_to_id("<>")) + self.assertIsNotNone(self.tokenizer.token_to_id("<>")) + self.assertIsNotNone(self.tokenizer.token_to_id("<>")) + + self.assertIsNotNone(self.tokenizer.token_to_id("<0x00>")) + self.assertIsNotNone(self.tokenizer.token_to_id("<0x1F>")) + self.assertIsNotNone(self.tokenizer.token_to_id("<0xA0>")) + self.assertIsNotNone(self.tokenizer.token_to_id("<0xFF>")) + + def test_special_token_ids_mapping(self): + self.assertEqual( + self.tokenizer.token_to_id(""), self.tokenizer.start_token_id + ) + self.assertEqual( + self.tokenizer.token_to_id(""), self.tokenizer.end_token_id + ) + self.assertEqual( + self.tokenizer.token_to_id(""), self.tokenizer.pad_token_id + ) + self.assertEqual( + self.tokenizer.token_to_id(""), self.tokenizer.unk_token_id + ) + + def test_special_tokens_tokenization(self): + tokenized_st42 = self.tokenizer("<>") + self.assertEqual(len(tokenized_st42), 1) + + tokenized_hex = self.tokenizer("<0x1F>") + self.assertEqual(len(tokenized_hex), 1) + + tokenized_empty = self.tokenizer("<>") + self.assertEqual(len(tokenized_empty), 1) + + def test_detokenization(self): + for text in self.input_data + self.special_token_inputs: + tokens = self.tokenizer(text) + decoded = self.tokenizer.detokenize(tokens) + if text in self.input_data: + self.assertIn(text.lower(), decoded.lower()) + + def test_errors_missing_special_tokens(self): + with self.assertRaises(ValueError): + MoonshineTokenizer( + proto=os.path.join( + self.get_test_data_dir(), "no_special_token_vocab.spm" + ) + ) + + def test_batch_tokenization(self): + batch_tokens = self.tokenizer(self.input_data) + self.assertEqual(len(batch_tokens), len(self.input_data)) + + def test_serialization(self): + instance = MoonshineTokenizer(**self.init_kwargs) + self.run_serialization_test(instance=instance) diff --git a/keras_hub/src/tests/test_data/audio_transcription_tests/female_long_voice_clip_64sec.wav b/keras_hub/src/tests/test_data/audio_transcription_tests/female_long_voice_clip_64sec.wav new file mode 100644 index 0000000000..6ecff04a48 Binary files /dev/null and b/keras_hub/src/tests/test_data/audio_transcription_tests/female_long_voice_clip_64sec.wav differ diff --git a/keras_hub/src/tests/test_data/audio_transcription_tests/female_short_voice_clip_17sec.wav b/keras_hub/src/tests/test_data/audio_transcription_tests/female_short_voice_clip_17sec.wav new file mode 100644 index 0000000000..428c1fee89 Binary files /dev/null and 
b/keras_hub/src/tests/test_data/audio_transcription_tests/female_short_voice_clip_17sec.wav differ diff --git a/keras_hub/src/tests/test_data/audio_transcription_tests/male_muffled_voice_clip_46sec.wav b/keras_hub/src/tests/test_data/audio_transcription_tests/male_muffled_voice_clip_46sec.wav new file mode 100644 index 0000000000..b051143373 Binary files /dev/null and b/keras_hub/src/tests/test_data/audio_transcription_tests/male_muffled_voice_clip_46sec.wav differ diff --git a/keras_hub/src/tests/test_data/audio_transcription_tests/male_short_voice_clip_3sec.wav b/keras_hub/src/tests/test_data/audio_transcription_tests/male_short_voice_clip_3sec.wav new file mode 100644 index 0000000000..46d0b072c2 Binary files /dev/null and b/keras_hub/src/tests/test_data/audio_transcription_tests/male_short_voice_clip_3sec.wav differ diff --git a/keras_hub/src/tests/test_data/llama2_tokenizer_full.spm b/keras_hub/src/tests/test_data/llama2_tokenizer_full.spm new file mode 100644 index 0000000000..22bccbcb41 Binary files /dev/null and b/keras_hub/src/tests/test_data/llama2_tokenizer_full.spm differ diff --git a/keras_hub/src/tests/test_data/moonshine_test_vocab.spm b/keras_hub/src/tests/test_data/moonshine_test_vocab.spm new file mode 100644 index 0000000000..f9ece424de Binary files /dev/null and b/keras_hub/src/tests/test_data/moonshine_test_vocab.spm differ diff --git a/tools/checkpoint_conversion/convert_moonshine_checkpoints.py b/tools/checkpoint_conversion/convert_moonshine_checkpoints.py new file mode 100644 index 0000000000..bea96a748d --- /dev/null +++ b/tools/checkpoint_conversion/convert_moonshine_checkpoints.py @@ -0,0 +1,538 @@ +""" +Convert Moonshine checkpoints to KerasHub format and provide a complete +end-to-end example. + +The weights are sourced from: +https://huggingface.co/UsefulSensors/moonshine/tree/main/base +https://huggingface.co/UsefulSensors/moonshine/tree/main/tiny + +The Hugging Face configs are available at: +https://huggingface.co/UsefulSensors/moonshine-base/blob/main/config.json +https://huggingface.co/UsefulSensors/moonshine-tiny/blob/main/config.json + +Usage: +```shell +python -m tools.checkpoint_conversion.convert_moonshine_checkpoints +``` +""" + +import json +import os +import warnings + +import h5py +import keras + +try: + import librosa +except ImportError: + raise ImportError( + "Moonshine ASR system requires librosa as a dependency. Please install " + "it using 'pip install librosa' before proceeding." + ) +import numpy as np +import requests +import torch +from huggingface_hub import hf_hub_download +from transformers import AutoModel + +from keras_hub.src.models.moonshine.moonshine_audio_converter import ( + MoonshineAudioConverter, +) +from keras_hub.src.models.moonshine.moonshine_audio_to_text import ( + MoonshineAudioToText, +) +from keras_hub.src.models.moonshine.moonshine_backbone import MoonshineBackbone +from keras_hub.src.models.moonshine.moonshine_seq_2_seq_lm_preprocessor import ( + MoonshineSeq2SeqLMPreprocessor, +) +from keras_hub.src.models.moonshine.moonshine_tokenizer import ( + MoonshineTokenizer, +) + +# Set random seed for reproducibility. +keras.utils.set_random_seed(50) + + +# Utility function to convert tensors to NumPy based on backend. 
+def to_numpy(tensor): + if keras.backend.backend() == "torch": + return tensor.detach().cpu().numpy() + elif keras.backend.backend() == "tensorflow": + return tensor.numpy() + elif keras.backend.backend() == "jax": + import jax + + return jax.device_get(tensor) + else: + raise ValueError("Unsupported backend") + + +def load_h5_weights(filepath): + with h5py.File(filepath, "r") as f: + weights = {} + + def recursive_load(group, prefix=""): + for key in group.keys(): + path = f"{prefix}/{key}" if prefix else key + if isinstance(group[key], h5py.Dataset): + weights[path] = np.array(group[key]) + else: + recursive_load(group[key], path) + + recursive_load(f) + return weights + + +# Init. +presets = ["moonshine-base", "moonshine-tiny"] +for preset in presets: + print(f"\n=== Processing {preset} ===") + PRESET_NAME = preset + PRESET = f"UsefulSensors/{preset}" + EXTRACT_DIR = "./{}" + + extract_dir = EXTRACT_DIR.format(PRESET_NAME) + if not os.path.exists(extract_dir): + os.makedirs(extract_dir) + + # Download and load config. + config_path = os.path.join(extract_dir, "config.json") + response = requests.get( + f"https://huggingface.co/{PRESET}/raw/main/config.json" + ) + open(config_path, "wb").write(response.content) + + cfg = {} + with open(config_path, "r") as pt_cfg_handler: + pt_cfg = json.load(pt_cfg_handler) + + # Setup Moonshine config. + cfg["vocabulary_size"] = pt_cfg["vocab_size"] + cfg["num_layers"] = pt_cfg["encoder_num_hidden_layers"] + cfg["filter_dim"] = pt_cfg["hidden_size"] + cfg["hidden_dim"] = pt_cfg["hidden_size"] + cfg["intermediate_dim"] = pt_cfg["intermediate_size"] + cfg["max_sequence_length"] = pt_cfg["max_position_embeddings"] + cfg["partial_rotary_factor"] = pt_cfg["partial_rotary_factor"] + cfg["rope_theta"] = pt_cfg["rope_theta"] + cfg["encoder_num_layers"] = pt_cfg["encoder_num_hidden_layers"] + cfg["decoder_num_layers"] = pt_cfg["decoder_num_hidden_layers"] + cfg["encoder_num_heads"] = pt_cfg.get("encoder_num_attention_heads", 8) + cfg["decoder_num_heads"] = pt_cfg.get("decoder_num_attention_heads", 8) + cfg["feedforward_expansion_factor"] = 4 + cfg["attention_bias"] = pt_cfg["attention_bias"] + cfg["attention_dropout"] = pt_cfg["attention_dropout"] + cfg["dtype"] = pt_cfg["torch_dtype"] + cfg["decoder_use_swiglu_activation"] = ( + pt_cfg["decoder_hidden_act"] == "silu" + ) + cfg["encoder_use_swiglu_activation"] = ( + pt_cfg["encoder_hidden_act"] == "silu" + ) + cfg["initializer_range"] = pt_cfg["initializer_range"] + cfg["rope_scaling"] = pt_cfg["rope_scaling"] + + # Taken from: https://huggingface.co/UsefulSensors/moonshine-{base/tiny}/blob/main/preprocessor_config.json. + cfg["sampling_rate"] = 16000 + cfg["padding_value"] = 0.0 + cfg["do_normalize"] = False + cfg["return_attention_mask"] = True + + # Download weights. + weights_dir = os.path.join(extract_dir, "weights") + repo_id = "UsefulSensors/moonshine" + variant = preset.split("-")[-1] + files = [ + "encoder.weights.h5", + "preprocessor.weights.h5", + "decoder.weights.h5", + ] + for fname in files: + file_path = os.path.join(weights_dir, f"{variant}/{fname}") + if not os.path.exists(file_path): + print(f"Downloading {fname} to {file_path}...") + hf_hub_download( + repo_id=repo_id, + filename=f"{variant}/{fname}", + local_dir=weights_dir, + ) + + # Set weights paths. 
+ encoder_weights_path = os.path.join( + weights_dir, variant, "encoder.weights.h5" + ) + preprocessor_weights_path = os.path.join( + weights_dir, variant, "preprocessor.weights.h5" + ) + decoder_weights_path = os.path.join( + weights_dir, variant, "decoder.weights.h5" + ) + hf_wts_encoder = load_h5_weights(encoder_weights_path) + hf_wts_preprocessor = load_h5_weights(preprocessor_weights_path) + hf_wts_decoder = load_h5_weights(decoder_weights_path) + + # Build Keras models. + backbone = MoonshineBackbone( + vocabulary_size=cfg["vocabulary_size"], + encoder_num_layers=cfg["encoder_num_layers"], + decoder_num_layers=cfg["decoder_num_layers"], + hidden_dim=cfg["hidden_dim"], + intermediate_dim=cfg["intermediate_dim"], + encoder_num_heads=cfg["encoder_num_heads"], + decoder_num_heads=cfg["decoder_num_heads"], + feedforward_expansion_factor=cfg["feedforward_expansion_factor"], + decoder_use_swiglu_activation=cfg["decoder_use_swiglu_activation"], + encoder_use_swiglu_activation=cfg["encoder_use_swiglu_activation"], + max_position_embeddings=cfg["max_sequence_length"], + partial_rotary_factor=cfg["partial_rotary_factor"], + dropout=cfg["attention_dropout"], + initializer_range=cfg["initializer_range"], + rope_theta=cfg["rope_theta"], + attention_bias=cfg["attention_bias"], + attention_dropout=cfg["attention_dropout"], + rope_scaling=cfg["rope_scaling"], + dtype=cfg["dtype"], + ) + + # Build tokenizer. + tokenizer = MoonshineTokenizer( + proto="keras_hub/src/tests/test_data/llama2_tokenizer_full.spm" + ) + # Build audio converter. + audio_converter = MoonshineAudioConverter( + filter_dim=cfg["filter_dim"], + initializer_range=cfg["initializer_range"], + sampling_rate=cfg["sampling_rate"], + padding_value=cfg["padding_value"], + do_normalize=cfg["do_normalize"], + return_attention_mask=cfg["return_attention_mask"], + ) + # Build preprocessor. + preprocessor = MoonshineSeq2SeqLMPreprocessor( + audio_converter=audio_converter, + tokenizer=tokenizer, + encoder_sequence_length=None, + decoder_sequence_length=cfg["max_sequence_length"], + ) + # Build the model. + keras_model = MoonshineAudioToText( + backbone=backbone, + preprocessor=preprocessor, + ) + + # Build the model with dummy data. + dummy_audio = np.zeros((1, 16000), dtype="float32") + dummy_text = [""] + dummy_inputs = {"audio": dummy_audio, "text": dummy_text} + preprocessed_inputs, _, _ = preprocessor(dummy_inputs) + keras_model(preprocessed_inputs) + + # Assign preprocessor weights. + base_path = "layers/sequential/layers/" + weights = [ + hf_wts_preprocessor[f"{base_path}conv1d/vars/0"], # conv1 kernel + hf_wts_preprocessor[f"{base_path}group_normalization/vars/0"], # gamma + hf_wts_preprocessor[f"{base_path}group_normalization/vars/1"], # beta + hf_wts_preprocessor[f"{base_path}conv1d_1/vars/0"], # conv2 kernel + hf_wts_preprocessor[f"{base_path}conv1d_1/vars/1"], # conv2 bias + hf_wts_preprocessor[f"{base_path}conv1d_2/vars/0"], # conv3 kernel + hf_wts_preprocessor[f"{base_path}conv1d_2/vars/1"], # conv3 bias + ] + keras_model.preprocessor.audio_converter.set_weights(weights) + + # Assign encoder weights. 
+ keras_model.backbone.encoder_rotary_embedding.inv_freq.assign( + hf_wts_encoder["layers/rotary_embedding/vars/0"] + ) + + for layer_index in range(cfg["encoder_num_layers"]): + if layer_index == 0: + base_prefix = "layers/functional/layers" + else: + base_prefix = f"layers/functional_{layer_index}/layers" + attention_prefix = f"{base_prefix}/mha_with_rope" + ff_prefix = f"{base_prefix}/functional/layers/sequential/layers" + + # Attention weights. + keras_model.backbone.encoder_blocks[ + layer_index + ].self_attention_layer._query_dense.kernel.assign( + hf_wts_encoder[f"{attention_prefix}/query_dense/vars/0"] + ) + keras_model.backbone.encoder_blocks[ + layer_index + ].self_attention_layer._key_dense.kernel.assign( + hf_wts_encoder[f"{attention_prefix}/key_dense/vars/0"] + ) + keras_model.backbone.encoder_blocks[ + layer_index + ].self_attention_layer._value_dense.kernel.assign( + hf_wts_encoder[f"{attention_prefix}/value_dense/vars/0"] + ) + keras_model.backbone.encoder_blocks[ + layer_index + ].self_attention_layer._output_dense.kernel.assign( + hf_wts_encoder[f"{attention_prefix}/output_dense/vars/0"] + ) + + # Layer norms. + keras_model.backbone.encoder_blocks[ + layer_index + ].self_attention_layer_norm.gamma.assign( + hf_wts_encoder[f"{base_prefix}/layer_normalization/vars/0"] + ) + keras_model.backbone.encoder_blocks[ + layer_index + ].feedforward_layer_norm.gamma.assign( + hf_wts_encoder[f"{base_prefix}/layer_normalization_1/vars/0"] + ) + + # Feedforward weights. + keras_model.backbone.encoder_blocks[ + layer_index + ].feedforward.dense_1.kernel.assign( + hf_wts_encoder[f"{ff_prefix}/dense/vars/0"] + ) + keras_model.backbone.encoder_blocks[ + layer_index + ].feedforward.dense_1.bias.assign( + hf_wts_encoder[f"{ff_prefix}/dense/vars/1"] + ) + keras_model.backbone.encoder_blocks[ + layer_index + ].feedforward.dense_2.kernel.assign( + hf_wts_encoder[f"{ff_prefix}/dense_1/vars/0"] + ) + keras_model.backbone.encoder_blocks[ + layer_index + ].feedforward.dense_2.bias.assign( + hf_wts_encoder[f"{ff_prefix}/dense_1/vars/1"] + ) + + keras_model.backbone.encoder_final_layer_norm.gamma.assign( + hf_wts_encoder["layers/layer_normalization/vars/0"] + ) + + # Assign decoder weights. + keras_model.backbone.token_embedding.embeddings.assign( + hf_wts_decoder["layers/reversible_embedding/vars/0"] + ) + keras_model.backbone.decoder_rotary_embedding.inv_freq.assign( + hf_wts_decoder["layers/rotary_embedding/vars/0"] + ) + + for layer_index in range(cfg["decoder_num_layers"]): + if layer_index == 0: + base_prefix = "layers/functional/layers" + else: + base_prefix = f"layers/functional_{layer_index}/layers" + self_attention_prefix = f"{base_prefix}/mha_causal_with_rope" + cross_attention_prefix = f"{base_prefix}/mha_precomputed_kv" + ff_prefix = f"{base_prefix}/functional/layers" + + # Self-attention weights. + keras_model.backbone.decoder_blocks[ + layer_index + ].self_attention._query_dense.kernel.assign( + hf_wts_decoder[f"{self_attention_prefix}/query_dense/vars/0"] + ) + keras_model.backbone.decoder_blocks[ + layer_index + ].self_attention._key_dense.kernel.assign( + hf_wts_decoder[f"{self_attention_prefix}/key_dense/vars/0"] + ) + keras_model.backbone.decoder_blocks[ + layer_index + ].self_attention._value_dense.kernel.assign( + hf_wts_decoder[f"{self_attention_prefix}/value_dense/vars/0"] + ) + keras_model.backbone.decoder_blocks[ + layer_index + ].self_attention._output_dense.kernel.assign( + hf_wts_decoder[f"{self_attention_prefix}/output_dense/vars/0"] + ) + + # Cross-attention weights. 
+ keras_model.backbone.decoder_blocks[ + layer_index + ].cross_attention._query_dense.kernel.assign( + hf_wts_decoder[f"{cross_attention_prefix}/query_dense/vars/0"] + ) + keras_model.backbone.decoder_blocks[ + layer_index + ].cross_attention._key_dense.kernel.assign( + hf_wts_decoder[f"{cross_attention_prefix}/key_dense/vars/0"] + ) + keras_model.backbone.decoder_blocks[ + layer_index + ].cross_attention._value_dense.kernel.assign( + hf_wts_decoder[f"{cross_attention_prefix}/value_dense/vars/0"] + ) + keras_model.backbone.decoder_blocks[ + layer_index + ].cross_attention._output_dense.kernel.assign( + hf_wts_decoder[f"{cross_attention_prefix}/output_dense/vars/0"] + ) + + # Layer norms. + keras_model.backbone.decoder_blocks[layer_index].norm1.gamma.assign( + hf_wts_decoder[f"{base_prefix}/layer_normalization/vars/0"] + ) + keras_model.backbone.decoder_blocks[layer_index].norm2.gamma.assign( + hf_wts_decoder[f"{base_prefix}/layer_normalization_1/vars/0"] + ) + keras_model.backbone.decoder_blocks[layer_index].norm3.gamma.assign( + hf_wts_decoder[f"{base_prefix}/layer_normalization_2/vars/0"] + ) + + # Feedforward weights. + keras_model.backbone.decoder_blocks[ + layer_index + ].ff.dense_1.kernel.assign(hf_wts_decoder[f"{ff_prefix}/dense/vars/0"]) + keras_model.backbone.decoder_blocks[layer_index].ff.dense_1.bias.assign( + hf_wts_decoder[f"{ff_prefix}/dense/vars/1"] + ) + keras_model.backbone.decoder_blocks[ + layer_index + ].ff.dense_2.kernel.assign( + hf_wts_decoder[f"{ff_prefix}/dense_1/vars/0"] + ) + keras_model.backbone.decoder_blocks[layer_index].ff.dense_2.bias.assign( + hf_wts_decoder[f"{ff_prefix}/dense_1/vars/1"] + ) + + keras_model.backbone.decoder_post_norm.gamma.assign( + hf_wts_decoder["layers/layer_normalization/vars/0"] + ) + + # Save Keras model weights. + output_dir = os.path.join(extract_dir, f"{preset}-model.keras") + keras_model.save(output_dir) + print(f"Saved Keras model weights to {output_dir}") + + # Prepare inputs. + sample_text = [ + np.random.randn(16000).astype("float32") + ] # Random audio sample + keras_preprocessed_inputs = keras_model.preprocessor.audio_converter( + keras.ops.convert_to_tensor(sample_text), padding="longest" + ) + encoder_input_values = keras_preprocessed_inputs["input_values"] + encoder_padding_mask = keras_preprocessed_inputs["attention_mask"] + + # Prepare raw audio for HF model. + raw_audio = np.array(sample_text) # Shape: (1, 16000) + + # For HF model, use raw audio instead of preprocessed features. + hf_inputs = { + "input_values": torch.from_numpy(raw_audio), # Shape: (1, 16000) + "decoder_input_ids": torch.randint( + 0, cfg["vocabulary_size"], (1, 32), dtype=torch.int32 + ), + } + position_ids = torch.arange(0, 32, dtype=torch.long).unsqueeze(0) + + # Prepare Keras inputs for backbone. + decoder_token_ids = keras.ops.convert_to_tensor( + hf_inputs["decoder_input_ids"] + ) + decoder_padding_mask = keras.ops.cast( + keras.ops.not_equal(decoder_token_ids, 0), "bool" + ) + + # Run Keras backbone. + keras_backbone_outputs = keras_model.backbone( + { + "encoder_input_values": encoder_input_values, + "decoder_token_ids": decoder_token_ids, + "encoder_padding_mask": encoder_padding_mask, + "decoder_padding_mask": decoder_padding_mask, + }, + training=False, + ) + keras_encoder_output = to_numpy( + keras_backbone_outputs["encoder_sequence_output"] + ) + keras_decoder_output = to_numpy( + keras_backbone_outputs["decoder_sequence_output"] + ) + + # Run Hugging Face model and compute outputs. 
+ hf_model = AutoModel.from_pretrained(PRESET) + hf_model.eval() + with torch.no_grad(): + hf_outputs = hf_model( + input_values=hf_inputs["input_values"], + decoder_input_ids=hf_inputs["decoder_input_ids"], + decoder_position_ids=position_ids, + output_hidden_states=True, + ) + hf_encoder_hidden_states = hf_outputs.encoder_hidden_states[-1] + hf_encoder_output_np = hf_encoder_hidden_states.numpy() + hf_decoder_hidden_states = hf_outputs.last_hidden_state + hf_decoder_output_np = hf_decoder_hidden_states.numpy() + + # Compute absolute differences between HF and Keras outputs. + encoder_abs_diff = np.abs(keras_encoder_output - hf_encoder_output_np) + encoder_min_abs_diff = np.min(encoder_abs_diff) + encoder_max_abs_diff = np.max(encoder_abs_diff) + decoder_abs_diff = np.abs(keras_decoder_output - hf_decoder_output_np) + decoder_min_abs_diff = np.min(decoder_abs_diff) + decoder_max_abs_diff = np.max(decoder_abs_diff) + # Print differences. + print(f"\n=== Differences for {preset} ===") + if preset == "moonshine-tiny": + warnings.warn( + "Note: The 'moonshine-tiny' numerics results differ between " + "implementations. This discrepancy stems from a bug in the HF " + "implementation, likely in the rotary embeddings calculation. The " + "bug causes failures with longer transcripts, while the Keras " + "implementation handles these correctly, as demonstrated in the " + "notebook." + ) + print( + f"Encoder output absolute differences: min={encoder_min_abs_diff}, " + f"max={encoder_max_abs_diff}" + ) + print( + f"Decoder output absolute differences: min={decoder_min_abs_diff}, " + f"max={decoder_max_abs_diff}" + ) + # Test: End-to-End ASR Examples. + print(f"\n=== End-to-End ASR Example for {preset} ===") + # Test 1: Male Clear Voice, Snippet (Length - 3 Sec) + print("\nTest: Male Clear Voice, Snippet (Length - 3 Sec)") + audio_path = "keras_hub/src/tests/test_data/audio_transcription_tests/male_short_voice_clip_3sec.wav" # noqa: E501 + audio, sr = librosa.load(audio_path, sr=cfg["sampling_rate"]) + audio = audio.reshape(1, -1) + inputs = {"audio": audio, "text": [""]} + transcription = keras_model.generate(inputs) + print("Transcription:", transcription) + + # Test 2: Female Clear Voice, Excerpt (Length - 17 Sec) + print("\nTest: Female Clear Voice, Excerpt (Length - 17 Sec)") + audio_path = "keras_hub/src/tests/test_data/audio_transcription_tests/female_short_voice_clip_17sec.wav" # noqa: E501 + audio, sr = librosa.load(audio_path, sr=cfg["sampling_rate"]) + audio = audio.reshape(1, -1) + inputs = {"audio": audio, "text": [""]} + transcription = keras_model.generate(inputs) + print("Transcription:", transcription) + + # Test 3: Male Muffled Voice, Manuscript (Length - 46 Sec) + print("\nTest: Male Muffled Voice, Manuscript (Length - 46 Sec)") + audio_path = "keras_hub/src/tests/test_data/audio_transcription_tests/male_muffled_voice_clip_46sec.wav" # noqa: E501 + audio, sr = librosa.load(audio_path, sr=cfg["sampling_rate"]) + audio = audio.reshape(1, -1) + inputs = {"audio": audio, "text": [""]} + transcription = keras_model.generate(inputs, max_length=200) + print("Transcription:", transcription) + + # Test 4: Female Clear Voice, Odyssey (Maximum Length - 64 Sec) + print("\nTest: Female Clear Voice, Odyssey (Maximum Length - 64 Sec)") + audio_path = "keras_hub/src/tests/test_data/audio_transcription_tests/female_long_voice_clip_64sec.wav" # noqa: E501 + audio, sr = librosa.load(audio_path, sr=cfg["sampling_rate"]) + audio = audio.reshape(1, -1) + inputs = {"audio": audio, "text": [""]} + 
+    transcription = keras_model.generate(inputs, max_length=200)
+    print("Transcription:", transcription)
diff --git a/tools/sentencepiece_testing/create_moonshine_test_proto.py b/tools/sentencepiece_testing/create_moonshine_test_proto.py
new file mode 100644
index 0000000000..ef2173d01b
--- /dev/null
+++ b/tools/sentencepiece_testing/create_moonshine_test_proto.py
@@ -0,0 +1,42 @@
+import os
+
+from tools.sentencepiece_testing.utils import train_sentencepiece
+
+
+def create_moonshine_test_vocab():
+    special_tokens = (
+        [
+            "<>",  # Empty token
+        ]
+        + [f"<0x{i:02X}>" for i in range(256)]
+        + [f"<>" for i in range(768)]
+    )
+    training_texts = ["the quick brown fox", "the earth is round"]
+    test_data_dir = os.path.join(
+        os.path.dirname(os.path.abspath(__file__)),
+        "..",
+        "..",
+        "keras_hub",
+        "src",
+        "tests",
+        "test_data",
+    )
+    os.makedirs(test_data_dir, exist_ok=True)
+    model_prefix = "moonshine_test_vocab"
+    target_path = os.path.join(test_data_dir, f"{model_prefix}.spm")
+    train_sentencepiece(
+        training_texts,
+        target_path,
+        vocab_size=11 + len(special_tokens),
+        model_type="WORD",
+        pad_id=0,  # <pad> token
+        unk_id=1,  # <unk> token
+        bos_id=2,  # <s> token
+        eos_id=3,  # </s> token
+        user_defined_symbols=special_tokens,
+    )
+    print(f"Moonshine test vocabulary created at: {target_path}")
+
+
+if __name__ == "__main__":
+    create_moonshine_test_vocab()
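For reviewers skimming the preprocessor diff above, here is a minimal sketch (not part of the PR) of the training-time feature layout that `MoonshineSeq2SeqLMPreprocessor.call()` produces: audio is converted to encoder features by `MoonshineAudioConverter`, text is packed with start/end tokens by the decoder packer, and the labels are the decoder token ids shifted left by one. The vocabulary path is a placeholder; the shapes follow the unit test in this PR (`filter_dim=32`, `decoder_sequence_length=8`, one second of 16 kHz audio).

```python
import keras

from keras_hub.src.models.moonshine.moonshine_audio_converter import (
    MoonshineAudioConverter,
)
from keras_hub.src.models.moonshine.moonshine_seq_2_seq_lm_preprocessor import (
    MoonshineSeq2SeqLMPreprocessor,
)
from keras_hub.src.models.moonshine.moonshine_tokenizer import (
    MoonshineTokenizer,
)

# Placeholder vocabulary path; the PR's tests use moonshine_test_vocab.spm.
tokenizer = MoonshineTokenizer(proto="path/to/moonshine_vocab.spm")
preprocessor = MoonshineSeq2SeqLMPreprocessor(
    audio_converter=MoonshineAudioConverter(filter_dim=32),
    tokenizer=tokenizer,
    decoder_sequence_length=8,
)

inputs = {
    "audio": keras.random.normal((1, 16000, 1)),  # 1 second at 16kHz.
    "text": ["the quick brown fox"],
}
x, y, sample_weight = preprocessor(inputs)
# x["encoder_input_values"] -> (1, 40, 32) downsampled audio features.
# x["encoder_padding_mask"] -> (1, 40) mask over audio feature frames.
# x["decoder_token_ids"]    -> (1, 8) teacher-forcing decoder inputs.
# y                         -> (1, 8) the same ids shifted left by one.
# sample_weight             -> (1, 8) padding mask over the targets.
```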
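And a hedged end-to-end sketch of the public API once the empty `kaggle_handle` fields in `moonshine_presets.py` are populated. The `from_preset` call and the preset name are assumptions based on KerasHub conventions and the preset metadata in this PR; the `{"audio", "text"}` input format mirrors the `generate()` calls in the conversion script above.

```python
import keras
import keras_hub

# Assumed preset name from moonshine_presets.py; not yet downloadable in
# this PR because the Kaggle handles are still empty.
asr_model = keras_hub.models.MoonshineAudioToText.from_preset(
    "moonshine_base_en"
)

# A 3-second silent waveform stands in for real 16kHz mono audio.
audio = keras.ops.zeros((1, 3 * 16000))
transcription = asr_model.generate({"audio": audio, "text": [""]})
print(transcription)
```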