From 918593e1dd3fbda0b14f39a4636ceb2f6fe75f8b Mon Sep 17 00:00:00 2001
From: Chen Cui
Date: Fri, 31 Oct 2025 15:15:55 -0700
Subject: [PATCH 1/2] make mamba blocks respect perform_initialization for
 bridge export

Signed-off-by: Chen Cui
(cherry picked from commit ccbda0ab88ff014b13918349702bb54d3a033b58)
---
 megatron/core/ssm/mamba_block.py |  13 ++--
 megatron/core/ssm/mamba_mixer.py | 102 ++++++++++++++++---------------
 2 files changed, 61 insertions(+), 54 deletions(-)

diff --git a/megatron/core/ssm/mamba_block.py b/megatron/core/ssm/mamba_block.py
index 01b9f4eac6..ad1eccbb74 100644
--- a/megatron/core/ssm/mamba_block.py
+++ b/megatron/core/ssm/mamba_block.py
@@ -203,13 +203,14 @@ def __init__(
                 eps=self.config.layernorm_epsilon,
             )
 
-        self.apply(
-            partial(
-                _init_weights,
-                n_layer=self.config.num_layers,
-                initializer_range=self.config.init_method_std,
+        if self.config.perform_initialization:
+            self.apply(
+                partial(
+                    _init_weights,
+                    n_layer=self.config.num_layers,
+                    initializer_range=self.config.init_method_std,
+                )
             )
-        )
 
     def _select_layers_for_pipeline_parallel(self, layer_type_list):
         num_layers_per_pipeline_rank = self.config.num_layers // self.pp_group.size()
diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py
index 2caa36fb1e..d3d428d5f9 100644
--- a/megatron/core/ssm/mamba_mixer.py
+++ b/megatron/core/ssm/mamba_mixer.py
@@ -268,60 +268,66 @@ def __init__(
         )
 
         conv_dim = self.d_inner_local_tp + 2 * self.ngroups_local_tp * self.d_state  # x B C
-        with get_cuda_rng_tracker().fork():
-            # weight shape: [conv_dim, 1, d_conv]
-            # bias shape: [conv_dim]
-            self.conv1d = nn.Conv1d(
-                in_channels=conv_dim,
-                out_channels=conv_dim,
-                bias=conv_bias,
-                kernel_size=d_conv,
-                groups=conv_dim,
-                padding=d_conv - 1,
-                device=torch.cuda.current_device(),
-                dtype=config.params_dtype,
-            )
-            setattr(self.conv1d.weight, "tensor_model_parallel", True)
-            setattr(self.conv1d.bias, "tensor_model_parallel", True)
+        # weight shape: [conv_dim, 1, d_conv]
+        # bias shape: [conv_dim]
+        self.conv1d = nn.Conv1d(
+            in_channels=conv_dim,
+            out_channels=conv_dim,
+            bias=conv_bias,
+            kernel_size=d_conv,
+            groups=conv_dim,
+            padding=d_conv - 1,
+            device=torch.cuda.current_device(),
+            dtype=config.params_dtype,
+        )
+        setattr(self.conv1d.weight, "tensor_model_parallel", True)
+        setattr(self.conv1d.bias, "tensor_model_parallel", True)
 
-            if self.conv_init is not None:
+        if self.config.perform_initialization and self.conv_init is not None:
+            with get_cuda_rng_tracker().fork():
                 nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
 
         self.activation = "silu"
         self.act = nn.SiLU()
 
-        with get_cuda_rng_tracker().fork():
-            # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
-            dt = torch.exp(
-                torch.rand(
-                    self.nheads_local_tp,
-                    device=torch.cuda.current_device(),
-                    dtype=config.params_dtype,
-                )
-                * (math.log(dt_max) - math.log(dt_min))
-                + math.log(dt_min)
-            ).clamp(min=dt_init_floor)
-            # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
-            inv_dt = dt + torch.log(-torch.expm1(-dt))
-            self.dt_bias = nn.Parameter(inv_dt)
-            # Our initialization would set all Linear.bias to zero,
-            # need to mark this one as _no_reinit
-            self.dt_bias._no_reinit = True
-            # Just to be explicit. Without this we already don't
-            # put wd on dt_bias because of the check
-            # name.endswith("bias") in param_grouping.py
-            self.dt_bias._no_weight_decay = True
-            setattr(self.dt_bias, "tensor_model_parallel", True)
-
-            # A parameter
-            assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
-            A = torch.empty(
-                self.nheads_local_tp, dtype=torch.float32, device=torch.cuda.current_device()
-            ).uniform_(*A_init_range)
-            A_log = torch.log(A)  # Keep A_log in fp32
-            self.A_log = nn.Parameter(A_log)
-            self.A_log._no_weight_decay = True
-            setattr(self.A_log, "tensor_model_parallel", True)
+        if self.config.perform_initialization:
+            with get_cuda_rng_tracker().fork():
+                # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
+                dt = torch.exp(
+                    torch.rand(
+                        self.nheads_local_tp,
+                        device=torch.cuda.current_device(),
+                        dtype=config.params_dtype,
+                    )
+                    * (math.log(dt_max) - math.log(dt_min))
+                    + math.log(dt_min)
+                ).clamp(min=dt_init_floor)
+                # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
+                inv_dt = dt + torch.log(-torch.expm1(-dt))
+        else:
+            inv_dt = torch.empty(self.nheads_local_tp)
+
+        self.dt_bias = nn.Parameter(inv_dt)
+        # Our initialization would set all Linear.bias to zero,
+        # need to mark this one as _no_reinit
+        self.dt_bias._no_reinit = True
+        # Just to be explicit. Without this we already don't
+        # put wd on dt_bias because of the check
+        # name.endswith("bias") in param_grouping.py
+        self.dt_bias._no_weight_decay = True
+        setattr(self.dt_bias, "tensor_model_parallel", True)
+
+        # A parameter
+        assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
+        A = torch.empty(
+            self.nheads_local_tp, dtype=torch.float32, device=torch.cuda.current_device()
+        )
+        if self.config.perform_initialization:
+            A = A.uniform_(*A_init_range)
+        A_log = torch.log(A)  # Keep A_log in fp32
+        self.A_log = nn.Parameter(A_log)
+        self.A_log._no_weight_decay = True
+        setattr(self.A_log, "tensor_model_parallel", True)
 
         # D "skip" parameter
         self.D = nn.Parameter(

From 3dab5c1df18dbf164800741fd2493c87d4f08c2c Mon Sep 17 00:00:00 2001
From: Chen Cui
Date: Fri, 31 Oct 2025 15:15:25 -0700
Subject: [PATCH 2/2] add fp8 arch arg

Signed-off-by: Chen Cui
---
 megatron/core/models/multimodal/llava_model.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/megatron/core/models/multimodal/llava_model.py b/megatron/core/models/multimodal/llava_model.py
index 2ac2657c1c..dae9a02b78 100644
--- a/megatron/core/models/multimodal/llava_model.py
+++ b/megatron/core/models/multimodal/llava_model.py
@@ -124,6 +124,7 @@ def __init__(
         max_num_tiles: int = 0,
         tokenizer_type: str = "",
         vp_stage: Optional[int] = None,
+        use_vision_backbone_fp8_arch: bool = False,
     ) -> None:
         super().__init__(config=language_transformer_config)
 
@@ -295,7 +296,7 @@ def __init__(
             ln_post_impl = None
            use_mask_token = False
 
-            if vision_transformer_config.fp8:
+            if vision_transformer_config.fp8 or use_vision_backbone_fp8_arch:
                 # FP8 padding for final sequence length to be a multiple of 16 or 32.
                 class_token_len = 32 if vision_transformer_config.fp8_recipe == "mxfp8" else 16
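
Aside (not part of the patch series): a self-contained toy sketch of the guard pattern the first commit applies. Parameters are always allocated, but the random initialization only runs when perform_initialization is set, so a bridge-export path can construct the module cheaply and then copy converted weights in. TinyMixer and the zero-filled checkpoint tensor below are illustrative stand-ins, not Megatron APIs.

# Toy sketch of the perform_initialization guard (illustrative; not Megatron code).
import torch
import torch.nn as nn


class TinyMixer(nn.Module):
    """Mirrors the pattern in the patch: always allocate, randomize only on demand."""

    def __init__(self, nheads: int, perform_initialization: bool = True):
        super().__init__()
        if perform_initialization:
            # Training-from-scratch path: random init, loosely like the dt_bias setup.
            inv_dt = torch.empty(nheads).uniform_(0.001, 0.1)
        else:
            # Bridge-export path: allocate only; real weights are loaded afterwards.
            inv_dt = torch.empty(nheads)
        self.dt_bias = nn.Parameter(inv_dt)


# Export path: skip random init, then copy converted weights into the parameter.
model = TinyMixer(nheads=8, perform_initialization=False)
with torch.no_grad():
    model.dt_bias.copy_(torch.zeros(8))  # stand-in for a converted checkpoint tensor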