@@ -43,7 +43,7 @@ def __call__(self, position_ids):
4343 inv_freq = self .base ** (- paddle .arange (0 , self .rotary_dim , 2 , dtype = "float32" ) / self .rotary_dim )
4444 partial_rotary_position_ids = position_ids / self .partial_rotary_factor
4545 freqs = paddle .einsum ("ij,k->ijk" , partial_rotary_position_ids .cast ("float32" ), inv_freq )
46- if paddle . is_compiled_with_xpu () or paddle .is_compiled_with_custom_device ("iluvatar_gpu" ):
46+ if current_platform . is_xpu () or paddle .is_compiled_with_custom_device ("iluvatar_gpu" ):
4747 # shape: [B, S, D]
4848 rot_emb = paddle .zeros ((2 , bsz , max_seq_len , 1 , self .rotary_dim ), dtype = "float32" )
4949 emb = paddle .stack ([freqs , freqs ], axis = - 1 ).reshape ((bsz , max_seq_len , self .rotary_dim ))
@@ -89,9 +89,14 @@ def __call__(self, position_ids):
8989 bsz , max_seq_len = position_ids .shape [:2 ]
9090 inv_freq = self .base ** (- paddle .arange (0 , self .rotary_dim , 2 , dtype = "float32" ) / self .rotary_dim )
9191 freqs = paddle .einsum ("ij,k->ijk" , position_ids .cast ("float32" ), inv_freq )
92- # shape: [B, S, D/2]
93- rot_emb = paddle .zeros ((2 , bsz , max_seq_len , 1 , self .rotary_dim // 2 ), dtype = "float32" )
94- emb = paddle .stack ([freqs ], axis = - 1 ).reshape ((bsz , max_seq_len , self .rotary_dim // 2 ))
92+ if current_platform .is_xpu ():
93+ # shape: [B, S, D]
94+ rot_emb = paddle .zeros ((2 , bsz , max_seq_len , 1 , self .rotary_dim ), dtype = "float32" )
95+ emb = paddle .concat ([freqs , freqs ], axis = - 1 ).reshape ((bsz , max_seq_len , self .rotary_dim ))
96+ else :
97+ # shape: [B, S, D/2]
98+ rot_emb = paddle .zeros ((2 , bsz , max_seq_len , 1 , self .rotary_dim // 2 ), dtype = "float32" )
99+ emb = paddle .stack ([freqs ], axis = - 1 ).reshape ((bsz , max_seq_len , self .rotary_dim // 2 ))
95100 # shape: [B, S, 1, D]
96101 emb = paddle .unsqueeze (emb , 2 )
97102 rot_emb [0 ] = paddle .cos (emb )
0 commit comments