27 changes: 21 additions & 6 deletions megatron/core/jit.py
@@ -7,12 +7,27 @@
 jit_fuser = torch.jit.script
 # nvFuser is deprecated in PyTorch JIT starting from 2.2

-try:
-    if is_torch_min_version("2.2.0a0"):
-        jit_fuser = torch.compile
-except ImportError:
-
-    def noop_decorator(func):
-        return func
+def noop_decorator(func):
+    '''No-op decorator'''
+    return func


+def enable_jit_fuser():
+    '''Enable the JIT fuser'''
+    global jit_fuser
+    try:
+        if is_torch_min_version("2.2.0a0"):
+            jit_fuser = torch.compile
+    except ImportError:
+        jit_fuser = noop_decorator


+def disable_jit_fuser():
+    '''Disable the JIT fuser'''
+    global jit_fuser
+    jit_fuser = noop_decorator


+enable_jit_fuser()
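
Taken together, jit.py now defaults jit_fuser to torch.jit.script, upgrades it to torch.compile via enable_jit_fuser() when PyTorch is at least 2.2.0a0, and lets callers swap in a plain no-op with disable_jit_fuser(). A minimal usage sketch, illustrative only (scaled_silu is a made-up function, not part of this PR):

import torch

from megatron.core import jit


@jit.jit_fuser  # resolved at decoration time: torch.compile, torch.jit.script, or a no-op
def scaled_silu(x: torch.Tensor, scale: float) -> torch.Tensor:
    return torch.nn.functional.silu(x) * scale


y = scaled_silu(torch.randn(4, 8), 0.5)

# Rebinding only affects decorators applied afterwards: modules that already did
# `from megatron.core.jit import jit_fuser` keep whatever value they imported.
jit.disable_jit_fuser()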
7 changes: 4 additions & 3 deletions megatron/core/ssm/gated_delta_net.py
@@ -18,6 +18,7 @@
 from megatron.core.dist_checkpointing.mapping import ReplicaId, ShardedTensorFactory
 from megatron.core.fp8_utils import get_fp8_align_size
 from megatron.core.inference.contexts import BaseInferenceContext
+from megatron.core.jit import jit_fuser
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.tensor_parallel import get_cuda_rng_tracker
@@ -384,7 +385,7 @@ def forward(

         # RMSNorm
         nvtx_range_push(suffix="gated_norm")
-        norm_out = self._torch_compiled_gated_norm(core_attn_out, gate)
+        norm_out = self._apply_gated_norm(core_attn_out, gate)
         nvtx_range_pop(suffix="gated_norm")

         # Transpose: b s x --> s b x
@@ -399,8 +400,8 @@ def forward(

         return out, out_bias

-    @torch.compile
-    def _torch_compiled_gated_norm(self, x, gate):
+    @jit_fuser
+    def _apply_gated_norm(self, x, gate):
         # Output Norm
         x_dtype = x.dtype
         x = x.reshape(-1, x.shape[-1])
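
The rest of the _apply_gated_norm body is collapsed in the diff above; only the rename from _torch_compiled_gated_norm and the switch from @torch.compile to @jit_fuser are visible. For context, the kind of computation such a decorator targets is a chain of elementwise ops that a fuser can collapse into one kernel. A hypothetical stand-in, not the actual implementation (the gate activation, epsilon, and weight handling are assumptions):

import torch

from megatron.core.jit import jit_fuser


@jit_fuser  # torch.compile on PyTorch >= 2.2.0a0, otherwise a fallback / no-op
def gated_rms_norm_sketch(x, gate, weight, eps=1e-6):
    # Hypothetical gated RMSNorm: normalize x, then modulate by a SiLU-activated gate.
    x_dtype = x.dtype
    x = x.reshape(-1, x.shape[-1]).float()
    gate = gate.reshape(-1, gate.shape[-1]).float()
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return (weight * x * torch.nn.functional.silu(gate)).to(x_dtype)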
7 changes: 4 additions & 3 deletions megatron/core/transformer/attention.py
@@ -9,6 +9,7 @@

 from megatron.core import tensor_parallel
 from megatron.core.inference.contexts import BaseInferenceContext
+from megatron.core.jit import jit_fuser
 from megatron.core.models.common.embeddings.rope_utils import (
     apply_rotary_pos_emb,
     apply_rotary_pos_emb_with_cos_sin,
@@ -958,7 +959,7 @@ def forward(
         # Output gate
         if gate is not None:
             nvtx_range_push(suffix="output_gate")
-            core_attn_out = self._torch_compiled_output_gate(core_attn_out, gate)
+            core_attn_out = self._apply_output_gate(core_attn_out, gate)
             nvtx_range_pop(suffix="output_gate")

         # =================
@@ -978,8 +979,8 @@ def forward(

         return output, bias

-    @torch.compile
-    def _torch_compiled_output_gate(self, x, gate):
+    @jit_fuser
+    def _apply_output_gate(self, x, gate):
         x_dtype = x.dtype
         gate = gate.contiguous()
         gate = gate.view(*x.shape)
6 changes: 6 additions & 0 deletions megatron/core/transformer/transformer_block.py
@@ -801,6 +801,12 @@ def sharded_state_dict(
         elif isinstance(self.config.moe_layer_freq, list):
             non_homogeneous_layers = True

+        if isinstance(self.config.linear_attention_freq, int):
+            if self.config.linear_attention_freq > 1:
+                non_homogeneous_layers = True
+        elif isinstance(self.config.linear_attention_freq, list):
+            non_homogeneous_layers = True
+
         if self.config.heterogeneous_block_specs:
             non_homogeneous_layers = True

2 changes: 2 additions & 0 deletions megatron/training/arguments.py
@@ -2349,6 +2349,8 @@ def _add_training_args(parser):
                       help='The submodules to offload its input. Choices: "attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act".')
    group.add_argument('--min-offloaded-tensor-size', type=int, default=1024*1024,
                       help='The minimum size of the tensor to be offloaded.')
+   group.add_argument('--disable-jit-fuser', action='store_true',
+                      help='Disable the JIT fuser.')
    return parser


4 changes: 4 additions & 0 deletions megatron/training/global_vars.py
@@ -9,6 +9,7 @@
 from megatron.core import Timers
 from megatron.core.config import set_experimental_flag
 from megatron.core.energy_monitor import EnergyMonitor
+from megatron.core.jit import disable_jit_fuser
 from megatron.core.num_microbatches_calculator import init_num_microbatches_calculator, unset_num_microbatches_calculator
 from megatron.training import dist_signal_handler
 from megatron.training.tokenizer import build_tokenizer
@@ -111,6 +112,9 @@ def set_global_variables(args, build_tokenizer=True):
     if args.exit_signal_handler:
         _set_signal_handler()

+    if args.disable_jit_fuser:
+        disable_jit_fuser()
+

 def unset_global_variables():
     """Unset global vars.
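
End to end, the new --disable-jit-fuser flag flows from argument parsing into set_global_variables(), which calls disable_jit_fuser(). A simplified sketch of that flow (the Namespace stands in for fully parsed Megatron args; the real path runs through parse_args() and set_global_variables()):

from argparse import Namespace

from megatron.core import jit

# Stand-in for parsed args when `--disable-jit-fuser` is on the command line.
args = Namespace(disable_jit_fuser=True)

# Mirrors the new branch in set_global_variables().
if args.disable_jit_fuser:
    jit.disable_jit_fuser()

assert jit.jit_fuser is jit.noop_decorator  # later @jit_fuser decorations become no-ops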