Skip to content

Commit 3bb379c

Browse files
committed
Improve FSDP wrapper implementation
1 parent 6280fdd commit 3bb379c

File tree

7 files changed

+136
-187
lines changed

7 files changed

+136
-187
lines changed

src/fairseq2/models/_handler.py

+14-1
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,8 @@
3030
UnknownModelArchitectureError,
3131
model_asset_card_error,
3232
)
33-
from fairseq2.nn.data_parallel import load_with_sdp_gang
33+
from fairseq2.models.fsdp import apply_fsdp_to_transformer
34+
from fairseq2.nn.data_parallel import FsdpGranularity, FsdpWrapper, load_with_sdp_gang
3435
from fairseq2.nn.utils.module import (
3536
load_state_dict,
3637
reset_non_persistent_buffers,
@@ -79,6 +80,11 @@ def load_from_path(
7980
@abstractmethod
8081
def compile(self, model: Module, config: object) -> Module: ...
8182

83+
@abstractmethod
def apply_fsdp(
    self, model: Module, granularity: FsdpGranularity, wrapper: FsdpWrapper
) -> Module:
    """Apply FSDP to ``model`` and return the resulting module.

    :param model: The module to process.
    :param granularity: The granularity selector forwarded by the caller.
        NOTE(review): the accepted values are defined by ``FsdpGranularity``,
        which is declared outside this view.
    :param wrapper: The callable used to wrap modules.
        NOTE(review): its exact signature is defined by ``FsdpWrapper``,
        which is declared outside this view.
    """
87+
8288
@property
8389
@abstractmethod
8490
def family(self) -> str: ...
@@ -461,6 +467,7 @@ def _do_create(
461467

462468
return model
463469

470+
@override
464471
def compile(self, model: Module, config: object) -> Module:
465472
if self._torch_compiler is None:
466473
raise NotSupportedError(
@@ -479,6 +486,12 @@ def compile(self, model: Module, config: object) -> Module:
479486

480487
return self._torch_compiler(model, config)
481488

489+
@override
def apply_fsdp(
    self, model: Module, granularity: FsdpGranularity, wrapper: FsdpWrapper
) -> Module:
    """Apply FSDP to ``model`` by delegating to ``apply_fsdp_to_transformer``.

    All three arguments are forwarded unchanged and the helper's return
    value is handed back to the caller as is.
    """
    fsdp_model = apply_fsdp_to_transformer(model, granularity, wrapper)

    return fsdp_model
494+
482495
@property
483496
@override
484497
def family(self) -> str:

src/fairseq2/models/fsdp.py

+32-52
Original file line number | Diff line number | Diff line change
@@ -6,57 +6,37 @@
66

77
from __future__ import annotations
88

9-
from functools import partial
10-
from typing import Literal
11-
12-
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
13-
CheckpointWrapper,
14-
)
15-
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
169
from torch.nn import Module
1710

18-
from fairseq2.nn.data_parallel import FSDPWrapPolicy
19-
from fairseq2.nn.transformer import (
20-
TransformerDecoder,
21-
TransformerDecoderLayer,
22-
TransformerEncoder,
23-
TransformerEncoderLayer,
24-
)
25-
26-
27-
def get_fsdp_wrap_policy(
28-
model: Module, wrap_granularity: Literal["layer", "stack", "model"] = "layer"
29-
) -> tuple[FSDPWrapPolicy | None, list[Module] | None]:
30-
"""Return the FSDP wrap policy for ``model`` along with ignored modules.
31-
32-
:param model: The model to be wrapped.
33-
:param wrap_granularity: The granularity at which to wrap modules of ``model``.
34-
- 'layer': Wraps individual layers (e.g. :class:`TransformerDecoderLayer`).
35-
- 'stack': Wraps layer stacks (e.g. :class:`TransformerDecoder`).
36-
- 'model': Wraps ``model``.
37-
"""
38-
if wrap_granularity == "model":
39-
return None, None
40-
41-
kls: set[type[Module]]
42-
43-
if wrap_granularity == "stack":
44-
kls = {TransformerEncoder, TransformerDecoder}
45-
elif wrap_granularity == "layer":
46-
kls = {TransformerEncoderLayer, TransformerDecoderLayer}
47-
48-
# We make the assumption that if the model uses activation checkpointing,
49-
# it is at the layer granularity.
50-
for m in model.modules():
51-
if isinstance(m, CheckpointWrapper):
52-
kls = {CheckpointWrapper}
53-
54-
break
55-
else:
56-
raise ValueError(
57-
f"`wrap_granularity` must be 'layer', 'stack', or 'model', but is '{wrap_granularity}' instead."
58-
)
59-
60-
wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls=kls)
61-
62-
return wrap_policy, None
11+
from fairseq2.nn.data_parallel import FsdpGranularity, FsdpWrapper
12+
from fairseq2.nn.transformer import TransformerDecoder, TransformerEncoder
13+
14+
15+
def apply_fsdp_to_transformer(
    model: Module, granularity: FsdpGranularity, wrapper: FsdpWrapper
) -> Module:
    """Wrap ``model`` or its Transformer stacks/layers with ``wrapper``.

    :param model: The module to process. Only its direct children are
        inspected for :class:`TransformerEncoder` and
        :class:`TransformerDecoder` instances.
    :param granularity: ``"model"`` wraps ``model`` itself, ``"stack"`` wraps
        each encoder/decoder child, and any other value wraps the individual
        entries of each stack's ``layers`` container.
    :param wrapper: The callable that wraps a module. It is called with the
        module alone for the model and stack cases, and with the module plus a
        ``reshard_after_forward`` value for the per-layer case.

    :returns: ``wrapper(model)`` when ``granularity`` is ``"model"``;
        otherwise ``model`` itself, with the wrapped stacks or layers
        re-registered in place under their original names.
    """
    if granularity == "model":
        return wrapper(model)

    # Take a snapshot of the children up front; modules are re-registered
    # under the same names while we walk over them.
    for stack_name, stack in list(model.named_children()):
        if not isinstance(stack, (TransformerEncoder, TransformerDecoder)):
            continue

        if granularity == "stack":
            model.register_module(stack_name, wrapper(stack))

            continue

        layer_container = stack.layers

        named_layers = list(layer_container.named_children())

        last_idx = len(named_layers) - 1

        for idx, (layer_name, layer) in enumerate(named_layers):
            # The last layer does not need to be resharded after the forward
            # pass since it is gathered again immediately for the backward
            # pass.
            reshard_after_forward = None if idx < last_idx else False

            layer_container.register_module(
                layer_name, wrapper(layer, reshard_after_forward)
            )

    return model

src/fairseq2/nn/data_parallel/__init__.py

+4-12
Original file line number | Diff line number | Diff line change
@@ -10,20 +10,12 @@
1010
from fairseq2.nn.data_parallel._error import (
1111
DistributedSetupError as DistributedSetupError,
1212
)
13+
from fairseq2.nn.data_parallel._fsdp import FsdpApplier as FsdpApplier
14+
from fairseq2.nn.data_parallel._fsdp import FsdpGranularity as FsdpGranularity
1315
from fairseq2.nn.data_parallel._fsdp import (
14-
FSDP_LOW_MEMORY_POLICY as FSDP_LOW_MEMORY_POLICY,
16+
FsdpParameterInitializer as FsdpParameterInitializer,
1517
)
16-
from fairseq2.nn.data_parallel._fsdp import (
17-
FSDP_STANDARD_MEMORY_POLICY as FSDP_STANDARD_MEMORY_POLICY,
18-
)
19-
from fairseq2.nn.data_parallel._fsdp import (
20-
FSDP_VERY_LOW_MEMORY_POLICY as FSDP_VERY_LOW_MEMORY_POLICY,
21-
)
22-
from fairseq2.nn.data_parallel._fsdp import FSDPMemoryPolicy as FSDPMemoryPolicy
23-
from fairseq2.nn.data_parallel._fsdp import (
24-
FSDPParameterInitializer as FSDPParameterInitializer,
25-
)
26-
from fairseq2.nn.data_parallel._fsdp import FSDPWrapPolicy as FSDPWrapPolicy
18+
from fairseq2.nn.data_parallel._fsdp import FsdpWrapper as FsdpWrapper
2719
from fairseq2.nn.data_parallel._fsdp import (
2820
fsdp_local_state_dict as fsdp_local_state_dict,
2921
)

0 commit comments

Comments
 (0)