[Model builder] Add option to exclude cache in inputs and outputs #1162
@microsoft-github-policy-service agree [company="{Hugging Face}"]
@microsoft-github-policy-service agree company="Hugging Face"
This modification only works for models w/ GQA. Maybe someone with a bit more experience with the model builder could help get it working for models w/ MHA? 😇
Does this give a memory overhead improvement?
Indeed - not having to store and pass these values back has improved execution time in my tests.
@@ -3267,6 +3275,8 @@ def get_args():
exclude_lm_head = Remove language modeling head from your ONNX model.
    Use this option when you want to remove the language modeling head from within your ONNX model.
    Instead of `logits`, you will have `hidden_states` as the output to your ONNX model.
exclude_cache = Remove cache inputs and outputs from your ONNX model.
Suggested change:
- exclude_cache = Remove cache inputs and outputs from your ONNX model.
+ exclude_kv_cache = Remove KV cache inputs and outputs from your ONNX model.
@@ -3267,6 +3275,8 @@ def get_args():
exclude_lm_head = Remove language modeling head from your ONNX model.
    Use this option when you want to remove the language modeling head from within your ONNX model.
    Instead of `logits`, you will have `hidden_states` as the output to your ONNX model.
exclude_cache = Remove cache inputs and outputs from your ONNX model.
    Use this option when you want to remove the `past_key_values` inputs and `present` outputs from within your ONNX model.
Suggested change:
- Use this option when you want to remove the `past_key_values` inputs and `present` outputs from within your ONNX model.
+ Use this option when you want to remove the `past_key_values` inputs and `present` outputs from within your ONNX model.
+ Note that this should be used when you want to run ONNX models with ONNX Runtime only. ONNX Runtime GenAI requires the KV cache inputs and outputs for inference.
@@ -111,6 +111,8 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
elif self.include_hidden_states:
    self.output_names = ["hidden_states"] + self.output_names

self.exclude_cache = "exclude_cache" in extra_options
Suggested change:
- self.exclude_cache = "exclude_cache" in extra_options
+ self.exclude_cache = extra_options.get("exclude_cache", False)
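The two forms differ when the key is present but set to a false value. A minimal sketch of that difference, assuming `extra_options` is a plain dict whose values have already been converted to booleans (how the builder parses its options is not shown in this thread, so that conversion is an assumption):

```python
# Hypothetical options dict; assumes values were already parsed into booleans.
extra_options = {"exclude_cache": False}

# Membership test: True whenever the key exists, regardless of its value.
by_membership = "exclude_cache" in extra_options

# Lookup with a default: returns the stored value, or False when the key is absent.
by_lookup = extra_options.get("exclude_cache", False)

# Caveat: if the value were still the raw string "false", .get() would return
# a non-empty (truthy) string, so the lookup form relies on upstream conversion.
raw_lookup = {"exclude_cache": "false"}.get("exclude_cache", False)

print(by_membership, by_lookup, bool(raw_lookup))
```

With the membership test, passing the option with a false value would still enable it; the lookup form respects the stored value as long as it is a real boolean.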
    past_k, past_v, present_k, present_v = "", "", "", ""
else:
    past_k, past_v = "", ""
    present_k = f"present.{layer_id}.key"
I think `present_k` and `present_v` should be empty strings since the KV cache inputs and outputs are not in the model.
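The reviewer's point can be sketched as a small helper that picks the per-layer cache names. This is an illustration only, not the builder's code: `kv_io_names` is a hypothetical function, and the `past_key_values.{layer_id}.key` / `.value` names are assumed by analogy with the `present.{layer_id}.key` name in the diff. In ONNX, an empty string in an input or output slot marks that optional argument as omitted.

```python
def kv_io_names(layer_id: int, exclude_cache: bool) -> tuple[str, str, str, str]:
    """Return (past_k, past_v, present_k, present_v) names for one layer.

    Hypothetical helper; the naming scheme is an assumption.
    """
    if exclude_cache:
        # No cache tensors in the graph: all four slots are left empty.
        return "", "", "", ""
    return (
        f"past_key_values.{layer_id}.key",
        f"past_key_values.{layer_id}.value",
        f"present.{layer_id}.key",
        f"present.{layer_id}.value",
    )
```

Keeping all four names empty when the cache is excluded avoids a dangling `present` output that nothing in the graph signature declares.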
Thanks for the contribution! Can you also update the following places?
onnxruntime-genai/src/python/py/models/builder.py Lines 3133 to 3141 in 44e541e
For onnxruntime-genai/src/python/py/models/builder.py Lines 1287 to 1304 in 44e541e
In certain cases (e.g., single-round conversations), it is not necessary to require `past_key_values` as inputs and `present` as outputs, like with https://huggingface.co/livekit/turn-detector (and its usage here).
So, this PR adds an option to exclude these inputs and outputs from the graph.
Example usage:
Output graph signature:
For comparison, graph signature w/ cache IO:
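The two signatures can be compared programmatically. A minimal sketch, assuming cache tensors are named with the `past_key_values` and `present` prefixes mentioned in the description; `cache_io_names` is a hypothetical helper, and the stand-in model below only mimics the `graph.input` / `graph.output` shape of a loaded ONNX model.

```python
from types import SimpleNamespace

# Prefixes taken from the PR description; the exact per-layer naming is an assumption.
CACHE_PREFIXES = ("past_key_values", "present")


def cache_io_names(model):
    """Return the graph input and output names that look like KV cache tensors."""
    inputs = [i.name for i in model.graph.input if i.name.startswith(CACHE_PREFIXES)]
    outputs = [o.name for o in model.graph.output if o.name.startswith(CACHE_PREFIXES)]
    return inputs, outputs


# Stand-in for a loaded model so the sketch runs without a model file.
fake = SimpleNamespace(graph=SimpleNamespace(
    input=[SimpleNamespace(name="input_ids"),
           SimpleNamespace(name="past_key_values.0.key")],
    output=[SimpleNamespace(name="logits"),
            SimpleNamespace(name="present.0.key")],
))

print(cache_io_names(fake))
```

With the `onnx` package installed, the same helper could be called as `cache_io_names(onnx.load("model.onnx"))`, where the path is a placeholder; a model built without cache IO would be expected to yield two empty lists.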