3 changes: 2 additions & 1 deletion megatron/core/models/multimodal/llava_model.py
@@ -124,6 +124,7 @@ def __init__(
         max_num_tiles: int = 0,
         tokenizer_type: str = "",
         vp_stage: Optional[int] = None,
+        use_vision_backbone_fp8_arch: bool = False,
     ) -> None:
         super().__init__(config=language_transformer_config)

@@ -295,7 +296,7 @@ def __init__(
             ln_post_impl = None
             use_mask_token = False

-            if vision_transformer_config.fp8:
+            if vision_transformer_config.fp8 or use_vision_backbone_fp8_arch:
                 # FP8 padding for final sequence length to be a multiple of 16 or 32.
                 class_token_len = 32 if vision_transformer_config.fp8_recipe == "mxfp8" else 16

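As the in-code comment notes, FP8 execution needs the final vision sequence length (class tokens plus image patch tokens) to be a multiple of 16, or 32 for the mxfp8 recipe, which is why class_token_len is bumped to 16 or 32. The new use_vision_backbone_fp8_arch flag forces that padded layout even when vision_transformer_config.fp8 is off, presumably so that non-FP8 runs keep shapes compatible with FP8. A minimal sketch of the sizing rule, with a hypothetical helper name and an illustrative 576-patch-token example (not Megatron-LM code):

```python
# Illustrative sketch: how the class-token length keeps the final vision
# sequence length a multiple of the FP8 alignment (names are hypothetical).
def fp8_class_token_len(
    fp8_enabled: bool, fp8_recipe: str, use_vision_backbone_fp8_arch: bool, default_len: int = 1
) -> int:
    if fp8_enabled or use_vision_backbone_fp8_arch:
        # mxfp8 needs 32-token alignment; other FP8 recipes need 16.
        return 32 if fp8_recipe == "mxfp8" else 16
    return default_len


# Example: 576 patch tokens (a 24x24 patch grid) plus the padded class tokens.
seq_len = 576 + fp8_class_token_len(False, "", use_vision_backbone_fp8_arch=True)
assert seq_len % 16 == 0  # 592 = 37 * 16, so the FP8 shape constraint holds
```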
13 changes: 7 additions & 6 deletions megatron/core/ssm/mamba_block.py
@@ -203,13 +203,14 @@ def __init__(
                 eps=self.config.layernorm_epsilon,
             )

-        self.apply(
-            partial(
-                _init_weights,
-                n_layer=self.config.num_layers,
-                initializer_range=self.config.init_method_std,
+        if self.config.perform_initialization:
+            self.apply(
+                partial(
+                    _init_weights,
+                    n_layer=self.config.num_layers,
+                    initializer_range=self.config.init_method_std,
+                )
             )
-        )

     def _select_layers_for_pipeline_parallel(self, layer_type_list):
         num_layers_per_pipeline_rank = self.config.num_layers // self.pp_group.size()
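The guard added above uses config.perform_initialization, the existing config flag that is normally set to False when the weights are about to be overwritten by a checkpoint load, so the random-init pass over all submodules can be skipped. A minimal, self-contained sketch of the same pattern, using toy Config/TinyBlock stand-ins (not Megatron-LM code):

```python
# Minimal sketch of the perform_initialization guard (illustrative stand-ins).
from dataclasses import dataclass
from functools import partial

import torch.nn as nn


@dataclass
class Config:
    num_layers: int = 2
    init_method_std: float = 0.02
    perform_initialization: bool = True  # False when a checkpoint load will overwrite weights


def _init_weights(module: nn.Module, n_layer: int, initializer_range: float) -> None:
    # Toy re-init: normal weights, zero biases; n_layer is only kept for signature parity.
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=initializer_range)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


class TinyBlock(nn.Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(config.num_layers)])
        # Same guard as the diff: only run the init pass when it is requested.
        if self.config.perform_initialization:
            self.apply(
                partial(
                    _init_weights,
                    n_layer=self.config.num_layers,
                    initializer_range=self.config.init_method_std,
                )
            )


# With perform_initialization=False the module keeps PyTorch's default init,
# e.g. immediately before a load_state_dict() call.
block = TinyBlock(Config(perform_initialization=False))
```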
102 changes: 54 additions & 48 deletions megatron/core/ssm/mamba_mixer.py
@@ -268,60 +268,66 @@ def __init__(
         )

         conv_dim = self.d_inner_local_tp + 2 * self.ngroups_local_tp * self.d_state  # x B C
-        with get_cuda_rng_tracker().fork():
-            # weight shape: [conv_dim, 1, d_conv]
-            # bias shape: [conv_dim]
-            self.conv1d = nn.Conv1d(
-                in_channels=conv_dim,
-                out_channels=conv_dim,
-                bias=conv_bias,
-                kernel_size=d_conv,
-                groups=conv_dim,
-                padding=d_conv - 1,
-                device=torch.cuda.current_device(),
-                dtype=config.params_dtype,
-            )
-            setattr(self.conv1d.weight, "tensor_model_parallel", True)
-            setattr(self.conv1d.bias, "tensor_model_parallel", True)
+        # weight shape: [conv_dim, 1, d_conv]
+        # bias shape: [conv_dim]
+        self.conv1d = nn.Conv1d(
+            in_channels=conv_dim,
+            out_channels=conv_dim,
+            bias=conv_bias,
+            kernel_size=d_conv,
+            groups=conv_dim,
+            padding=d_conv - 1,
+            device=torch.cuda.current_device(),
+            dtype=config.params_dtype,
+        )
+        setattr(self.conv1d.weight, "tensor_model_parallel", True)
+        setattr(self.conv1d.bias, "tensor_model_parallel", True)

-            if self.conv_init is not None:
+        if self.config.perform_initialization and self.conv_init is not None:
+            with get_cuda_rng_tracker().fork():
                 nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)

         self.activation = "silu"
         self.act = nn.SiLU()

-        with get_cuda_rng_tracker().fork():
-            # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
-            dt = torch.exp(
-                torch.rand(
-                    self.nheads_local_tp,
-                    device=torch.cuda.current_device(),
-                    dtype=config.params_dtype,
-                )
-                * (math.log(dt_max) - math.log(dt_min))
-                + math.log(dt_min)
-            ).clamp(min=dt_init_floor)
-            # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
-            inv_dt = dt + torch.log(-torch.expm1(-dt))
-            self.dt_bias = nn.Parameter(inv_dt)
-            # Our initialization would set all Linear.bias to zero,
-            # need to mark this one as _no_reinit
-            self.dt_bias._no_reinit = True
-            # Just to be explicit. Without this we already don't
-            # put wd on dt_bias because of the check
-            # name.endswith("bias") in param_grouping.py
-            self.dt_bias._no_weight_decay = True
-            setattr(self.dt_bias, "tensor_model_parallel", True)
-
-            # A parameter
-            assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
-            A = torch.empty(
-                self.nheads_local_tp, dtype=torch.float32, device=torch.cuda.current_device()
-            ).uniform_(*A_init_range)
-            A_log = torch.log(A)  # Keep A_log in fp32
-            self.A_log = nn.Parameter(A_log)
-            self.A_log._no_weight_decay = True
-            setattr(self.A_log, "tensor_model_parallel", True)
+        if self.config.perform_initialization:
+            with get_cuda_rng_tracker().fork():
+                # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
+                dt = torch.exp(
+                    torch.rand(
+                        self.nheads_local_tp,
+                        device=torch.cuda.current_device(),
+                        dtype=config.params_dtype,
+                    )
+                    * (math.log(dt_max) - math.log(dt_min))
+                    + math.log(dt_min)
+                ).clamp(min=dt_init_floor)
+                # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
+                inv_dt = dt + torch.log(-torch.expm1(-dt))
+        else:
+            inv_dt = torch.empty(self.nheads_local_tp)
+
+        self.dt_bias = nn.Parameter(inv_dt)
+        # Our initialization would set all Linear.bias to zero,
+        # need to mark this one as _no_reinit
+        self.dt_bias._no_reinit = True
+        # Just to be explicit. Without this we already don't
+        # put wd on dt_bias because of the check
+        # name.endswith("bias") in param_grouping.py
+        self.dt_bias._no_weight_decay = True
+        setattr(self.dt_bias, "tensor_model_parallel", True)
+
+        # A parameter
+        assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
+        A = torch.empty(
+            self.nheads_local_tp, dtype=torch.float32, device=torch.cuda.current_device()
+        )
+        if self.config.perform_initialization:
+            A = A.uniform_(*A_init_range)
+        A_log = torch.log(A)  # Keep A_log in fp32
+        self.A_log = nn.Parameter(A_log)
+        self.A_log._no_weight_decay = True
+        setattr(self.A_log, "tensor_model_parallel", True)

         # D "skip" parameter
         self.D = nn.Parameter(
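The mamba_mixer.py changes apply the same perform_initialization guard: the Conv1d module is now constructed outside the RNG-tracker fork (the fork is kept only around the explicit nn.init.uniform_ when conv_init is set), and the dt_bias and A_log initializations run only when initialization is requested; otherwise inv_dt and A are left as torch.empty placeholders to be filled by a later checkpoint load (A_log computed from an uninitialized A is only meaningful after that load). The dt_bias math: dt is sampled log-uniformly in [dt_min, dt_max], clamped at dt_init_floor, and stored through the inverse softplus, inv_dt = dt + log(-expm1(-dt)) = log(exp(dt) - 1), so that softplus(dt_bias) recovers the sampled dt. A small self-contained check, with illustrative dt_min/dt_max/nheads values (not Megatron-LM code):

```python
# Sketch of the dt_bias initialization math from the guarded branch above.
import math

import torch
import torch.nn.functional as F

dt_min, dt_max, dt_init_floor = 1e-3, 1e-1, 1e-4  # illustrative values
nheads = 8

# Log-uniform sample in [dt_min, dt_max], then clamp, as in the diff.
dt = torch.exp(
    torch.rand(nheads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
).clamp(min=dt_init_floor)

# Inverse of softplus: dt + log(-expm1(-dt)) == log(exp(dt) - 1).
inv_dt = dt + torch.log(-torch.expm1(-dt))

# softplus(dt_bias) recovers the sampled dt at the first forward pass.
assert torch.allclose(F.softplus(inv_dt), dt, atol=1e-6)
```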