
Commit 478b6b2

[#9230][refactor] Replace nemotron patches with custom model implementation (#9751)
* Why? Patching for nemotron H models was growing out of hand, and made certain
  optimizations more complex than they needed to be.
* What? This commit finally gets rid of them, and replaces them with the custom
  model implementation in `modeling_nemotron_h.py`.

Closes #9230
Closes NvBug 5747867

Signed-off-by: William Zhang <[email protected]>
1 parent 72c5480 commit 478b6b2


13 files changed: +379 −465 lines

tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py

Lines changed: 34 additions & 0 deletions
@@ -2,6 +2,8 @@

 import flashinfer
 import torch
+import torch.nn.functional as F
+from einops import rearrange

 from ...flashinfer_utils import get_env_enable_pdl
 from ...modules.mamba.layernorm_gated import _layer_norm_fwd
@@ -159,3 +161,35 @@ def _triton_rmsnorm_gated_meta(
     assert gate.shape == x.shape, "gate must match x shape"

     return x.new_empty(x.shape, dtype=torch.float32)
+
+
+# Forked from:
+# https://github.com/state-spaces/mamba/blob/6b32be06d026e170b3fdaf3ae6282c5a6ff57b06/mamba_ssm/ops/triton/layernorm_gated.py
+# NOTES:
+# 1. At time of writing (09/25/2025), the nano nemotron v2 modeling code expects `mamba_ssm`
+#    to be installed so as to be able to make use of its grouped gated RMS norm operation.
+#    We therefore replace it with one that uses einops + pytorch.
+def gated_rms_norm_ref(
+    x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, upcast=True
+):
+    dtype = x.dtype
+    # N = x.shape[-1]
+    weight = weight.float()
+    bias = bias.float() if bias is not None else None
+    if upcast:
+        x = x.float()
+        z = z.float() if z is not None else z
+    if z is not None and not norm_before_gate:
+        x = x * F.silu(z)
+    if group_size is None:
+        rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
+        out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
+    else:
+        x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
+        rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
+        out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
+        if bias is not None:
+            out = out + bias
+    if z is not None and norm_before_gate:
+        out *= F.silu(z)
+    return out.to(dtype)
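
For reference, a minimal sketch (not part of this commit) of how the grouped, gated path of `gated_rms_norm_ref` behaves; the shapes, group size, and tolerance check below are made up for illustration.

# Illustrative only: exercise the grouped, gated path of `gated_rms_norm_ref`
# against an equivalent per-group computation (shapes/group size are arbitrary).
import torch
import torch.nn.functional as F

from tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm import gated_rms_norm_ref

hidden = 8 * 16                        # 8 groups of 16 channels each
x = torch.randn(2, 4, hidden)
z = torch.randn(2, 4, hidden)          # gate input
w = torch.randn(hidden)

out = gated_rms_norm_ref(x, w, bias=None, z=z, group_size=16, norm_before_gate=True)

# Reference: RMS-normalize each group of 16 channels independently, scale by the
# weight, then gate with SiLU(z).
xg = x.float().view(2, 4, 8, 16)
ref = (xg * torch.rsqrt(xg.pow(2).mean(-1, keepdim=True) + 1e-6)).view(2, 4, hidden) * w.float()
ref = (ref * F.silu(z.float())).to(x.dtype)

torch.testing.assert_close(out, ref)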
Lines changed: 0 additions & 2 deletions
@@ -1,4 +1,2 @@
-# TODO: When getting rid of the nemotron H patches, import `modeling_nemotron_h` here to ensure the
-# custom model implementation is registered.
 from . import custom, hf, nemotron_flash, patches
 from .factory import *
Lines changed: 7 additions & 0 deletions
@@ -1 +1,8 @@
 from .modeling_nemotron_flash import NemotronFlashForCausalLM, NemotronFlashPreTrainedTokenizerFast
+from .modeling_nemotron_h import NemotronHForCausalLM
+
+__all__ = (
+    "NemotronFlashForCausalLM",
+    "NemotronFlashPreTrainedTokenizerFast",
+    "NemotronHForCausalLM",
+)

tensorrt_llm/_torch/auto_deploy/models/modeling_nemotron_h.py renamed to tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_h.py

Lines changed: 86 additions & 63 deletions
@@ -25,17 +25,14 @@
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
-from einops import rearrange
 from torch import nn
 from transformers.activations import ACT2FN
 from transformers.generation import GenerationMixin
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import ModelOutput

-from tensorrt_llm._torch.auto_deploy.models.patches.nemotron_h import (
-    _nemotron_h_moe_forward,
-    _nemotron_h_topk_router_forward,
-)
+from tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm import gated_rms_norm_ref
+from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory


 class MambaRMSNormGated(torch.nn.Module):
@@ -46,7 +43,7 @@ def __init__(self, hidden_size, group_size, eps=1e-5):
         self.group_size = group_size

     def forward(self, hidden_states, gate=None):
-        return _rms_norm_ref(
+        return gated_rms_norm_ref(
             x=hidden_states,
             weight=self.weight,
             bias=None,
@@ -57,38 +54,6 @@ def forward(self, hidden_states, gate=None):
         )


-# Forked from:
-# https://github.com/state-spaces/mamba/blob/6b32be06d026e170b3fdaf3ae6282c5a6ff57b06/mamba_ssm/ops/triton/layernorm_gated.py
-# NOTES:
-# 1. At time of writing (09/25/2025), the nano nemotron v2 modeling code expects `mamba_ssm`
-#    to be installed so as to be able to make use of its grouped gated RMS norm operation.
-#    We therefore replace it with one that uses einops + pytorch.
-def _rms_norm_ref(
-    x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, upcast=True
-):
-    dtype = x.dtype
-    # N = x.shape[-1]
-    weight = weight.float()
-    bias = bias.float() if bias is not None else None
-    if upcast:
-        x = x.float()
-        z = z.float() if z is not None else z
-    if z is not None and not norm_before_gate:
-        x = x * F.silu(z)
-    if group_size is None:
-        rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
-        out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
-    else:
-        x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
-        rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
-        out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
-        if bias is not None:
-            out = out + bias
-    if z is not None and norm_before_gate:
-        out *= F.silu(z)
-    return out.to(dtype)
-
-
 class NemotronHMamba2Mixer(nn.Module):
     """
     Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
@@ -149,9 +114,9 @@ def __init__(self, config, layer_idx: int):
         self.A_log._no_weight_decay = True
         # Instead of recomputing `torch.exp(self.A_log.float())` on every forward pass, we will register a hook
         # that sets this appropriately when loading weights.
-        # NOTE: we explicitly do NOT make this a `nn.Parameter` so that it does not appear in the state dict of
-        # this module, or an equivalent graph module trace from it.
-        self._minus_A = -A.float()
+        # NOTE: we explicitly register this as a non-persistent buffer so that it does not appear in the state dict of
+        # this module, or an equivalent graph module trace from it, but still gets included in e.g. `to()` calls.
+        self.register_buffer("_minus_A", -A.float(), persistent=False)
         self.norm = MambaRMSNormGated(
             self.intermediate_size,
             eps=self.layer_norm_epsilon,
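
As a side note on the `persistent=False` change above, a small illustrative sketch (not from this commit) of the behavior the new comment describes: the buffer stays out of the state dict yet still follows `to()` calls.

# Illustrative only: a non-persistent buffer is excluded from state_dict()
# but is still moved/cast by Module.to().
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        A = torch.arange(1, 5, dtype=torch.float32)
        self.register_buffer("_minus_A", -A, persistent=False)

m = Toy()
assert "_minus_A" not in m.state_dict()      # not serialized, not in traced state dicts
m = m.to(torch.float64)                      # but still converted by `to()`
assert m._minus_A.dtype == torch.float64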
@@ -317,8 +282,43 @@ def __init__(self, config, layer_idx: Optional[int] = None):
             layer_idx=layer_idx,
         )

-    # TODO: inline code from `_nemotron_h_moe_forward` when removing patches.
-    forward = _nemotron_h_moe_forward
+    def forward(self, hidden_states: torch.Tensor):
+        residuals = hidden_states
+        orig_shape = hidden_states.shape
+        topk_indices, topk_weights = self.gate(hidden_states)
+        x_flat = hidden_states.view(-1, hidden_states.shape[-1])
+
+        # NOTE: So far we've seen that the dispatch order in eager code is the same as the node order in the exported
+        # graph.
+        # We dispatch shared expert first so that we can easily fork the execution of the routed experts
+        # (using the custom op below) to an auxiliary stream.
+        shared_out = self.shared_experts(residuals)
+        # Check if this is a latent MOE (has fc1_latent_proj and fc2_latent_proj)
+        has_latent_proj = hasattr(self, "fc1_latent_proj") and hasattr(self, "fc2_latent_proj")
+
+        if has_latent_proj:
+            # Latent MOE: project to latent space before routing
+            x_flat = self.fc1_latent_proj(x_flat)
+
+        # Route through experts (operates in latent space if latent MOE, full space otherwise)
+        out_flat = torch.ops.auto_deploy.torch_moe(
+            x_flat,
+            topk_indices,
+            topk_weights,
+            w1_weight=[e.up_proj.weight for e in self.experts],
+            w2_weight=[e.down_proj.weight for e in self.experts],
+            w3_weight=[],
+            act_fn="relu2",
+            mlp_style="mlp",
+        )
+
+        if has_latent_proj:
+            # Latent MOE: project back from latent space
+            out_flat = self.fc2_latent_proj(out_flat)
+
+        routed_out = out_flat.view(*orig_shape)
+        out = shared_out + routed_out
+        return out


 class NemotronHTopkRouter(nn.Module):
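
For readers unfamiliar with `torch.ops.auto_deploy.torch_moe`, the sketch below is one rough pure-PyTorch reading of what the call above is expected to compute for `mlp_style="mlp"` with `act_fn="relu2"`. This is an assumption for illustration, not the op's actual implementation; the helper name `torch_moe_reference` is made up here.

# Assumed semantics only: per token, sum over its top-k experts of
# weight * down_proj(relu(up_proj(x)) ** 2).
import torch
import torch.nn.functional as F

def torch_moe_reference(x_flat, topk_indices, topk_weights, w1_weight, w2_weight):
    out = torch.zeros_like(x_flat)
    for tok in range(x_flat.shape[0]):
        for k in range(topk_indices.shape[1]):
            expert = int(topk_indices[tok, k])
            h = F.linear(x_flat[tok], w1_weight[expert])        # up_proj
            h = F.relu(h).pow(2)                                # "relu2" activation
            out[tok] += topk_weights[tok, k] * F.linear(h, w2_weight[expert])  # down_proj
    return out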
@@ -339,22 +339,33 @@ def __init__(self, config):
             "e_score_correction_bias", torch.zeros(self.n_routed_experts, dtype=torch.float32)
         )

-    forward = _nemotron_h_topk_router_forward
+    def forward(self, hidden_states):
+        """
+        Forward pass for NemotronHTopkRouter using the optimized noaux_tc_op kernel.

+        This replaces the original forward method which used pure PyTorch operations
+        with optimized CUDA kernels:
+        """
+        hidden_states = hidden_states.view(-1, self.config.hidden_size)
+        if self.weight.dtype == torch.float32:
+            router_logits = F.linear(hidden_states.type(torch.float32), self.weight)
+        else:
+            router_logits = torch.ops.trtllm.dsv3_router_gemm_op(
+                hidden_states, self.weight.t(), bias=None, out_dtype=torch.float32
+            )

-# Copied from transformers.models.llama.modeling_llama.repeat_kv
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(
-        batch, num_key_value_heads, n_rep, slen, head_dim
-    )
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+        # Use the fused noaux_tc_op kernel which applies sigmoid internally
+        # and performs group-based top-k selection with normalization
+        topk_weights, topk_indices = torch.ops.trtllm.noaux_tc_op(
+            router_logits,
+            self.e_score_correction_bias,
+            self.n_group,
+            self.topk_group,
+            self.top_k,
+            self.routed_scaling_factor,
+        )
+
+        return topk_indices, topk_weights


 class NemotronHAttention(nn.Module):
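
The fused `noaux_tc_op` call above is opaque in this diff. The sketch below is one plausible pure-PyTorch reading of the selection it performs (sigmoid scoring with bias correction, group-limited top-k, weight normalization, and routed scaling), written in the DeepSeek-V3 style the kernel name suggests; the actual kernel may aggregate per-group scores differently, and `noaux_topk_reference` is a made-up name.

# Assumed semantics only; the real trtllm kernel may differ in detail.
import torch

def noaux_topk_reference(router_logits, e_score_correction_bias, n_group, topk_group,
                         top_k, routed_scaling_factor):
    scores = router_logits.sigmoid()                       # sigmoid applied internally
    biased = scores + e_score_correction_bias              # bias only affects selection
    n_tok, n_exp = biased.shape
    grouped = biased.view(n_tok, n_group, n_exp // n_group)
    group_scores = grouped.topk(2, dim=-1).values.sum(-1)  # assumed per-group score
    top_groups = group_scores.topk(topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, top_groups, 1.0).bool()
    masked = grouped.masked_fill(~group_mask.unsqueeze(-1), float("-inf"))
    topk_indices = masked.view(n_tok, n_exp).topk(top_k, dim=-1).indices
    topk_weights = scores.gather(1, topk_indices)
    topk_weights = topk_weights / topk_weights.sum(-1, keepdim=True) * routed_scaling_factor
    return topk_weights, topk_indices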
@@ -369,8 +380,23 @@ def __init__(self, config, layer_idx: Optional[int] = None):

         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
-        if config.head_dim is not None:
-            self.head_dim = config.head_dim
+
+        # At some point during NemotronH development, what used to be called `attention_head_dim`
+        # was renamed to `head_dim`. Since no configuration class's code (nor the modeling code,
+        # for that matter) was ever upstreamed into `transformers`, we have to resort to the below
+        # hack in order to support multiple iterations of NemotronH models.
+        if hasattr(config, "head_dim"):
+            head_dim = config.head_dim
+        elif hasattr(config, "attention_head_dim"):
+            head_dim = config.attention_head_dim
+        else:
+            raise AttributeError(
+                "Expected either `head_dim` or `attention_head_dim` to be present in the config "
+                "class, found neither."
+            )
+
+        if head_dim is not None:
+            self.head_dim = head_dim
         else:
             self.head_dim = config.hidden_size // config.num_attention_heads
         self.num_key_value_heads = config.num_key_value_heads
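
A tiny illustration of the fallback above (the values are made up): configs from before the rename expose `attention_head_dim`, newer ones expose `head_dim`, and both resolve to the same value.

# Illustrative only: both config vintages resolve through the same fallback.
from types import SimpleNamespace

def resolve_head_dim(config):
    if hasattr(config, "head_dim"):
        return config.head_dim
    if hasattr(config, "attention_head_dim"):
        return config.attention_head_dim
    raise AttributeError("Expected either `head_dim` or `attention_head_dim` in the config.")

old_cfg = SimpleNamespace(attention_head_dim=128)   # pre-rename config
new_cfg = SimpleNamespace(head_dim=128)             # post-rename config
assert resolve_head_dim(old_cfg) == resolve_head_dim(new_cfg) == 128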
@@ -594,7 +620,4 @@ def forward(
         return NemotronHCausalLMOutput(logits)


-# TODO: uncomment after removing patches (and make sure it is imported in `__init__.py`).
-# from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory
-#
-# AutoModelForCausalLMFactory.register_custom_model_cls("NemotronHConfig", NemotronHForCausalLM)
+AutoModelForCausalLMFactory.register_custom_model_cls("NemotronHConfig", NemotronHForCausalLM)
