30 changes: 27 additions & 3 deletions QEfficient/base/modeling_qeff.py
@@ -8,6 +8,7 @@
import gc
import inspect
import logging
import re
import shutil
import subprocess
import warnings
@@ -18,10 +19,12 @@
import onnx
import torch

from QEfficient.base.onnx_transforms import OnnxTransform
from QEfficient.base.onnx_transforms import CustomOpTransform, OnnxTransform, RenameFunctionOutputsTransform
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.compile.qnn_compiler import compile as qnn_compile
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.transformers.cache_utils import InvalidIndexProvider
from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export
from QEfficient.utils import (
constants,
create_json,
@@ -32,6 +35,7 @@
hash_dict_params,
load_json,
)
from QEfficient.utils.torch_patches import apply_torch_patches, undo_torch_patches

logger = logging.getLogger(__name__)

@@ -179,6 +183,7 @@ def _export(
onnx_transform_kwargs: Optional[Dict[str, any]] = None,
export_dir: Optional[str] = None,
offload_pt_weights: bool = True,
use_onnx_subfunctions: bool = False,
) -> str:
"""
Export the PyTorch model to ONNX and apply ONNX transforms
@@ -243,7 +248,19 @@
input_names.append(param)

try:
export_kwargs = {} if export_kwargs is None else export_kwargs
if use_onnx_subfunctions:
# Patch torch, register the custom-op transforms, and export decoder layers as ONNX functions
warnings.warn(
"The ONNX subfunction feature is experimental. Compiling the same model consecutively with and without subfunctions may produce inconsistent results."
)
apply_torch_patches()
InvalidIndexProvider.SUBFUNC_ENABLED = True
output_names = [re.sub("_RetainedState", "_InternalRetainedState", s) for s in output_names]
export_kwargs["export_modules_as_functions"] = get_decoder_layer_classes_for_export(self.model)
self._onnx_transforms.append(RenameFunctionOutputsTransform)
self._onnx_transforms.append(CustomOpTransform)

torch.onnx.export(
self.model,
(example_inputs,),
@@ -255,7 +272,6 @@
**export_kwargs,
)
logger.info("PyTorch export successful")

_ = self._offload_model_weights(offload_pt_weights)

model = onnx.load(tmp_onnx_path, load_external_data=False)
@@ -284,6 +300,12 @@
finally:
shutil.rmtree(tmp_onnx_dir, ignore_errors=True)
# Restore global state even when export fails, so a later export
# without subfunctions does not inherit patched torch ops or
# duplicate ONNX transforms
if use_onnx_subfunctions:
undo_torch_patches()
InvalidIndexProvider.SUBFUNC_ENABLED = False
self._onnx_transforms.remove(CustomOpTransform)
self._onnx_transforms.remove(RenameFunctionOutputsTransform)

self.onnx_path = onnx_path
return onnx_path

@@ -300,6 +322,7 @@ def _compile(
num_speculative_tokens: Optional[int] = None,
enable_qnn: Optional[bool] = False,
qnn_config: Optional[str] = None,
use_onnx_subfunctions: bool = False,
**compiler_options,
) -> str:
"""
@@ -325,8 +348,9 @@

For the QNN compilation path, when enable_qnn is set to True, any parameters passed in compiler_options are ignored.
"""

if onnx_path is None and self.onnx_path is None:
self.export()
self.export(use_onnx_subfunctions=use_onnx_subfunctions)

onnx_path = Path(onnx_path or self.onnx_path)
compile_dir = Path(compile_dir or onnx_path.parent)
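For context, a minimal end-to-end sketch of the new flag. The top-level `QEFFAutoModelForCausalLM` entry point and the checkpoint name are assumptions for illustration, not part of this diff:

```python
# Hedged sketch: exporting and compiling with ONNX subfunctions enabled.
from QEfficient import QEFFAutoModelForCausalLM

model = QEFFAutoModelForCausalLM.from_pretrained("gpt2")

# Each decoder layer is exported as an ONNX local function; per the warning
# above, avoid alternating runs with and without subfunctions on the same model.
qpc_path = model.compile(num_cores=16, use_onnx_subfunctions=True)
```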
97 changes: 96 additions & 1 deletion QEfficient/base/onnx_transforms.py
@@ -5,11 +5,15 @@
#
# ----------------------------------------------------------------------------

from typing import Optional, Tuple
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch
from onnx import ModelProto, external_data_helper, numpy_helper

from QEfficient.customop.ctx_scatter_gather import CtxGather, CtxGatherFunc, CtxScatter, CtxScatterFunc
from QEfficient.customop.rms_norm import CustomRMSNorm, CustomRMSNormFunc


class OnnxTransform:
"""
@@ -99,3 +103,94 @@ def apply(
current_file_size = tsize
external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data")
return model, transformed


class CustomOpTransform(OnnxTransform):
"""
Transform to register custom operations and add their function protos to the ONNX model.
"""

_custom_ops: Dict[str, Tuple[Any, Any]] = {
"CustomRMSNormFunc": (CustomRMSNormFunc, CustomRMSNorm),
"CtxScatterFunc": (CtxScatterFunc, CtxScatter),
"CtxGatherFunc": (CtxGatherFunc, CtxGather),
}

@classmethod
def register_custom_op(cls, op_name: str, func_class: Any, onnxscript_func: Any) -> None:
"""Register a custom operation."""
cls._custom_ops[op_name] = (func_class, onnxscript_func)

@classmethod
def apply(cls, model: ModelProto, *, opset_version: int = 17, **kwargs) -> Tuple[ModelProto, bool]:
"""
Apply custom op registration and add all function protos to the model.

:param model: The ONNX model to transform.
:param opset_version: ONNX opset version for symbolic registration.
:returns: (Transformed model, success flag).
"""
transformed = False

# Register all custom op symbolic functions with torch.onnx
for op_name, (func_class, _) in cls._custom_ops.items():
if hasattr(func_class, "symbolic"):
torch.onnx.register_custom_op_symbolic(f"::{op_name}", func_class.symbolic, opset_version)

func_names = {func.name for func in model.functions}

for _, onnxscript_func in cls._custom_ops.values():
proto = onnxscript_func.to_function_proto()
if proto.name not in func_names:
model.functions.append(proto)
transformed = True

return model, transformed
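
A short sketch of extending the registry; `MyNormFunc` and `my_norm_script` are hypothetical placeholders for a `torch.autograd.Function` with a `symbolic` method and its onnxscript counterpart:

```python
# Hypothetical registration of an additional custom op before export.
# MyNormFunc / my_norm_script are illustrative names, not part of this PR.
CustomOpTransform.register_custom_op("MyNormFunc", MyNormFunc, my_norm_script)

# apply() then registers the symbolic functions with torch.onnx and appends
# any missing function protos to the model.
model_proto, changed = CustomOpTransform.apply(model_proto, opset_version=17)
```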


class RenameFunctionOutputsTransform(OnnxTransform):
"""
Renames decoder-layer function outputs, mapping the '_InternalRetainedState' pattern to the per-layer '_RetainedState' graph output names expected by the runtime.
"""

@classmethod
def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]:
"""
Rename function outputs in decoder layer nodes.

:param model: The ONNX model to transform
:returns: Transformed model and boolean indicating whether transform was applied
"""
graph = model.graph
op_type_to_func_map = {func.name: func for func in model.functions}
decoder_layer_patterns = ["DecoderLayer", "Block", "Layer"]
transformed = False

# Create a dict mapping output name to its index for quick lookup
model_graph_outputs_map = {val.name: idx for idx, val in enumerate(model.graph.output)}

layer_index = 0
for node in graph.node:
if any(pattern in node.name or pattern in node.op_type for pattern in decoder_layer_patterns):
func = op_type_to_func_map.get(node.op_type)
if func is None:
continue

for i, out_name in enumerate(func.output):
if "_InternalRetainedState" in out_name:
original_output_name = node.output[i]

# Generate new name based on key/value
if "key" in out_name:
new_name = f"past_key.{layer_index}_RetainedState"
elif "value" in out_name:
new_name = f"past_value.{layer_index}_RetainedState"
else:
# Neither a key nor a value output; leave it untouched
continue
node.output[i] = new_name
transformed = True

# Update graph output name if it exists
if original_output_name in model_graph_outputs_map:
idx = model_graph_outputs_map[original_output_name]
model.graph.output[idx].name = new_name
layer_index += 1
return model, transformed
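
For intuition, a hedged before/after sketch of what the rename produces on a small decoder; the file name and output names are illustrative:

```python
import onnx

# Assuming a model exported with export_modules_as_functions, whose
# decoder-layer functions emit "..._InternalRetainedState" outputs.
model = onnx.load("model.onnx")
model, changed = RenameFunctionOutputsTransform.apply(model)

# Graph outputs are now renamed per layer, e.g.:
#   past_key.0_RetainedState, past_value.0_RetainedState,
#   past_key.1_RetainedState, past_value.1_RetainedState, ...
```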
5 changes: 4 additions & 1 deletion QEfficient/peft/auto.py
@@ -245,7 +245,7 @@ def from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs):
obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs)
return obj

def export(self, export_dir: Optional[str] = None) -> str:
def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str:
"""
Export the model with the active adapter to ONNX format.

@@ -286,6 +286,7 @@ def export(self, export_dir: Optional[str] = None) -> str:
export_kwargs={"do_constant_folding": False}, # To avoid merging adapter weights with base weights
onnx_transform_kwargs={"adapter_name": self.model.active_adapter},
export_dir=export_dir,
use_onnx_subfunctions=use_onnx_subfunctions,
)

def compile(
@@ -300,6 +301,7 @@
num_cores: int = 16,
mxfp6_matmul: bool = False,
mxint8_kv_cache: bool = False,
use_onnx_subfunctions: bool = False,
**compiler_options,
) -> str:
"""
@@ -367,6 +369,7 @@
mdp_ts_num_devices=num_devices,
aic_num_cores=num_cores,
mxint8_kv_cache=mxint8_kv_cache,
use_onnx_subfunctions=use_onnx_subfunctions,
**compiler_options,
)

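A hedged usage sketch for the adapter path; the class name matches QEfficient/peft/auto.py above, while the adapter ID and keyword usage are assumptions:

```python
# Illustrative only: the adapter checkpoint is a placeholder.
from QEfficient.peft import QEffAutoPeftModelForCausalLM

model = QEffAutoPeftModelForCausalLM.from_pretrained("predibase/magicoder")
model.export(use_onnx_subfunctions=True)
model.compile(num_cores=16, use_onnx_subfunctions=True)
```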
3 changes: 2 additions & 1 deletion QEfficient/peft/lora/auto.py
@@ -327,7 +327,7 @@ def _init_adapter_model(self):
# load_weight to model
self._load_adapter_weights_to_model()

def export(self, export_dir: Optional[str] = None) -> str:
def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str:
"""
Export the model with all loaded adapters to ONNX format using ``torch.onnx.export``.

@@ -387,6 +387,7 @@ def export(self, export_dir: Optional[str] = None) -> str:
output_names,
dynamic_axes,
export_dir=export_dir,
use_onnx_subfunctions=use_onnx_subfunctions,
)

def generate(
42 changes: 30 additions & 12 deletions QEfficient/transformers/cache_utils.py
@@ -24,6 +24,33 @@
)


class InvalidIndexProvider:
SUBFUNC_ENABLED = False

@classmethod
def enable_subfunc(cls):
cls.SUBFUNC_ENABLED = True

@classmethod
def _get_invalid_idx_value(cls):
"""
Get the appropriate invalid index value for CtxGather operations.

For ONNX export with functions, we use 0 to avoid INT32_MAX constants
that cause issues when functions are inlined at runtime.

Returns:
int: Invalid index value (0 for subfunction export and eager mode, INT32_MAX for plain ONNX export)
"""
if torch.onnx.is_in_onnx_export():
if cls.SUBFUNC_ENABLED:
return 0
else:
return torch.iinfo(torch.int32).max
else:
return 0
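
A small sketch of the masking this feeds into (shapes illustrative): out-of-window positions get the sentinel index, so the subsequent CtxGather reads a harmless slot instead of carrying an INT32_MAX constant through function inlining:

```python
import torch

ctx_indices = torch.arange(8)[None, None, :]   # (1, 1, ctx_len)
gather_limit = torch.tensor([[[4]]])           # max valid position id
invalid_mask = ctx_indices > gather_limit

# 0 here (eager mode); INT32_MAX during plain ONNX export
invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value()
ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
# -> tensor([[[0, 1, 2, 3, 4, 0, 0, 0]]])
```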


class QEffDynamicLayer(DynamicLayer):
def read_only(self, cache_kwargs):
"""
@@ -46,10 +73,7 @@ def read_only(self, cache_kwargs):
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit

if torch.onnx.is_in_onnx_export():
invalid_idx_value = torch.iinfo(torch.int32).max
else:
invalid_idx_value = 0
invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value()

ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

@@ -143,10 +167,7 @@ def update(
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit

if torch.onnx.is_in_onnx_export():
invalid_idx_value = torch.iinfo(torch.int32).max
else:
invalid_idx_value = 0
invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value()

ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
if batch_index is not None:
@@ -419,10 +440,7 @@ def update(
ctx_indices = torch.arange(ctx_len)[None, None, ...]
gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit
if torch.onnx.is_in_onnx_export():
invalid_idx_value = torch.iinfo(torch.int32).max
else:
invalid_idx_value = 0
invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value()
ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

all_indices = torch.arange(layer_ctx_len) + kv_position_ids.max() + 1