-
Notifications
You must be signed in to change notification settings - Fork 278
Convert TF and Numpy ops in whisper_audio_convert.py to Keras Ops #2225
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,10 @@ | ||
import numpy as np | ||
import keras.ops as ops | ||
import tensorflow as tf | ||
|
||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter | ||
from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone | ||
|
||
try: | ||
import tensorflow as tf | ||
except ImportError: | ||
tf = None | ||
|
||
|
||
@keras_hub_export("keras_hub.layers.WhisperAudioConverter") | ||
class WhisperAudioConverter(AudioConverter): | ||
|
@@ -36,7 +32,7 @@ class WhisperAudioConverter(AudioConverter): | |
|
||
Examples: | ||
```python | ||
audio_tensor = tf.ones((8000,), dtype="float32") | ||
audio_tensor = ops.ones((8000,), dtype="float32") | ||
|
||
# Compute the log-mel spectrogram. | ||
audio_converter = keras_hub.layers.WhisperAudioConverter.from_preset( | ||
|
@@ -45,8 +41,8 @@ class WhisperAudioConverter(AudioConverter): | |
audio_converter(audio_tensor) | ||
|
||
# Compute the log-mel spectrogram for a batch of audio tensors. | ||
audio_tensor_1 = tf.ones((8000,), dtype="float32") | ||
audio_tensor_2 = tf.ones((10000,), dtype="float32") | ||
audio_tensor_1 = ops.ones((8000,), dtype="float32") | ||
audio_tensor_2 = ops.ones((10000,), dtype="float32") | ||
audio_tensor = tf.ragged.stack([audio_tensor_1, audio_tensor_2], axis=0) | ||
audio_converter(audio_tensor) | ||
``` | ||
|
@@ -84,33 +80,33 @@ def audio_shape(self): | |
"""Returns the preprocessed size of a single audio sample.""" | ||
return (self.max_audio_length, self.num_mels) | ||
|
||
def _get_rfftfreq_keras(self): | ||
n = self.num_fft_bins | ||
d = 1.0 / self.sampling_rate | ||
|
||
if n % 2 == 0: | ||
freqs = ops.arange(0, n // 2 + 1, dtype="float32") / (d * n) | ||
else: | ||
freqs = ops.arange(0, (n - 1) // 2 + 1, dtype="float32") / (d * n) | ||
|
||
return freqs | ||
|
||
def _get_mel_filters(self): | ||
""" | ||
Adapted from Hugging Face | ||
(https://github.com/huggingface/transformers/blob/v4.27.1/src/transformers/models/whisper/feature_extraction_whisper.py#L86) | ||
""" | ||
|
||
# TODO: Convert to TensorFlow ops (if possible). | ||
|
||
dtype = np.float32 | ||
dtype = self.compute_dtype # Use the class's dtype | ||
# Initialize the weights | ||
weights = np.zeros( | ||
weights = ops.zeros( | ||
(self.num_mels, int(1 + self.num_fft_bins // 2)), dtype=dtype | ||
) | ||
|
||
# Center freqs of each FFT bin | ||
fftfreqs = np.fft.rfftfreq( | ||
n=self.num_fft_bins, d=1.0 / self.sampling_rate | ||
) | ||
|
||
fftfreqs = self._get_rfftfreq_keras() | ||
# 'Center freqs' of mel bands - uniformly spaced between limits | ||
min_mel = 0.0 | ||
max_mel = 45.245640471924965 | ||
|
||
mels = np.linspace(min_mel, max_mel, self.num_mels + 2) | ||
|
||
mels = np.asanyarray(mels) | ||
|
||
mels = ops.linspace(min_mel, max_mel, self.num_mels + 2) | ||
# Fill in the linear scale | ||
f_min = 0.0 | ||
f_sp = 200.0 / 3 | ||
|
@@ -119,93 +115,102 @@ def _get_mel_filters(self): | |
# And now the nonlinear scale | ||
min_log_hz = 1000.0 # beginning of log region (Hz) | ||
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels) | ||
logstep = np.log(6.4) / 27.0 # step size for log region | ||
|
||
logstep = ops.log(6.4) / 27.0 # step size for log region | ||
# If we have vector data, vectorize | ||
log_t = mels >= min_log_mel | ||
freqs[log_t] = min_log_hz * np.exp( | ||
logstep * (mels[log_t] - min_log_mel) | ||
freqs = ops.where( | ||
log_t, min_log_hz * ops.exp(logstep * (mels - min_log_mel)), freqs | ||
) | ||
|
||
mel_f = freqs | ||
fdiff = ops.diff(mel_f) | ||
ramps = ( | ||
ops.expand_dims(mel_f, axis=1) - fftfreqs | ||
) # keras subtract outer | ||
|
||
fdiff = np.diff(mel_f) | ||
ramps = np.subtract.outer(mel_f, fftfreqs) | ||
|
||
weights_list = [] | ||
for i in range(self.num_mels): | ||
# lower and upper slopes for all bins | ||
lower = -ramps[i] / fdiff[i] | ||
upper = ramps[i + 2] / fdiff[i + 1] | ||
|
||
# .. then intersect them with each other and zero | ||
weights[i] = np.maximum(0, np.minimum(lower, upper)) | ||
weights_i = ops.maximum(0, ops.minimum(lower, upper)) | ||
weights_list.append(weights_i) | ||
|
||
weights = ops.stack(weights_list) | ||
|
||
# Slaney-style mel is scaled to be approx constant energy per channel | ||
enorm = 2.0 / (mel_f[2 : self.num_mels + 2] - mel_f[: self.num_mels]) | ||
weights *= enorm[:, np.newaxis] | ||
weights *= ops.expand_dims(enorm, axis=1) | ||
|
||
weights = np.transpose(weights) | ||
return tf.constant(weights, dtype=self.compute_dtype) | ||
weights = ops.transpose(weights) | ||
return weights | ||
|
||
def _extract_audio_features(self, audio): | ||
audio = tf.cast(audio, self.compute_dtype) | ||
audio = ops.cast(audio, self.compute_dtype) | ||
# Use "reflection" padding - `tf.signal.stft` uses symmetric padding | ||
# internally. | ||
audio = tf.pad( | ||
audio = ops.pad( | ||
audio, | ||
paddings=[[0, 0], [self.num_fft_bins // 2, self.num_fft_bins // 2]], | ||
mode="REFLECT", | ||
pad_width=[ | ||
[0, 0], | ||
[self.num_fft_bins // 2, self.num_fft_bins // 2], | ||
], | ||
mode="reflect", | ||
) | ||
|
||
# Compute the mel spectrogram. | ||
stft = tf.signal.stft( | ||
stft = ops.stft( | ||
audio, | ||
frame_length=self.num_fft_bins, | ||
frame_step=self.stride, | ||
sequence_length=self.num_fft_bins, | ||
sequence_stride=self.stride, | ||
fft_length=self.num_fft_bins, | ||
center=False, | ||
) | ||
magnitudes = tf.square(tf.abs(stft[:, :-1, :])) | ||
stft = ops.sum(stft, axis=0) | ||
magnitudes = ops.square(ops.absolute(stft[:, :-1, :])) | ||
|
||
mel_spec = tf.matmul( | ||
mel_spec = ops.matmul( | ||
magnitudes, | ||
self.mel_filters, | ||
) | ||
# mel_spec = ops.matmul(magnitudes,mel_filters_casted,) | ||
|
||
def tf_log10(x): | ||
"""Computes log base 10 of input tensor using TensorFlow.""" | ||
numerator = tf.math.log(x) | ||
denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype)) | ||
numerator = ops.log(x) | ||
denominator = ops.log( | ||
ops.cast(ops.array(10), dtype=numerator.dtype) | ||
) | ||
return numerator / denominator | ||
|
||
# Clamp the values to a minimum value of 1e-10. This is done to avoid | ||
# taking the log of 0, i.e., for numerical stability. | ||
mel_spec = tf.maximum(mel_spec, 1e-10) | ||
mel_spec = ops.maximum(mel_spec, 1e-10) | ||
|
||
# Calculate the log mel spectrogram. | ||
log_spec = tf_log10(mel_spec) | ||
# Dynamic range compression. | ||
log_spec_shape = tf.shape(log_spec) | ||
max_value_minus_eight = tf.math.subtract( | ||
tf.math.reduce_max(log_spec, axis=[1, 2]), | ||
tf.cast(8, dtype=log_spec.dtype), | ||
log_spec_shape = ops.shape(log_spec) | ||
max_value_minus_eight = ops.subtract( | ||
ops.max(log_spec, axis=[1, 2]), | ||
ops.cast(8, dtype=log_spec.dtype), | ||
) | ||
max_value_minus_eight = tf.expand_dims(max_value_minus_eight, axis=1) | ||
max_value_minus_eight = tf.repeat( | ||
max_value_minus_eight = ops.expand_dims(max_value_minus_eight, axis=1) | ||
max_value_minus_eight = ops.repeat( | ||
max_value_minus_eight, | ||
repeats=log_spec_shape[1] * log_spec_shape[2], | ||
axis=1, | ||
) | ||
max_value_minus_eight = tf.reshape( | ||
max_value_minus_eight, shape=log_spec_shape | ||
max_value_minus_eight = ops.reshape( | ||
max_value_minus_eight, newshape=log_spec_shape | ||
) | ||
log_spec = tf.maximum(log_spec, max_value_minus_eight) | ||
log_spec = ops.maximum(log_spec, max_value_minus_eight) | ||
# Normalization. | ||
type_cast_four = tf.cast(4, dtype=log_spec.dtype) | ||
log_spec = tf.math.divide( | ||
tf.math.add(log_spec, type_cast_four), | ||
type_cast_four = ops.cast(4, dtype=log_spec.dtype) | ||
log_spec = ops.divide( | ||
ops.add(log_spec, type_cast_four), | ||
type_cast_four, | ||
) | ||
|
||
return log_spec | ||
|
||
def call(self, audio): | ||
|
@@ -214,21 +219,21 @@ def call(self, audio): | |
|
||
rank_1_input = audio.shape.rank == 1 | ||
if rank_1_input: | ||
audio = tf.expand_dims(audio, 0) | ||
audio = ops.expand_dims(audio, 0) | ||
|
||
# Convert the tensor to a Ragged Tensor. | ||
if isinstance(audio, tf.Tensor): | ||
audio = tf.RaggedTensor.from_tensor(audio) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We could probably switch to a |
||
|
||
# Pad audio. | ||
audio_shape = audio.shape.as_list() | ||
audio_shape = list(audio.shape) | ||
audio_shape[-1] = self.num_samples | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I believe tf.shape cannot always be listified like this. Maybe call |
||
audio = audio.to_tensor(shape=audio_shape) | ||
|
||
# Find the log mel spectrogram. | ||
log_spec = self._extract_audio_features(audio) | ||
if rank_1_input: | ||
log_spec = tf.squeeze(log_spec, 0) | ||
log_spec = ops.squeeze(log_spec, 0) | ||
return log_spec | ||
|
||
def get_config(self): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
import tensorflow as tf | ||
import keras.ops as ops | ||
|
||
from keras_hub.src.models.whisper.whisper_audio_converter import ( | ||
WhisperAudioConverter, | ||
|
@@ -15,8 +16,8 @@ def setUp(self): | |
"sampling_rate": 100, | ||
"max_audio_length": 5, | ||
} | ||
audio_tensor_1 = tf.ones((2,), dtype="float32") | ||
audio_tensor_2 = tf.ones((25,), dtype="float32") | ||
audio_tensor_1 = ops.ones((2,), dtype="float32") | ||
audio_tensor_2 = ops.ones((25,), dtype="float32") | ||
self.input_data = tf.ragged.stack( | ||
[audio_tensor_1, audio_tensor_2], | ||
axis=0, | ||
|
@@ -30,11 +31,11 @@ def test_feature_extractor_basics(self): | |
) | ||
|
||
def test_correctness(self): | ||
audio_tensor = tf.ones((2,), dtype="float32") | ||
audio_tensor = ops.ones((2,), dtype="float32") | ||
outputs = WhisperAudioConverter(**self.init_kwargs)(audio_tensor) | ||
|
||
# Verify shape. | ||
self.assertEqual(outputs.shape, (5, 80)) | ||
# Verify output. | ||
expected = [1.1656, 1.0151, -0.8343, -0.8343, -0.8343] | ||
self.assertAllClose(outputs[:, 0], expected, atol=0.01, rtol=0.01) | ||
self.assertAllClose(outputs[:, 0], expected, atol=0.01, rtol=0.01) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Keep newlines at end of files. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can never do a bare import of tf like this. Check other files in the library.