Fix ValidateInputs crash for AOTI (pt2) models #185
Conversation
cc the original author of pt2 support @whoisj 🙏
(force-pushed: efe9379 → 8080040)
@danielsig727, thanks for your contribution here. Though, I'm not certain that it is correct. Do you have a repro? Asking because I've been unable to reproduce this, even with multi-argument signatures.
Thanks, you're correct! I missed that; I've only tested with models with 1 input. I'm working on another fix that further parses the call_spec instead.
(force-pushed: 01bfd9e → e200a92)
get_call_spec() returns exactly two pytree serialisation strings
{in_spec, out_spec} since pytorch/pytorch#110020. ValidateInputs
incorrectly used the vector size (always 2) as the model input count,
causing every torch_aoti model to fail with:
"configuration expects N inputs, but model expects 2 inputs"
Parse the true input count from the leading integer in in_spec
("[num_leaves, tree_def]"), fail fast with a clear error if the format
is unexpected, and use STRICT_CONFIG_ORDERING unconditionally since
pytree specs carry no argument names.
(force-pushed: e200a92 → a567b73)
I've updated the diff and description. Locally validated with models with 1, 3, and 5 inputs.
Extends the ValidateInputs kwargs guard to also navigate the kwargs subtree
of in_spec and fail at load time if any kwarg tensor leaves are present.
Triton has no mechanism to supply named keyword arguments to a model, so an
AOTI model exported via torch.export.export(..., kwargs={"x": t}) would
silently undercount its required inputs under the old code. The new guard
surfaces this as a clear INVALID_ARG error at model load time with an
actionable message directing users to re-export using positional args only.
Also extends validate/export_multi_input.py and validate/check_call_spec.py
to inspect the kwargs subtree and assert it is empty for Triton-compatible
models, and to export a contrasting kwargs model that demonstrates the non-empty
case detected by the new guard.
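A minimal Python sketch of what the guard checks (the actual guard is C++ in the backend; the function name and error text here are invented for illustration):

```python
import json

def assert_no_kwargs(in_spec_raw: str) -> None:
    """Reject an in_spec whose kwargs subtree contains any tensor leaves."""
    spec = json.loads(in_spec_raw)          # [protocol, tree_def]
    args_kwargs = spec[1]["children_spec"]  # [args tuple node, kwargs dict node]
    if args_kwargs[1]["children_spec"]:     # non-empty => model requires kwargs
        raise ValueError(
            "model was exported with kwargs; Triton cannot supply named "
            "keyword arguments; re-export using positional args only"
        )
```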
Need some confirmation from you: currently, the proposed fix only checks args from in_spec and asserts that kwargs are empty. This imposes a limitation: models exported with kwargs won't be supported. Technically it's not impossible to allow args + kwargs inputs together, but the challenge is that it's trickier for the user to make sure the order of the input names in config.pbtxt matches the exported order, especially for the keys in kwargs.
@danielsig727, I've not been able to get it to handle dynamic keys in dictionaries being passed as inputs/outputs reliably.
It does seem to accept a model exported with kwargs. However, I agree with you that it might not work if by "dynamic keys" you instead mean a situation where a model accepts non-fixed kwarg keys. But that could be a rarer use case IMO that might not even be worth supporting.

```python
import json
import os
import zipfile

import torch


class KwargsModel(torch.nn.Module):
    def forward(self, x_pos: torch.Tensor, *, x_kw: torch.Tensor) -> torch.Tensor:
        return x_pos + x_kw


def main():
    output = "/tmp/kwargs_model.pt2"
    print(f"PyTorch version: {torch.__version__}")

    model = KwargsModel().eval()
    x_pos = torch.randn(2, 4)
    x_kw = torch.randn(2, 4)

    # Export with one positional arg and one kwarg so the in_spec kwargs subtree is non-empty.
    batch = torch.export.Dim.AUTO
    ep = torch.export.export(
        model,
        args=(x_pos,),
        kwargs={"x_kw": x_kw},
        dynamic_shapes={"x_pos": {0: batch}, "x_kw": {0: batch}},
    )
    os.makedirs(os.path.dirname(os.path.abspath(output)), exist_ok=True)
    torch._inductor.aoti_compile_and_package(ep, package_path=output)
    print(f"Exported → {output}")

    # Discover the model name from the archive.
    model_name = "model"
    with zipfile.ZipFile(output) as z:
        for entry in z.namelist():
            parts = entry.split("/")
            if len(parts) >= 4 and parts[2] == "aotinductor":
                model_name = parts[3]
                break

    # Load and inspect in_spec.
    loader = torch._C._aoti.AOTIModelPackageLoader(output, model_name, False, 1, -1)
    in_spec_raw = loader.get_call_spec()[0]
    parsed = json.loads(in_spec_raw)
    args_children = len(parsed[1]["children_spec"][0]["children_spec"])
    kwargs_children = len(parsed[1]["children_spec"][1]["children_spec"])
    print(f"\nin_spec (formatted):\n{json.dumps(parsed, indent=2)}")
    print(f"args children = {args_children} ← positional tensors")
    print(f"kwargs children = {kwargs_children} ← keyword tensors")


if __name__ == "__main__":
    main()
```

Locally I get this:
@danielsig727 it appears that PyTorch has encoded your example as ((args...), {kwargs...}), which is basically a tuple with a dict of keyword arguments as its second element. As I understand it, if any other keys are passed to the model's forward it won't work, for example:

```python
input_dict = {"not_x_kw": some_tensor}
model.forward(input_tensor, input_dict)
```

Can you confirm?
Hmm, I feel we're referring to different things. The in_spec in the exported model only reflects how the arguments were provided to the exporter during export. in_spec contains a tuple of 2 elements: ((args...), dict{kwargs...}):

```json
[
  1,
  {
    "type": "builtins.tuple",
    "context": "null",
    "children_spec": [
      {
        "type": "builtins.tuple",
        "context": "null",
        "children_spec": [
          {
            "type": null,
            "context": null,
            "children_spec": []
          },
          {
            "type": null,
            "context": null,
            "children_spec": []
          }
        ]
      },
      {
        "type": "builtins.dict",
        "context": "[\"x_kw1\", \"x_kw2\"]",
        "children_spec": [
          {
            "type": null,
            "context": null,
            "children_spec": []
          },
          {
            "type": null,
            "context": null,
            "children_spec": []
          }
        ]
      }
    ]
  }
]
```

So in this PR, the validation will reject models exported with kwargs in torch.export.export(...).
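As a small follow-on to the script above (assuming the same in_spec layout), the kwarg names live in the dict node's `context` field as a JSON-encoded list; this is part of why matching config.pbtxt input order against exported kwarg keys would be error-prone:

```python
import json

def kwarg_names(in_spec_raw: str) -> list:
    """Extract exported kwarg names from get_call_spec()[0]."""
    spec = json.loads(in_spec_raw)
    dict_node = spec[1]["children_spec"][1]   # the builtins.dict subtree
    ctx = dict_node["context"]
    return json.loads(ctx) if ctx else []     # e.g. ["x_kw1", "x_kw2"]
```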
Hi @whoisj, are there any updates regarding this? 🙏
@danielsig727 it is being actively worked on. Deciding on and agreeing on the mapping from Triton's config.pbtxt -> PT2 callspec -> …
Fix `ValidateInputs` crash for AOTI models

Component: `src/pt2/model_instance_state.cc` — `ModelInstanceState::ValidateInputs`
Severity: Critical — all `torch_aoti` (`.pt2`) models fail to load
Affects: `pytorch_backend` from the initial AOTI commit (688df56, 2026-03-04) onwards

Symptom
Any `torch_aoti` model fails immediately at load time with "configuration expects N inputs, but model expects 2 inputs". This fires for every model regardless of actual input count. The model never becomes READY.

Root Cause
`ValidateInputs` calls `GetModelCallSpec()` (which wraps `AOTIModelPackageLoader::get_call_spec()`) and uses the result's size as the input count. Since PyTorch first introduced `get_call_spec()` (#110020, Oct 2023) it has always returned exactly two strings: the pytree specs for inputs and outputs.

What `get_call_spec()` actually returns

Both strings are pytree specs from `treespec_dumps()` (#106116), serialised as `[<header_int>, <tree_def>]`. The `in_spec` encodes the `(args, kwargs)` call signature (see the `in_spec` JSON earlier in this thread). The header integer `json[0]` is always `1`: it is the pytree serialisation protocol version, not the number of tensor inputs. The true input count is the length of the inner args tuple's `children_spec`.
len(get_call_spec())json[0]in_spec[1]["children_spec"][0]["children_spec"].size()Fix
Parse
in_specwithTritonJson::Value(already used throughout the file) and navigate the pytree JSON to count args leaves directly.Validation
Validation

Environment: ARM64 (Apple Silicon / OrbStack), `nvcr.io/nvidia/tritonserver:26.03-py3`, PyTorch 2.11.0+cpu, CPU-only instance group.

Stock backend: fails for every model regardless of input count. Patched backend: all models load, with outputs verified numerically (exit 0). The `get_call_spec()` values observed on ARM64 PyTorch 2.11 across all input counts match the table above.

References

- pytorch/pytorch#106116 — `treespec_dumps` pytree serialisation.
- pytorch/pytorch#110020 — `get_call_spec()` returning `{in_spec, out_spec}`.
- 688df56 — introduced the incorrect `ValidateInputs` assumption.