3 changes: 2 additions & 1 deletion megatron/core/models/multimodal/llava_model.py
@@ -124,6 +124,7 @@ def __init__(
max_num_tiles: int = 0,
tokenizer_type: str = "",
vp_stage: Optional[int] = None,
use_vision_backbone_fp8_arch: bool = False,
) -> None:
super().__init__(config=language_transformer_config)

@@ -295,7 +296,7 @@ def __init__(
ln_post_impl = None
use_mask_token = False

if vision_transformer_config.fp8:
if vision_transformer_config.fp8 or use_vision_backbone_fp8_arch:
# FP8 padding for final sequence length to be a multiple of 16 or 32.
class_token_len = 32 if vision_transformer_config.fp8_recipe == "mxfp8" else 16

13 changes: 7 additions & 6 deletions megatron/core/ssm/mamba_block.py
@@ -203,13 +203,14 @@ def __init__(
eps=self.config.layernorm_epsilon,
)

self.apply(
partial(
_init_weights,
n_layer=self.config.num_layers,
initializer_range=self.config.init_method_std,
if self.config.perform_initialization:
self.apply(
partial(
_init_weights,
n_layer=self.config.num_layers,
initializer_range=self.config.init_method_std,
)
)
)

def _select_layers_for_pipeline_parallel(self, layer_type_list):
num_layers_per_pipeline_rank = self.config.num_layers // self.pp_group.size()
104 changes: 56 additions & 48 deletions megatron/core/ssm/mamba_mixer.py
@@ -268,60 +268,68 @@ def __init__(
)

conv_dim = self.d_inner_local_tp + 2 * self.ngroups_local_tp * self.d_state # x B C
with get_cuda_rng_tracker().fork():
# weight shape: [conv_dim, 1, d_conv]
# bias shape: [conv_dim]
self.conv1d = nn.Conv1d(
in_channels=conv_dim,
out_channels=conv_dim,
bias=conv_bias,
kernel_size=d_conv,
groups=conv_dim,
padding=d_conv - 1,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
setattr(self.conv1d.weight, "tensor_model_parallel", True)
setattr(self.conv1d.bias, "tensor_model_parallel", True)
# weight shape: [conv_dim, 1, d_conv]
Contributor:

This seems to be a pretty significant change to do this with a different rng tracker, and it's not related to respecting config.perform_initialization. Why this change?

Contributor (Author):

My understanding is that there is no randomness in creating the conv layer if the weights are not initialized, so the get_cuda_rng_tracker() context manager should only apply to the param init (line 288).
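
For what it's worth, a minimal sketch of the pattern being described, with made-up channel sizes; it assumes a CUDA device and that the Megatron RNG tracker has already been seeded (e.g. via model_parallel_cuda_manual_seed), and is not the PR's code verbatim.

import torch
import torch.nn as nn

from megatron.core.tensor_parallel import get_cuda_rng_tracker

# Construction itself is left outside the tracker...
conv = nn.Conv1d(
    in_channels=8, out_channels=8, kernel_size=4, groups=8,
    padding=3, device=torch.cuda.current_device(),
)

# ...and only the explicit (re-)initialization draws from the forked, TP-aware RNG stream.
with get_cuda_rng_tracker().fork():
    nn.init.uniform_(conv.weight, -0.1, 0.1)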

Contributor:

What if conv_bias=True?

This could also change randomness relative to previous versions, since it will still consume random numbers for the weight initialization even though it is overwritten below, so a run with seed=N from before this commit won't match a run with seed=N after this commit. Not sure if that's a big deal, but it might unnecessarily change the golden values.
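
A generic torch-only illustration of that last point (nothing Megatron-specific, sizes made up): an initialization that is later thrown away still advances the RNG stream, so subsequent draws under the same seed change.

import torch

torch.manual_seed(1234)
before = torch.rand(4)             # what a pre-change run would sample next

torch.manual_seed(1234)
_ = torch.empty(8).normal_()       # a default init that gets overwritten anyway
after = torch.rand(4)              # same seed, but the stream has advanced

assert not torch.allclose(before, after)  # golden values diverge despite seed=N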

# bias shape: [conv_dim]
self.conv1d = nn.Conv1d(
in_channels=conv_dim,
out_channels=conv_dim,
bias=conv_bias,
kernel_size=d_conv,
groups=conv_dim,
padding=d_conv - 1,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
setattr(self.conv1d.weight, "tensor_model_parallel", True)
setattr(self.conv1d.bias, "tensor_model_parallel", True)

if self.conv_init is not None:
if self.config.perform_initialization and self.conv_init is not None:
with get_cuda_rng_tracker().fork():
nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)

self.activation = "silu"
self.act = nn.SiLU()

with get_cuda_rng_tracker().fork():
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
dt = torch.exp(
torch.rand(
self.nheads_local_tp,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
* (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
self.dt_bias = nn.Parameter(inv_dt)
# Our initialization would set all Linear.bias to zero,
# need to mark this one as _no_reinit
self.dt_bias._no_reinit = True
# Just to be explicit. Without this we already don't
# put wd on dt_bias because of the check
# name.endswith("bias") in param_grouping.py
self.dt_bias._no_weight_decay = True
setattr(self.dt_bias, "tensor_model_parallel", True)

# A parameter
assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
A = torch.empty(
self.nheads_local_tp, dtype=torch.float32, device=torch.cuda.current_device()
).uniform_(*A_init_range)
A_log = torch.log(A) # Keep A_log in fp32
self.A_log = nn.Parameter(A_log)
self.A_log._no_weight_decay = True
setattr(self.A_log, "tensor_model_parallel", True)
if self.config.perform_initialization:
with get_cuda_rng_tracker().fork():
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
dt = torch.exp(
torch.rand(
self.nheads_local_tp,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
* (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
else:
inv_dt = torch.empty(
self.nheads_local_tp, device=torch.cuda.current_device(), dtype=config.params_dtype
)

self.dt_bias = nn.Parameter(inv_dt)
# Our initialization would set all Linear.bias to zero,
# need to mark this one as _no_reinit
self.dt_bias._no_reinit = True
# Just to be explicit. Without this we already don't
# put wd on dt_bias because of the check
# name.endswith("bias") in param_grouping.py
self.dt_bias._no_weight_decay = True
setattr(self.dt_bias, "tensor_model_parallel", True)

# A parameter
assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
A = torch.empty(
self.nheads_local_tp, dtype=torch.float32, device=torch.cuda.current_device()
)
if self.config.perform_initialization:
A = A.uniform_(*A_init_range)
A_log = torch.log(A) # Keep A_log in fp32
self.A_log = nn.Parameter(A_log)
self.A_log._no_weight_decay = True
setattr(self.A_log, "tensor_model_parallel", True)

# D "skip" parameter
self.D = nn.Parameter(