3 changes: 2 additions & 1 deletion megatron/core/models/multimodal/llava_model.py
@@ -124,6 +124,7 @@ def __init__(
max_num_tiles: int = 0,
tokenizer_type: str = "",
vp_stage: Optional[int] = None,
use_vision_backbone_fp8_arch: bool = False,
) -> None:
super().__init__(config=language_transformer_config)

@@ -295,7 +296,7 @@ def __init__(
ln_post_impl = None
use_mask_token = False

if vision_transformer_config.fp8:
if vision_transformer_config.fp8 or use_vision_backbone_fp8_arch:
# FP8 padding for final sequence length to be a multiple of 16 or 32.
class_token_len = 32 if vision_transformer_config.fp8_recipe == "mxfp8" else 16

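The new use_vision_backbone_fp8_arch flag forces the FP8-friendly class-token padding even when vision_transformer_config.fp8 is not set, presumably so the vision backbone keeps the same sequence length as an FP8-configured run. A minimal sketch of the padding arithmetic, using assumed example geometry (336px tiles with 14px patches; these values are illustrative, not taken from this PR):

# Sketch only; assumed example geometry, not from this PR.
img_h = img_w = 336
patch_dim = 14
num_patches = (img_h // patch_dim) * (img_w // patch_dim)  # 24 * 24 = 576

def class_token_len_for(fp8_recipe: str) -> int:
    # Matches the rule above: mxfp8 pads to a multiple of 32, other FP8 recipes to 16.
    return 32 if fp8_recipe == "mxfp8" else 16

for recipe in ("delayed", "mxfp8"):
    seq_len = num_patches + class_token_len_for(recipe)  # 592 or 608
    assert seq_len % (32 if recipe == "mxfp8" else 16) == 0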
13 changes: 7 additions & 6 deletions megatron/core/ssm/mamba_block.py
@@ -203,13 +203,14 @@ def __init__(
eps=self.config.layernorm_epsilon,
)

self.apply(
partial(
_init_weights,
n_layer=self.config.num_layers,
initializer_range=self.config.init_method_std,
if self.config.perform_initialization:
self.apply(
partial(
_init_weights,
n_layer=self.config.num_layers,
initializer_range=self.config.init_method_std,
)
)
)

def _select_layers_for_pipeline_parallel(self, layer_type_list):
num_layers_per_pipeline_rank = self.config.num_layers // self.pp_group.size()
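Guarding self.apply(partial(_init_weights, ...)) behind config.perform_initialization skips the recursive re-initialization pass, which is typically what you want when the weights will be overwritten anyway (e.g. by a checkpoint load). A hedged sketch of the pattern, with _init_weights replaced by a simplified stand-in rather than the Megatron implementation:

from functools import partial

import torch.nn as nn

def _init_weights(module, n_layer, initializer_range=0.02):
    # Simplified stand-in: only re-initializes nn.Linear sub-modules.
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=initializer_range)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

class TinyStack(nn.Module):
    def __init__(self, perform_initialization: bool = True):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(16, 16) for _ in range(2))
        if perform_initialization:
            # nn.Module.apply walks every sub-module and calls the function on it.
            self.apply(partial(_init_weights, n_layer=2, initializer_range=0.02))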
102 changes: 54 additions & 48 deletions megatron/core/ssm/mamba_mixer.py
@@ -268,60 +268,66 @@ def __init__(
)

conv_dim = self.d_inner_local_tp + 2 * self.ngroups_local_tp * self.d_state # x B C
with get_cuda_rng_tracker().fork():
# weight shape: [conv_dim, 1, d_conv]
# bias shape: [conv_dim]
self.conv1d = nn.Conv1d(
in_channels=conv_dim,
out_channels=conv_dim,
bias=conv_bias,
kernel_size=d_conv,
groups=conv_dim,
padding=d_conv - 1,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
setattr(self.conv1d.weight, "tensor_model_parallel", True)
setattr(self.conv1d.bias, "tensor_model_parallel", True)
# weight shape: [conv_dim, 1, d_conv]
# bias shape: [conv_dim]
self.conv1d = nn.Conv1d(
in_channels=conv_dim,
out_channels=conv_dim,
bias=conv_bias,
kernel_size=d_conv,
groups=conv_dim,
padding=d_conv - 1,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
setattr(self.conv1d.weight, "tensor_model_parallel", True)
setattr(self.conv1d.bias, "tensor_model_parallel", True)

if self.conv_init is not None:
if self.config.perform_initialization and self.conv_init is not None:
with get_cuda_rng_tracker().fork():
nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)

self.activation = "silu"
self.act = nn.SiLU()

with get_cuda_rng_tracker().fork():
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
dt = torch.exp(
torch.rand(
self.nheads_local_tp,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
* (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
self.dt_bias = nn.Parameter(inv_dt)
# Our initialization would set all Linear.bias to zero,
# need to mark this one as _no_reinit
self.dt_bias._no_reinit = True
# Just to be explicit. Without this we already don't
# put wd on dt_bias because of the check
# name.endswith("bias") in param_grouping.py
self.dt_bias._no_weight_decay = True
setattr(self.dt_bias, "tensor_model_parallel", True)

# A parameter
assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
A = torch.empty(
self.nheads_local_tp, dtype=torch.float32, device=torch.cuda.current_device()
).uniform_(*A_init_range)
A_log = torch.log(A) # Keep A_log in fp32
self.A_log = nn.Parameter(A_log)
self.A_log._no_weight_decay = True
setattr(self.A_log, "tensor_model_parallel", True)
if self.config.perform_initialization:
with get_cuda_rng_tracker().fork():
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
dt = torch.exp(
torch.rand(
self.nheads_local_tp,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
* (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
else:
inv_dt = torch.empty(self.nheads_local_tp)

self.dt_bias = nn.Parameter(inv_dt)
# Our initialization would set all Linear.bias to zero,
# need to mark this one as _no_reinit
self.dt_bias._no_reinit = True
# Just to be explicit. Without this we already don't
# put wd on dt_bias because of the check
# name.endswith("bias") in param_grouping.py
self.dt_bias._no_weight_decay = True
setattr(self.dt_bias, "tensor_model_parallel", True)

# A parameter
assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
A = torch.empty(
self.nheads_local_tp, dtype=torch.float32, device=torch.cuda.current_device()
)
if self.config.perform_initialization:
A = A.uniform_(*A_init_range)
A_log = torch.log(A) # Keep A_log in fp32
self.A_log = nn.Parameter(A_log)
self.A_log._no_weight_decay = True
setattr(self.A_log, "tensor_model_parallel", True)

# D "skip" parameter
self.D = nn.Parameter(
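In the mixer, the Conv1d module is now always constructed, but its uniform init runs only under config.perform_initialization (still inside the get_cuda_rng_tracker().fork() context), and dt_bias falls back to an uninitialized tensor of the right shape when initialization is skipped. As a hedged aside on the dt path: dt + torch.log(-torch.expm1(-dt)) is the inverse of softplus, so F.softplus(dt_bias) recovers a dt in [dt_min, dt_max]. A standalone check with illustrative values (not from this PR):

import math

import torch
import torch.nn.functional as F

dt_min, dt_max, dt_init_floor = 1e-3, 1e-1, 1e-4  # illustrative values only
dt = torch.exp(
    torch.rand(8) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
).clamp(min=dt_init_floor)
# softplus(x) = log(1 + exp(x)), so its inverse is log(exp(y) - 1) = y + log(-expm1(-y)).
inv_dt = dt + torch.log(-torch.expm1(-dt))
assert torch.allclose(F.softplus(inv_dt), dt, atol=1e-6)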