@@ -58,7 +58,7 @@ def __init__(
             self.bias_key: Optional[str] = prefix + ".bias"
         else:
             self.bias_key: Optional[str] = None
-        self.use_ep: bool = fd_config.parallel_config.use_ep
+        self.tp_group = fd_config.parallel_config.tp_group
         self.column_cut = True
         self.nranks = fd_config.parallel_config.tensor_parallel_size
         self.fd_config = fd_config
@@ -68,60 +68,46 @@ def __init__(
 
         self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
 
-        if self.use_ep:
-            self.weight = self.create_parameter(
-                shape=[embedding_dim, num_embeddings],
-                dtype=paddle.get_default_dtype(),
-                is_bias=False,
+        if self.column_cut:
+            need_gather = True
+            self.linear = ColumnParallelLinear(
+                embedding_dim,
+                num_embeddings,
+                mp_group=self.tp_group,
+                weight_attr=None,
+                has_bias=True if self.bias_key is not None else False,
+                gather_output=need_gather,
+                fuse_matmul_bias=False,
             )
-            if self.bias_key is not None:
-                self.bias = self.create_parameter(
-                    shape=[num_embeddings],
-                    dtype=paddle.get_default_dtype(),
-                    is_bias=True,
-                )
-
+            set_weight_attrs(
+                self.linear.weight,
+                {
+                    "weight_loader": default_weight_loader(self.fd_config),
+                    "model_format": self.fd_config.model_config.model_format,
+                },
+            )
+            if self.nranks > 1:
+                set_weight_attrs(self.linear.weight, {"output_dim": True})
         else:
-            if self.column_cut:
-                need_gather = True
-                self.linear = ColumnParallelLinear(
-                    embedding_dim,
-                    num_embeddings,
-                    mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    weight_attr=None,
-                    has_bias=True if self.bias_key is not None else False,
-                    gather_output=need_gather,
-                    fuse_matmul_bias=False,
-                )
-                set_weight_attrs(
-                    self.linear.weight,
-                    {
-                        "weight_loader": default_weight_loader(self.fd_config),
-                        "model_format": self.fd_config.model_config.model_format,
-                    },
-                )
-                if self.nranks > 1:
-                    set_weight_attrs(self.linear.weight, {"output_dim": True})
-            else:
-                self.linear = RowParallelLinear(
-                    embedding_dim,
-                    num_embeddings,
-                    mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    weight_attr=None,
-                    has_bias=True if self.bias_key is not None else False,
-                    input_is_parallel=False,
-                    fuse_matmul_bias=False,
-                )
-                set_weight_attrs(
-                    self.linear.weight,
-                    {
-                        "weight_loader": default_weight_loader(self.fd_config),
-                        "model_format": self.fd_config.model_config.model_format,
-                    },
-                )
-
-            if self.nranks > 1:
-                set_weight_attrs(self.linear.weight, {"output_dim": False})
+            self.linear = RowParallelLinear(
+                embedding_dim,
+                num_embeddings,
+                mp_group=self.tp_group,
+                weight_attr=None,
+                has_bias=True if self.bias_key is not None else False,
+                input_is_parallel=False,
+                fuse_matmul_bias=False,
+            )
+            set_weight_attrs(
+                self.linear.weight,
+                {
+                    "weight_loader": default_weight_loader(self.fd_config),
+                    "model_format": self.fd_config.model_config.model_format,
+                },
+            )
+
+            if self.nranks > 1:
+                set_weight_attrs(self.linear.weight, {"output_dim": False})
 
     def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
         """
@@ -131,24 +117,19 @@ def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
             state_dict (dict): A dictionary containing the checkpoint weights and biases.
         """
 
-        if self.use_ep:
-            self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
-            if self.bias_key is not None:
-                self.bias.set_value(get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype()))
+        if self.tie_word_embeddings:
+            self.linear.weight.set_value(
+                get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()).transpose([1, 0])
+            )
         else:
-            if self.tie_word_embeddings:
-                self.linear.weight.set_value(
-                    get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()).transpose([1, 0])
-                )
-            else:
-                weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
-                if self.linear.weight.shape != weight_tensor.shape:
-                    weight_tensor = weight_tensor.transpose([1, 0])
-                self.linear.weight.set_value(weight_tensor)
-
-            if self.bias_key is not None:
-                bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
-                self.linear.bias.set_value(bias)
+            weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
+            if self.linear.weight.shape != weight_tensor.shape:
+                weight_tensor = weight_tensor.transpose([1, 0])
+            self.linear.weight.set_value(weight_tensor)
+
+        if self.bias_key is not None:
+            bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
+            self.linear.bias.set_value(bias)
 
     def forward(self, input: paddle.Tensor) -> paddle.Tensor:
         """
@@ -161,11 +142,5 @@ def forward(self, input: paddle.Tensor) -> paddle.Tensor:
             Tensor: The output tensor after processing through the layer.
         """
         logits = input
-        if self.use_ep:
-            if self.bias_key is None:
-                logits = paddle.matmul(logits, self.weight)
-            else:
-                logits = paddle.incubate.nn.functional.fused_linear(logits, self.weight, self.bias)
-        else:
-            logits = self.linear(logits)
+        logits = self.linear(logits)
         return logits
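
For readers skimming the diff: with the `use_ep` branch removed, the only weight-layout handling left in `load_state_dict` is the non-tied path's shape check plus optional transpose, since checkpoints may store `lm_head.weight` in embedding layout (`[num_embeddings, embedding_dim]`) while the linear layer holds `[embedding_dim, num_embeddings]`. Below is a minimal sketch of that logic, using a plain `paddle.nn.Linear` and made-up toy sizes rather than the FastDeploy parallel layers:

```python
import paddle

# Toy sizes, assumed only for illustration.
embedding_dim, num_embeddings = 8, 32

# paddle.nn.Linear stores its weight as [in_features, out_features],
# i.e. [embedding_dim, num_embeddings] here.
linear = paddle.nn.Linear(embedding_dim, num_embeddings)

# A checkpoint may instead store lm_head.weight in embedding layout:
# [num_embeddings, embedding_dim].
ckpt_weight = paddle.rand([num_embeddings, embedding_dim])

# Same idea as the new non-tied branch: cast, transpose only when the
# layouts disagree, then copy the tensor into the parameter.
weight_tensor = ckpt_weight.astype(paddle.get_default_dtype())
if linear.weight.shape != weight_tensor.shape:
    weight_tensor = weight_tensor.transpose([1, 0])
linear.weight.set_value(weight_tensor)
```

The tied-embeddings branch, by contrast, always transposes, since it reuses the embedding matrix, which is stored as `[num_embeddings, embedding_dim]`.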