|
6 | 6 |
|
7 | 7 | from __future__ import annotations
|
8 | 8 |
|
9 |
| -from functools import partial |
10 |
| -from typing import Literal |
11 |
| - |
12 |
| -from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( |
13 |
| - CheckpointWrapper, |
14 |
| -) |
15 |
| -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy |
16 | 9 | from torch.nn import Module
|
17 | 10 |
|
18 |
| -from fairseq2.nn.data_parallel import FSDPWrapPolicy |
19 |
| -from fairseq2.nn.transformer import ( |
20 |
| - TransformerDecoder, |
21 |
| - TransformerDecoderLayer, |
22 |
| - TransformerEncoder, |
23 |
| - TransformerEncoderLayer, |
24 |
| -) |
25 |
| - |
26 |
| - |
27 |
def get_fsdp_wrap_policy(
    model: Module, wrap_granularity: Literal["layer", "stack", "model"] = "layer"
) -> tuple[FSDPWrapPolicy | None, list[Module] | None]:
    """Return the FSDP wrap policy for ``model`` along with ignored modules.

    :param model: The model to be wrapped.
    :param wrap_granularity: The granularity at which to wrap modules of ``model``.
        - 'layer': Wraps individual layers (e.g. :class:`TransformerDecoderLayer`).
        - 'stack': Wraps layer stacks (e.g. :class:`TransformerDecoder`).
        - 'model': Wraps ``model``.
    """
    # A ``None`` policy means FSDP flat-wraps the root module only; the second
    # element (ignored modules) is always ``None`` here.
    if wrap_granularity == "model":
        return None, None

    kls: set[type[Module]]

    if wrap_granularity == "stack":
        kls = {TransformerEncoder, TransformerDecoder}
    elif wrap_granularity == "layer":
        kls = {TransformerEncoderLayer, TransformerDecoderLayer}

        # We make the assumption that if the model uses activation checkpointing,
        # it is at the layer granularity.
        for m in model.modules():
            if isinstance(m, CheckpointWrapper):
                kls = {CheckpointWrapper}

                break
    else:
        raise ValueError(
            f"`wrap_granularity` must be 'layer', 'stack', or 'model', but is '{wrap_granularity}' instead."
        )

    # ``transformer_auto_wrap_policy`` wraps every module whose type is in the
    # bound ``transformer_layer_cls`` set; ``partial`` fixes that set so FSDP
    # can invoke the policy with its own callback signature.
    wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls=kls)

    return wrap_policy, None
| 11 | +from fairseq2.nn.data_parallel import FsdpGranularity, FsdpWrapper |
| 12 | +from fairseq2.nn.transformer import TransformerDecoder, TransformerEncoder |
| 13 | + |
| 14 | + |
def apply_fsdp_to_transformer(
    model: Module, granularity: FsdpGranularity, wrapper: FsdpWrapper
) -> Module:
    """Apply FSDP to a Transformer ``model`` at the requested granularity.

    :param model: The module to shard.
    :param granularity: ``"model"`` wraps ``model`` as a whole, ``"stack"``
        wraps each encoder/decoder stack, any other value wraps the stacks'
        individual layers.
    :param wrapper: Callable that wraps a module with FSDP; it optionally
        accepts a ``reshard_after_forward`` override as second argument.
    :returns: ``model`` (or its wrapped replacement for ``"model"``).
    """
    if granularity == "model":
        return wrapper(model)

    # Snapshot the children up front: ``register_module`` replaces entries
    # while we iterate.
    for child_name, child in list(model.named_children()):
        if not isinstance(child, (TransformerEncoder, TransformerDecoder)):
            continue

        if granularity == "stack":
            model.register_module(child_name, wrapper(child))
            continue

        # Layer granularity: wrap every layer of the stack individually.
        stack_layers = list(child.layers.named_children())

        last_idx = len(stack_layers) - 1

        for idx, (layer_name, layer) in enumerate(stack_layers):
            # We don't need to reshard the last layer since we will
            # immediately gather it for the backward pass.
            reshard = False if idx == last_idx else None

            child.layers.register_module(layer_name, wrapper(layer, reshard))

    return model
0 commit comments