From 8aecc91e3cbedfed548046530cd700a3c55614f9 Mon Sep 17 00:00:00 2001 From: xiangyuT Date: Tue, 18 Feb 2025 14:56:56 +0800 Subject: [PATCH] init --- .../generate.py | 114 ++++++++++++------ 1 file changed, 75 insertions(+), 39 deletions(-) diff --git a/python/llm/example/GPU/CPU-GPU-Hybrid-DeepSeek-R1-Inference/generate.py b/python/llm/example/GPU/CPU-GPU-Hybrid-DeepSeek-R1-Inference/generate.py index 50b7d4a2ff6..28123edeab4 100644 --- a/python/llm/example/GPU/CPU-GPU-Hybrid-DeepSeek-R1-Inference/generate.py +++ b/python/llm/example/GPU/CPU-GPU-Hybrid-DeepSeek-R1-Inference/generate.py @@ -25,6 +25,8 @@ import ipex_llm from ipex_llm.transformers import AutoModelForCausalLM +from ipex_llm.utils.common.log4Error import invalidInputError +from ipex_llm.transformers.models.common import scaled_dot_product_attention from transformers import AutoTokenizer, GenerationConfig from transformers.cache_utils import Cache, DynamicCache @@ -147,54 +149,88 @@ def hybrid_DeepseekV3Attention_forward( "with a layer index." ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + # ipex-llm modify: test ipex-llm mla kernel + if False: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states = q + key_states = torch.cat( + [k_nope, k_pe.expand([-1, self.num_heads, -1, -1])], + dim=-1 + ) + import xe_addons + # print(self.rotary_emb.__class__.__name__) + + if self.rotary_emb.__class__.__name__ == "DeepseekV3YarnRotaryEmbedding": + xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids, + query_states[:, :, :, self.qk_nope_head_dim:], + key_states[:, :, :, self.qk_nope_head_dim:]) + else: + invalidInputError(False, f"unknown rope method: {self.rotary_emb.__class__.__name__}") + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) - q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) - query_states[:, :, :, : self.qk_nope_head_dim] = q_nope - query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) - key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) - key_states[:, :, :, : self.qk_nope_head_dim] = k_nope - key_states[:, :, :, self.qk_nope_head_dim :] = k_pe - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe - attn_weights = ( - torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale - ) + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" + if True: + attn_weights = None + attn_output = scaled_dot_product_attention( + query_states, key_states, value_states, + attention_mask, q_len == kv_seq_len, self.softmax_scale ) - assert attention_mask is not None - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + attn_output = attn_output[:, :, :, :self.v_head_dim] + else: + attn_weights = ( + torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale + ) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ).to(query_states.dtype) - attn_weights = nn.functional.dropout( - attn_weights, p=self.attention_dropout, training=self.training - ) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is" - f" {attn_output.size()}" + assert attention_mask is not None + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is" + f" {attn_output.size()}" + ) attn_output = attn_output.transpose(1, 2).contiguous()