
Commit 68e2630

feat(archon): improve pipeline parallelism memory handling
Add reshard_after_forward_policy config, MoE-aware donated_buffer management, output chunk memory optimization, and comprehensive PP memory guide documentation.

Key changes:
- Add reshard_after_forward_policy config for FSDP forward resharding control
- Add is_moe_model_config utility; only disable donated_buffer for MoE models
- Add _NullOutputChunks to free logits during PP training step
- Fix microbatch validation to use num_total_stages instead of pp_size
- Add PP Memory Guide appendix to archon.md
- Replace handling_oom.md PP section with seealso cross-reference
1 parent f1fc2d3 commit 68e2630

9 files changed

Lines changed: 302 additions & 28 deletions


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@

 # opencode
 .opencode/sessions/
+.sisyphus/

 # Ruff
 .ruff_cache/

AGENTS.md

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@
 ```bash
 # Environment
 uv sync --extra cuda # dependencies (or `uv sync` without CUDA)
+source .venv/bin/activate # activate venv BEFORE pre-commit or git commit if venv exists
 pre-commit install # formatting hooks (Ruff, mdformat, clang-format, nbstripout, autoflake)
 pre-commit run --all-files # lint + format everything

areal/api/cli_args.py

Lines changed: 18 additions & 0 deletions
@@ -489,6 +489,18 @@ class ArchonEngineConfig:
         },
     )

+    # FSDP reshard policy after forward pass
+    reshard_after_forward_policy: str = field(
+        default="default",
+        metadata={
+            "help": "FSDP reshard policy after forward pass. "
+            "'default': reshard when pipeline parallelism is off; keep unsharded when on to avoid repeated all-gather per microbatch. "
+            "'always': always reshard after forward (saves memory). "
+            "'never': never reshard after forward.",
+            "choices": ["default", "always", "never"],
+        },
+    )
+
     # Deterministic mode
     use_deterministic_algorithms: bool = field(
         default=False,
@@ -515,6 +527,12 @@ def __post_init__(self):
                 f"pp_last_stage_less_layers must be >= 0, "
                 f"got {self.pp_last_stage_less_layers}"
             )
+        valid_reshard_policies = ("default", "always", "never")
+        if self.reshard_after_forward_policy not in valid_reshard_policies:
+            raise ValueError(
+                f"reshard_after_forward_policy must be one of {valid_reshard_policies}, "
+                f"got '{self.reshard_after_forward_policy}'"
+            )


 # These configurations are used by Megatron Bridge to build Megatron models.
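For illustration, a minimal sketch of setting the new option from Python. This assumes the remaining `ArchonEngineConfig` fields have defaults (as the dataclass pattern above suggests); it only exercises the new field and its `__post_init__` check.

```python
from areal.api.cli_args import ArchonEngineConfig

# Trade extra all-gathers per microbatch for lower peak memory.
cfg = ArchonEngineConfig(reshard_after_forward_policy="always")

# Invalid values are rejected at construction time by __post_init__.
try:
    ArchonEngineConfig(reshard_after_forward_policy="sometimes")
except ValueError as err:
    print(err)  # reshard_after_forward_policy must be one of ('default', 'always', 'never'), ...
```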

areal/experimental/engine/archon_engine.py

Lines changed: 38 additions & 12 deletions
@@ -74,6 +74,7 @@
     ulysses_gather_output,
     ulysses_slice_inputs,
 )
+from areal.experimental.models.archon.utils import is_moe_model_config
 from areal.infra.dist_rollout import DistRolloutCoordinator
 from areal.infra.platforms import current_platform
 from areal.models.tree_attn.functional import (
@@ -292,13 +293,18 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
         ac_config = self._build_ac_config()
         enable_compile = self.config.archon.enable_compile

+        # NOTE: Upgrading PyTorch may resolve these in the future.
         # Zero-bubble schedules (InterleavedZeroBubble, ZBVZeroBubble, DualPipeV)
-        # use split backward (I/W steps). This is incompatible with:
-        # 1. torch.compile - donated buffer optimization assumes a single
-        #    backward pass (retain_graph=False).
-        # 2. Op-level selective AC - its per-op cache (storage.pop) is consumed
+        # use split backward (I/W steps) with retain_graph=True between them.
+        # This is incompatible with:
+        # 1. torch.compile - disabled unconditionally for zero-bubble.
+        # 2. donated_buffer (MoE only) - MoE models have internally compiled
+        #    ops (via AOTAutograd) whose backward uses donated buffers. These
+        #    are freed after backward, conflicting with retain_graph=True.
+        #    Dense models have no such ops and are unaffected.
+        # 3. Op-level selective AC - its per-op cache (storage.pop) is consumed
         #    by the I step, leaving nothing for the W step recompute.
-        # 3. memory_budget AC - it depends on torch.compile.
+        # 4. memory_budget AC - it depends on torch.compile.
         # Full AC / layer-level selective AC use standard checkpoint_wrapper
         # whose gid-based recompute supports multiple backward passes.
         schedule_class = get_schedule_class(self.config.archon.pp_schedule)
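The constraint documented above hinges on split backward reusing the autograd graph. A toy PyTorch illustration (not AReaL code) of separating the "I" (input-grad) and "W" (weight-grad) steps, which is only legal if the first pass keeps the graph alive:

```python
import torch

x = torch.randn(4, requires_grad=True)   # stands in for activations/inputs
w = torch.randn(4, requires_grad=True)   # stands in for weights
loss = (x * w).sum()

# "I" step: input gradients only; retain_graph=True keeps buffers for the W step.
(grad_x,) = torch.autograd.grad(loss, (x,), retain_graph=True)
# "W" step: weight gradients from the same graph, i.e. a second backward pass.
(grad_w,) = torch.autograd.grad(loss, (w,))
```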
@@ -316,6 +322,22 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs):
             )
             enable_compile = False

+            # NOTE: Upgrading PyTorch may resolve this in the future.
+            # MoE models have internally compiled ops (via AOTAutograd)
+            # whose backward uses donated buffers - these conflict with
+            # retain_graph=True in split backward. Dense models have no
+            # such ops and are unaffected.
+            if is_moe_model_config(self.model_config):
+                import torch._functorch.config as functorch_config
+
+                if getattr(functorch_config, "donated_buffer", False):
+                    self.logger.info(
+                        f"{schedule_name} requires donated_buffer=False "
+                        "for MoE models (internally compiled ops conflict "
+                        "with retain_graph=True in split backward). Disabling."
+                    )
+                    functorch_config.donated_buffer = False
+
         if ac_config is not None and (
             (
                 ac_config.mode == "selective"
@@ -899,7 +921,7 @@ def _apply_pipeline_parallelism(
             reduce_dtype=torch.float32,
             loss_parallel=True,
             cpu_offload=self.config.archon.offload_params,
-            reshard_after_forward_policy="default",
+            reshard_after_forward_policy=self.config.archon.reshard_after_forward_policy,
             ac_config=ac_config,
             enable_compile=enable_compile,
         )
@@ -938,7 +960,7 @@ def _apply_parallelism(
             reduce_dtype=torch.float32,
             loss_parallel=True,
             cpu_offload=self.config.archon.offload_params,
-            reshard_after_forward_policy="default",
+            reshard_after_forward_policy=self.config.archon.reshard_after_forward_policy,
             ac_config=ac_config,
             enable_compile=enable_compile,
         )
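How the three policy strings translate into FSDP's per-module reshard decision follows the help text added in cli_args.py. A hypothetical resolution helper, written only to summarize that behavior (it is an assumption for illustration, not the repo's actual implementation), might look like:

```python
def resolve_reshard_after_forward(policy: str, pp_enabled: bool) -> bool:
    """Map the config string to FSDP's reshard-after-forward flag (illustrative only)."""
    if policy == "always":
        return True   # cheapest on memory, pays an all-gather per forward
    if policy == "never":
        return False  # keeps params unsharded: fastest, highest memory
    # "default": reshard only when pipeline parallelism is off, so PP
    # microbatches do not repeat the all-gather on every forward pass.
    return not pp_enabled
```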
@@ -1239,16 +1261,20 @@ def _prepare_mb_list(self, input_: dict[str, Any]) -> MicroBatchList:

         input_ = amend_position_ids(input_)

-        # Pipeline parallelism requires n_microbatches >= pp_stages
+        # Pipeline parallelism requires n_microbatches >= num_total_stages
         if self.parallel_dims.pp_enabled:
             pp_size = self.parallel_dims.pp
+            stages_per_rank = len(self.pp_stages)
+            num_total_stages = pp_size * stages_per_rank
             n_seqs = input_["attention_mask"].shape[0]
-            if n_seqs < pp_size:
+            if n_seqs < num_total_stages:
                 raise RuntimeError(
-                    f"Pipeline parallelism requires at least {pp_size} sequences, "
-                    f"but got {n_seqs}. Increase batch size or reduce PP degree."
+                    f"Pipeline parallelism requires at least {num_total_stages} "
+                    f"sequences (pp_size={pp_size} * stages_per_rank="
+                    f"{stages_per_rank}), but got {n_seqs}. "
+                    f"Increase batch size or reduce PP degree/stages."
                 )
-            min_n_mbs = pp_size
+            min_n_mbs = num_total_stages
         mb_spec = MicroBatchSpec.new(
             self.config.mb_spec,
             n_mbs=max(min_n_mbs, self.config.mb_spec.n_mbs or 1),
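A quick worked example of the tightened check (numbers chosen arbitrarily): with pp_size=4 and an interleaved schedule placing two stages on each rank, a batch must carry at least 8 sequences.

```python
pp_size = 4            # ranks along the pipeline-parallel dimension
stages_per_rank = 2    # e.g. an interleaved schedule; len(self.pp_stages) in the code above
num_total_stages = pp_size * stages_per_rank   # 8

n_seqs = 6             # hypothetical batch size
if n_seqs < num_total_stages:
    print(f"Too few sequences: need >= {num_total_stages}, got {n_seqs}")
```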

areal/experimental/engine/archon_runner.py

Lines changed: 25 additions & 1 deletion
@@ -20,6 +20,11 @@
 logger = logging.getLogger("ArchonRunner")


+class _NullOutputChunks(list):
+    def append(self, item: Any) -> None:
+        pass
+
+
 class ForwardBackwardRunner(ABC):
     """Abstract base for forward/backward execution strategies."""

@@ -216,9 +221,11 @@ def _run_eval(
         if not self.has_last_stage:
             return None
         output_stage = self._get_output_stage()
-        return self._process_outputs(
+        results = self._process_outputs(
             output_stage.output_chunks, contexts, process_output_fn
         )
+        output_stage.output_chunks.clear()
+        return results

     def _run_train(
         self,
@@ -232,7 +239,24 @@
         pp_loss_fn = self._create_loss_fn(contexts, process_output_fn)
         schedule = self._create_schedule(n_microbatches, loss_fn=pp_loss_fn)
         self._patch_skip_output_merge(schedule)
+
+        # NOTE: Upgrading PyTorch may resolve this in the future.
+        # Replace output_chunks with a null list so
+        # forward_one_chunk's `output_chunks.append(output)` becomes a no-op.
+        # (torch/distributed/pipelining/schedules.py)
+        # This lets each microbatch's logits be freed right after its backward,
+        # instead of holding all N sets of logits until step() returns.
+        output_stage = None
+        if self.has_last_stage:
+            output_stage = self._get_output_stage()
+            output_stage.output_chunks = _NullOutputChunks()
+
         schedule.step(*args, target=batched_target, **batched_kwargs)
+
+        # Restore normal list so subsequent eval() calls on the same
+        # stage can read output_chunks normally.
+        if output_stage is not None:
+            output_stage.output_chunks = []
         return []

     def _create_loss_fn(
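To make the output-chunk trick concrete, here is a standalone toy mirroring the class added above (redefined locally, not imported from the repo): anything "appended" is dropped immediately, so per-microbatch logits are never pinned by the list until step() finishes.

```python
class _NullOutputChunks(list):
    def append(self, item) -> None:  # swallow outputs instead of retaining them
        pass

chunks = _NullOutputChunks()
chunks.append("logits for microbatch 0")   # no-op: nothing is stored
chunks.append("logits for microbatch 1")
print(len(chunks))  # 0 -- the only reference is the producer's, which it soon drops
```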

areal/experimental/models/archon/utils.py

Lines changed: 20 additions & 0 deletions
@@ -137,9 +137,29 @@ def validate_ep_constraints(
         )


+def is_moe_model_config(model_config: object) -> bool:
+    """Check if a HuggingFace PretrainedConfig represents a Mixture-of-Experts model.
+
+    Inspects common HF config attributes (num_experts, num_local_experts)
+    to determine whether the model uses MoE layers.
+
+    Args:
+        model_config: A HuggingFace PretrainedConfig (or any object with
+            num_experts / num_local_experts attributes).
+
+    Returns:
+        True if the config indicates an MoE model with more than one expert.
+    """
+    num_experts = getattr(model_config, "num_experts", None)
+    if num_experts is None:
+        num_experts = getattr(model_config, "num_local_experts", None)
+    return num_experts is not None and num_experts > 1
+
+
 __all__ = [
     "ModelArgsProtocol",
     "MoEModelArgsProtocol",
+    "is_moe_model_config",
     "validate_cp_constraints",
     "validate_tp_constraints",
     "validate_ep_constraints",

docs/best_practices/handling_oom.md

Lines changed: 8 additions & 0 deletions
@@ -144,6 +144,14 @@ allocation_mode: sglang:d4+archon:d2p2e2
 We recommend pipeline and expert parallelism over tensor/context parallelism. Check
 [Allocation Mode Reference](../reference/alloc_mode.md) for more details.

+```{seealso}
+Pipeline parallelism introduces unique memory challenges (microbatch warmup accumulation,
+zero-bubble `retain_graph` overhead, FSDP resharding trade-offs, gradient accumulation
+costs, and per-rank memory budgeting). See the
+[Archon PP Memory Guide](../tutorial/archon.md#appendix-pipeline-parallelism-memory-guide)
+for a comprehensive walkthrough.
+```
+
 ### 4. Switch to a Lightweight Optimizer

 AReaL supports different optimizers depending on the training engine.
