from pippy import Pipe, PipeSplitWrapper, annotate_split_points, PipelineStage
from torch.distributed._tensor import init_device_mesh

+
+def modify_view(
+    gm: torch.fx.GraphModule,
+    tp: int
+):
+    """
+    Adjust the dimension sizes of view/reshape ops to make them compatible with tensor parallelism.
+    """
+    for node in gm.graph.nodes:
+        if node.op == "call_method" and (
+            node.target == "view" or node.target == "reshape"
+        ):
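+            # Assumes args[3] (args[0] is the tensor itself) is the num_heads or
+            # hidden-size dim of Llama's attention .view()/.reshape() calls; each
+            # TP rank keeps only 1/tp of that dimension after sharding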
+            assert len(node.args) >= 4
+            node.update_arg(3, node.args[3] // tp)
+    gm.recompile()
+
+
# Grab the model
llama = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True

⋮

stage_idx = rank // tp_group_size
stage = PipelineStage(llama_pipe, stage_idx, device=device, group=pp_group)

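+# Rewrite hard-coded view/reshape sizes in this stage's traced submodule so they
+# match the per-rank shapes once tensor parallelism is applied below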
+modify_view(stage.submod, tp_group_size)
+
# Tensor parallel
from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
starting_layer = stage_idx * layers_per_stage
-plan = {}
+attn_plan = {}
+mlp_plan = {}
for i in range(layers_per_stage):
    # HACK: the right fix is to remove the ".mod" added by PipeSplitWrapper
    extra = "_mod" if starting_layer > 0 and i == 0 else ""
    layer_name = f"L__self___model_layers_{starting_layer + i}{extra}"
-    plan.update({
+    attn_plan.update({
-        # Parallel self attention not working yet due to the dimension mismatch
-        # after TP in view operation
+        # The dimension mismatch after TP in the view ops is handled by
+        # modify_view(), so the self-attention projections can be sharded too
-        #f"{layer_name}_self_attn_q_proj": ColwiseParallel(),
-        #f"{layer_name}_self_attn_k_proj": ColwiseParallel(),
-        #f"{layer_name}_self_attn_v_proj": ColwiseParallel(),
-        #f"{layer_name}_self_attn_o_proj": RowwiseParallel(),
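+        # q/k/v: ColwiseParallel shards the projection outputs (attention heads)
+        # across the TP ranks; o_proj: RowwiseParallel shards its input and
+        # reduces the partial results across the TP group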
+        f"{layer_name}_self_attn_q_proj": ColwiseParallel(),
+        f"{layer_name}_self_attn_k_proj": ColwiseParallel(),
+        f"{layer_name}_self_attn_v_proj": ColwiseParallel(),
+        f"{layer_name}_self_attn_o_proj": RowwiseParallel(),
+    })
+    mlp_plan.update({
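+        # Megatron-style MLP sharding: gate/up column-parallel, down row-parallel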
        f"{layer_name}_mlp_gate_proj": ColwiseParallel(),
        f"{layer_name}_mlp_up_proj": ColwiseParallel(),
        f"{layer_name}_mlp_down_proj": RowwiseParallel(),
    })
tp_mesh = mesh_2d["tp"]
-parallelize_module(stage.submod, tp_mesh, plan)
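+# Apply both the attention and MLP sharding plans to this stage's submodule
+# over the TP mesh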
+parallelize_module(
+    stage.submod, tp_mesh, {**attn_plan, **mlp_plan}
+)

# Run
if stage_idx == 0:
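
For reference, a minimal self-contained sketch (not part of the commit; the toy module, tensor sizes, and `tp=2` are made up for illustration) of what `modify_view` does to a traced graph:

```python
import torch
import torch.fx


class ToyAttn(torch.nn.Module):
    def forward(self, x):
        # 8 "heads" of size 4 hard-coded in the view, as in Llama's attention
        return x.view(2, 3, 8, 4)


gm = torch.fx.symbolic_trace(ToyAttn())
modify_view(gm, 2)               # heads dim 8 -> 4 in the traced view call
print(gm.code)                   # recompiled forward now views to (2, 3, 4, 4)
out = gm(torch.randn(2, 3, 16))  # 16 = (8 * 4) / 2 features on each TP rank
```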