From bfcf5a9d3cd6e347b8f5d57361e8719f89e497a1 Mon Sep 17 00:00:00 2001
From: Reza Yazdani
Date: Mon, 27 Nov 2023 16:40:36 -0800
Subject: [PATCH] fuse query, key, and value all-2-all for better SP performance

---
 deepspeed/sequence/layer.py | 43 +++++++++++++++++++++----------------
 1 file changed, 24 insertions(+), 19 deletions(-)

diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py
index e1dbff87f4ec..214e64017e9d 100644
--- a/deepspeed/sequence/layer.py
+++ b/deepspeed/sequence/layer.py
@@ -67,41 +67,46 @@ class DistributedAttention(torch.nn.Module):
         gather_idx (int): gather_idx for all2all comm
     """
 
-    def __init__(
-        self,
-        local_attention: Module,
-        sequence_process_group: dist.ProcessGroup,
-        scatter_idx: int = 2,
-        gather_idx: int = 0,
-    ) -> None:
+    def __init__(self,
+                 local_attention: Module,
+                 sequence_process_group: dist.ProcessGroup,
+                 scatter_idx: int = 2,
+                 gather_idx: int = 0,
+                 hidden_size_per_attention_head: int = 128,
+                 num_q_per_kv: int = -1) -> None:
 
         super(DistributedAttention, self).__init__()
         self.local_attn = local_attention
         self.spg = sequence_process_group
         self.scatter_idx = scatter_idx
         self.gather_idx = gather_idx
+        self.hidden_size_per_attention_head = hidden_size_per_attention_head
+        self.num_q_per_kv = num_q_per_kv
 
-    def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any) -> Tensor:
+    def forward(self, mixed_x_layer: Tensor, *args: Any, **kwargs: Any) -> Tensor:
         """ forward
 
         Arguments:
-            query (Tensor): query input to the layer
-            key (Tensor): key input to the layer
-            value (Tensor): value input to the layer
+            mixed_x_layer including:
+                1. query (Tensor): query input to the layer
+                2. key (Tensor): key input to the layer
+                3. value (Tensor): value input to the layer
             args: other args
-
+            kwargs: other kw args
 
         Returns:
             * output (Tensor): context output
         """
-        # TODO Merge three alltoall calls into one
-        # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
-        #in shape : e.g., [s/p:h:]
-        query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx)
-        key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx)
-        value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)
+        sq, bs = mixed_x_layer.shape[:2]
+        if self.num_q_per_kv > 0 and \
+           mixed_x_layer.shape[-1] % ((self.num_q_per_kv + 2) * self.hidden_size_per_attention_head) == 0:
+            intermediate_shape = (sq, bs, -1, (self.num_q_per_kv + 2), self.hidden_size_per_attention_head)
+        else:
+            intermediate_shape = (sq, bs, -1, self.hidden_size_per_attention_head)
+        mixed_x_layer = mixed_x_layer.view(*intermediate_shape)
+        mixed_x_layer = _SeqAllToAll.apply(self.spg, mixed_x_layer, self.scatter_idx, self.gather_idx)
 
         #out shape : e.g., [s:h/p:]
-        context_layer = self.local_attn(query_layer, key_layer, value_layer, *args)
+        context_layer = self.local_attn(mixed_x_layer.reshape(sq, bs, -1), *args, **kwargs)
 
         output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)
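
A minimal single-process sketch of the shape math behind the fused all-to-all follows. All concrete sizes (num_kv_heads, num_q_per_kv, head_dim, sq, bs) are illustrative assumptions, and the per-group [q, ..., q, k, v] packing order is assumed to come from the fused QKV projection on the Megatron-DeepSpeed side; the patch itself only views and communicates the packed tensor.

# Sketch only: illustrates the packed-QKV reshape, not the distributed path.
# All sizes below are assumptions for the example, not values from the patch.
import torch

sq, bs = 16, 2        # per-rank sequence length, batch size
num_kv_heads = 4      # key/value head groups on this rank before the comm
num_q_per_kv = 3      # query heads per kv group (GQA); patch default -1 disables grouping
head_dim = 128        # hidden_size_per_attention_head

# Fused projection output: each kv group packs num_q_per_kv query heads
# plus one key head and one value head, hence the (num_q_per_kv + 2) factor.
hidden = num_kv_heads * (num_q_per_kv + 2) * head_dim
mixed_x_layer = torch.randn(sq, bs, hidden)

# The view applied before the single all-to-all. Dim 2 (the kv-group dim)
# is scatter_idx, so scattering it across ranks keeps each group's q, k,
# and v together on the same rank.
grouped = mixed_x_layer.view(sq, bs, -1, num_q_per_kv + 2, head_dim)
assert grouped.shape == (sq, bs, num_kv_heads, num_q_per_kv + 2, head_dim)

# After the comm, the local attention receives the groups flattened back to
# (sq, bs, -1); an attention impl expecting this layout could recover q, k,
# and v per group like so (assumed packing order):
q = grouped[:, :, :, :num_q_per_kv, :]   # (sq, bs, groups, num_q_per_kv, head_dim)
k = grouped[:, :, :, -2, :]              # (sq, bs, groups, head_dim)
v = grouped[:, :, :, -1, :]              # (sq, bs, groups, head_dim)
print(q.shape, k.shape, v.shape)

Packing q, k, and v this way replaces three collectives per direction with one, and guarantees that a group's query heads land on the same rank as their shared key/value heads, which is the "better SP performance" the subject line refers to.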