3 files changed: +8 -4 lines changed
transformer_engine/pytorch/module: 3 files changed, +8 -4 lines changed
@@ -323,7 +323,9 @@ def forward(
323323 # Output buffer for Userbuffers reduce-scatter
324324 reduce_scatter_out = None
325325 if ub_overlap_rs_fprop :
326- out_shape = [reduce (multiply_op , inp_shape [:- 1 ]) // tp_world_size , out_features ]
326+ out_shape = list (inp_shape )
327+ out_shape [0 ] //= tp_world_size
328+ out_shape [- 1 ] = out_features
327329 reduce_scatter_out = torch .empty (
328330 out_shape , dtype = activation_dtype , device = ln_out_total .device
329331 )
@@ -457,8 +457,8 @@ def forward(
457457 if ub_overlap_rs :
458458 ub_obj_fc2out = get_ub ("fc2_fprop" )
459459 dim_size = list (act_out .size ())
460- dim_size [0 ] = dim_size [ 0 ] // tp_world_size
461- dim_size [1 ] = fc2_weight .size (0 )
460+ dim_size [0 ] //= tp_world_size
461+ dim_size [- 1 ] = fc2_weight .size (0 )
462462 reduce_scatter_out = torch .empty (dim_size , dtype = activation_dtype , device = device )
463463
464464 # ------------------------------------------------------
@@ -278,7 +278,9 @@ def forward(
278278 # Output buffer for Userbuffers reduce-scatter
279279 reduce_scatter_out = None
280280 if ub_overlap_rs_fprop :
281- out_shape = [reduce (multiply_op , inp .shape [:- 1 ]) // tp_world_size , out_features ]
281+ out_shape = list (inp .shape )
282+ out_shape [0 ] //= tp_world_size
283+ out_shape [- 1 ] = out_features
282284 reduce_scatter_out = torch .empty (
283285 out_shape , dtype = activation_dtype , device = inputmat_total .device
284286 )
You can’t perform that action at this time.
0 commit comments