Skip to content

Add Moonshine to KerasHub #2093

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 47 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
8037ed0
init: Add MoonshineBackbone files
harshaljanjani Feb 10, 2025
51a40b8
feat: Make backbone test suite more robust
harshaljanjani Feb 10, 2025
098781e
feat: Exactness to the original and robustness of test cases
harshaljanjani Feb 11, 2025
047de1f
fix: Support stacked encoder layers from original implementation
harshaljanjani Feb 11, 2025
885f77f
TODO: Fix layer names
harshaljanjani Feb 12, 2025
805a806
fix: Add __init__ file
harshaljanjani Feb 12, 2025
9f579c0
Merge branch 'master' into moonshine
harshaljanjani Feb 12, 2025
aebeac7
fix: Correct subclassing and make ops more robust
harshaljanjani Feb 12, 2025
60112d5
feat: Incorporate feedback for Moonshine
harshaljanjani Feb 16, 2025
10cff1e
refactor: Move super.build() calls to the beginning of build() functions
harshaljanjani Feb 18, 2025
8dac22f
fix: Resolve API issue and fix duplicate parameters in attention
harshaljanjani Feb 18, 2025
2bacaf2
init: Add MoonshineDecoderBlock files (TODO: MoonshineDecoder)
harshaljanjani Feb 19, 2025
3af8498
feat: Add MoonshineDecoder with questionable tolerance
harshaljanjani Feb 20, 2025
2a2fcb9
fix: Fix decoder numerics (TODO: serialization and tokenizer)
harshaljanjani Feb 21, 2025
e05d1ed
feat: Add Tokenizer and SentencePiece model files
harshaljanjani Feb 22, 2025
9130d2c
refactor: API modification and temporarily removed TestCase
harshaljanjani Feb 22, 2025
b4e1ae9
chore: Update HF params (TODO: Resolve numerics issue)
harshaljanjani Feb 24, 2025
524e052
fix: Refactor model components: improve documentation, fix numerical …
harshaljanjani Feb 27, 2025
6cd6cd9
fix: Update checkpoint paths to include base directory for encoder, p…
harshaljanjani Feb 27, 2025
1e330b8
refactor: Rename arguments for clarity in Moonshine layers and unit t…
harshaljanjani Feb 28, 2025
64bcd63
feat: Add decoder to MoonshineBackbone, enable mixed-precision traini…
harshaljanjani Feb 28, 2025
c8f82aa
feat: Revamp test suites, finalize MoonshineBackbone, and improve doc…
harshaljanjani Mar 3, 2025
706078f
test: Add unit tests for Moonshine layers including InvFreqInitialize…
harshaljanjani Mar 3, 2025
57ad858
fix: TensorFlow compatibility in MoonshineInvFreqInitializer test
harshaljanjani Mar 3, 2025
67d01a9
refactor: Remove MoonshinePreprocessor and update related tests and i…
harshaljanjani Mar 8, 2025
2c76289
clean up: Use ReversibleEmbedding and shorten the weights conversion …
harshaljanjani Mar 9, 2025
a32d292
refactor: Simplify input handling in MoonshineBackbone and remove unu…
harshaljanjani Mar 9, 2025
189d39e
refactor: Update MoonshineBackbone, remove testable components, and f…
harshaljanjani Mar 10, 2025
3e236bb
feat: Add padding mask support, make the logits() function for a trai…
harshaljanjani Mar 11, 2025
e993ead
feat: Add trainable conditional generation task model, fix nits
harshaljanjani Mar 14, 2025
34ea915
refactor: Reformat MoonshineForConditionalGeneration and MoonshineBac…
harshaljanjani Mar 14, 2025
5599073
Merge branch 'keras-team:master' into moonshine
harshaljanjani Mar 14, 2025
f57fcd1
may fix JAX (Keras 3.5) backend tests: Update input handling to use a…
harshaljanjani Mar 15, 2025
c9e4d76
cleanup: Merge MoonshineAttention into a single class, remove unneces…
harshaljanjani Mar 15, 2025
80b1d9d
may fix JAX (Keras 3.5) backend: Fix initializer error in MoonshineRo…
harshaljanjani Mar 16, 2025
efc1424
finalizing changes: Complete generate() API with caching speedup
harshaljanjani Mar 18, 2025
18c06ef
refactor: Apply BART-inspired structural changes and optimize generate()
harshaljanjani Mar 23, 2025
d719ca2
bug fix: Fix the build() method in MoonshineAudioConverter, thus reso…
harshaljanjani Mar 24, 2025
578c7d0
task: Complete rewrite of the generation strategy in one go (TODO: up…
harshaljanjani Mar 31, 2025
0705d58
feat: Update weights conversion script
harshaljanjani Mar 31, 2025
3224a28
revert: Leave comments in the code for next review and revert caching…
harshaljanjani Apr 2, 2025
f5541f4
fix nits: Add warnings and the missing decoder_attention_mask param; …
harshaljanjani Apr 3, 2025
63a457f
another refactor: Single caching strategy, easily integrable into Ker…
harshaljanjani Apr 6, 2025
f961a06
fix nits: Remove unused encoder packer init, re-enable MHA tests; the…
harshaljanjani Apr 9, 2025
4f53d78
hooraayyy: The tests are yet to be fixed, but the task model works on…
harshaljanjani Apr 12, 2025
17ec26e
TODO: Fix JAX backend
harshaljanjani Apr 14, 2025
c59607d
end of sprint: Complete JAX backend implementation
harshaljanjani Apr 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions keras_hub/api/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@
from keras_hub.src.models.mobilenet.mobilenet_image_converter import (
MobileNetImageConverter,
)
from keras_hub.src.models.moonshine.moonshine_audio_converter import (
MoonshineAudioConverter,
)
from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import (
PaliGemmaImageConverter,
)
Expand Down
10 changes: 10 additions & 0 deletions keras_hub/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,16 @@
from keras_hub.src.models.mobilenet.mobilenet_image_classifier_preprocessor import (
MobileNetImageClassifierPreprocessor,
)
from keras_hub.src.models.moonshine.moonshine_audio_to_text import (
MoonshineAudioToText,
)
from keras_hub.src.models.moonshine.moonshine_backbone import MoonshineBackbone
from keras_hub.src.models.moonshine.moonshine_seq_2_seq_lm_preprocessor import (
MoonshineSeq2SeqLMPreprocessor,
)
from keras_hub.src.models.moonshine.moonshine_tokenizer import (
MoonshineTokenizer,
)
from keras_hub.src.models.object_detector import ObjectDetector
from keras_hub.src.models.object_detector import (
ObjectDetector as ImageObjectDetector,
Expand Down
3 changes: 3 additions & 0 deletions keras_hub/api/tokenizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer
from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer
from keras_hub.src.models.moonshine.moonshine_tokenizer import (
MoonshineTokenizer,
)
from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer
from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import (
PaliGemmaTokenizer,
Expand Down
Empty file.
260 changes: 260 additions & 0 deletions keras_hub/src/models/moonshine/moonshine_audio_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
import warnings

import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter
from keras_hub.src.models.moonshine.moonshine_backbone import MoonshineBackbone
from keras_hub.src.models.moonshine.moonshine_layers import (
moonshine_kernel_initializer,
)
from keras_hub.src.utils.keras_utils import clone_initializer


@keras_hub_export("keras_hub.layers.MoonshineAudioConverter")
class MoonshineAudioConverter(AudioConverter):
"""Moonshine preprocessor and audio converter layer.

This layer processes raw audio waveforms for the Moonshine ASR model. Audio
is formatted as a batched tensor at a 16kHz sample rate and validated for
length (0.1 to 64 seconds). The layer downsamples and extracts key features
from the audio signal through a series of convolutional operations,
normalization, and nonlinear activations.

Args:
filter_dim: int. The number of filters for the first convolutional
layer. This influences the dimensionality of the feature extraction
pipeline and determines the richness of the audio representation.
sampling_rate: int, optional. The audio sampling rate in Hz. Defaults to
16,000.
padding_value: float, optional. The value for padding. Defaults to 0.0.
do_normalize: bool, optional. Whether to normalize inputs. Defaults to
False.
return_attention_mask: bool, optional. Whether to return an attention
mask. Defaults to True.
initializer_range: float, optional. The standard deviation for kernel
initialization. Defaults to 0.02.
**kwargs: Additional keyword arguments passed to the base AudioConverter
class for customizing the underlying preprocessing behavior.

Examples:
```python
import keras
from keras_hub.layers import MoonshineAudioConverter

# Create a dummy audio input (1 second at 16kHz).
dummy_audio = keras.ops.convert_to_tensor(
[[0.1] * 16000],
dtype="float32"
)
dummy_audio = keras.ops.expand_dims(dummy_audio, axis=-1)

# Initialize the preprocessor.
preprocessor = MoonshineAudioConverter(filter_dim=256)

# Process the audio.
features = preprocessor(dummy_audio)

# Output shapes.
print(features["input_values"].shape) # Expected: (1, 40, 256)
print(features["attention_mask"].shape) # Expected: (1, 40)
```
"""

# References:
# Defined and formulated based on the Hugging Face implementation of the
# Wav2Vec2FeatureExtractor class (https://github.com/huggingface/transformers/blob/66f29aaaf55c8fe0c3dbcd24beede2ca4effac56/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py#L31-L243)
# and the convolutional layer structure defined in the UsefulSensors
# implementation of the AudioPreprocessor class (https://github.com/usefulsensors/moonshine/blob/4a000427bd36a1c2c6d20a86c672dbd850b44c88/moonshine/model.py#L6-L32).

backbone_cls = MoonshineBackbone

def __init__(
self,
filter_dim,
sampling_rate=16000,
padding_value=0.0,
do_normalize=False,
return_attention_mask=True,
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
self.filter_dim = filter_dim
self.sampling_rate = sampling_rate
self.padding_value = padding_value
self.do_normalize = do_normalize
self.return_attention_mask = return_attention_mask
self.initializer_range = initializer_range
self.kernel_initializer = moonshine_kernel_initializer(
initializer_range=initializer_range
)

self.conv1 = keras.layers.Conv1D(
filters=filter_dim,
kernel_size=127,
strides=64,
use_bias=False,
kernel_initializer=clone_initializer(self.kernel_initializer),
)
self.tanh = keras.layers.Activation("tanh")
self.group_norm = keras.layers.GroupNormalization(
groups=1,
axis=-1,
epsilon=1e-5,
)
self.conv2 = keras.layers.Conv1D(
filters=2 * filter_dim,
kernel_size=7,
strides=3,
padding="valid",
kernel_initializer=clone_initializer(self.kernel_initializer),
)
self.gelu1 = keras.layers.Activation("gelu")
self.conv3 = keras.layers.Conv1D(
filters=filter_dim,
kernel_size=3,
strides=2,
padding="valid",
kernel_initializer=clone_initializer(self.kernel_initializer),
)
self.gelu2 = keras.layers.Activation("gelu")

def build(self, input_shape):
self.conv1.build((None, None, 1))
self.group_norm.build((None, None, self.filter_dim))
self.conv2.build((None, None, self.filter_dim))
self.conv3.build((None, None, 2 * self.filter_dim))
self.built = True

def call(
self,
inputs,
sampling_rate=None,
padding=None,
max_length=None,
pad_to_multiple_of=None,
return_tensors=None,
):
# Validate sampling rate.
if sampling_rate is not None and sampling_rate != self.sampling_rate:
raise ValueError(
f"Expected sampling_rate {self.sampling_rate}, got "
f"{sampling_rate}"
)

# Ensure inputs are (batch_size, time_steps, 1).
if keras.ops.ndim(inputs) == 2:
inputs = keras.ops.expand_dims(inputs, axis=-1)
elif keras.ops.ndim(inputs) != 3 or keras.ops.shape(inputs)[-1] != 1:
raise ValueError(
"Inputs must be mono audio: (batch_size, time_steps, 1)"
)

# Get original length and validate duration.
original_length = keras.ops.shape(inputs)[1]
duration = original_length / self.sampling_rate
# Source: https://github.com/usefulsensors/moonshine/blob/4a000427bd36a1c2c6d20a86c672dbd850b44c88/moonshine/transcribe.py#L20
if duration < 0.1 or duration > 64:
raise warnings.warn(
f"Audio duration must be between 0.1 and 64 seconds, got "
f"{duration:.2f} seconds in a single transcribe call. For "
"transcribing longer segments, pre-segment your audio and "
"provide shorter segments."
)
# Handle padding.
if padding == "longest":
max_length = original_length
elif padding == "max_length" and max_length is None:
max_length = original_length
if max_length is not None:
if pad_to_multiple_of:
max_length = (
(max_length + pad_to_multiple_of - 1) // pad_to_multiple_of
) * pad_to_multiple_of
if original_length < max_length:
padding_amount = max_length - original_length
inputs = keras.ops.pad(
inputs,
[(0, 0), (0, padding_amount), (0, 0)],
constant_values=self.padding_value,
)

# Normalize if enabled.
if self.do_normalize:
mean = keras.ops.mean(inputs, axis=1, keepdims=True)
var = keras.ops.var(inputs, axis=1, keepdims=True)
inputs = (inputs - mean) / keras.ops.sqrt(var + 1e-7)

# Apply convolutional feature extraction.
x = self.conv1(inputs)
x = self.tanh(x)
x = self.group_norm(x)
x = self.conv2(x)
x = self.gelu1(x)
x = self.conv3(x)
features = self.gelu2(x)

# Generate attention mask.
output_length = keras.ops.shape(features)[1]
attention_mask = None
if self.return_attention_mask:
# Calculate mask length through the network's downsampling ops.
# Step 1: First conv layer (conv1).
conv1_out = (original_length - 127 + 1) / 64
# Step 2: Second conv layer (conv2).
conv2_out = (conv1_out - 7 + 1) / 3
# Step 3: Third conv layer (conv3).
conv3_out = (conv2_out - 3 + 1) / 2

# Apply ceil() to get the final mask length as an int.
mask_length = keras.ops.cast(
keras.ops.ceil(keras.ops.cast(conv3_out, "float32")), "int32"
)
# Broadcast the mask length to match the batch size.
batch_size = keras.ops.shape(inputs)[0]
mask_length = keras.ops.broadcast_to(mask_length, [batch_size])
indices = keras.ops.arange(output_length, dtype="int32")
attention_mask = keras.ops.cast(
indices[None, :] < mask_length[:, None], dtype="int32"
)

output = {"input_values": features}
if attention_mask is not None:
output["attention_mask"] = attention_mask

return output

def compute_output_shape(self, input_shape):
# [batch_size, time_steps] → [batch_size, time_steps, 1].
if len(input_shape) == 2:
expanded_shape = (input_shape[0], input_shape[1], 1)
else:
expanded_shape = input_shape
# Compute output shape sequentially.
x_shape = self.conv1.compute_output_shape(expanded_shape)
x_shape = self.tanh.compute_output_shape(x_shape)
x_shape = self.group_norm.compute_output_shape(x_shape)
x_shape = self.conv2.compute_output_shape(x_shape)
x_shape = self.gelu1.compute_output_shape(x_shape)
x_shape = self.conv3.compute_output_shape(x_shape)
x_shape = self.gelu2.compute_output_shape(x_shape)
output_shape = {"input_values": x_shape}
if self.return_attention_mask:
# [batch_size, output_time_steps].
output_shape["attention_mask"] = (expanded_shape[0], x_shape[1])
return output_shape

def get_config(self):
config = super().get_config()
config.update(
{
"filter_dim": self.filter_dim,
"sampling_rate": self.sampling_rate,
"padding_value": self.padding_value,
"do_normalize": self.do_normalize,
"return_attention_mask": self.return_attention_mask,
"initializer_range": self.initializer_range,
}
)
return config
Loading
Loading