diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 1a98705412b..f04401672b4 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -2072,12 +2072,31 @@ def _optimize_post(model):
         convert_forward(model.thinker.visual, module.Qwen2_5OmniVisionSdpaAttention,
                         qwen2_5_omni_vision_attention_forward)
+        # audio opt
+        from ipex_llm.transformers.models.qwen2_5_omni import qwen2_5_omni_audio_attention_forward
+        convert_forward(model.thinker.audio_tower, module.Qwen2_5OmniAudioAttention,
+                        qwen2_5_omni_audio_attention_forward)
+        convert_forward(model.thinker.audio_tower, module.Qwen2_5OmniAudioSdpaAttention,
+                        qwen2_5_omni_audio_attention_forward)
+
         # tts opt
-        if hasattr(model, "talker"):
-            convert_forward(model.talker, module.Qwen2_5OmniAttention,
+        if model.has_talker:
+            # talker part
+            convert_forward(model.talker.model, module.Qwen2_5OmniAttention,
+                            qwen2_5_omni_attention_forward)
+            convert_forward(model.talker.model, module.Qwen2_5OmniSdpaAttention,
                             qwen2_5_omni_attention_forward)
-            convert_forward(model.talker, module.Qwen2_5OmniThinkerModel,
+            convert_forward(model.talker.model, module.Qwen2_5OmniTalkerModel,
                             qwen2_5_omni_thinker_model_forward)
+            convert_forward(model.talker.model, module.Qwen2MLP, qwen2_mlp_forward)
+
+            # token2wav part
+            from ipex_llm.transformers.models.qwen2_5_omni import dit_attention_forward
+            from ipex_llm.transformers.models.qwen2_5_omni import _create_block_diff
+            convert_forward(model.token2wav, module.DiTAttention, dit_attention_forward)
+            dit_model = model.token2wav.code2wav_dit_model
+            dit_model._create_block_diff = MethodType(_create_block_diff, dit_model)
+
     return model
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2_5_omni.py b/python/llm/src/ipex_llm/transformers/models/qwen2_5_omni.py
index 744b117c4db..5efe96327a2 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2_5_omni.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2_5_omni.py
@@ -20,9 +20,11 @@
 import math
 import torch
 from typing import Optional, Tuple, List, Union
-from transformers.cache_utils import Cache
+from transformers.cache_utils import Cache, EncoderDecoderCache
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import Qwen2_5OmniAttention
+from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import apply_rotary_pos_emb
 from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import apply_rotary_pos_emb_vision
 from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import apply_multimodal_rotary_pos_emb
@@ -284,3 +286,160 @@ def qwen2_5_omni_vision_attention_forward(
     attn_output = attn_output.reshape(seq_length, -1)
     attn_output = self.proj(attn_output)
     return attn_output
+
+
+def qwen2_5_omni_audio_attention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    key_value_states: Optional[torch.Tensor] = None,
+    past_key_value: Optional[EncoderDecoderCache] = None,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    layer_head_mask: Optional[torch.Tensor] = None,
+    output_attentions: bool = False,
+    cache_position: Optional[torch.LongTensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    """Input shape: Batch x Time x Channel"""
+
+    # if key_value_states are provided this layer is used as a cross-attention layer
+    # for the decoder
+    is_cross_attention = key_value_states is not None
+    seq_length, _ = hidden_states.size()
+
+    # get query proj
+    query_states = self.q_proj(hidden_states)
+    query_states = query_states.reshape(seq_length, self.num_heads, -1)
+
+    seq_lens = cu_seqlens.tolist()
+    invalidInputError(seq_lens[0] == 0 and seq_lens[-1] == seq_length,
+                      "unexpected input")
+
+    if past_key_value is not None:
+        is_updated = past_key_value.is_updated.get(self.layer_idx)
+        if is_cross_attention:
+            # after the first generated id,
+            # we can subsequently re-use all key/value_states from cache
+            past_key_value.is_updated[self.layer_idx] = True
+            past_key_value = past_key_value.cross_attention_cache
+        else:
+            past_key_value = past_key_value.self_attention_cache
+
+    # use key_value_states if cross attention
+    current_states = key_value_states if key_value_states is not None else hidden_states
+    if is_cross_attention and past_key_value and is_updated:
+        # reuse k,v, cross_attentions
+        key_states = past_key_value.key_cache[self.layer_idx]
+        value_states = past_key_value.value_cache[self.layer_idx]
+    else:
+        key_states = self.k_proj(current_states).reshape(seq_length, self.num_heads, -1)
+        value_states = self.v_proj(current_states).reshape(seq_length, self.num_heads, -1)
+        if past_key_value is not None:
+            # save all key/value_states to cache to be re-used for fast auto-regressive generation
+            cache_position = cache_position if not is_cross_attention else None
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+            )
+
+    if layer_head_mask is None and use_sdp_non_causal(query_states.size(-1),
+                                                      query_states.device, query_states.dtype):
+        kv_length = key_states.size(0)
+        padding_kv_length = (kv_length + 128 - 1) // 128 * 128
+        attention_mask = torch.full(
+            [1, 1, seq_length, padding_kv_length], torch.finfo(query_states.dtype).min,
+            device=query_states.device, dtype=query_states.dtype,
+        )
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[..., seq_lens[i - 1]:seq_lens[i], seq_lens[i - 1]:seq_lens[i]] = 0
+
+        q = query_states.transpose(0, 1).unsqueeze(0)
+        k = key_states.transpose(0, 1).unsqueeze(0).contiguous()
+        v = value_states.transpose(0, 1).unsqueeze(0).contiguous()
+        # q, k, v: [1, num_heads, seq_length, head_dim]
+
+        attn_weights = None
+        attn_output = scaled_dot_product_attention(q, k, v, attention_mask, False)
+        attn_output = attn_output.permute(0, 2, 1, 3).squeeze(0)
+        # attn_output: [seq_length, num_heads, head_dim]
+    else:
+        attention_mask = torch.full(
+            [1, seq_length, key_states.size(0)], torch.finfo(query_states.dtype).min,
+            device=query_states.device, dtype=query_states.dtype,
+        )
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[..., seq_lens[i - 1]:seq_lens[i], seq_lens[i - 1]:seq_lens[i]] = 0
+
+        query_states = query_states.transpose(0, 1)
+        key_states = key_states.transpose(0, 1)
+        value_states = value_states.transpose(0, 1)
+
+        attn_weights = torch.matmul(query_states,
+                                    key_states.transpose(1, 2)) / math.sqrt(self.head_dim)
+        attn_weights = attn_weights + attention_mask
+        attn_weights = attention_softmax(attn_weights)
+
+        if layer_head_mask is not None:
+            attn_weights = layer_head_mask.view(-1, 1, 1) * attn_weights
+
+        attn_output = torch.matmul(attn_weights, value_states).transpose(0, 1)
+
+    # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state`s
+    # because `attn_output` can be partitioned across GPUs when using tensor-parallelism.
+    attn_output = attn_output.reshape(seq_length, self.embed_dim)
+    attn_output = self.out_proj(attn_output)
+
+    return attn_output, attn_weights, past_key_value
+
+
+def dit_attention_forward(
+    self,
+    x,
+    rope=None,
+    mask=None,
+) -> torch.Tensor:
+    batch_size = x.shape[0]
+
+    # `sample` projections.
+    query = self.to_q(x)
+    key = self.to_k(x)
+    value = self.to_v(x)
+
+    # attention
+    inner_dim = key.shape[-1]
+    head_dim = inner_dim // self.heads
+    query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+    key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+    value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+
+    # apply rotary position embedding
+    # Due to the training process, only the first head gets RoPE; will be fixed in a later release
+    cos, sin = rope
+    query[:, :1], key[:, :1] = apply_rotary_pos_emb(query[:, :1], key[:, :1], cos, sin)
+
+    if use_sdp_non_causal(head_dim, query.device, query.dtype):
+        mask = torch.where(mask, 0, torch.finfo(query.dtype).min)
+        x = scaled_dot_product_attention(query, key.contiguous(), value.contiguous(), mask, False)
+        x = x.transpose(1, 2)
+    else:
+        attention_interface = ALL_ATTENTION_FUNCTIONS[self._attn_implementation]
+        x, _ = attention_interface(self, query, key, value, attention_mask=mask, is_causal=False)
+
+    # merge heads
+    x = x.reshape(batch_size, -1, self.heads * head_dim)
+    x = x.to(query.dtype)
+
+    # linear proj
+    x = self.to_out[0](x)
+    # dropout
+    x = self.to_out[1](x)
+
+    return x
+
+
+def _create_block_diff(self, x):
+    batch, seq_len = x.shape[0], x.shape[1]
+    block_indices = torch.arange(seq_len, device=x.device) // self.block_size
+
+    block_i = block_indices.unsqueeze(1)  # [seq_length, 1]
+    block_j = block_indices.unsqueeze(0)  # [1, seq_length]
+
+    block_diff = block_j - block_i  # [seq_length, seq_length]
+    return block_diff.unsqueeze(0).unsqueeze(0)
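
Reviewer note: a minimal usage sketch of how this patch is exercised end to end. This is an assumption-laden illustration, not part of the diff: the checkpoint id, the `low_bit` choice, and the `.to("xpu")` step are illustrative, and it assumes `optimize_model()` reaches `_optimize_post()` for Qwen2.5-Omni checkpoints.

    # Illustrative only: load a Qwen2.5-Omni checkpoint and let ipex-llm apply
    # the converted forwards added in this patch.
    import torch
    from transformers import Qwen2_5OmniForConditionalGeneration
    from ipex_llm import optimize_model

    model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2.5-Omni-7B", torch_dtype=torch.float16)
    model = optimize_model(model, low_bit="sym_int4")  # assumed to run _optimize_post(model)
    model = model.to("xpu")

    # After conversion:
    #   - model.thinker.audio_tower attention -> qwen2_5_omni_audio_attention_forward
    #   - model.talker (when model.has_talker) -> talker forwards registered above
    #   - model.token2wav DiTAttention -> dit_attention_forward, and its
    #     code2wav_dit_model carries the patched _create_block_diff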
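
For the audio path, `qwen2_5_omni_audio_attention_forward` turns `cu_seqlens` into an additive block-diagonal mask so frames from different audio segments never attend to each other. A standalone sketch of that construction, with made-up segment lengths (the real SDP branch additionally pads the key/value dimension up to a multiple of 128):

    import torch

    cu_seqlens = torch.tensor([0, 3, 7, 10])  # three segments of length 3, 4, 3 (illustrative)
    seq_length = int(cu_seqlens[-1])
    dtype = torch.float16

    # start fully masked, then zero out each segment's diagonal block
    mask = torch.full([1, 1, seq_length, seq_length], torch.finfo(dtype).min, dtype=dtype)
    seq_lens = cu_seqlens.tolist()
    for i in range(1, len(seq_lens)):
        mask[..., seq_lens[i - 1]:seq_lens[i], seq_lens[i - 1]:seq_lens[i]] = 0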
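
Similarly, the `_create_block_diff` patched onto `code2wav_dit_model` only computes signed block distances between positions. A small standalone version makes the intent easier to check; the `block_size` and `seq_len` values here are made up, the real ones come from the DiT model:

    import torch

    block_size = 4   # illustrative; the patched method reads self.block_size
    seq_len = 10
    block_indices = torch.arange(seq_len) // block_size
    block_diff = block_indices.unsqueeze(0) - block_indices.unsqueeze(1)
    # block_diff[i, j] == 0 -> j is in the same block as i
    # block_diff[i, j] > 0  -> j lies in a later block; < 0 -> an earlier block
    # the patched method returns this with shape [1, 1, seq_len, seq_len]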