diff --git a/keras_hub/src/models/whisper/whisper_audio_converter.py b/keras_hub/src/models/whisper/whisper_audio_converter.py
index e1da985cc2..45e3aafca4 100644
--- a/keras_hub/src/models/whisper/whisper_audio_converter.py
+++ b/keras_hub/src/models/whisper/whisper_audio_converter.py
@@ -1,14 +1,10 @@
-import numpy as np
+import tensorflow as tf
+from keras import ops

 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter
 from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone

-try:
-    import tensorflow as tf
-except ImportError:
-    tf = None
-

 @keras_hub_export("keras_hub.layers.WhisperAudioConverter")
 class WhisperAudioConverter(AudioConverter):
@@ -36,7 +32,7 @@ class WhisperAudioConverter(AudioConverter):
     Examples:
     ```python
-    audio_tensor = tf.ones((8000,), dtype="float32")
+    audio_tensor = ops.ones((8000,), dtype="float32")

     # Compute the log-mel spectrogram.
     audio_converter = keras_hub.layers.WhisperAudioConverter.from_preset(
@@ -45,8 +41,8 @@ class WhisperAudioConverter(AudioConverter):
     audio_converter(audio_tensor)

     # Compute the log-mel spectrogram for a batch of audio tensors.
-    audio_tensor_1 = tf.ones((8000,), dtype="float32")
-    audio_tensor_2 = tf.ones((10000,), dtype="float32")
+    audio_tensor_1 = ops.ones((8000,), dtype="float32")
+    audio_tensor_2 = ops.ones((10000,), dtype="float32")
     audio_tensor = tf.ragged.stack([audio_tensor_1, audio_tensor_2], axis=0)
     audio_converter(audio_tensor)
     ```
@@ -84,33 +80,33 @@ def audio_shape(self):
         """Returns the preprocessed size of a single audio sample."""
         return (self.max_audio_length, self.num_mels)

+    def _get_rfftfreq_keras(self):
+        """Keras-ops equivalent of `np.fft.rfftfreq`.
+
+        Returns the center frequency of each FFT bin: `n // 2 + 1`
+        evenly spaced values from 0 Hz up to the Nyquist frequency.
+        `n // 2 + 1` covers both even and odd `n`, since
+        `(n - 1) // 2 == n // 2` when `n` is odd.
+        """
+        n = self.num_fft_bins
+        d = 1.0 / self.sampling_rate
+        return ops.arange(0, n // 2 + 1, dtype="float32") / (d * n)
+
     def _get_mel_filters(self):
         """
        Adapted from Hugging Face
        (https://github.com/huggingface/transformers/blob/v4.27.1/src/transformers/models/whisper/feature_extraction_whisper.py#L86)
         """
-
-        # TODO: Convert to TensorFlow ops (if possible).
-
-        dtype = np.float32
-        # Initialize the weights
-        weights = np.zeros(
-            (self.num_mels, int(1 + self.num_fft_bins // 2)), dtype=dtype
-        )
         # Center freqs of each FFT bin
-        fftfreqs = np.fft.rfftfreq(
-            n=self.num_fft_bins, d=1.0 / self.sampling_rate
-        )
-
+        fftfreqs = self._get_rfftfreq_keras()

         # 'Center freqs' of mel bands - uniformly spaced between limits
         min_mel = 0.0
         max_mel = 45.245640471924965
-        mels = np.linspace(min_mel, max_mel, self.num_mels + 2)
-        mels = np.asanyarray(mels)
+        mels = ops.linspace(min_mel, max_mel, self.num_mels + 2)

         # Fill in the linear scale
         f_min = 0.0
         f_sp = 200.0 / 3
@@ -119,93 +115,102 @@ def _get_mel_filters(self):
         # And now the nonlinear scale
         min_log_hz = 1000.0  # beginning of log region (Hz)
         min_log_mel = (min_log_hz - f_min) / f_sp  # same (Mels)
-        logstep = np.log(6.4) / 27.0  # step size for log region
+        logstep = ops.log(6.4) / 27.0  # step size for log region

         # If we have vector data, vectorize
         log_t = mels >= min_log_mel
-        freqs[log_t] = min_log_hz * np.exp(
-            logstep * (mels[log_t] - min_log_mel)
-        )
+        # Boolean masks can't be assigned in-place on backend tensors,
+        # so select with `ops.where` instead.
+        freqs = ops.where(
+            log_t, min_log_hz * ops.exp(logstep * (mels - min_log_mel)), freqs
+        )

         mel_f = freqs

-        fdiff = np.diff(mel_f)
-        ramps = np.subtract.outer(mel_f, fftfreqs)
+        fdiff = ops.diff(mel_f)
+        # Outer subtraction via broadcasting, equivalent to
+        # `np.subtract.outer(mel_f, fftfreqs)`.
+        ramps = ops.expand_dims(mel_f, axis=1) - fftfreqs

+        weights_list = []
         for i in range(self.num_mels):
             # lower and upper slopes for all bins
             lower = -ramps[i] / fdiff[i]
             upper = ramps[i + 2] / fdiff[i + 1]

             # .. then intersect them with each other and zero
-            weights[i] = np.maximum(0, np.minimum(lower, upper))
+            weights_list.append(ops.maximum(0, ops.minimum(lower, upper)))
+
+        weights = ops.stack(weights_list)

         # Slaney-style mel is scaled to be approx constant energy per channel
         enorm = 2.0 / (mel_f[2 : self.num_mels + 2] - mel_f[: self.num_mels])
-        weights *= enorm[:, np.newaxis]
+        weights *= ops.expand_dims(enorm, axis=1)

-        weights = np.transpose(weights)
-        return tf.constant(weights, dtype=self.compute_dtype)
+        weights = ops.transpose(weights)
+        return ops.cast(weights, self.compute_dtype)

     def _extract_audio_features(self, audio):
-        audio = tf.cast(audio, self.compute_dtype)
+        audio = ops.cast(audio, self.compute_dtype)

-        # Use "reflection" padding - `tf.signal.stft` uses symmetric padding
-        # internally.
-        audio = tf.pad(
+        # Whisper uses reflection padding; `ops.stft` is called with
+        # `center=False` below, so the input is padded manually here.
+        audio = ops.pad(
             audio,
-            paddings=[[0, 0], [self.num_fft_bins // 2, self.num_fft_bins // 2]],
-            mode="REFLECT",
+            pad_width=[
+                [0, 0],
+                [self.num_fft_bins // 2, self.num_fft_bins // 2],
+            ],
+            mode="reflect",
         )

         # Compute the mel spectrogram.
-        stft = tf.signal.stft(
+        # `ops.stft` returns the real and imaginary parts as a pair, so
+        # the power spectrogram is `real**2 + imag**2`. The last frame
+        # is dropped, as in the TensorFlow implementation.
+        real, imag = ops.stft(
             audio,
-            frame_length=self.num_fft_bins,
-            frame_step=self.stride,
+            sequence_length=self.num_fft_bins,
+            sequence_stride=self.stride,
             fft_length=self.num_fft_bins,
+            center=False,
         )
-        magnitudes = tf.square(tf.abs(stft[:, :-1, :]))
+        magnitudes = ops.square(real[:, :-1, :]) + ops.square(imag[:, :-1, :])

-        mel_spec = tf.matmul(
+        mel_spec = ops.matmul(
             magnitudes,
             self.mel_filters,
         )

-        def tf_log10(x):
-            """Computes log base 10 of input tensor using TensorFlow."""
-            numerator = tf.math.log(x)
-            denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
+        def log10(x):
+            """Computes log base 10 of the input tensor via Keras ops."""
+            numerator = ops.log(x)
+            denominator = ops.log(ops.cast(10.0, dtype=numerator.dtype))
             return numerator / denominator

         # Clamp the values to a minimum value of 1e-10. This is done to avoid
         # taking the log of 0, i.e., for numerical stability.
-        mel_spec = tf.maximum(mel_spec, 1e-10)
+        mel_spec = ops.maximum(mel_spec, 1e-10)

         # Calculate the log mel spectrogram.
-        log_spec = tf_log10(mel_spec)
+        log_spec = log10(mel_spec)

         # Dynamic range compression.
-        log_spec_shape = tf.shape(log_spec)
-        max_value_minus_eight = tf.math.subtract(
-            tf.math.reduce_max(log_spec, axis=[1, 2]),
-            tf.cast(8, dtype=log_spec.dtype),
-        )
-        max_value_minus_eight = tf.expand_dims(max_value_minus_eight, axis=1)
-        max_value_minus_eight = tf.repeat(
-            max_value_minus_eight,
-            repeats=log_spec_shape[1] * log_spec_shape[2],
-            axis=1,
-        )
-        max_value_minus_eight = tf.reshape(
-            max_value_minus_eight, shape=log_spec_shape
-        )
-        log_spec = tf.maximum(log_spec, max_value_minus_eight)
+        # Clamp everything to within 8 dB of the per-sample maximum.
+        # `keepdims=True` lets broadcasting replace the old
+        # expand/repeat/reshape dance.
+        max_value_minus_eight = ops.max(
+            log_spec, axis=(1, 2), keepdims=True
+        ) - ops.cast(8, dtype=log_spec.dtype)
+        log_spec = ops.maximum(log_spec, max_value_minus_eight)

         # Normalization.
-        type_cast_four = tf.cast(4, dtype=log_spec.dtype)
-        log_spec = tf.math.divide(
-            tf.math.add(log_spec, type_cast_four),
-            type_cast_four,
-        )
-
+        log_spec = (log_spec + 4.0) / 4.0
         return log_spec

     def call(self, audio):
@@ -214,21 +219,21 @@ def call(self, audio):
         rank_1_input = audio.shape.rank == 1
         if rank_1_input:
-            audio = tf.expand_dims(audio, 0)
+            audio = ops.expand_dims(audio, 0)

         # Convert the tensor to a Ragged Tensor.
         if isinstance(audio, tf.Tensor):
             audio = tf.RaggedTensor.from_tensor(audio)

         # Pad audio.
-        audio_shape = audio.shape.as_list()
+        audio_shape = list(audio.shape)
         audio_shape[-1] = self.num_samples
         audio = audio.to_tensor(shape=audio_shape)

         # Find the log mel spectrogram.
         log_spec = self._extract_audio_features(audio)
         if rank_1_input:
-            log_spec = tf.squeeze(log_spec, 0)
+            log_spec = ops.squeeze(log_spec, 0)

         return log_spec

     def get_config(self):
diff --git a/keras_hub/src/models/whisper/whisper_audio_converter_test.py b/keras_hub/src/models/whisper/whisper_audio_converter_test.py
index 6e6d451748..f5054e933e 100644
--- a/keras_hub/src/models/whisper/whisper_audio_converter_test.py
+++ b/keras_hub/src/models/whisper/whisper_audio_converter_test.py
@@ -1,4 +1,5 @@
 import tensorflow as tf
+from keras import ops

 from keras_hub.src.models.whisper.whisper_audio_converter import (
     WhisperAudioConverter,
@@ -15,8 +16,8 @@ def setUp(self):
             "sampling_rate": 100,
             "max_audio_length": 5,
         }
-        audio_tensor_1 = tf.ones((2,), dtype="float32")
-        audio_tensor_2 = tf.ones((25,), dtype="float32")
+        audio_tensor_1 = ops.ones((2,), dtype="float32")
+        audio_tensor_2 = ops.ones((25,), dtype="float32")
         self.input_data = tf.ragged.stack(
             [audio_tensor_1, audio_tensor_2],
             axis=0,
@@ -30,11 +31,11 @@ def test_feature_extractor_basics(self):
         )

     def test_correctness(self):
-        audio_tensor = tf.ones((2,), dtype="float32")
+        audio_tensor = ops.ones((2,), dtype="float32")
         outputs = WhisperAudioConverter(**self.init_kwargs)(audio_tensor)

         # Verify shape.
         self.assertEqual(outputs.shape, (5, 80))
         # Verify output.
         expected = [1.1656, 1.0151, -0.8343, -0.8343, -0.8343]
         self.assertAllClose(outputs[:, 0], expected, atol=0.01, rtol=0.01)
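Two spots in this port are worth a numerical sanity check. First, `keras.ops.stft` returns the real and imaginary parts as a pair rather than a single complex tensor, so the power spectrogram has to be computed as `real**2 + imag**2`. A minimal parity sketch against `tf.signal.stft` (assuming the TensorFlow backend is installed; the 400/160 frame length and stride are Whisper's defaults, and the tolerance may need loosening if window conventions differ across backends):

```python
import numpy as np
import tensorflow as tf
from keras import ops

audio = np.random.randn(2, 16000).astype("float32")
frame_length, frame_step = 400, 160  # Whisper defaults.

# Reference: complex STFT from TensorFlow, squared magnitude.
reference = tf.signal.stft(
    audio,
    frame_length=frame_length,
    frame_step=frame_step,
    fft_length=frame_length,
)
reference_power = tf.square(tf.abs(reference)).numpy()

# Keras ops: real/imag pair; the power is real^2 + imag^2.
real, imag = ops.stft(
    audio,
    sequence_length=frame_length,
    sequence_stride=frame_step,
    fft_length=frame_length,
    center=False,
)
power = ops.convert_to_numpy(ops.square(real) + ops.square(imag))

np.testing.assert_allclose(power, reference_power, atol=1e-3)
```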
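Second, since `np.fft.rfftfreq` is replaced by a hand-rolled helper, the FFT bin centers can be checked directly. A standalone sketch mirroring `_get_rfftfreq_keras`, again assuming Whisper's default 400-point FFT at 16 kHz:

```python
import numpy as np
from keras import ops

num_fft_bins = 400  # Whisper default.
sampling_rate = 16000

d = 1.0 / sampling_rate
keras_freqs = ops.arange(0, num_fft_bins // 2 + 1, dtype="float32") / (
    d * num_fft_bins
)
numpy_freqs = np.fft.rfftfreq(n=num_fft_bins, d=d)

np.testing.assert_allclose(
    ops.convert_to_numpy(keras_freqs), numpy_freqs, rtol=1e-6
)
```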