Fix ValidateInputs crash for AOTI (pt2) models #185

Open

danielsig727 wants to merge 2 commits into triton-inference-server:main from danielsig727:call_spec_fix

Conversation

danielsig727 commented Apr 3, 2026

Fix ValidateInputs crash for AOTI models

Component: src/pt2/model_instance_state.cc, ModelInstanceState::ValidateInputs
Severity: Critical — all torch_aoti (.pt2) models fail to load
Affects: pytorch_backend from the initial AOTI commit (688df56, 2026-03-04) onwards


Symptom

Any torch_aoti model fails immediately at load time:

Failed to load model "my_model_0" configuration expects 1 inputs, but model expects 2 inputs.

This fires for every model regardless of actual input count. The model never becomes READY.


Root Cause

ValidateInputs calls GetModelCallSpec() (which wraps AOTIModelPackageLoader::get_call_spec()) and uses the result's size as the input count. Since PyTorch first introduced get_call_spec() (#110020, Oct 2023) it has always returned exactly two strings: the pytree specs for inputs and outputs.

// Bug since commit 688df56 — allowed_inputs.size() is always 2
std::vector<std::string> allowed_inputs{model_->GetModelCallSpec()};
if (allowed_inputs.size() != expected_input_count) {   // always 2 != 1
    THROW_TRITON_EXCEPTION(..., "model expects " << allowed_inputs.size() << " inputs.");
}

What get_call_spec() actually returns

Both strings are pytree specs from treespec_dumps() (#106116), serialised as [<header_int>, <tree_def>]. The in_spec encodes the (args, kwargs) call signature:

in_spec = [1, {"type": "builtins.tuple", ..., "children_spec": [
               {"type": "builtins.tuple", ..., "children_spec": [<N leaf nodes>]},  ← args
               {"type": "builtins.dict",  ..., "children_spec": []}                 ← kwargs
           ]}]

The header integer json[0] is always 1 — it counts the outer (args, kwargs) wrapper as a single node, not the number of tensor inputs. The true input count is the length of the inner args tuple's children_spec:

in_spec[1]["children_spec"][0]["children_spec"].size()   ← correct count
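
For illustration, the same navigation in Python over a hand-written in_spec literal (a minimal sketch: the literal mimics the treespec_dumps() format for a hypothetical 3-input, no-kwargs model and is not captured from a real export):

import json

# Hand-written in_spec for a hypothetical 3-input, no-kwargs model,
# following the [<header_int>, <tree_def>] layout described above.
raw_in_spec = json.dumps([1, {
    "type": "builtins.tuple", "context": "null", "children_spec": [
        {"type": "builtins.tuple", "context": "null", "children_spec": [
            {"type": None, "context": None, "children_spec": []},
            {"type": None, "context": None, "children_spec": []},
            {"type": None, "context": None, "children_spec": []},
        ]},
        {"type": "builtins.dict", "context": "[]", "children_spec": []},
    ]}])

in_spec = json.loads(raw_in_spec)
print(in_spec[0])                                            # 1, regardless of input count
print(len(in_spec[1]["children_spec"][0]["children_spec"]))  # 3, the true input count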

Verified across 1-, 3-, and 5-input models (ARM64, PyTorch 2.11, NGC 26.03):

Model      len(get_call_spec())   json[0]   args children_spec size
1 input    2                      1         1
3 inputs   2                      1         3
5 inputs   2                      1         5

Fix

Parse in_spec with TritonJson::Value (already used throughout the file) and navigate the pytree JSON to count args leaves directly.

-  std::vector<std::string> allowed_inputs{model_->GetModelCallSpec()};
-  if (allowed_inputs.size() != expected_input_count) {
+  // get_call_spec() returns [in_spec, out_spec] — always exactly 2 strings.
+  // in_spec is a pytree spec "[<header_int>, <tree_def_object>]" produced by
+  // treespec_dumps().  The outer tuple encodes (args, kwargs), so in_spec is:
+  //   [1, {"type": "builtins.tuple", ..., "children_spec": [
+  //       {"type": "builtins.tuple", ..., "children_spec": [<N leaf nodes>]},
+  //       {"type": "builtins.dict",  ..., "children_spec": []}
+  //   ]}]
+  // json[0] is always 1 — it counts the outer (args, kwargs) wrapper as one
+  // node, NOT the number of inputs.  True count: in_spec[1]["children_spec"][0]
+  // ["children_spec"].size().
+  std::vector<std::string> call_spec{model_->GetModelCallSpec()};
+  if (call_spec.size() != 2 || call_spec[0].empty() || call_spec[0][0] != '[') {
+    THROW_TRITON_EXCEPTION(TRITONSERVER_ERROR_INTERNAL,
+        "Unexpected get_call_spec() format for model \"" << Name() << "\"");
+  }
+
+  // Parse in_spec with TritonJson to count the positional tensor inputs.
+  //
+  // TritonJson::Value::Parse accepts a top-level JSON array (rapidjson handles
+  // both objects and arrays).  in_spec is [<int>, <object>], so element [0] is
+  // an integer (skipped) and element [1] is the tree_def object.
+  //
+  // We need:  in_spec[1]["children_spec"][0]["children_spec"].size()
+  //
+  //   [1]                → tree_def for the outer (args, kwargs) tuple
+  //   ["children_spec"]  → array of two children: [args_tuple, kwargs_dict]
+  //   [0]                → the args_tuple node (kwargs dict is always empty)
+  //   ["children_spec"]  → one entry per positional tensor argument
+  //
+  // IndexAsObject / MemberAsArray return nullptr on success, so chaining them
+  // with && short-circuits at the first failure.
+  size_t model_input_count = 0;
+  {
+    TritonJsonValue in_spec_doc;
+    TRITONSERVER_Error* parse_err =
+        in_spec_doc.Parse(call_spec[0].c_str(), call_spec[0].size());
+    if (parse_err != nullptr) {
+      THROW_TRITON_EXCEPTION(parse_err,
+          "Failed to parse in_spec JSON for model \"" << Name() << "\"");
+    }
+    TritonJsonValue outer_tree;      // in_spec_doc[1]               — outer (args,kwargs) tree_def
+    TritonJsonValue outer_children;  // outer_tree["children_spec"]  — [args_tuple, kwargs_dict]
+    TritonJsonValue args_tree;       // outer_children[0]            — the args tuple node
+    TritonJsonValue args_children;   // args_tree["children_spec"]   — one entry per input tensor
+    bool nav_ok =
+        (in_spec_doc.IndexAsObject(1, &outer_tree) == nullptr) &&
+        (outer_tree.MemberAsArray("children_spec", &outer_children) == nullptr) &&
+        (outer_children.IndexAsObject(0, &args_tree) == nullptr) &&
+        (args_tree.MemberAsArray("children_spec", &args_children) == nullptr);
+    if (!nav_ok) {
+      THROW_TRITON_EXCEPTION(TRITONSERVER_ERROR_INTERNAL,
+          "Unexpected in_spec pytree structure for model \"" << Name() << "\": "
+          << call_spec[0].substr(0, 200));
+    }
+    model_input_count = static_cast<size_t>(args_children.ArraySize());
+  }
+
+  if (model_input_count != expected_input_count) {
     THROW_TRITON_EXCEPTION(TRITONSERVER_ERROR_INTERNAL,
         "Failed to load model \"" << Name() << "\" configuration expects "
-        << expected_input_count << " inputs, but model expects "
-        << allowed_inputs.size() << " inputs.");
+        << expected_input_count << " inputs, but model expects "
+        << model_input_count << " inputs.");
   }

   /* ... ios validation, dtype checks unchanged ... */

-  auto naming_convention = GetNamingConvention(allowed_inputs);
+  // Pytree specs encode the tensor tree structure but carry no argument names,
+  // so positional ordering is the only option for AOTI models.
+  // STRICT_CONFIG_ORDERING maps each config input to its runner slot by the
+  // loop index i, i.e.  input_index_map_[io_name] = i  (see AddInputToMap).
+  auto naming_convention = TritonNamingConvention::STRICT_CONFIG_ORDERING;
+
+  // allowed_inputs is only consulted by the FORWARD_ARGUMENT branch inside
+  // AddInputToMap (name-based lookup).  Under STRICT_CONFIG_ORDERING that
+  // branch is never reached, so this vector is never read — it exists solely
+  // to satisfy the function signature.
+  std::vector<std::string> allowed_inputs{};

Validation

Environment: ARM64 (Apple Silicon / OrbStack), nvcr.io/nvidia/tritonserver:26.03-py3, PyTorch 2.11.0+cpu, CPU-only instance group.

Stock backend — fails for every model regardless of input count:

# 1-input model (config expects 1, get_call_spec() returns 2 strings)
Failed to load model "minimal_model_0_0" configuration expects 1 inputs, but model expects 2 inputs.

# 3-input model
Failed to load model "model_3in_0_0" configuration expects 3 inputs, but model expects 2 inputs.

Patched backend — all models load, outputs verified numerically (exit 0):

✓ Model 'minimal_model' is READY   output matches x * 2.0      ✓ Inference PASSED
✓ Model 'model_1in'     is READY   output matches x_1 * 2.0    ✓ Inference PASSED
✓ Model 'model_3in'     is READY   output matches x_1+x_2+x_3  ✓ Inference PASSED
✓ Model 'model_5in'     is READY   output matches x_1+…+x_5    ✓ Inference PASSED

get_call_spec() values observed on ARM64 with PyTorch 2.11, across all input counts:

model_1in  len(specs)=2  json[0]=1  treespec_loads().num_leaves=1  ✓ PASS
model_3in  len(specs)=2  json[0]=1  treespec_loads().num_leaves=3  ✓ PASS
model_5in  len(specs)=2  json[0]=1  treespec_loads().num_leaves=5  ✓ PASS
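
A minimal sketch of reading those values back from a packaged model (the package path and model name are placeholder assumptions; the loader arguments mirror the inspection script later in this thread):

import json
import torch
from torch.utils._pytree import treespec_loads

# Hypothetical .pt2 package path and model name.
loader = torch._C._aoti.AOTIModelPackageLoader(
    "/tmp/model_3in.pt2", "model", False, 1, -1)
specs = loader.get_call_spec()  # [in_spec, out_spec]
in_spec_raw = specs[0]
print(len(specs))                              # 2 for every model
print(json.loads(in_spec_raw)[0])              # header int, always 1
print(treespec_loads(in_spec_raw).num_leaves)  # 3 for a 3-input model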

@danielsig727 (Author)

cc the original author of pt2 support @whoisj 🙏

whoisj commented Apr 3, 2026

@danielsig727, thanks for your contribution here. Though I'm not certain that it is correct.

Do you have a forward(self, ...) signature that actually produces a call_spec() string value that does not start with "[1, "?

Asking because I've been unable to produce this. Even with multi-argument signatures like forward(self, INPUT0, INPUT1, INPUT2), the returned value always seems to start with "[1, ".

@danielsig727 (Author)

> Asking because I've been unable to produce this. Even with multi-argument signatures like forward(self, INPUT0, INPUT1, INPUT2), the returned value always seems to start with "[1, ".

Thanks, you're correct! I missed that I'd only tested with 1-input models. I'm working on another fix that parses the call_spec further instead.

danielsig727 force-pushed the call_spec_fix branch 2 times, most recently from 01bfd9e to e200a92 on April 4, 2026 03:46

get_call_spec() returns exactly two pytree serialisation strings
{in_spec, out_spec} since pytorch/pytorch#110020. ValidateInputs
incorrectly used the vector size (always 2) as the model input count,
causing every torch_aoti model to fail with:

  "configuration expects N inputs, but model expects 2 inputs"

Parse the true input count by navigating the pytree JSON in in_spec
("[header_int, tree_def]"), fail fast with a clear error if the format
is unexpected, and use STRICT_CONFIG_ORDERING unconditionally since
pytree specs carry no argument names.

@danielsig727 (Author)

I've updated the diff and description. Locally validated with models with 1, 3, and 5 inputs.

Extends the ValidateInputs kwargs guard to also navigate the kwargs subtree
of in_spec and fail at load time if any kwarg tensor leaves are present.

Triton has no mechanism to supply named keyword arguments to a model, so an
AOTI model exported via torch.export.export(..., kwargs={"x": t}) would
silently undercount its required inputs under the old code. The new guard
surfaces this as a clear INVALID_ARG error at model load time with an
actionable message directing users to re-export using positional args only.

Also extends validate/export_multi_input.py and validate/check_call_spec.py
to inspect the kwargs subtree and assert it is empty for Triton-compatible
models, and to export a contrast kwargs model that demonstrates the non-empty
case detected by the new guard.
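
A minimal sketch of the guard's logic in Python, assuming raw_in_spec holds get_call_spec()[0]; the C++ guard performs the equivalent navigation with TritonJson, and assert_no_kwargs is a hypothetical helper name:

import json

# Reject models whose in_spec kwargs subtree contains tensor leaves,
# since Triton cannot supply named keyword arguments at inference time.
def assert_no_kwargs(raw_in_spec: str) -> None:
    spec = json.loads(raw_in_spec)
    # children_spec[1] is the kwargs dict node of the outer (args, kwargs) tuple.
    kwargs_children = spec[1]["children_spec"][1]["children_spec"]
    if kwargs_children:
        raise ValueError(
            f"model was exported with {len(kwargs_children)} kwarg tensor(s); "
            "re-export with positional args only")
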
danielsig727 commented Apr 4, 2026

Need some confirmation from you: currently, the proposed fix only checks args from in_spec and asserts that the kwargs subtree is empty. This imposes a limitation: models exported with kwargs won't be supported.

However, technically it's not impossible to allow all args + kwargs inputs together. The challenge is that it's trickier for the user to make sure the order of the input names in config.pbtxt matches the exported order, especially for the keys in kwargs.

whoisj commented Apr 6, 2026

@danielsig727 will torch._inductor.aoti_compile_and_package() even handle kwargs?

I've not been able to get it to handle dynamic keys in dictionaries being passed as inputs/outputs reliably.

danielsig727 commented Apr 7, 2026

It does seem to accept a model exported with kwargs during torch.export.export. I could get 1 arg and 1 kwarg with the script below on torch 2.11.

However, I agree with you that it might not work if by "dynamic keys" you mean a model that accepts non-fixed kwarg keys. But that seems like a rarer use case, IMO, that might not even be worth supporting.

import json
import os
import zipfile

import torch


class KwargsModel(torch.nn.Module):
    def forward(self, x_pos: torch.Tensor, *, x_kw: torch.Tensor) -> torch.Tensor:
        return x_pos + x_kw


def main():
    output = "/tmp/kwargs_model.pt2"

    print(f"PyTorch version: {torch.__version__}")

    model = KwargsModel().eval()
    x_pos = torch.randn(2, 4)
    x_kw = torch.randn(2, 4)

    # Export with one positional arg and one kwarg so in_spec kwargs subtree is non-empty.
    batch = torch.export.Dim.AUTO
    ep = torch.export.export(
        model,
        args=(x_pos,),
        kwargs={"x_kw": x_kw},
        dynamic_shapes={"x_pos": {0: batch}, "x_kw": {0: batch}},
    )

    os.makedirs(os.path.dirname(os.path.abspath(output)), exist_ok=True)
    torch._inductor.aoti_compile_and_package(ep, package_path=output)
    print(f"Exported → {output}")

    # Discover model name from the archive.
    model_name = "model"
    with zipfile.ZipFile(output) as z:
        for entry in z.namelist():
            parts = entry.split("/")
            if len(parts) >= 4 and parts[2] == "aotinductor":
                model_name = parts[3]
                break

    # Load and inspect in_spec.
    loader = torch._C._aoti.AOTIModelPackageLoader(output, model_name, False, 1, -1)

    in_spec_raw = loader.get_call_spec()[0]
    parsed = json.loads(in_spec_raw)

    args_children = len(parsed[1]["children_spec"][0]["children_spec"])
    kwargs_children = len(parsed[1]["children_spec"][1]["children_spec"])

    print(f"\nin_spec (formatted):\n{json.dumps(parsed, indent=2)}")
    print(f"args children       = {args_children}   ← positional tensors")
    print(f"kwargs children     = {kwargs_children}   ← keyword tensors")

if __name__ == "__main__":
    main()

Locally I get this:

in_spec (formatted):
[
  1,
  {
    "type": "builtins.tuple",
    "context": "null",
    "children_spec": [
      {
        "type": "builtins.tuple",
        "context": "null",
        "children_spec": [
          {
            "type": null,
            "context": null,
            "children_spec": []
          }
        ]
      },
      {
        "type": "builtins.dict",
        "context": "[\"x_kw\"]",
        "children_spec": [
          {
            "type": null,
            "context": null,
            "children_spec": []
          }
        ]
      }
    ]
  }
]
args children       = 1   ← positional tensors
kwargs children     = 1   ← keyword tensors

whoisj commented Apr 8, 2026

@danielsig727 it appears that PyTorch has encoded your example as

"children_spec": [
      {
        "type": "builtins.tuple",
        "context": "null",
        "children_spec": [
          {
            "type": null,
            "context": null,
            "children_spec": []
          }
        ]
      },
      {
        "type": "builtins.dict",
        "context": "[\"x_kw\"]",
        "children_spec": [
          {
            "type": null,
            "context": null,
            "children_spec": []
          }
        ]
      }
    ]

which is basically a tuple with

  1. a torch.Tensor
  2. a dict[str, torch.Tensor] which ONLY has the key "x_kw"

As I understand it, if any other keys are passed to the model's .forward() function, it'll result in an error.

for example:

  input_dict = { "not_x_kw": some_tensor }
  model.forward(input_tensor, input_dict)

Can you confirm?

danielsig727 commented Apr 9, 2026

Hmm, I feel we're referring to different things.

The in_spec in the exported model only reflects how the arguments were provided to the exporter during torch.export.export; it does not say whether kwargs appear in the forward() signature. The exporter works as long as forward() can be called correctly.

in_spec contains a tuple of 2 elements: ((args...), dict{kwargs...}).
For example, I have another model with forward(x_pos1, x_pos2, x_kw1, x_kw2). Even without kwargs in the signature, I can still export it with export(args=(x_pos1, x_pos2), kwargs={"x_kw1": x_kw1, "x_kw2": x_kw2}), resulting in an in_spec of

[
  1,
  {
    "type": "builtins.tuple",
    "context": "null",
    "children_spec": [
      {
        "type": "builtins.tuple",
        "context": "null",
        "children_spec": [
          {
            "type": null,
            "context": null,
            "children_spec": []
          },
          {
            "type": null,
            "context": null,
            "children_spec": []
          }
        ]
      },
      {
        "type": "builtins.dict",
        "context": "[\"x_kw1\", \"x_kw2\"]",
        "children_spec": [
          {
            "type": null,
            "context": null,
            "children_spec": []
          },
          {
            "type": null,
            "context": null,
            "children_spec": []
          }
        ]
      }
    ]
  }
]

So in this PR, the validation will reject models exported with kwargs in the torch.export.export call (again, not kwargs usage in forward()). I think that's a reasonable design decision given how the rest of the code handles input tensors as a flattened std::vector<torch::Tensor>.

@danielsig727 (Author)

Hi @whoisj, are there any updates regarding this? 🙏

whoisj commented Apr 17, 2026

@danielsig727 it is being actively worked on. Deciding on and agreeing on the mapping from Triton's config.pbtxt -> PT2 callspec -> model.forward() is taking time as there are quite a few things to be concerned about, with forward/backward compatibility being one of them.
