diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index ab8984b7ae..1d0e519dff 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -577,18 +577,24 @@ def __init__(self, config, layer_number,
         else:
             local_attn = CoreAttention(self.layer_number, config, self.attn_mask_type)

-        self.enable_ds_sequence_parallel = parallel_state.get_sequence_parallel_world_size() > 1 \
-            or args.force_ds_sequence_parallel
-        if self.enable_ds_sequence_parallel:
+        if parallel_state.get_sequence_parallel_world_size() > 1 \
+            or args.force_ds_sequence_parallel:
             assert dist_attn_supported, 'Distributed attention is not supported in this DeepSpeed version'
             assert args.num_attention_heads % parallel_state.get_sequence_parallel_world_size() == 0
-            self.dist_attn = DistributedAttention(local_attn, parallel_state.get_sequence_parallel_group())
+            # self.dist_attn = DistributedAttention(local_attn, parallel_state.get_sequence_parallel_group())
+            self.compute_attn_sp = DistributedAttention(self.compute_attn,
+                                                        parallel_state.get_sequence_parallel_group(),
+                                                        scatter_idx=2,
+                                                        gather_idx=0,
+                                                        hidden_size_per_attention_head=hidden_size_per_attention_head,
+                                                        num_q_per_kv=self.num_key_value_groups if projection_size != kv_projection_size else -1)
+            self.compute_attn = lambda mixed_x_layer, *args, **kwargs: self.compute_attn_sp(mixed_x_layer, *args, **kwargs)
+
+        if self.use_flash_attn:
+            self.core_attention_flash = local_attn
         else:
-            if self.use_flash_attn:
-                self.core_attention_flash = local_attn
-            else:
-                self.core_attention = local_attn
-                self.checkpoint_core_attention = config.recompute_granularity == 'selective'
+            self.core_attention = local_attn
+            self.checkpoint_core_attention = config.recompute_granularity == 'selective'

         # Output.
         self.dense = tensor_parallel.RowParallelLinear(
@@ -650,78 +656,28 @@ def split_tensor(self, mixed_x_layer):

         return query_layer, key_layer, value_layer

-    def forward(self, hidden_states, attention_mask,
-                encoder_output=None, inference_params=None,
-                rotary_pos_emb=None):
-        # hidden_states: [sq, b, h]
-
-        # =================================================
-        # Pre-allocate memory for key-values for inference.
-        # =================================================
-        is_first_step = False
-        if inference_params:
-            if self.layer_number not in inference_params.key_value_memory_dict:
-                inf_max_seq_len = inference_params.max_sequence_len
-                inf_max_batch_size = inference_params.max_batch_size
-                inference_key_memory = self._allocate_memory(
-                    inf_max_seq_len, inf_max_batch_size)
-                inference_value_memory = self._allocate_memory(
-                    inf_max_seq_len, inf_max_batch_size)
-                inference_params.key_value_memory_dict[self.layer_number] = (
-                    inference_key_memory, inference_value_memory)
-                is_first_step = True
-            else:
-                inference_key_memory, inference_value_memory = \
-                    inference_params.key_value_memory_dict[self.layer_number]
-
-        # =====================
-        # Query, Key, and Value
-        # =====================
-
-        if self.attention_type == AttnType.self_attn:
-            # Attention heads [sq, b, h] --> [sq, b, ((nq + 2 * nkv) * hn)]
-            mixed_x_layer, _ = self.query_key_value(hidden_states)
-
-            # [sq, b, ((nq + 2 * nkv) * hn)] --> [sq, b, nkv, (nq // nkv + 2), hn]
-            new_tensor_shape = mixed_x_layer.size()[:-1] + \
-                (-1, (self.num_key_value_groups + 2),
-                 self.hidden_size_per_attention_head)
-            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
-
-            # [sq, b, nkv, (nq // nkv + 2), hn] --> 3 [sq, b, np, hn]
-            (query_layer,
-             key_layer,
-             value_layer) = self.split_tensor(mixed_x_layer)
-
-            # Repeat kv
-            if self.use_gqa:
-                key_layer = self.repeat_kv(key_layer, self.num_key_value_groups)
-                value_layer = self.repeat_kv(value_layer,
-                                             self.num_key_value_groups)
-        else:
-            assert not self.use_gqa, 'GQA + cross-attn not tested yet'
-
-            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
-            mixed_kv_layer, _ = self.key_value(encoder_output)
-
-            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
-            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
-                (self.num_attention_heads_per_partition,
-                 2 * self.hidden_size_per_attention_head)
-            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
-
-            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
-            (key_layer,
-             value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)
-
-            # Attention head [sq, b, h] --> [sq, b, hp]
-            query_layer, _ = self.query(hidden_states)
-            # [sq, b, hp] --> [sq, b, np, hn]
-            new_tensor_shape = query_layer.size()[:-1] + \
-                (self.num_attention_heads_per_partition,
-                 self.hidden_size_per_attention_head)
-            query_layer = query_layer.view(*new_tensor_shape)
-
+    def compute_attn(self,
+                     mixed_x_layer,
+                     attention_mask,
+                     inference_params=None,
+                     rotary_pos_emb=None):
+
+        # [sq, b, ((nq + 2 * nkv) * hn)] --> [sq, b, nkv, (nq // nkv + 2), hn]
+        new_tensor_shape = mixed_x_layer.size()[:-1] + \
+            (-1, (self.num_key_value_groups + 2),
+             self.hidden_size_per_attention_head)
+        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+
+        # [sq, b, nkv, (nq // nkv + 2), hn] --> 3 [sq, b, np, hn]
+        (query_layer,
+         key_layer,
+         value_layer) = self.split_tensor(mixed_x_layer)
+
+        # Repeat kv
+        if self.use_gqa:
+            key_layer = self.repeat_kv(key_layer, self.num_key_value_groups)
+            value_layer = self.repeat_kv(value_layer,
+                                         self.num_key_value_groups)
         # ==================================
         # Adjust key and value for inference
         # ==================================
@@ -786,43 +742,90 @@ def forward(self, hidden_states, attention_mask,
            # otherwise, only relative positional embedding takes effect
            # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)

-        if self.enable_ds_sequence_parallel:
-            if self.use_flash_attn:
-                if not self.use_flash_attn_triton:
-                    query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> b s ...').contiguous()
-                                                           for x in (query_layer, key_layer, value_layer)]
-
-                context_layer = self.dist_attn(query_layer, key_layer, value_layer)
+        if self.use_flash_attn:
+            if not self.use_flash_attn_triton:
+                query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> b s ...').contiguous()
+                                                       for x in (query_layer, key_layer, value_layer)]

-                if not self.use_flash_attn_triton:
-                    context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
+            if self.sequence_parallel:
+                context_layer = self.core_attention_flash(query_layer, key_layer, value_layer)
             else:
-                context_layer = self.dist_attn(query_layer, key_layer, value_layer, attention_mask)
-        else:
-            if self.use_flash_attn:
-                if not self.use_flash_attn_triton:
-                    query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> b s ...').contiguous()
-                                                           for x in (query_layer, key_layer, value_layer)]
-
-                if self.sequence_parallel:
+                with tensor_parallel.get_cuda_rng_tracker().fork():
                     context_layer = self.core_attention_flash(query_layer, key_layer, value_layer)
-                else:
-                    with tensor_parallel.get_cuda_rng_tracker().fork():
-                        context_layer = self.core_attention_flash(query_layer, key_layer, value_layer)
-                if not self.use_flash_attn_triton:
-                    context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
+            if not self.use_flash_attn_triton:
+                context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
+        else:
+            if self.checkpoint_core_attention:
+                context_layer = self._checkpointed_attention_forward(
+                    query_layer, key_layer, value_layer, attention_mask)
             else:
-                if self.checkpoint_core_attention:
-                    context_layer = self._checkpointed_attention_forward(
-                        query_layer, key_layer, value_layer, attention_mask)
-                else:
-                    context_layer = self.core_attention(
-                        query_layer, key_layer, value_layer, attention_mask)
+                context_layer = self.core_attention(
+                    query_layer, key_layer, value_layer, attention_mask)

         # =================
         # Output. [sq, b, h]
         # =================

+    def forward(self, hidden_states,
+                attention_mask,
+                encoder_output=None,
+                inference_params=None,
+                rotary_pos_emb=None):
+        # hidden_states: [sq, b, h]
+
+        # =================================================
+        # Pre-allocate memory for key-values for inference.
+        # =================================================
+        is_first_step = False
+        if inference_params:
+            if self.layer_number not in inference_params.key_value_memory_dict:
+                inf_max_seq_len = inference_params.max_sequence_len
+                inf_max_batch_size = inference_params.max_batch_size
+                inference_key_memory = self._allocate_memory(
+                    inf_max_seq_len, inf_max_batch_size)
+                inference_value_memory = self._allocate_memory(
+                    inf_max_seq_len, inf_max_batch_size)
+                inference_params.key_value_memory_dict[self.layer_number] = (
+                    inference_key_memory, inference_value_memory)
+                is_first_step = True
+            else:
+                inference_key_memory, inference_value_memory = \
+                    inference_params.key_value_memory_dict[self.layer_number]
+
+        # =====================
+        # Query, Key, and Value
+        # =====================
+
+        if self.attention_type == AttnType.self_attn:
+            # Attention heads [sq, b, h] --> [sq, b, ((nq + 2 * nkv) * hn)]
+            mixed_x_layer, _ = self.query_key_value(hidden_states)
+
+        else:
+            assert not self.use_gqa, 'GQA + cross-attn not tested yet'
+
+            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
+            mixed_kv_layer, _ = self.key_value(encoder_output)
+
+            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
+            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
+                (self.num_attention_heads_per_partition,
+                 2 * self.hidden_size_per_attention_head)
+            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
+
+            # Attention head [sq, b, h] --> [sq, b, hp]
+            query_layer, _ = self.query(hidden_states)
+            # [sq, b, hp] --> [sq, b, np, hn]
+            new_tensor_shape = query_layer.size()[:-1] + \
+                (self.num_attention_heads_per_partition,
+                 self.hidden_size_per_attention_head)
+            query_layer = query_layer.view(*new_tensor_shape)
+
+            mixed_x_layer = torch.cat((query_layer, mixed_kv_layer), dim=-1).reshape(query_layer.size()[:-1], (-1,))
+
+        context_layer = self.compute_attn(mixed_x_layer,
+                                          attention_mask,
+                                          inference_params=inference_params,
+                                          rotary_pos_emb=rotary_pos_emb)

         output, bias = self.dense(context_layer)
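
Note on the wiring above: when sequence parallelism is enabled, `self.compute_attn` is replaced by a `DistributedAttention` wrapper around the original method, with `gather_idx=0` (the sequence axis of the [sq, b, ...] layout) and `scatter_idx=2` (the axis that carries the packed heads). The sketch below is illustrative only and is not code from this patch or from DeepSpeed's API. It assumes plain torch.distributed collectives, a hypothetical sequence-parallel process group `group`, and made-up helper names (`seq_all_to_all`, `ulysses_attention`); for clarity it exchanges the tensor after it has been viewed as [sq, b, nkv, (nq // nkv + 2), hn], so dim 2 is the key/value-head axis. The point it shows is the layout swap such a wrapper performs: trade the sequence shard for a head shard before local attention, and trade it back afterwards.

# Illustrative sketch only; not code from this patch or from DeepSpeed.
import torch
import torch.distributed as dist


def seq_all_to_all(x: torch.Tensor, group, scatter_dim: int, gather_dim: int) -> torch.Tensor:
    # Split x into one chunk per rank along scatter_dim, exchange the chunks with
    # an all-to-all, and re-assemble the received chunks along gather_dim.
    # Assumes x.size(scatter_dim) is divisible by the group size.
    world_size = dist.get_world_size(group=group)
    inputs = [chunk.contiguous() for chunk in x.chunk(world_size, dim=scatter_dim)]
    outputs = [torch.empty_like(inputs[0]) for _ in range(world_size)]
    dist.all_to_all(outputs, inputs, group=group)
    return torch.cat(outputs, dim=gather_dim)


def ulysses_attention(mixed_x_layer, attention_mask, local_compute_attn, group):
    # mixed_x_layer: [sq/p, b, nkv, (nq // nkv + 2), hn], sharded over the sequence
    # (p = sequence-parallel world size).
    # 1) Trade the sequence shard for a head shard: [sq, b, nkv/p, (nq // nkv + 2), hn].
    mixed_x_layer = seq_all_to_all(mixed_x_layer, group, scatter_dim=2, gather_dim=0)
    # 2) Ordinary local attention over the full sequence with a subset of heads.
    context_layer = local_compute_attn(mixed_x_layer, attention_mask)  # [sq, b, local hidden]
    # 3) Reverse the exchange so the output is sharded over the sequence again.
    return seq_all_to_all(context_layer, group, scatter_dim=0, gather_dim=2)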