|
1 | 1 | # -----------------------------------------------------------------------------
|
2 | 2 | #
|
3 |
| -# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. |
| 3 | +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. |
4 | 4 | # SPDX-License-Identifier: BSD-3-Clause
|
5 | 5 | #
|
6 | 6 | # ----------------------------------------------------------------------------
|
|
9 | 9 |
|
10 | 10 | from torch import nn
|
11 | 11 |
|
| 12 | +from QEfficient.utils.logging_utils import logger |
| 13 | + |
12 | 14 |
|
13 | 15 | class PytorchTransform:
|
14 | 16 | """
|
@@ -110,3 +112,65 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
|
110 | 112 | transformed = True
|
111 | 113 |
|
112 | 114 | return model, transformed
|
| 115 | + |
| 116 | + |
class SplitGateUpWeightsTransform(PytorchTransform):
    """
    Split fused Gate+Up expert weights and copy the halves into the model.

    For every transformer layer inside ``model``:
      • expects ``<PREFIX>.experts.gate_up_proj`` in the *source* state dict,
        assumed shape [E, H, 2I] — TODO confirm against the checkpoint layout
      • copies the two halves into
            ``<PREFIX>.experts.gate_proj``  <-- Gate [E, H, I]
            ``<PREFIX>.experts.up_proj``    <-- Up   [E, H, I]
    """

    @classmethod
    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
        """Split each layer's fused ``gate_up_proj`` tensor in place.

        Args:
            model: model to transform; may be a VLM wrapper exposing a
                ``language_model`` attribute that holds the text decoder.

        Returns:
            Tuple of the (possibly modified) model and a flag that is True
            iff at least one layer was transformed.

        Raises:
            KeyError: if a layer listed in the model has no fused
                ``gate_up_proj`` entry in the state dict.
        """
        transformed = False

        # NOTE(review): the original guarded on hasattr(model, "model") but both
        # branches computed the same expression; the class name of `model`
        # itself is what is matched against VLM_SPLIT_GATE_UP_WEIGHTS.
        if model.__class__.__name__ not in VLM_SPLIT_GATE_UP_WEIGHTS:
            return model, transformed

        # VLM wrappers keep the text decoder under `language_model`.
        model_tmp = model.language_model if hasattr(model, "language_model") else model

        sd = model_tmp.state_dict()
        for layer_idx in range(len(model_tmp.model.layers)):
            # ---- build the textual prefix once per layer ----------
            prefix = f"model.layers.{layer_idx}.feed_forward.experts."
            fused_key = prefix + "gate_up_proj"
            gate_key = prefix + "gate_proj"
            up_key = prefix + "up_proj"

            # ---- split [E, H, 2I] → two [E, H, I] tensors (views – no copy) --
            fused = sd[fused_key]  # bare tensor key, no ".weight" suffix here
            ffn_dim = fused.shape[-1] // 2
            gate, up = fused.split(ffn_dim, dim=-1)

            experts = model_tmp.model.layers[layer_idx].feed_forward.experts
            experts.gate_proj.data.copy_(gate)
            experts.up_proj.data.copy_(up)

            # ---- update the state dict so load_state_dict sees the right keys
            sd[gate_key] = gate
            sd[up_key] = up
            # The fused entry must not survive, or load_state_dict would see an
            # unexpected key.
            del sd[fused_key]

            # Lazy %-formatting: message is only built if INFO is enabled.
            logger.info(
                "[layer %02d] loaded gate_proj & up_proj from fused tensor (shape %s)",
                layer_idx,
                fused.shape,
            )
            transformed = True

        # Re-attach the decoder for wrapper models; for bare models `model_tmp`
        # is `model` itself, so nothing needs rebinding.
        if hasattr(model, "language_model"):
            model.language_model = model_tmp
        return model, transformed
| 175 | + |
# Model classes whose experts store a fused gate_up_proj tensor that
# SplitGateUpWeightsTransform must split into gate_proj / up_proj.
VLM_SPLIT_GATE_UP_WEIGHTS = {"QEffLlama4ForConditionalGeneration"}
0 commit comments