Convert TF and NumPy ops in whisper_audio_converter.py to Keras Ops #2225


Open · wants to merge 2 commits into master
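For context, `keras.ops` mirrors the NumPy API while dispatching to whichever Keras 3 backend is active (TensorFlow, JAX, or PyTorch), which is what makes the substitutions in this diff backend-agnostic. A minimal illustration, not part of the PR:

```python
import numpy as np
import keras.ops as ops

# The same call runs unchanged on any Keras 3 backend, e.g. replacing the
# tf.square(tf.abs(...)) pattern rewritten in this diff.
x = np.array([1.0, -2.0, 3.0], dtype="float32")
print(ops.square(ops.absolute(x)))  # [1. 4. 9.]
```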
135 changes: 70 additions & 65 deletions keras_hub/src/models/whisper/whisper_audio_converter.py
@@ -1,14 +1,10 @@
-import numpy as np
+import keras.ops as ops
+import tensorflow as tf
Review comment (Member): We can never do a bare import of tf like this. Check other files in the library.

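The guarded import the reviewer is pointing at is the pattern this diff deletes a few lines below; other files in the library keep TensorFlow an optional dependency like this (sketch only, with a hypothetical error helper for illustration):

```python
# Guarded import pattern used elsewhere in keras-hub (and removed by this
# diff): TensorFlow stays optional rather than a hard dependency.
try:
    import tensorflow as tf
except ImportError:
    tf = None


def _assert_tf_installed():
    # Hypothetical helper, for illustration only: fail lazily at the call
    # site that actually needs TensorFlow, not at import time.
    if tf is None:
        raise ImportError(
            "WhisperAudioConverter requires TensorFlow. "
            "Install it with `pip install tensorflow`."
        )
```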
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter
 from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone

-try:
-    import tensorflow as tf
-except ImportError:
-    tf = None


 @keras_hub_export("keras_hub.layers.WhisperAudioConverter")
 class WhisperAudioConverter(AudioConverter):
@@ -36,7 +32,7 @@ class WhisperAudioConverter(AudioConverter):

     Examples:
     ```python
-    audio_tensor = tf.ones((8000,), dtype="float32")
+    audio_tensor = ops.ones((8000,), dtype="float32")

     # Compute the log-mel spectrogram.
     audio_converter = keras_hub.layers.WhisperAudioConverter.from_preset(
@@ -45,8 +41,8 @@ class WhisperAudioConverter(AudioConverter):
     audio_converter(audio_tensor)

     # Compute the log-mel spectrogram for a batch of audio tensors.
-    audio_tensor_1 = tf.ones((8000,), dtype="float32")
-    audio_tensor_2 = tf.ones((10000,), dtype="float32")
+    audio_tensor_1 = ops.ones((8000,), dtype="float32")
+    audio_tensor_2 = ops.ones((10000,), dtype="float32")
     audio_tensor = tf.ragged.stack([audio_tensor_1, audio_tensor_2], axis=0)
     audio_converter(audio_tensor)
     ```
@@ -84,33 +80,33 @@ def audio_shape(self):
         """Returns the preprocessed size of a single audio sample."""
         return (self.max_audio_length, self.num_mels)

+    def _get_rfftfreq_keras(self):
+        n = self.num_fft_bins
+        d = 1.0 / self.sampling_rate
+
+        if n % 2 == 0:
+            freqs = ops.arange(0, n // 2 + 1, dtype="float32") / (d * n)
+        else:
+            freqs = ops.arange(0, (n - 1) // 2 + 1, dtype="float32") / (d * n)
+
+        return freqs
+
     def _get_mel_filters(self):
         """
         Adapted from Hugging Face
         (https://github.com/huggingface/transformers/blob/v4.27.1/src/transformers/models/whisper/feature_extraction_whisper.py#L86)
         """

-        # TODO: Convert to TensorFlow ops (if possible).
-
-        dtype = np.float32
+        dtype = self.compute_dtype  # Use the class's dtype
         # Initialize the weights
-        weights = np.zeros(
+        weights = ops.zeros(
             (self.num_mels, int(1 + self.num_fft_bins // 2)), dtype=dtype
         )

-        # Center freqs of each FFT bin
-        fftfreqs = np.fft.rfftfreq(
-            n=self.num_fft_bins, d=1.0 / self.sampling_rate
-        )
-
+        fftfreqs = self._get_rfftfreq_keras()
         # 'Center freqs' of mel bands - uniformly spaced between limits
         min_mel = 0.0
         max_mel = 45.245640471924965

-        mels = np.linspace(min_mel, max_mel, self.num_mels + 2)
-
-        mels = np.asanyarray(mels)
-
+        mels = ops.linspace(min_mel, max_mel, self.num_mels + 2)
         # Fill in the linear scale
         f_min = 0.0
         f_sp = 200.0 / 3
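The new `_get_rfftfreq_keras` helper above is meant to reproduce `np.fft.rfftfreq`, whose even-`n` result is defined as `arange(0, n // 2 + 1) / (d * n)`. A quick equivalence check, illustrative only and not part of the PR:

```python
import numpy as np
import keras.ops as ops

n, d = 400, 1.0 / 16000  # Whisper-like FFT size and sample spacing
expected = np.fft.rfftfreq(n=n, d=d)
actual = ops.arange(0, n // 2 + 1, dtype="float32") / (d * n)  # even n
np.testing.assert_allclose(ops.convert_to_numpy(actual), expected, rtol=1e-5)
```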
@@ -119,93 +115,102 @@ def _get_mel_filters(self):
         # And now the nonlinear scale
         min_log_hz = 1000.0  # beginning of log region (Hz)
         min_log_mel = (min_log_hz - f_min) / f_sp  # same (Mels)
-        logstep = np.log(6.4) / 27.0  # step size for log region
-
+        logstep = ops.log(6.4) / 27.0  # step size for log region
         # If we have vector data, vectorize
         log_t = mels >= min_log_mel
-        freqs[log_t] = min_log_hz * np.exp(
-            logstep * (mels[log_t] - min_log_mel)
+        freqs = ops.where(
+            log_t, min_log_hz * ops.exp(logstep * (mels - min_log_mel)), freqs
         )

         mel_f = freqs
+        fdiff = ops.diff(mel_f)
+        ramps = (
+            ops.expand_dims(mel_f, axis=1) - fftfreqs
+        )  # keras subtract outer

-        fdiff = np.diff(mel_f)
-        ramps = np.subtract.outer(mel_f, fftfreqs)
-
+        weights_list = []
         for i in range(self.num_mels):
             # lower and upper slopes for all bins
             lower = -ramps[i] / fdiff[i]
             upper = ramps[i + 2] / fdiff[i + 1]

             # .. then intersect them with each other and zero
-            weights[i] = np.maximum(0, np.minimum(lower, upper))
+            weights_i = ops.maximum(0, ops.minimum(lower, upper))
+            weights_list.append(weights_i)
+
+        weights = ops.stack(weights_list)

         # Slaney-style mel is scaled to be approx constant energy per channel
         enorm = 2.0 / (mel_f[2 : self.num_mels + 2] - mel_f[: self.num_mels])
-        weights *= enorm[:, np.newaxis]
+        weights *= ops.expand_dims(enorm, axis=1)

-        weights = np.transpose(weights)
-        return tf.constant(weights, dtype=self.compute_dtype)
+        weights = ops.transpose(weights)
+        return weights

     def _extract_audio_features(self, audio):
-        audio = tf.cast(audio, self.compute_dtype)
+        audio = ops.cast(audio, self.compute_dtype)
         # Use "reflection" padding - `tf.signal.stft` uses symmetric padding
         # internally.
-        audio = tf.pad(
+        audio = ops.pad(
             audio,
-            paddings=[[0, 0], [self.num_fft_bins // 2, self.num_fft_bins // 2]],
-            mode="REFLECT",
+            pad_width=[
+                [0, 0],
+                [self.num_fft_bins // 2, self.num_fft_bins // 2],
+            ],
+            mode="reflect",
         )

         # Compute the mel spectrogram.
-        stft = tf.signal.stft(
+        stft = ops.stft(
             audio,
-            frame_length=self.num_fft_bins,
-            frame_step=self.stride,
+            sequence_length=self.num_fft_bins,
+            sequence_stride=self.stride,
             fft_length=self.num_fft_bins,
+            center=False,
         )
-        magnitudes = tf.square(tf.abs(stft[:, :-1, :]))
+        stft = ops.sum(stft, axis=0)
+        magnitudes = ops.square(ops.absolute(stft[:, :-1, :]))

-        mel_spec = tf.matmul(
+        mel_spec = ops.matmul(
             magnitudes,
             self.mel_filters,
         )
+        # mel_spec = ops.matmul(magnitudes,mel_filters_casted,)

         def tf_log10(x):
             """Computes log base 10 of input tensor using TensorFlow."""
-            numerator = tf.math.log(x)
-            denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
+            numerator = ops.log(x)
+            denominator = ops.log(
+                ops.cast(ops.array(10), dtype=numerator.dtype)
+            )
             return numerator / denominator

         # Clamp the values to a minimum value of 1e-10. This is done to avoid
         # taking the log of 0, i.e., for numerical stability.
-        mel_spec = tf.maximum(mel_spec, 1e-10)
+        mel_spec = ops.maximum(mel_spec, 1e-10)

         # Calculate the log mel spectrogram.
         log_spec = tf_log10(mel_spec)
         # Dynamic range compression.
-        log_spec_shape = tf.shape(log_spec)
-        max_value_minus_eight = tf.math.subtract(
-            tf.math.reduce_max(log_spec, axis=[1, 2]),
-            tf.cast(8, dtype=log_spec.dtype),
+        log_spec_shape = ops.shape(log_spec)
+        max_value_minus_eight = ops.subtract(
+            ops.max(log_spec, axis=[1, 2]),
+            ops.cast(8, dtype=log_spec.dtype),
         )
-        max_value_minus_eight = tf.expand_dims(max_value_minus_eight, axis=1)
-        max_value_minus_eight = tf.repeat(
+        max_value_minus_eight = ops.expand_dims(max_value_minus_eight, axis=1)
+        max_value_minus_eight = ops.repeat(
             max_value_minus_eight,
             repeats=log_spec_shape[1] * log_spec_shape[2],
             axis=1,
         )
-        max_value_minus_eight = tf.reshape(
-            max_value_minus_eight, shape=log_spec_shape
+        max_value_minus_eight = ops.reshape(
+            max_value_minus_eight, newshape=log_spec_shape
         )
-        log_spec = tf.maximum(log_spec, max_value_minus_eight)
+        log_spec = ops.maximum(log_spec, max_value_minus_eight)
         # Normalization.
-        type_cast_four = tf.cast(4, dtype=log_spec.dtype)
-        log_spec = tf.math.divide(
-            tf.math.add(log_spec, type_cast_four),
+        type_cast_four = ops.cast(4, dtype=log_spec.dtype)
+        log_spec = ops.divide(
+            ops.add(log_spec, type_cast_four),
             type_cast_four,
         )

         return log_spec

     def call(self, audio):
@@ -214,21 +219,21 @@ def call(self, audio):

         rank_1_input = audio.shape.rank == 1
         if rank_1_input:
-            audio = tf.expand_dims(audio, 0)
+            audio = ops.expand_dims(audio, 0)

         # Convert the tensor to a Ragged Tensor.
         if isinstance(audio, tf.Tensor):
             audio = tf.RaggedTensor.from_tensor(audio)
Review comment (Member): We could probably switch to a @preprocessing_function annotation and remove this? Something to try at least.


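A rough sketch of what that suggestion might look like, assuming the `@preprocessing_function` decorator from `keras_hub.src.utils.tensor_utils` takes over the tensor conversion that the explicit `tf.RaggedTensor.from_tensor` call handles today (untested, not part of the PR):

```python
from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter
from keras_hub.src.utils.tensor_utils import preprocessing_function


class WhisperAudioConverter(AudioConverter):
    @preprocessing_function
    def call(self, audio):
        # With the decorator managing conversion between backend tensors
        # and tf.data-compatible tensors, the manual ragged-tensor
        # bookkeeping above might be removable.
        ...
```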
         # Pad audio.
-        audio_shape = audio.shape.as_list()
+        audio_shape = list(audio.shape)
         audio_shape[-1] = self.num_samples
Review comment (Member): I believe tf.shape cannot always be listified like this. Maybe call ops.shape?

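A minimal sketch of the reviewer's alternative (illustrative values, not part of the PR):

```python
import keras.ops as ops

audio = ops.ones((2, 8000))
num_samples = 3000  # stands in for self.num_samples

# ops.shape stays usable when static dimensions are unknown (None),
# which is where listifying `.shape` directly can break.
audio_shape = list(ops.shape(audio))
audio_shape[-1] = num_samples
print(audio_shape)  # [2, 3000]
```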
         audio = audio.to_tensor(shape=audio_shape)

         # Find the log mel spectrogram.
         log_spec = self._extract_audio_features(audio)
         if rank_1_input:
-            log_spec = tf.squeeze(log_spec, 0)
+            log_spec = ops.squeeze(log_spec, 0)
         return log_spec

     def get_config(self):
9 changes: 5 additions & 4 deletions keras_hub/src/models/whisper/whisper_audio_converter_test.py
@@ -1,4 +1,5 @@
 import tensorflow as tf
+import keras.ops as ops

 from keras_hub.src.models.whisper.whisper_audio_converter import (
     WhisperAudioConverter,
@@ -15,8 +16,8 @@ def setUp(self):
             "sampling_rate": 100,
             "max_audio_length": 5,
         }
-        audio_tensor_1 = tf.ones((2,), dtype="float32")
-        audio_tensor_2 = tf.ones((25,), dtype="float32")
+        audio_tensor_1 = ops.ones((2,), dtype="float32")
+        audio_tensor_2 = ops.ones((25,), dtype="float32")
         self.input_data = tf.ragged.stack(
             [audio_tensor_1, audio_tensor_2],
             axis=0,
@@ -30,11 +31,11 @@ def test_feature_extractor_basics(self):
         )

     def test_correctness(self):
-        audio_tensor = tf.ones((2,), dtype="float32")
+        audio_tensor = ops.ones((2,), dtype="float32")
         outputs = WhisperAudioConverter(**self.init_kwargs)(audio_tensor)

         # Verify shape.
         self.assertEqual(outputs.shape, (5, 80))
         # Verify output.
         expected = [1.1656, 1.0151, -0.8343, -0.8343, -0.8343]
-        self.assertAllClose(outputs[:, 0], expected, atol=0.01, rtol=0.01)
+        self.assertAllClose(outputs[:, 0], expected, atol=0.01, rtol=0.01)
\ No newline at end of file
Review comment (Member): Keep newlines at end of files.
