Commit 9a2ce88

add a flag to control whether to use torch.compile
1 parent 1d1ac73 commit 9a2ce88

File tree: 5 files changed (+28, -12 lines)

megatron/core/jit.py

Lines changed: 14 additions & 6 deletions
@@ -7,12 +7,20 @@
 jit_fuser = torch.jit.script
 # nvFuser is deprecated in PyTorch JIT starting from 2.2
 
-try:
-    if is_torch_min_version("2.2.0a0"):
-        jit_fuser = torch.compile
-except ImportError:
+def noop_decorator(func):
+    return func
 
-    def noop_decorator(func):
-        return func
+def enable_jit_fuser():
+    global jit_fuser
+    try:
+        if is_torch_min_version("2.2.0a0"):
+            jit_fuser = torch.compile
+    except ImportError:
 
+        jit_fuser = noop_decorator
+
+def disable_jit_fuser():
+    global jit_fuser
     jit_fuser = noop_decorator
+
+enable_jit_fuser()
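
The module now exposes enable_jit_fuser() and disable_jit_fuser() while keeping jit_fuser as the alias the rest of the codebase decorates with. A minimal usage sketch, assuming Megatron-LM and PyTorch >= 2.2 are installed (the gated_silu function is hypothetical; the alias is read through the module so the rebinding is visible):

import torch

import megatron.core.jit as core_jit

core_jit.disable_jit_fuser()  # what --disable-jit-fuser ultimately triggers

# The alias is resolved when the decorated definition is evaluated, so this
# function stays uncompiled (noop_decorator). Under the default
# enable_jit_fuser() it would have been wrapped by torch.compile instead.
@core_jit.jit_fuser
def gated_silu(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
    return x * torch.nn.functional.silu(gate)

out = gated_silu(torch.randn(4, 8), torch.randn(4, 8))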

megatron/core/ssm/gated_delta_net.py

Lines changed: 4 additions & 3 deletions
@@ -18,6 +18,7 @@
 from megatron.core.dist_checkpointing.mapping import ReplicaId, ShardedTensorFactory
 from megatron.core.fp8_utils import get_fp8_align_size
 from megatron.core.inference.contexts import BaseInferenceContext
+from megatron.core.jit import jit_fuser
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.tensor_parallel import get_cuda_rng_tracker
@@ -384,7 +385,7 @@ def forward(
 
         # RMSNorm
         nvtx_range_push(suffix="gated_norm")
-        norm_out = self._torch_compiled_gated_norm(core_attn_out, gate)
+        norm_out = self._apply_gated_norm(core_attn_out, gate)
         nvtx_range_pop(suffix="gated_norm")
 
         # Transpose: b s x --> s b x
@@ -399,8 +400,8 @@ def forward(
 
         return out, out_bias
 
-    @torch.compile
-    def _torch_compiled_gated_norm(self, x, gate):
+    @jit_fuser
+    def _apply_gated_norm(self, x, gate):
         # Output Norm
         x_dtype = x.dtype
         x = x.reshape(-1, x.shape[-1])
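
Renaming _torch_compiled_gated_norm to _apply_gated_norm and decorating it with @jit_fuser makes the gated-norm fusion respect the new flag. A hedged sketch of the decorated-method pattern, assuming PyTorch >= 2.2 so the alias resolves to torch.compile (the GatedNormSketch class and the RMS-norm body are illustrative stand-ins; only the dtype/reshape prologue is visible in the hunk):

import torch

from megatron.core.jit import jit_fuser


class GatedNormSketch(torch.nn.Module):
    """Illustrative stand-in; the real method lives on the gated delta net layer."""

    @jit_fuser  # torch.compile by default, a plain no-op once the fuser is disabled
    def _apply_gated_norm(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
        # Prologue shown in the hunk: remember the dtype, flatten to 2-D.
        x_dtype = x.dtype
        x = x.reshape(-1, x.shape[-1])
        gate = gate.reshape(-1, x.shape[-1])
        # Assumed body: RMS-normalize, gate with SiLU, then restore the dtype.
        x = x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-6)
        return (x * torch.nn.functional.silu(gate.float())).to(x_dtype)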

megatron/core/transformer/attention.py

Lines changed: 4 additions & 3 deletions
@@ -9,6 +9,7 @@
 
 from megatron.core import tensor_parallel
 from megatron.core.inference.contexts import BaseInferenceContext
+from megatron.core.jit import jit_fuser
 from megatron.core.models.common.embeddings.rope_utils import (
     apply_rotary_pos_emb,
     apply_rotary_pos_emb_with_cos_sin,
@@ -958,7 +959,7 @@ def forward(
         # Output gate
         if gate is not None:
             nvtx_range_push(suffix="output_gate")
-            core_attn_out = self._torch_compiled_output_gate(core_attn_out, gate)
+            core_attn_out = self._apply_output_gate(core_attn_out, gate)
             nvtx_range_pop(suffix="output_gate")
 
         # =================
@@ -978,8 +979,8 @@ def forward(
 
         return output, bias
 
-    @torch.compile
-    def _torch_compiled_output_gate(self, x, gate):
+    @jit_fuser
+    def _apply_output_gate(self, x, gate):
         x_dtype = x.dtype
         gate = gate.contiguous()
         gate = gate.view(*x.shape)
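
The attention path gets the same treatment: the hard-coded @torch.compile on the output gate becomes @jit_fuser, and the method is renamed to _apply_output_gate since it is no longer guaranteed to be compiled. When the fuser stays enabled this is behavior-preserving, because on PyTorch >= 2.2 the alias is literally torch.compile; a quick check (assumes Megatron-LM is importable):

import torch

import megatron.core.jit as core_jit

core_jit.enable_jit_fuser()
# On PyTorch >= 2.2 the alias is torch.compile itself, so the default
# behavior of the renamed methods is unchanged by this commit.
assert core_jit.jit_fuser is torch.compile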

megatron/training/arguments.py

Lines changed: 2 additions & 0 deletions
@@ -2341,6 +2341,8 @@ def _add_training_args(parser):
                        help='The submodules to offload its input. Choices: "attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act".')
     group.add_argument('--min-offloaded-tensor-size', type=int, default=1024*1024,
                        help='The minimum size of the tensor to be offloaded.')
+    group.add_argument('--disable-jit-fuser', action='store_true',
+                       help='Disable the JIT fuser.')
     return parser
 
 

megatron/training/global_vars.py

Lines changed: 4 additions & 0 deletions
@@ -9,6 +9,7 @@
 from megatron.core import Timers
 from megatron.core.config import set_experimental_flag
 from megatron.core.energy_monitor import EnergyMonitor
+from megatron.core.jit import disable_jit_fuser
 from megatron.core.num_microbatches_calculator import init_num_microbatches_calculator, unset_num_microbatches_calculator
 from megatron.training import dist_signal_handler
 from megatron.training.tokenizer import build_tokenizer
@@ -111,6 +112,9 @@ def set_global_variables(args, build_tokenizer=True):
     if args.exit_signal_handler:
         _set_signal_handler()
 
+    if args.disable_jit_fuser:
+        disable_jit_fuser()
+
 
 def unset_global_variables():
     """Unset global vars.
