import torch
import torch.nn.functional as F
import torch.utils.checkpoint
-from einops import rearrange
from torch import nn
from transformers.activations import ACT2FN
from transformers.generation import GenerationMixin
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput

-from tensorrt_llm._torch.auto_deploy.models.patches.nemotron_h import (
-    _nemotron_h_moe_forward,
-    _nemotron_h_topk_router_forward,
-)
+from tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm import gated_rms_norm_ref
+from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory


class MambaRMSNormGated(torch.nn.Module):
@@ -46,7 +43,7 @@ def __init__(self, hidden_size, group_size, eps=1e-5):
        self.group_size = group_size

    def forward(self, hidden_states, gate=None):
-        return _rms_norm_ref(
+        return gated_rms_norm_ref(
            x=hidden_states,
            weight=self.weight,
            bias=None,
@@ -57,38 +54,6 @@ def forward(self, hidden_states, gate=None):
        )


-# Forked from:
-# https://github.com/state-spaces/mamba/blob/6b32be06d026e170b3fdaf3ae6282c5a6ff57b06/mamba_ssm/ops/triton/layernorm_gated.py
-# NOTES:
-# 1. At time of writing (09/25/2025), the nano nemotron v2 modeling code expects `mamba_ssm`
-#    to be installed so as to be able to make use of its grouped gated RMS norm operation.
-#    We therefore replace it with one that uses einops + pytorch.
-def _rms_norm_ref(
-    x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, upcast=True
-):
-    dtype = x.dtype
-    # N = x.shape[-1]
-    weight = weight.float()
-    bias = bias.float() if bias is not None else None
-    if upcast:
-        x = x.float()
-        z = z.float() if z is not None else z
-    if z is not None and not norm_before_gate:
-        x = x * F.silu(z)
-    if group_size is None:
-        rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
-        out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
-    else:
-        x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
-        rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
-        out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
-        if bias is not None:
-            out = out + bias
-    if z is not None and norm_before_gate:
-        out *= F.silu(z)
-    return out.to(dtype)
-
-
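The imported `gated_rms_norm_ref` is expected to cover the same semantics as the helper removed above: per-group RMS normalization, a learned scale, then SiLU gating. A minimal PyTorch-only sketch of that math, assuming the default norm-before-gate behavior and no bias (which is how `MambaRMSNormGated.forward` calls it); the custom op remains the source of truth:

def _grouped_gated_rms_norm_sketch(x, weight, gate=None, eps=1e-5, group_size=None):
    # Upcast for numerical stability, normalize per channel group, scale, then gate.
    dtype = x.dtype
    x = x.float()
    group_size = group_size or x.shape[-1]
    x_group = x.unflatten(-1, (-1, group_size))                      # (..., g, d)
    rstd = torch.rsqrt(x_group.square().mean(dim=-1, keepdim=True) + eps)
    out = (x_group * rstd).flatten(-2) * weight.float()              # (..., g*d)
    if gate is not None:
        out = out * F.silu(gate.float())
    return out.to(dtype)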
class NemotronHMamba2Mixer(nn.Module):
    """
    Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
@@ -149,9 +114,9 @@ def __init__(self, config, layer_idx: int):
        self.A_log._no_weight_decay = True
        # Instead of recomputing `torch.exp(self.A_log.float())` on every forward pass, we will register a hook
        # that sets this appropriately when loading weights.
-        # NOTE: we explicitly do NOT make this a `nn.Parameter` so that it does not appear in the state dict of
-        # this module, or an equivalent graph module trace from it.
-        self._minus_A = -A.float()
+        # NOTE: we explicitly register this as a non-persistent buffer so that it does not appear in the state dict
+        # of this module, or an equivalent graph module trace from it, but still gets included in e.g. `to()` calls.
+        self.register_buffer("_minus_A", -A.float(), persistent=False)
        self.norm = MambaRMSNormGated(
            self.intermediate_size,
            eps=self.layer_norm_epsilon,
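As a standalone illustration of the non-persistent buffer registration above (the demo module is hypothetical): the tensor stays out of `state_dict()` while still following `Module.to()` conversions, unlike the plain attribute it replaces, which `to()` would not touch.

import torch
from torch import nn

class _BufferDemo(nn.Module):
    def __init__(self):
        super().__init__()
        A = torch.arange(1.0, 5.0)
        # Non-persistent: excluded from checkpoints and exported state dicts.
        self.register_buffer("_minus_A", -A.float(), persistent=False)

m = _BufferDemo()
assert "_minus_A" not in m.state_dict()   # not serialized
m.to(torch.float64)                       # ...but converted/moved with the module
assert m._minus_A.dtype == torch.float64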
@@ -317,8 +282,43 @@ def __init__(self, config, layer_idx: Optional[int] = None):
            layer_idx=layer_idx,
        )

-    # TODO: inline code from `_nemotron_h_moe_forward` when removing patches.
-    forward = _nemotron_h_moe_forward
+    def forward(self, hidden_states: torch.Tensor):
+        residuals = hidden_states
+        orig_shape = hidden_states.shape
+        topk_indices, topk_weights = self.gate(hidden_states)
+        x_flat = hidden_states.view(-1, hidden_states.shape[-1])
+
+        # NOTE: So far we've seen that the dispatch order in eager code is the same as the node order in the
+        # exported graph.
+        # We dispatch the shared expert first so that we can easily fork the execution of the routed experts
+        # (using the custom op below) to an auxiliary stream.
+        shared_out = self.shared_experts(residuals)
+        # Check if this is a latent MOE (has fc1_latent_proj and fc2_latent_proj).
+        has_latent_proj = hasattr(self, "fc1_latent_proj") and hasattr(self, "fc2_latent_proj")
+
+        if has_latent_proj:
+            # Latent MOE: project to latent space before routing.
+            x_flat = self.fc1_latent_proj(x_flat)
+
+        # Route through experts (operates in latent space if latent MOE, full space otherwise).
+        out_flat = torch.ops.auto_deploy.torch_moe(
+            x_flat,
+            topk_indices,
+            topk_weights,
+            w1_weight=[e.up_proj.weight for e in self.experts],
+            w2_weight=[e.down_proj.weight for e in self.experts],
+            w3_weight=[],
+            act_fn="relu2",
+            mlp_style="mlp",
+        )
+
+        if has_latent_proj:
+            # Latent MOE: project back from latent space.
+            out_flat = self.fc2_latent_proj(out_flat)
+
+        routed_out = out_flat.view(*orig_shape)
+        out = shared_out + routed_out
+        return out
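To make the routed-expert call above concrete, here is a rough, unfused reference of what `torch.ops.auto_deploy.torch_moe` is assumed to compute with `mlp_style="mlp"`, an empty `w3_weight`, and `act_fn="relu2"` (taken here to mean squared ReLU): each selected expert applies a two-layer MLP and per-token outputs are blended with the router weights. This is an illustrative sketch only; the custom op's own implementation is authoritative.

def _torch_moe_reference(x_flat, topk_indices, topk_weights, up_weights, down_weights):
    # Unfused sketch: per token, run each selected expert's down(relu(up(x))**2)
    # and accumulate, weighted by the router's top-k weights.
    out = torch.zeros_like(x_flat)
    for tok in range(x_flat.shape[0]):
        for slot in range(topk_indices.shape[1]):
            e = int(topk_indices[tok, slot])
            h = F.relu(F.linear(x_flat[tok], up_weights[e])) ** 2
            out[tok] += (topk_weights[tok, slot] * F.linear(h, down_weights[e])).to(out.dtype)
    return out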


class NemotronHTopkRouter(nn.Module):
@@ -339,22 +339,33 @@ def __init__(self, config):
            "e_score_correction_bias", torch.zeros(self.n_routed_experts, dtype=torch.float32)
        )

-    forward = _nemotron_h_topk_router_forward
+    def forward(self, hidden_states):
+        """
+        Forward pass for NemotronHTopkRouter using the optimized noaux_tc_op kernel.

+        This replaces the original forward method, which used pure PyTorch operations,
+        with optimized CUDA kernels.
+        """
+        hidden_states = hidden_states.view(-1, self.config.hidden_size)
+        if self.weight.dtype == torch.float32:
+            router_logits = F.linear(hidden_states.type(torch.float32), self.weight)
+        else:
+            router_logits = torch.ops.trtllm.dsv3_router_gemm_op(
+                hidden_states, self.weight.t(), bias=None, out_dtype=torch.float32
+            )

-# Copied from transformers.models.llama.modeling_llama.repeat_kv
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(
-        batch, num_key_value_heads, n_rep, slen, head_dim
-    )
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+        # Use the fused noaux_tc_op kernel, which applies the sigmoid internally
+        # and performs group-based top-k selection with normalization.
+        topk_weights, topk_indices = torch.ops.trtllm.noaux_tc_op(
+            router_logits,
+            self.e_score_correction_bias,
+            self.n_group,
+            self.topk_group,
+            self.top_k,
+            self.routed_scaling_factor,
+        )
+
+        return topk_indices, topk_weights
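For readers without the TRT-LLM kernels at hand, the routing that `noaux_tc_op` fuses can be approximated in plain PyTorch roughly as below, assuming DeepSeek-V3-style no-aux-loss routing (sigmoid scores, bias-corrected selection, group-limited top-k, weight normalization, and rescaling). This is an assumption-laden sketch for intuition only; the fused kernel defines the actual behavior.

def _noaux_topk_sketch(router_logits, bias, n_group, topk_group, top_k, routed_scaling_factor):
    # Plain-PyTorch approximation of the fused routing above (intuition only).
    scores = router_logits.sigmoid()                                  # (tokens, n_experts)
    biased = scores + bias                                            # bias influences selection only
    grouped = biased.view(biased.shape[0], n_group, -1)
    group_scores = grouped.topk(2, dim=-1).values.sum(dim=-1)         # score each expert group
    keep = torch.zeros_like(group_scores).scatter_(
        1, group_scores.topk(topk_group, dim=-1).indices, 1.0
    )
    masked = (grouped * keep.unsqueeze(-1)).flatten(1)                # drop non-selected groups
    topk_indices = masked.topk(top_k, dim=-1).indices
    topk_weights = scores.gather(1, topk_indices)                     # unbiased scores as weights
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_indices, topk_weights * routed_scaling_factor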


class NemotronHAttention(nn.Module):
@@ -369,8 +380,23 @@ def __init__(self, config, layer_idx: Optional[int] = None):

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
-        if config.head_dim is not None:
-            self.head_dim = config.head_dim
+
+        # At some point during NemotronH development, what used to be called `attention_head_dim`
+        # was renamed to `head_dim`. Since no configuration class's code (nor the modeling code,
+        # for that matter) was ever upstreamed into `transformers`, we have to resort to the hack
+        # below in order to support multiple iterations of NemotronH models.
+        if hasattr(config, "head_dim"):
+            head_dim = config.head_dim
+        elif hasattr(config, "attention_head_dim"):
+            head_dim = config.attention_head_dim
+        else:
+            raise AttributeError(
+                "Expected either `head_dim` or `attention_head_dim` to be present in the config "
+                "class, found neither."
+            )
+
+        if head_dim is not None:
+            self.head_dim = head_dim
        else:
            self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
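The same attribute fallback can also be written with `getattr`; a hypothetical sketch (the config classes here are made up for illustration) showing both spellings resolving identically:

class _LegacyCfg:                 # hypothetical config using the old spelling
    attention_head_dim = 128

class _CurrentCfg:                # hypothetical config using the new spelling
    head_dim = 128

def _resolve_head_dim(config, hidden_size=4096, num_heads=32):
    head_dim = getattr(config, "head_dim", None)
    if head_dim is None:
        head_dim = getattr(config, "attention_head_dim", None)
    return head_dim if head_dim is not None else hidden_size // num_heads

assert _resolve_head_dim(_LegacyCfg()) == _resolve_head_dim(_CurrentCfg()) == 128

Note that, unlike the explicit `AttributeError` branch above, this variant silently falls back to `hidden_size // num_heads` when neither attribute exists, which is why the louder failure mode may be preferable in the model code.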
@@ -594,7 +620,4 @@ def forward(
        return NemotronHCausalLMOutput(logits)


-# TODO: uncomment after removing patches (and make sure it is imported in `__init__.py`).
-# from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory
-#
-# AutoModelForCausalLMFactory.register_custom_model_cls("NemotronHConfig", NemotronHForCausalLM)
+AutoModelForCausalLMFactory.register_custom_model_cls("NemotronHConfig", NemotronHForCausalLM)