sequence parallel with communication overlap #5691

Merged: 20 commits, merged on Aug 1, 2024
Changes from 1 commit
refine code
inkcherry committed Jun 21, 2024

Verified: This commit was created on GitHub.com and signed with GitHub's verified signature. The key has expired.
commit 70a6d0c9944d731366710662ec6440c257c155bc
108 changes: 40 additions & 68 deletions deepspeed/sequence/layer.py
@@ -45,11 +45,11 @@ def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False,hand
shape=( inp_shape[: gather_idx] + \
[inp_shape[gather_idx] * seq_world_size,] + \
inp_shape[gather_idx + 1:])
c=output.reshape(shape).contiguous()
res=output.reshape(shape).contiguous()
if type=='dq' or type=='dk':
handle[type+'_grad']=output
handle[type+'_grad_shape']=shape
return c, work
return res, work
#!! need to delete
c= output.reshape(
inp_shape[: gather_idx] + \
@@ -61,7 +61,7 @@ def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False,hand
class _SeqAllToAll(torch.autograd.Function):

@staticmethod
def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int, stream=None, fwd_async=False,bwd_async=False, handle=None,type=None,is_fwd=True) -> Tensor:
def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int, stream=None,bwd_async=False, handle=None,type=None,is_fwd=True) -> Tensor:

ctx.group = group
ctx.scatter_idx = scatter_idx
@@ -72,16 +72,12 @@ def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int,
ctx.handle=handle
ctx.type=type

# if fwd_async and stream!=None:
if not is_fwd and type=='o':
assert stream!=None
# print0('')
res , work=single_all_to_all(input, scatter_idx, gather_idx, group,False)

get_accelerator().current_stream().wait_stream(ctx.stream)
# elif fwd_async and handle!=None:
elif not is_fwd and (type=='q' or type=='k'):
assert fwd_async==True
type='d'+type
res , work=single_all_to_all(input, scatter_idx, gather_idx, group,True,handle,type)
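For reference, a minimal, self-contained sketch of the pattern the hunks above rely on: an asynchronous all-to-all that returns both the output buffer and the communication work handle, and stashes the buffer plus its target shape in a handle dict so a backward hook can wait and reshape later. The helper name and dict keys here are illustrative, not the PR's exact API.

import torch
import torch.distributed as dist

def async_all_to_all_sketch(inp: torch.Tensor, group, handles: dict, key: str):
    # Illustrative stand-in for single_all_to_all(..., async_op=True, handle, type):
    # launch the collective without blocking and record what the backward hook needs.
    output = torch.empty_like(inp)
    work = dist.all_to_all_single(output, inp.contiguous(), group=group, async_op=True)
    handles[key + '_grad'] = output                 # raw buffer, reshaped after wait()
    handles[key + '_grad_shape'] = list(inp.shape)  # target shape (gathered layout in the real code)
    return output, work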

@@ -99,15 +95,11 @@ def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int,

@staticmethod
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
# print0("all2all o before")
# import pydevd
# pydevd.settrace()



#def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int, stream=None, fwd_async=False,bwd_async=False) -> Tensor:
q= (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.stream,ctx.bwd_async,False,ctx.handle,ctx.type,False), None, None,None,None,None,None,None,None)
# print0("all2all o after")

return q
return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.stream,False,ctx.handle,ctx.type,False), None,None,None,None,None,None,None)
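The hunk above collapses the backward into a single return. For reference, a minimal, non-overlapping version of this autograd pattern is sketched below (names are illustrative); the real class additionally threads the stream, async flags, handle dict, and tensor type through extra arguments and returns a matching number of Nones.

from typing import Any, Tuple
import torch
import torch.distributed as dist

class MinimalSeqAllToAll(torch.autograd.Function):
    """Sketch only: an all-to-all whose backward is the reverse all-to-all."""

    @staticmethod
    def forward(ctx: Any, group: dist.ProcessGroup, inp: torch.Tensor,
                scatter_idx: int, gather_idx: int) -> torch.Tensor:
        ctx.group, ctx.scatter_idx, ctx.gather_idx = group, scatter_idx, gather_idx
        # The file's single_all_to_all also reshapes around scatter_idx/gather_idx;
        # a plain all_to_all_single stands in here for brevity.
        output = torch.empty_like(inp)
        dist.all_to_all_single(output, inp.contiguous(), group=group)
        return output

    @staticmethod
    def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[None, torch.Tensor, None, None]:
        # The gradient of an all-to-all is the all-to-all with scatter/gather swapped;
        # one None is returned for each non-tensor argument of forward.
        return (None,
                MinimalSeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx),
                None, None)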


class DistributedAttention(torch.nn.Module):
@@ -137,18 +129,15 @@ def __init__(


self.sp_stream=sp_stream
self.bwd_all2all_handels={}
self.bwd_all2all_handels['dq']=None
self.bwd_all2all_handels['dq_grad']=None
self.bwd_all2all_handels['dk']=None
self.bwd_all2all_handels['dk_grad']=None
self.overlap_handles={}
self.overlap_handles['dq']=None
self.overlap_handles['dq_grad']=None
self.overlap_handles['dk']=None
self.overlap_handles['dk_grad']=None
self.dafult_stream=get_accelerator().default_stream()

self.hook_register=False



# query = slef.linearq(hidden)
def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any) -> Tensor:
""" forward
@@ -169,31 +158,21 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any) -> Tens


#step1 get q ,k ,v outside out this function
# def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int, stream=None, fwd_async=False,bwd_async=False) -> Tensor:


def q_hook(*notneeded):

self.bwd_all2all_handels['dq'].wait()
self.sp_stream.wait_stream(torch.cuda.default_stream())

tmp=self.bwd_all2all_handels['dq_grad']
notneeded=list(notneeded)
notneeded[0]=list(notneeded[0])
notneeded[0][0]=tmp.reshape(self.bwd_all2all_handels['dq_grad_shape']).contiguous()
notneeded[0]=tuple(notneeded[0])
notneeded=tuple(notneeded)
def bwd_hook(type):

def pre_hook(*notneeded):
self.overlap_handles['d'+type].wait()
self.sp_stream.wait_stream(torch.cuda.default_stream())
tmp=self.overlap_handles['d'+type+'_grad']
notneeded=list(notneeded)
notneeded[0]=list(notneeded[0])
notneeded[0][0]=tmp.reshape(self.overlap_handles['d'+type+'_grad_shape']).contiguous()
notneeded[0]=tuple(notneeded[0])
notneeded=tuple(notneeded)
return pre_hook


def k_hook(*notneeded):
self.bwd_all2all_handels['dk'].wait()
self.sp_stream.wait_stream(torch.cuda.default_stream())
tmp=self.bwd_all2all_handels['dk_grad']
notneeded=list(notneeded)
notneeded[0]=list(notneeded[0])
notneeded[0][0]=tmp.reshape(self.bwd_all2all_handels['dk_grad_shape']).contiguous()
notneeded[0]=tuple(notneeded[0])
notneeded=tuple(notneeded)




@@ -204,41 +183,34 @@ def k_hook(*notneeded):


self.dafult_stream.wait_event(query.done_event)
query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx,None,False,async_bwd_comm_q,self.bwd_all2all_handels,'q') #[1,512,32,32]
query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx,None,async_bwd_comm_q,self.overlap_handles,'q') #[1,512,32,32]
self.dafult_stream.wait_event(key.done_event)
key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx,None,False,async_bwd_comm_k, self.bwd_all2all_handels,'k') #[1,512,32,32]
key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx,None,async_bwd_comm_k, self.overlap_handles,'k') #[1,512,32,32]
self.dafult_stream.wait_stream(self.sp_stream)
value_layer= _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx,None,False,False, self.bwd_all2all_handels,'v') #[1,512,32,32]
value_layer= _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx,None,False, self.overlap_handles,'v') #[1,512,32,32]

# hard code currently
if True:
async_bwd_comm_q=True
async_bwd_comm_k=True
#eval interval
fn_q = query.grad_fn.next_functions[0][0]
fn_q.register_prehook(q_hook)
fn_k = key.grad_fn.next_functions[0][0]
fn_k.register_prehook(k_hook)
#do dq qk k v
# def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int, stream=None, fwd_async=False,bwd_async=False, handle=None,type=None) -> Tensor:

self.bwd_all2all_handels['fwd_q'].wait()
self.bwd_all2all_handels['fwd_k'].wait()
# self.bwd_all2all_handels['fwd_q'].wait()
grad_fn_q = query.grad_fn.next_functions[0][0]
grad_fn_q.register_prehook(bwd_hook(type='q'))
grad_fn_k = key.grad_fn.next_functions[0][0]
grad_fn_k.register_prehook(bwd_hook(type='k'))



self.overlap_handles['fwd_q'].wait()
self.overlap_handles['fwd_k'].wait()
# self.overlap_handles['fwd_q'].wait()
#all2all ayns to k_dense_bwd wait
#out shape : e.g., [s:h/p:]
# print(query_layer) #2,8, 2,4 sp=2 2gpus
# #
# print(key_layer)
# print(value_layer) #seq_len 16 , sp 2 , head_dim = 4, num_heads=4, hidding=16


context_layer = self.local_attn(query_layer, key_layer, value_layer, *args) #[8,512,4,32]
bwd_o_async=False
if self.sp_stream is not None:
bwd_o_async=True
output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx,self.sp_stream,False,bwd_o_async)
output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx,self.sp_stream,bwd_o_async)


#out e.g., [s/p::h]
return output

#o= self.dense(output)
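Overall, the commit replaces the duplicated q_hook / k_hook bodies with a single bwd_hook factory, renames bwd_all2all_handels to overlap_handles, and drops the unused fwd_async argument. A hedged sketch of the underlying overlap pattern follows; make_bwd_hook and handles are illustrative names, while register_prehook on the autograd graph node is the mechanism the diff itself uses.

import torch

def make_bwd_hook(handles: dict, sp_stream: torch.cuda.Stream, key: str):
    """Build a backward pre-hook that finishes the async gradient all-to-all for `key`."""

    def pre_hook(*grad_outputs):
        # Wait for the asynchronous dq/dk all-to-all launched earlier in backward, then
        # order the side communication stream behind the default compute stream.
        handles['d' + key].wait()
        sp_stream.wait_stream(torch.cuda.default_stream())

    return pre_hook

# Usage sketch: attach the hook to the autograd node feeding q's projection so the wait
# overlaps with unrelated backward compute instead of blocking at the all-to-all call site.
# query.grad_fn.next_functions[0][0].register_prehook(make_bwd_hook(handles, sp_stream, 'q'))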