
Commit e55eef4
Committed Dec 28, 2023

Modify view ops to make them compatible with TP

1 parent: d2e8e62

File tree: 1 file changed (+31, -7 lines)

examples/llama/2d_llama.py

@@ -5,6 +5,23 @@
 from pippy import Pipe, PipeSplitWrapper, annotate_split_points, PipelineStage
 from torch.distributed._tensor import init_device_mesh
 
+
+def modify_view(
+    gm: torch.fx.GraphModule,
+    tp: int
+):
+    """
+    Adjust dimension size of view ops to make them compatible with tensor parallelism.
+    """
+    for node in gm.graph.nodes:
+        if node.op == "call_method" and (
+            node.target == "view" or node.target == "reshape"
+        ):
+            assert len(node.args) >= 4
+            node.update_arg(3, node.args[3] // tp)
+    gm.recompile()
+
+
 # Grab the model
 llama = AutoModelForCausalLM.from_pretrained(
     "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True
@@ -42,27 +59,34 @@
 stage_idx = rank // tp_group_size
 stage = PipelineStage(llama_pipe, stage_idx, device=device, group=pp_group)
 
+modify_view(stage.submod, tp_group_size)
+
 # Tensor parallel
 from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
 starting_layer = stage_idx * layers_per_stage
-plan = {}
+attn_plan = {}
+mlp_plan = {}
 for i in range(layers_per_stage):
     # HACK: the right fix is to remove the ".mod" added by PipeSplitWrapper
     extra = "_mod" if starting_layer > 0 and i == 0 else ""
     layer_name = f"L__self___model_layers_{starting_layer + i}{extra}"
-    plan.update({
+    attn_plan.update({
         # Parallel self attention not working yet due to the dimension mismatch
         # after TP in view operation
-        #f"{layer_name}_self_attn_q_proj": ColwiseParallel(),
-        #f"{layer_name}_self_attn_k_proj": ColwiseParallel(),
-        #f"{layer_name}_self_attn_v_proj": ColwiseParallel(),
-        #f"{layer_name}_self_attn_o_proj": RowwiseParallel(),
+        f"{layer_name}_self_attn_q_proj": ColwiseParallel(),
+        f"{layer_name}_self_attn_k_proj": ColwiseParallel(),
+        f"{layer_name}_self_attn_v_proj": ColwiseParallel(),
+        f"{layer_name}_self_attn_o_proj": RowwiseParallel(),
+    })
+    mlp_plan.update({
         f"{layer_name}_mlp_gate_proj": ColwiseParallel(),
         f"{layer_name}_mlp_up_proj": ColwiseParallel(),
         f"{layer_name}_mlp_down_proj": RowwiseParallel(),
     })
 tp_mesh = mesh_2d["tp"]
-parallelize_module(stage.submod, tp_mesh, plan)
+parallelize_module(
+    stage.submod, tp_mesh, {**attn_plan, **mlp_plan}
+)
 
 # Run
 if stage_idx == 0:
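The hunk above leans on a 2D device mesh and process groups (mesh_2d, pp_group, tp_group_size, rank, layers_per_stage) defined in the unchanged part of examples/llama/2d_llama.py, which this diff does not show. A hedged sketch of what that setup presumably looks like, with pp_group_size and tp_group_size values chosen only for illustration:

import os
import torch
from torch.distributed._tensor import init_device_mesh

pp_group_size = 2                      # number of pipeline stages (assumed)
tp_group_size = 4                      # TP ranks per stage (assumed)
rank = int(os.environ["RANK"])
device = torch.device("cuda", rank % torch.cuda.device_count())

# One 2D mesh with named dims: "pp" connects ranks across pipeline stages,
# "tp" groups the ranks that share a stage.
mesh_2d = init_device_mesh("cuda", (pp_group_size, tp_group_size),
                           mesh_dim_names=("pp", "tp"))
pp_group = mesh_2d["pp"].get_group()   # passed to PipelineStage(..., group=pp_group)
tp_mesh = mesh_2d["tp"]                # passed to parallelize_module(...)
stage_idx = rank // tp_group_size      # matches the first line of the hunk above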

0 commit comments

Comments
 (0)
Please sign in to comment.