Whisper Redesigned Solution #23549
base: main
Conversation
Diff excerpt (whisper_decoder.py):

    return model

    def fix_past_sequence_length(model: ModelProto):

Check notice — Code scanning / CodeQL (dismissed)
Explicit returns mixed with implicit (fall through) returns
onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py
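This notice fires when a function mixes explicit `return` statements with paths that fall off the end (and so implicitly return `None`). A minimal sketch of the pattern and one way to resolve it is shown below; the body is illustrative only and is not the PR's actual implementation of `fix_past_sequence_length`.

```python
from onnx import ModelProto


def fix_past_sequence_length(model: ModelProto):
    # Illustrative body only, not the PR's implementation.
    for graph_input in model.graph.input:
        if graph_input.name == "past_sequence_length":
            # ... adjust the graph input here ...
            return model  # explicit return on this path
    # Without this line the function would fall through and implicitly
    # return None, which is the mix CodeQL points out. Returning the
    # model on every path keeps callers' expectations uniform.
    return model
```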
Diff excerpt (whisper_encoder_decoder_init.py):

    diff = np.abs(pt_outputs[i] - ort_outputs[i])
    logger.warning(f"Comparing {output_name}...")
    logger.warning(f"Max diff: {np.max(diff)}")
    except:  # noqa: E722

Check notice — Code scanning / CodeQL (dismissed)
Except block handles 'BaseException'
onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py
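The flagged line is a bare `except:`, which also catches `BaseException` subclasses such as `KeyboardInterrupt` and `SystemExit`. Below is a hedged, self-contained sketch of the narrower form; the stand-in arrays replace the real PyTorch and ONNX Runtime outputs compared by the verification code.

```python
import logging

import numpy as np

logger = logging.getLogger(__name__)

# Stand-ins for the PyTorch and ONNX Runtime outputs compared in the real
# verification code; substitute the actual model outputs when reproducing.
pt_outputs = [np.zeros((2, 4), dtype=np.float32)]
ort_outputs = [np.zeros((2, 4), dtype=np.float32)]
output_name, i = "logits", 0

try:
    diff = np.abs(pt_outputs[i] - ort_outputs[i])
    logger.warning(f"Comparing {output_name}...")
    logger.warning(f"Max diff: {np.max(diff)}")
except Exception:
    # Narrower than a bare `except:`; BaseException subclasses like
    # KeyboardInterrupt and SystemExit are no longer swallowed.
    logger.warning(f"Could not compare {output_name}", exc_info=True)
```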
Diff excerpt (new test file, @@ -0,0 +1,195 @@):

    import numpy as np

    import onnxruntime as ort

Check notice — Code scanning / CodeQL
Module is imported with 'import' and 'import from' (note, in test code)
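This notice fires when one file imports the same module both with `import` and with `from ... import`. The second import below is a hypothetical example added only to show the pattern; the test file's actual imports may differ.

```python
# The mixed style CodeQL flags: the same module imported two ways.
import onnxruntime as ort
from onnxruntime import SessionOptions  # hypothetical second import, `from` style

# Sticking to one style (e.g. ort.SessionOptions()) avoids the notice.
options = SessionOptions()
print(ort.get_available_providers(), options.log_severity_level)
```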
Description
This PR re-designs how Whisper is created and supported in ONNX Runtime. The new solution leverages previous optimization work, and it is designed to be used in conjunction with ONNX Runtime GenAI.
Some of the added changes include:
- The new export no longer requires the `WhisperBeamSearch` op. The previous solution used the `WhisperBeamSearch` op to chain the encoder and decoder subgraphs, and exporting with the `WhisperBeamSearch` op created an encoder-decoder-init model and a decoder-with-past model. The decoder was duplicated twice, one in each.
- Added `DUMP_STRING` to enable easy logging of intermediate information when running in debug mode to debug a problem. This info is not printed in release mode, so it will not impact performance.
- Integrated `DecoderMaskedMultiHeadAttention` into `MultiHeadAttention`, so decoding can run through the `MultiHeadAttention` op for improved performance.
- Added `cache_indirection` and `past_sequence_length` as new optional inputs to `MultiHeadAttention` (see the node sketch below).
- Added `output_qk` as a new optional output to `MultiHeadAttention`.
- The `output_qk` tensor can be calculated with FP16 or FP32 precision, regardless of the model's precision.

The existing solutions are still available if desired.
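To make the new `MultiHeadAttention` interface concrete, here is a hedged, graph-level sketch of building a `com.microsoft` `MultiHeadAttention` node with `onnx.helper`. The positions of `past_sequence_length`, `cache_indirection`, and `output_qk` in the input/output lists are assumptions for illustration; the contrib-op schema in ONNX Runtime is authoritative, and empty strings mark unused optional slots.

```python
from onnx import helper

# Hedged sketch only: the ordering of the optional inputs/outputs below is
# assumed; consult the com.microsoft MultiHeadAttention schema for the
# authoritative layout.
mha_node = helper.make_node(
    "MultiHeadAttention",
    inputs=[
        "query", "key", "value",
        "",                      # bias (unused in this sketch)
        "",                      # key padding mask (unused in this sketch)
        "",                      # attention bias (unused in this sketch)
        "past_key", "past_value",
        "past_sequence_length",  # new optional input from this PR
        "cache_indirection",     # new optional input from this PR
    ],
    outputs=[
        "output",
        "present_key", "present_value",
        "output_qk",             # new optional output from this PR
    ],
    name="whisper_decoder_mha_0",
    domain="com.microsoft",
    num_heads=8,
)
print(mha_node)
```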
Known Issues
- Exporting the model with both the `WhisperBeamSearch` op and output QK is currently disabled. This is because ONNX Runtime doesn't currently support output QK kernels on CPU, only on CUDA.
- There is an issue with `Neg --> Shape` in the jump times model when exporting the model to contain the `WhisperBeamSearch` op.
- The `DecoderMaskedMultiHeadAttention` CPU kernel has a parity mismatch with the `DecoderMaskedMultiHeadAttention` CUDA kernel (a simple parity-check sketch follows below).
- `DecoderMaskedMultiHeadAttention` for the FP32 CPU model is not enabled. Currently, it uses `MultiHeadAttention` to avoid the parity mismatch issue.
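Since two of the known issues concern CPU/CUDA parity, a mismatch can be quantified by running the same model on both execution providers and diffing the outputs, in the same spirit as the max-diff logging shown earlier. The model path and input feed below are placeholders, and a CUDA-enabled ONNX Runtime build is assumed.

```python
import numpy as np
import onnxruntime as ort

# Placeholders: substitute the exported Whisper decoder model and a real
# input feed when reproducing a CPU-vs-CUDA parity check.
model_path = "whisper_decoder.onnx"
feeds = {}  # e.g. {"input_ids": ..., "encoder_hidden_states": ...}

cpu_sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
cuda_sess = ort.InferenceSession(model_path, providers=["CUDAExecutionProvider"])

cpu_outputs = cpu_sess.run(None, feeds)
cuda_outputs = cuda_sess.run(None, feeds)

output_names = [o.name for o in cpu_sess.get_outputs()]
for name, cpu_out, cuda_out in zip(output_names, cpu_outputs, cuda_outputs):
    diff = np.abs(cpu_out - cuda_out)
    print(f"{name}: max diff = {np.max(diff)}")
```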
Motivation and Context
Using the beam search op has made it more difficult to debug and fix errors that are encountered. This new approach is more flexible and more customizable for users (e.g., by running with ONNX Runtime GenAI). It also helps address this issue.
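As one concrete illustration of that flexibility, the separately exported encoder and decoder are plain ONNX models that can be loaded and inspected with ONNX Runtime directly, independent of any beam search op. The file names below are assumptions for illustration.

```python
import onnxruntime as ort

# Assumed file names for the separately exported models.
encoder = ort.InferenceSession("whisper_encoder.onnx", providers=["CPUExecutionProvider"])
decoder = ort.InferenceSession("whisper_decoder.onnx", providers=["CPUExecutionProvider"])

for label, session in [("encoder", encoder), ("decoder", decoder)]:
    print(f"{label} inputs :", [(i.name, i.shape) for i in session.get_inputs()])
    print(f"{label} outputs:", [(o.name, o.shape) for o in session.get_outputs()])
```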