     ulysses_gather_output,
     ulysses_slice_inputs,
 )
+from areal.experimental.models.archon.utils import is_moe_model_config
 from areal.infra.dist_rollout import DistRolloutCoordinator
 from areal.infra.platforms import current_platform
 from areal.models.tree_attn.functional import (
@@ -292,13 +293,18 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
         ac_config = self._build_ac_config()
         enable_compile = self.config.archon.enable_compile
 
+        # NOTE: Upgrading PyTorch may resolve these in the future.
         # Zero-bubble schedules (InterleavedZeroBubble, ZBVZeroBubble, DualPipeV)
-        # use split backward (I/W steps). This is incompatible with:
-        # 1. torch.compile - donated buffer optimization assumes a single
-        #    backward pass (retain_graph=False).
-        # 2. Op-level selective AC - its per-op cache (storage.pop) is consumed
+        # use split backward (I/W steps) with retain_graph=True between them.
+        # This is incompatible with:
+        # 1. torch.compile - disabled unconditionally for zero-bubble.
+        # 2. donated_buffer (MoE only) - MoE models have internally compiled
+        #    ops (via AOTAutograd) whose backward uses donated buffers. These
+        #    are freed after backward, conflicting with retain_graph=True.
+        #    Dense models have no such ops and are unaffected.
+        # 3. Op-level selective AC - its per-op cache (storage.pop) is consumed
         #    by the I step, leaving nothing for the W step recompute.
-        # 3. memory_budget AC - it depends on torch.compile.
+        # 4. memory_budget AC - it depends on torch.compile.
         # Full AC / layer-level selective AC use standard checkpoint_wrapper
         # whose gid-based recompute supports multiple backward passes.
         schedule_class = get_schedule_class(self.config.archon.pp_schedule)
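
The retain_graph point in the comment above can be seen with a minimal, self-contained sketch (not part of the patch; plain autograd with made-up tensor shapes): the zero-bubble I step backprops to the stage input while keeping the graph alive, and the W step later reuses that same graph for the weight gradients, which is exactly what a donated (freed) backward buffer would break.

import torch

w = torch.randn(4, 4, requires_grad=True)   # stand-in for stage weights
x = torch.randn(2, 4, requires_grad=True)   # stand-in for the stage input
loss = (x @ w).sum()

# "I" step: gradient w.r.t. the input only; the graph must be retained
# so the weight step can traverse it again later.
(grad_x,) = torch.autograd.grad(loss, (x,), retain_graph=True)

# "W" step: gradient w.r.t. the weights, reusing the retained graph. A
# compiled op that donated (freed) its saved buffers in the first pass
# would fail here, which is why donated_buffer must be off for MoE.
(grad_w,) = torch.autograd.grad(loss, (w,))
print(grad_x.shape, grad_w.shape)
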
@@ -316,6 +322,22 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
                 )
                 enable_compile = False
 
+            # NOTE: Upgrading PyTorch may resolve this in the future.
+            # MoE models have internally compiled ops (via AOTAutograd)
+            # whose backward uses donated buffers - these conflict with
+            # retain_graph=True in split backward. Dense models have no
+            # such ops and are unaffected.
+            if is_moe_model_config(self.model_config):
+                import torch._functorch.config as functorch_config
+
+                if getattr(functorch_config, "donated_buffer", False):
+                    self.logger.info(
+                        f"{schedule_name} requires donated_buffer=False "
+                        "for MoE models (internally compiled ops conflict "
+                        "with retain_graph=True in split backward). Disabling."
+                    )
+                    functorch_config.donated_buffer = False
+
         if ac_config is not None and (
             (
                 ac_config.mode == "selective"
@@ -899,7 +921,7 @@ def _apply_pipeline_parallelism(
             reduce_dtype=torch.float32,
             loss_parallel=True,
             cpu_offload=self.config.archon.offload_params,
-            reshard_after_forward_policy="default",
+            reshard_after_forward_policy=self.config.archon.reshard_after_forward_policy,
             ac_config=ac_config,
             enable_compile=enable_compile,
         )
@@ -938,7 +960,7 @@ def _apply_parallelism(
             reduce_dtype=torch.float32,
             loss_parallel=True,
             cpu_offload=self.config.archon.offload_params,
-            reshard_after_forward_policy="default",
+            reshard_after_forward_policy=self.config.archon.reshard_after_forward_policy,
             ac_config=ac_config,
             enable_compile=enable_compile,
         )
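
The two hunks above only plumb self.config.archon.reshard_after_forward_policy through to the parallelism setup in place of the hardcoded "default". As a rough, hypothetical illustration of what such a policy string typically controls in FSDP2-style sharding (the policy names and mapping below are assumptions, not this repo's implementation):

# Hypothetical mapping from a policy string to the value passed to FSDP2's
# fully_shard(reshard_after_forward=...). Names and semantics are assumed.
def resolve_reshard_after_forward(policy: str, is_last_block: bool) -> bool:
    if policy == "always":
        return True   # free sharded params right after forward (lowest memory)
    if policy == "never":
        return False  # keep params gathered until backward (more memory, fewer gathers)
    # "default": reshard everything except the last block, whose params are
    # needed again immediately when backward starts.
    return not is_last_block

print(resolve_reshard_after_forward("default", is_last_block=True))   # False
print(resolve_reshard_after_forward("default", is_last_block=False))  # True
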
@@ -1239,16 +1261,20 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList:
 
         input_ = amend_position_ids(input_)
 
-        # Pipeline parallelism requires n_microbatches >= pp_stages
+        # Pipeline parallelism requires n_microbatches >= num_total_stages
         if self.parallel_dims.pp_enabled:
             pp_size = self.parallel_dims.pp
+            stages_per_rank = len(self.pp_stages)
+            num_total_stages = pp_size * stages_per_rank
             n_seqs = input_["attention_mask"].shape[0]
-            if n_seqs < pp_size:
+            if n_seqs < num_total_stages:
                 raise RuntimeError(
-                    f"Pipeline parallelism requires at least {pp_size} sequences, "
-                    f"but got {n_seqs}. Increase batch size or reduce PP degree."
+                    f"Pipeline parallelism requires at least {num_total_stages} "
+                    f"sequences (pp_size={pp_size} * stages_per_rank="
+                    f"{stages_per_rank}), but got {n_seqs}. "
+                    f"Increase batch size or reduce PP degree/stages."
                 )
-            min_n_mbs = pp_size
+            min_n_mbs = num_total_stages
         mb_spec = MicroBatchSpec.new(
             self.config.mb_spec,
             n_mbs=max(min_n_mbs, self.config.mb_spec.n_mbs or 1),
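
A tiny worked example (made-up numbers, outside the patch) of the lower bound the hunk above enforces: with an interleaved or zero-bubble schedule running pp_size ranks and more than one stage per rank, each batch has to split into at least pp_size * stages_per_rank microbatches, so it also needs at least that many sequences.

pp_size = 4            # pipeline-parallel ranks
stages_per_rank = 2    # e.g. an interleaved / zero-bubble schedule
num_total_stages = pp_size * stages_per_rank    # 8
configured_n_mbs = 4                            # from mb_spec.n_mbs
n_mbs = max(num_total_stages, configured_n_mbs)
print(num_total_stages, n_mbs)                  # 8 8
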