Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions QEfficient/base/pytorch_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,33 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
raise NotImplementedError("Use subclasses for Pytorch transform")


class ProxyModuleMappingTransform(PytorchTransform):
"""
Replaces the PyTorch modules based on the _module_mapping class variable.
"""

_module_mapping: Dict[Type[nn.Module], Type[nn.Module]]

@classmethod
def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
transformed = False
for name, module in model.named_modules():
for base_type, repl_type in cls._module_mapping.items():
if isinstance(module, base_type):
if base_type is nn.Linear:
short_name = name.split(".")[-1] if name else ""
if short_name != "lm_head":
continue
# Perform in-place class replacement (preserve parameters/state)
try:
module.__class__ = repl_type
transformed = True
except Exception as e:
logger.warning(f"Failed to replace module {name} ({base_type}) -> {repl_type}: {e}")

return model, transformed


class ModuleMappingTransform(PytorchTransform):
"""
Replaces the PyTorch modules based on the _module_mapping class variable.
Expand Down
13 changes: 13 additions & 0 deletions QEfficient/proxy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------

from QEfficient.proxy.proxy_transform import QeffProxyEmbedding, QeffProxyLinear

__all__ = [
"QeffProxyEmbedding",
"QeffProxyLinear",
]
27 changes: 27 additions & 0 deletions QEfficient/proxy/proxy_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------
import torch
from torch import nn


class QeffProxyEmbedding(nn.Module):
def __init__(self, num_embeddings, embedding_dim):
self.embed_tokens = None
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim

def forward(self, hidden_states):
inputs_embeds = torch.unsqueeze(hidden_states.float(), 2).expand(-1, -1, self.embedding_dim)
return inputs_embeds


class QeffProxyLinear(nn.Module):
def __init__(self, in_features, out_features, bias=False):
self.lm_head = None

def forward(self, hidden_states):
return hidden_states
22 changes: 22 additions & 0 deletions QEfficient/proxy/pytorch_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------

import torch.nn as nn

from QEfficient.base.pytorch_transforms import ProxyModuleMappingTransform
from QEfficient.proxy import QeffProxyEmbedding, QeffProxyLinear


class QeffProxyModuleTransform(ProxyModuleMappingTransform):
"""
This transform is used to replace the original modules with QEfficient modules.
"""

_module_mapping = {
nn.Embedding: QeffProxyEmbedding,
nn.Linear: QeffProxyLinear,
}
13 changes: 13 additions & 0 deletions QEfficient/transformers/models/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#
# ----------------------------------------------------------------------------

import os
import warnings
from pathlib import Path
from time import perf_counter
Expand Down Expand Up @@ -40,6 +41,7 @@
get_compilation_dims,
)
from QEfficient.generation.vlm_generation import VisionLanguageGeneration
from QEfficient.proxy.pytorch_transform import QeffProxyModuleTransform
from QEfficient.transformers.modeling_utils import DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH
from QEfficient.transformers.models.pytorch_transforms import (
CustomOpsTransform,
Expand Down Expand Up @@ -2348,6 +2350,10 @@ def __init__(
if not (model_class_name.endswith("ForCausalLM") or model_class_name.endswith("LMHeadModel")):
raise TypeError(f"Required pytorch module for CausalLM or LMHeadModel, got {model_class_name}")

if kwargs.pop("enable_proxy", False):
self._pytorch_transforms.append(QeffProxyModuleTransform)
logger.info("Proxy Model Enabled for QEfficient Model")

# TODO: remove from version 1.20
if kwargs.pop("full_batch_size", None):
continuous_batching = True
Expand Down Expand Up @@ -2452,6 +2458,10 @@ def from_pretrained(
QEFFAutoModelForCausalLM
An instance initialized with the pretrained weights.
"""
if kwargs.pop("enable_proxy", False):
cls._pytorch_transforms.append(QeffProxyModuleTransform)
logger.info("Proxy Model Enabled for QEfficient Model")

if kwargs.pop("full_batch_size", None):
continuous_batching = True
warnings.warn(
Expand Down Expand Up @@ -3062,6 +3072,7 @@ def generate(
**kwargs :
Additional keyword arguments. Currently supports:
- `generation_len (int, optional)`: The maximum number of tokens to generate.
- `write_io (bool, optional)`: Whether to save the io files.

Returns
-------
Expand All @@ -3079,6 +3090,7 @@ def generate(
if not isinstance(self.qpc_path, Path):
raise TypeError("Please run compile API first!")
generation_len = kwargs.pop("generation_len", None)
write_io = kwargs.pop("write_io", False)
return QEfficient.cloud_ai_100_exec_kv(
tokenizer=tokenizer,
qpc_path=self.qpc_path,
Expand All @@ -3090,6 +3102,7 @@ def generate(
automation=kwargs.pop("automation", False),
iteration=kwargs.pop("iteration", 1),
is_tlm=self.is_tlm,
write_io_dir=os.path.join(os.path.dirname(self.onnx_path), "io_dir") if write_io else None,
**kwargs,
)
else:
Expand Down
17 changes: 17 additions & 0 deletions examples/proxy_model_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

from transformers import AutoTokenizer

from QEfficient import QEFFAutoModelForCausalLM

model = QEFFAutoModelForCausalLM.from_pretrained(
"gpt2", num_hidden_layers=2, enable_proxy=True
) # enable_proxy=True to use proxy model export i.e., export model disable the embedding and LM head layers
model.compile(num_cores=16)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model.generate(prompts=["Hi there!!"], tokenizer=tokenizer, write_io=True) # write_io = True to save io files