
add audio optimization for qwen2.5-omni #13037

Merged: 3 commits, Apr 7, 2025
Changes from all commits
25 changes: 22 additions & 3 deletions python/llm/src/ipex_llm/transformers/convert.py
@@ -2072,12 +2072,31 @@ def _optimize_post(model):
     convert_forward(model.thinker.visual, module.Qwen2_5OmniVisionSdpaAttention,
                     qwen2_5_omni_vision_attention_forward)

+    # audio opt
+    from ipex_llm.transformers.models.qwen2_5_omni import qwen2_5_omni_audio_attention_forward
+    convert_forward(model.thinker.audio_tower, module.Qwen2_5OmniAudioAttention,
+                    qwen2_5_omni_audio_attention_forward)
+    convert_forward(model.thinker.audio_tower, module.Qwen2_5OmniAudioSdpaAttention,
+                    qwen2_5_omni_audio_attention_forward)
+
     # tts opt
-    if hasattr(model, "talker"):
-        convert_forward(model.talker, module.Qwen2_5OmniAttention,
+    if model.has_talker:
+        # talker part
+        convert_forward(model.talker.model, module.Qwen2_5OmniAttention,
                         qwen2_5_omni_attention_forward)
+        convert_forward(model.talker.model, module.Qwen2_5OmniSdpaAttention,
+                        qwen2_5_omni_attention_forward)
-        convert_forward(model.talker, module.Qwen2_5OmniThinkerModel,
+        convert_forward(model.talker.model, module.Qwen2_5OmniTalkerModel,
                         qwen2_5_omni_thinker_model_forward)
+        convert_forward(model.talker.model, module.Qwen2MLP, qwen2_mlp_forward)
+
+        # token2wav part
+        from ipex_llm.transformers.models.qwen2_5_omni import dit_attention_forward
+        from ipex_llm.transformers.models.qwen2_5_omni import _create_block_diff
+        convert_forward(model.token2wav, module.DiTAttention, dit_attention_forward)
+        dit_model = model.token2wav.code2wav_dit_model
+        dit_model._create_block_diff = MethodType(_create_block_diff, dit_model)

     return model

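For context, `convert_forward` is the ipex-llm helper that swaps in the optimized forwards registered above. A minimal sketch of what such a helper does, assuming it simply walks the module tree and rebinds `forward` on instances of the target class (the actual ipex-llm implementation may differ in details):

# Hypothetical sketch only, not the real ipex-llm source.
from types import MethodType
import torch.nn as nn

def convert_forward_sketch(model: nn.Module, target_cls, new_forward):
    # Replace the bound forward of every submodule whose class matches target_cls.
    for module in model.modules():
        if module.__class__ == target_cls:
            module.forward = MethodType(new_forward, module)
    return model

The same `MethodType` trick is used directly in the hunk above to attach the patched `_create_block_diff` to the token2wav DiT model instance.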

161 changes: 160 additions & 1 deletion python/llm/src/ipex_llm/transformers/models/qwen2_5_omni.py
@@ -20,9 +20,11 @@
 import math
 import torch
 from typing import Optional, Tuple, List, Union
-from transformers.cache_utils import Cache
+from transformers.cache_utils import Cache, EncoderDecoderCache
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import Qwen2_5OmniAttention
+from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import apply_rotary_pos_emb
 from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import apply_rotary_pos_emb_vision
 from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import apply_multimodal_rotary_pos_emb

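The new `EncoderDecoderCache` import matters because the audio attention added below can run either as self-attention or as cross-attention, and the cache wraps one sub-cache for each. A hedged sketch of how such a cache is typically built from transformers' `DynamicCache` (illustrative, not taken from this PR):

# Sketch only: an EncoderDecoderCache exposing the attributes the audio attention
# below relies on (.self_attention_cache, .cross_attention_cache, .is_updated).
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

past_key_values = EncoderDecoderCache(
    DynamicCache(),  # self-attention cache: keys/values of the decoder's own tokens
    DynamicCache(),  # cross-attention cache: keys/values projected from encoder states
)
# past_key_values.is_updated[layer_idx] records whether a layer's cross-attention
# keys/values were already computed, so later steps can reuse them from the cache.
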
@@ -284,3 +286,160 @@ def qwen2_5_omni_vision_attention_forward(
attn_output = attn_output.reshape(seq_length, -1)
attn_output = self.proj(attn_output)
return attn_output


def qwen2_5_omni_audio_attention_forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[EncoderDecoderCache] = None,
cu_seqlens: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""

# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
seq_length, _ = hidden_states.size()

# get query proj
query_states = self.q_proj(hidden_states)
query_states = query_states.reshape(seq_length, self.num_heads, -1)

seq_lens = cu_seqlens.tolist()
invalidInputError(seq_lens[0] == 0 and seq_lens[-1] == seq_length,
"unexpected input")

if past_key_value is not None:
is_updated = past_key_value.is_updated.get(self.layer_idx)
if is_cross_attention:
# after the first generated id,
# we can subsequently re-use all key/value_states from cache
past_key_value.is_updated[self.layer_idx] = True
past_key_value = past_key_value.cross_attention_cache
else:
past_key_value = past_key_value.self_attention_cache

# use key_value_states if cross attention
current_states = key_value_states if key_value_states is not None else hidden_states
if is_cross_attention and past_key_value and is_updated:
# reuse k,v, cross_attentions
key_states = past_key_value.key_cache[self.layer_idx]
value_states = past_key_value.value_cache[self.layer_idx]
else:
key_states = self.k_proj(current_states).reshape(seq_length, self.num_heads, -1)
value_states = self.v_proj(current_states).reshape(seq_length, self.num_heads, -1)
if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
cache_position = cache_position if not is_cross_attention else None
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)

if layer_head_mask is None and use_sdp_non_causal(query_states.size(-1),
query_states.device, query_states.dtype):
kv_length = key_states.size(0)
padding_kv_length = (kv_length + 128 - 1) // 128 * 128
attention_mask = torch.full(
[1, 1, seq_length, padding_kv_length], torch.finfo(query_states.dtype).min,
device=query_states.device, dtype=query_states.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., seq_lens[i - 1]:seq_lens[i], seq_lens[i - 1]:seq_lens[i]] = 0

q = query_states.transpose(0, 1).unsqueeze(0)
k = key_states.transpose(0, 1).unsqueeze(0).contiguous()
v = value_states.transpose(0, 1).unsqueeze(0).contiguous()
# q, k, v: [1, num_heads, seq_length, head_dim]

attn_weights = None
attn_output = scaled_dot_product_attention(q, k, v, attention_mask, False)
attn_output = attn_output.permute(0, 2, 1, 3).squeeze(0)
# attn_output: [seq_length, num_heads, head_dim]
else:
attention_mask = torch.full(
[1, seq_length, key_states.size(0)], torch.finfo(query_states.dtype).min,
device=query_states.device, dtype=query_states.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., seq_lens[i - 1]:seq_lens[i], seq_lens[i - 1]:seq_lens[i]] = 0

query_states = query_states.transpose(0, 1)
key_states = key_states.transpose(0, 1)
value_states = value_states.transpose(0, 1)

attn_weights = torch.matmul(query_states,
key_states.transpose(1, 2)) / math.sqrt(self.head_dim)
attn_weights = attn_weights + attention_mask
attn_weights = attention_softmax(attn_weights)

if layer_head_mask is not None:
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights

attn_output = torch.matmul(attn_weights, value_states).transpose(0, 1)

# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state`s
# because `attn_output` can be partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(seq_length, self.embed_dim)
attn_output = self.out_proj(attn_output)

return attn_output, attn_weights, past_key_value

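Here `cu_seqlens` holds cumulative segment lengths, so the loop above opens one square block per audio segment while everything else stays at the dtype minimum, giving a block-diagonal mask in which tokens only attend within their own segment. A small worked example of that mask construction in isolation (illustrative values, not from the PR):

import torch

# Two audio segments of lengths 3 and 2 packed into one sequence of 5 tokens.
cu_seqlens = torch.tensor([0, 3, 5])
seq_lens = cu_seqlens.tolist()
seq_length = seq_lens[-1]
dtype = torch.float32

# Start fully masked, then zero out a square block for each segment,
# mirroring the loop in qwen2_5_omni_audio_attention_forward.
attention_mask = torch.full((1, seq_length, seq_length), torch.finfo(dtype).min)
for i in range(1, len(seq_lens)):
    attention_mask[..., seq_lens[i - 1]:seq_lens[i], seq_lens[i - 1]:seq_lens[i]] = 0

print((attention_mask[0] == 0).int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 1, 1]], dtype=torch.int32)

The SDPA branch pads the key dimension up to a multiple of 128 before building the same mask, which is why `padding_kv_length` appears only there.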

def dit_attention_forward(
self,
x,
rope=None,
mask=None,
) -> torch.Tensor:
batch_size = x.shape[0]

# `sample` projections.
query = self.to_q(x)
key = self.to_k(x)
value = self.to_v(x)

# attention
inner_dim = key.shape[-1]
head_dim = inner_dim // self.heads
query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)

# apply rotary position embedding
    # Due to the training process, RoPE is applied only to the first head; this will be fixed in a future release
cos, sin = rope
query[:, :1], key[:, :1] = apply_rotary_pos_emb(query[:, :1], key[:, :1], cos, sin)

if use_sdp_non_causal(head_dim, query.device, query.dtype):
mask = torch.where(mask, 0, torch.finfo(query.dtype).min)
x = scaled_dot_product_attention(query, key.contiguous(), value.contiguous(), mask, False)
x = x.transpose(1, 2)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self._attn_implementation]
x, _ = attention_interface(self, query, key, value, attention_mask=mask, is_causal=False)

    # merge heads
x = x.reshape(batch_size, -1, self.heads * head_dim)
x = x.to(query.dtype)

# linear proj
x = self.to_out[0](x)
# dropout
x = self.to_out[1](x)

return x

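In the SDPA branch above, `mask` is assumed to be a boolean keep-mask (True = attend), and `torch.where` converts it into the additive bias that `scaled_dot_product_attention` expects. A tiny illustration with made-up values:

import torch

keep = torch.tensor([[True, True, False],
                     [True, True, True]])
bias = torch.where(keep, 0.0, torch.finfo(torch.float32).min)
# bias is 0 where attention is allowed and a very large negative number
# (about -3.4e38 for float32) where it is masked, so those logits vanish after softmax.
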

def _create_block_diff(self, x):
batch, seq_len = x.shape[0], x.shape[1]
block_indices = torch.arange(seq_len, device=x.device) // self.block_size

block_i = block_indices.unsqueeze(1) # [seq_length, 1]
block_j = block_indices.unsqueeze(0) # [1, seq_length]

block_diff = block_j - block_i # (n, n)
return block_diff.unsqueeze(0).unsqueeze(0)
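
`_create_block_diff` returns, for every (query, key) position pair, how many blocks apart the two positions are; the DiT attention can then restrict attention to nearby blocks by thresholding this difference (a look-back / look-ahead window). A small worked illustration with a hypothetical `block_size` (not taken from the PR):

import torch

# 6 positions with block_size 2 -> block indices [0, 0, 1, 1, 2, 2].
seq_len, block_size = 6, 2
block_indices = torch.arange(seq_len) // block_size

block_i = block_indices.unsqueeze(1)  # [seq_len, 1], block of the query position
block_j = block_indices.unsqueeze(0)  # [1, seq_len], block of the key position
block_diff = block_j - block_i        # [seq_len, seq_len]

print(block_diff)
# tensor([[ 0,  0,  1,  1,  2,  2],
#         [ 0,  0,  1,  1,  2,  2],
#         [-1, -1,  0,  0,  1,  1],
#         [-1, -1,  0,  0,  1,  1],
#         [-2, -2, -1, -1,  0,  0],
#         [-2, -2, -1, -1,  0,  0]])
# e.g. keeping only entries with a small |block_diff| yields a block-local attention window.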