Whisper Redesigned Solution #23549

Open · wants to merge 38 commits into main
Conversation

kunal-vaishnavi (Contributor) commented on Jan 31, 2025

Description

This PR redesigns how Whisper is created and supported in ONNX Runtime. The new solution leverages previous optimization work and is designed to be used in conjunction with ONNX Runtime GenAI.

Some of the changes include:

  • Redesigned export that creates new ONNX models without needing a WhisperBeamSearch op
    • Creates one encoder model that also pre-computes the cross-attention KV caches (since they only need to be computed once)
    • Creates one decoder model that can be used during pre-fill and token generation (see the runtime sketch after this list)
    • Creates one jump-times model that can be used for word-level timestamps
    • Removes the need for a WhisperBeamSearch op to chain the encoder and decoder subgraphs
    • Removes the need to duplicate the decoder's weights in memory
      • The previous solution with the WhisperBeamSearch op created an encoder-decoder-init model and a decoder-with-past model, so the decoder's weights were stored twice, once in each model.
    • Removes the need for separate logic to export the PyTorch model coming from OpenAI vs. the PyTorch model coming from Hugging Face
  • Refactors common parameters and logic used in the CPU and CUDA attention kernels
    • Adds DUMP_STRING to enable easy logging of intermediate information when debugging an issue in debug builds. This info is not printed in release builds, so it does not impact performance.
    • Integrates DecoderMaskedMultiHeadAttention into MultiHeadAttention
    • Enables past-present buffer sharing in the MultiHeadAttention op for improved performance
    • Adds cache_indirection and past_sequence_length as new optional inputs to MultiHeadAttention
    • Adds output_qk as new optional output to MultiHeadAttention
    • Enables calculating output_qk tensor with FP16 or FP32 precision, regardless of the model's precision
  • Adds CI tests that run end-to-end across the various flag combinations used by many customers internally and externally
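
To illustrate how the separate encoder and decoder models are meant to be consumed, here is a minimal sketch that loads them with ONNX Runtime directly (ONNX Runtime GenAI normally drives the generation loop). The file names below are assumptions for illustration; the actual input and output names come from the exported models.

    import onnxruntime as ort

    # Hypothetical file names; the real ones come from the export step.
    encoder = ort.InferenceSession("whisper_encoder.onnx", providers=["CPUExecutionProvider"])
    decoder = ort.InferenceSession("whisper_decoder.onnx", providers=["CPUExecutionProvider"])

    # The encoder runs once per audio segment; its outputs include the
    # pre-computed cross-attention KV caches consumed by the decoder.
    print([o.name for o in encoder.get_outputs()])

    # The same decoder session serves both pre-fill and token generation,
    # so its weights are kept in memory only once.
    print([i.name for i in decoder.get_inputs()])
    print([o.name for o in decoder.get_outputs()])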

The existing solutions are still available if desired.
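
As a rough sketch of the MultiHeadAttention changes listed above (cache_indirection, past_sequence_length, and output_qk), the snippet below builds a com.microsoft.MultiHeadAttention node that wires in the new optional inputs and requests the new optional output. Skipped optional inputs are passed as empty strings, per ONNX convention; the exact input/output ordering shown here is an assumption for illustration, and the contrib op schema in this PR is authoritative.

    from onnx import helper

    # Positions of the optional inputs/outputs are illustrative assumptions;
    # consult the updated MultiHeadAttention contrib-op schema for the real order.
    mha_node = helper.make_node(
        "MultiHeadAttention",
        inputs=[
            "query", "key", "value",
            "",                        # bias (not used in this sketch)
            "",                        # key padding mask (not used in this sketch)
            "",                        # attention bias (not used in this sketch)
            "past_key", "past_value",  # KV caches, shareable with present_key/present_value
            "past_sequence_length",    # new optional input
            "cache_indirection",       # new optional input for beam search bookkeeping
        ],
        outputs=[
            "output",
            "present_key", "present_value",
            "output_qk",               # new optional output (FP16 or FP32)
        ],
        domain="com.microsoft",
        num_heads=8,
    )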

Known Issues

  • The FP32 CPU model with the WhisperBeamSearch op and output QK is currently disabled because ONNX Runtime only supports output QK kernels on CUDA, not on CPU.
  • The FP32 CPU model has a bug with Neg --> Shape in the jump-times model when exporting the model to contain the WhisperBeamSearch op.
  • The DecoderMaskedMultiHeadAttention CPU kernel has a parity mismatch with the DecoderMaskedMultiHeadAttention CUDA kernel.
  • Using DecoderMaskedMultiHeadAttention for the FP32 CPU model is not enabled. Currently, it uses MultiHeadAttention to avoid the parity mismatch issue.

Motivation and Context

Using the beam search op has made it more difficult to debug and fix errors that are encountered. This new approach is more flexible and more customizable for users (e.g. by running with ONNX Runtime GenAI). It also helps this issue.

kunal-vaishnavi and others added 30 commits April 25, 2024 18:32
        return model


    def fix_past_sequence_length(model: ModelProto):

Check notice (Code scanning / CodeQL): Explicit returns mixed with implicit (fall through) returns
Mixing implicit and explicit returns may indicate an error as implicit returns always return None.
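
As a side note, a minimal illustration of the pattern this notice describes (not the PR's code): when one path returns a value and another falls off the end of the function, the fall-through path silently returns None, so making every exit explicit is clearer.

    # Flagged pattern: one path returns a node, the other falls through (returns None).
    def find_mha_node(graph):
        for node in graph.node:
            if node.op_type == "MultiHeadAttention":
                return node
        # implicit fall-through here returns None

    # Clearer: make the fall-through explicit.
    def find_mha_node_explicit(graph):
        for node in graph.node:
            if node.op_type == "MultiHeadAttention":
                return node
        return None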

        diff = np.abs(pt_outputs[i] - ort_outputs[i])
        logger.warning(f"Comparing {output_name}...")
        logger.warning(f"Max diff: {np.max(diff)}")
    except:  # noqa: E722

Check notice (Code scanning / CodeQL): Except block handles 'BaseException'
Except block directly handles BaseException.
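
For context on this notice, a small example (not the PR's code): a bare except also catches KeyboardInterrupt and SystemExit, while catching Exception keeps those control-flow exceptions intact.

    import numpy as np

    pt_output = np.array([1.0, 2.0])
    ort_output = np.array([1.0, 2.5])

    try:
        diff = np.abs(pt_output - ort_output)
        print(f"Max diff: {np.max(diff)}")
    except Exception as e:  # narrower than a bare `except:`, which handles BaseException
        print(f"Comparison failed: {e}")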

@@ -0,0 +1,195 @@

    import numpy as np

    import onnxruntime as ort

Check notice (Code scanning / CodeQL): Module is imported with 'import' and 'import from' (test)
Module 'onnxruntime' is imported with both 'import' and 'import from'.
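
For reference, the pattern this notice points at and one consistent alternative (not the PR's code):

    # Flagged pattern: the same module pulled in two ways, e.g.
    #   import onnxruntime as ort
    #   from onnxruntime import InferenceSession
    # A consistent alternative: import the module once and qualify its members.
    import onnxruntime as ort

    print(ort.__version__)  # members referenced through the single module alias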