@@ -5,20 +5,40 @@
 #
 # -----------------------------------------------------------------------------
 
-import math
+import os
 
+import numpy as np
 from transformers import AutoConfig
 
 from QEfficient.utils._utils import get_num_layers_from_config
 
 
-def get_device_map(model_name, num_pp_stages, rank):
-    """Returns device map for model layers based number of pipeline stages and given process rank.
+def get_device_map(train_config):
+    """Returns device map for the given model.
 
     Args:
-        model_name (str): model name to get the device map for.
-        num_pp_stages (int): number of stages in pipeline
-        rank (int): process rank
+        train_config (TrainConfig): Training configuration object containing the model name, number of pipeline stages, etc.
+
+    Returns:
+        Dict: A dictionary of layers and corresponding device id.
+    """
+
+    if train_config.num_pp_stages > 1:
+        if train_config.enable_ddp:
+            device_map = custom_device_map(train_config)
+        else:
+            device_map = "auto"
+    else:
+        device_map = None
+
+    return device_map
+
+
+def custom_device_map(train_config):
+    """Returns custom device map for model layers based on the number of pipeline stages and the given process rank.
+
+    Args:
+        train_config (TrainConfig): Training configuration object containing the model name, number of pipeline stages, etc.
 
     Returns:
         Dict: A dictionary of layers and corresponding device id.
@@ -33,32 +53,34 @@ def get_device_map(model_name, num_pp_stages, rank):
     PP (Pipeline Parallelism): Each copy of the model is split into 2 stages
     DDP (Distributed Data Parallel): 2 model copies run in parallel
 
-    |-------------------------------------------------------------------------------
-    | Process Rank | Assigned Device IDs | Model Component                          |
-    |-------------------------------------------------------------------------------
-    | Rank 0       | 0                   | model.embed_tokens                       |
-    |              |                     | model.lm_head                            |
-    |              |                     | model.layers.0 - model.layers.7          |
-    |-------------------------------------------------------------------------------
-    | Rank 0       | 1                   | model.norm                               |
-    |              |                     | model.rotary_emb                         |
-    |              |                     | model.layers.8 - model.layers.15         |
-    |-------------------------------------------------------------------------------
-    | Rank 1       | 2                   | model.embed_tokens                       |
-    |              |                     | model.lm_head                            |
-    |              |                     | model.layers.0 - model.layers.7          |
-    |-------------------------------------------------------------------------------
-    | Rank 1       | 3                   | model.norm                               |
-    |              |                     | model.rotary_emb                         |
-    |              |                     | model.layers.8 - model.layers.15         |
-    |-------------------------------------------------------------------------------
+    |--------------------------------------------------------------------------|
+    | Process Rank | Assigned Device IDs | Model Component                     |
+    |--------------------------------------------------------------------------|
+    | Rank 0       | 0                   | model.embed_tokens                  |
+    |              |                     | model.lm_head                       |
+    |              |                     | model.layers.0 - model.layers.7     |
+    |--------------------------------------------------------------------------|
+    | Rank 0       | 1                   | model.norm                          |
+    |              |                     | model.rotary_emb                    |
+    |              |                     | model.layers.8 - model.layers.15    |
+    |--------------------------------------------------------------------------|
+    | Rank 1       | 2                   | model.embed_tokens                  |
+    |              |                     | model.lm_head                       |
+    |              |                     | model.layers.0 - model.layers.7     |
+    |--------------------------------------------------------------------------|
+    | Rank 1       | 3                   | model.norm                          |
+    |              |                     | model.rotary_emb                    |
+    |              |                     | model.layers.8 - model.layers.15    |
+    |--------------------------------------------------------------------------|
     """
 
-    config = AutoConfig.from_pretrained(model_name)
+    config = AutoConfig.from_pretrained(train_config.model_name)
     num_layers = get_num_layers_from_config(config)
 
-    first_device = rank * num_pp_stages
-    last_device = rank * num_pp_stages + (num_pp_stages - 1)
+    n = train_config.num_pp_stages
+    rank = int(os.getenv("LOCAL_RANK", 0))
+    first_device = rank * n
+    last_device = rank * n + (n - 1)
 
     if config.tie_word_embeddings:
         lm_head_device = first_device
@@ -71,11 +93,12 @@ def get_device_map(model_name, num_pp_stages, rank):
         "model.norm": last_device,
         "model.rotary_emb": last_device,
     }
+    n_layer_per_stage = int(np.ceil(num_layers / n))
 
-    n_layer_per_stage = math.ceil(num_layers / num_pp_stages)
+    pp_stage_ids = np.arange(n)
+    pp_device_map = np.repeat(pp_stage_ids, n_layer_per_stage)
 
-    for j in range(num_pp_stages):
-        for i in range(n_layer_per_stage * j, min(n_layer_per_stage * (j + 1), num_layers)):
-            device_map[f"model.layers.{i}"] = first_device + j
+    for i in range(num_layers):
+        device_map[f"model.layers.{i}"] = pp_device_map[i] + rank * n
 
     return device_map
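
For reference, here is a minimal standalone sketch (not part of the diff) of the layer-to-device arithmetic in the new custom_device_map, with the AutoConfig loading stubbed out so it runs without any model files; num_layers, the stage count, and the rank are passed in directly, and lm_head is pinned with the embeddings as in the tied-embeddings case shown in the docstring table:

import numpy as np


def sketch_device_map(num_layers: int, num_pp_stages: int, rank: int) -> dict:
    # Each DDP rank owns a contiguous block of num_pp_stages devices.
    n = num_pp_stages
    first_device = rank * n
    last_device = rank * n + (n - 1)

    device_map = {
        "model.embed_tokens": first_device,
        "lm_head": first_device,  # tied-embeddings case, as in the docstring table
        "model.norm": last_device,
        "model.rotary_emb": last_device,
    }

    # Each pipeline stage gets a contiguous block of ceil(num_layers / n) layers.
    n_layer_per_stage = int(np.ceil(num_layers / n))
    pp_device_map = np.repeat(np.arange(n), n_layer_per_stage)

    for i in range(num_layers):
        device_map[f"model.layers.{i}"] = int(pp_device_map[i]) + rank * n
    return device_map


# 16 layers, 2 stages, 2 DDP ranks -> matches the docstring table:
assert sketch_device_map(16, 2, 0)["model.layers.7"] == 0
assert sketch_device_map(16, 2, 0)["model.layers.8"] == 1
assert sketch_device_map(16, 2, 1)["model.layers.0"] == 2
assert sketch_device_map(16, 2, 1)["model.layers.15"] == 3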
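One detail worth noting: np.ceil returns a float while np.repeat expects an integer repeat count, hence the int(...) cast. When the layer count does not divide evenly by the stage count, the last stage simply receives the remainder, e.g.:

import numpy as np

num_layers, n = 15, 2
n_layer_per_stage = int(np.ceil(num_layers / n))        # ceil(15 / 2) = 8
stage_ids = np.repeat(np.arange(n), n_layer_per_stage)  # [0]*8 + [1]*8

# Only the first num_layers entries are ever indexed, so stage 0 gets
# layers 0-7 and stage 1 gets the remaining seven layers, 8-14.
print(stage_ids[:num_layers])  # [0 0 0 0 0 0 0 0 1 1 1 1 1 1 1]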