
Commit 014047e

Fix bug in apply_rotary_pos_emb_flashatt in Qwen2.5-VL (#36065)

1 parent: 006d924

File tree: 2 files changed (+4, -4 lines)

src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py (2 additions, 2 deletions)

@@ -162,8 +162,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 def apply_rotary_pos_emb_flashatt(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
     tensor_ = tensor.float()
-    cos = freqs.cos()
-    sin = freqs.sin()
+    cos = freqs.cos().float()
+    sin = freqs.sin().float()
     output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor)
     return output

src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py (2 additions, 2 deletions)

@@ -65,8 +65,8 @@
 def apply_rotary_pos_emb_flashatt(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
     tensor_ = tensor.float()
-    cos = freqs.cos()
-    sin = freqs.sin()
+    cos = freqs.cos().float()
+    sin = freqs.sin().float()
     output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor)
     return output
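The change in both files is the same: `tensor` is upcast to float32 before the rotary multiply, but when `freqs` arrives in a half-precision dtype (e.g. bfloat16), `freqs.cos()` and `freqs.sin()` stay half precision, so the operands no longer match. The fix casts `cos` and `sin` to float32 as well. The following is a minimal, self-contained sketch of the fixed function; `apply_rotary_pos_emb_flashatt_fixed` and the inlined pairwise rotation (a stand-in for flash-attn's `apply_rotary_emb`, which is not imported here) are hypothetical names for illustration, not the library's actual implementation:

```python
import torch


def apply_rotary_pos_emb_flashatt_fixed(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # Upcast activations to float32 for a numerically stable rotation.
    tensor_ = tensor.float()
    # The fix: freqs may be bf16/fp16, so cos/sin must be upcast too,
    # keeping every operand of the rotary multiply in float32.
    cos = freqs.cos().float()
    sin = freqs.sin().float()
    # Stand-in for flash-attn's apply_rotary_emb: rotate the two halves
    # of the head dimension as (x1, x2) -> (x1*cos - x2*sin, x2*cos + x1*sin).
    x1, x2 = tensor_.chunk(2, dim=-1)
    rotated = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    # Cast back to the caller's dtype, as the original function does.
    return rotated.type_as(tensor)


# Example: bfloat16 queries and bfloat16 frequencies, as under autocast.
q = torch.randn(2, 4, 8, dtype=torch.bfloat16)
freqs = torch.randn(2, 4, 4, dtype=torch.bfloat16)
out = apply_rotary_pos_emb_flashatt_fixed(q, freqs)
print(out.dtype)  # matches the input dtype
```

Without the `.float()` casts, `cos`/`sin` would remain bfloat16 while `tensor_` is float32; depending on the kernel this either errors out or silently degrades precision, which is the bug the commit addresses.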

0 commit comments
