Skip to content

Commit 546e02a

Browse files
committed
Fix RS out dims
Signed-off-by: Tim Moon <tmoon@nvidia.com>
1 parent dfb53ca commit 546e02a

File tree

3 files changed

+8
-4
lines changed

3 files changed

+8
-4
lines changed

transformer_engine/pytorch/module/layernorm_linear.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,9 @@ def forward(
323323
# Output buffer for Userbuffers reduce-scatter
324324
reduce_scatter_out = None
325325
if ub_overlap_rs_fprop:
326-
out_shape = [reduce(multiply_op, inp_shape[:-1]) // tp_world_size, out_features]
326+
out_shape = list(inp_shape)
327+
out_shape[0] //= tp_world_size
328+
out_shape[-1] = out_features
327329
reduce_scatter_out = torch.empty(
328330
out_shape, dtype=activation_dtype, device=ln_out_total.device
329331
)

transformer_engine/pytorch/module/layernorm_mlp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -457,8 +457,8 @@ def forward(
457457
if ub_overlap_rs:
458458
ub_obj_fc2out = get_ub("fc2_fprop")
459459
dim_size = list(act_out.size())
460-
dim_size[0] = dim_size[0] // tp_world_size
461-
dim_size[1] = fc2_weight.size(0)
460+
dim_size[0] //= tp_world_size
461+
dim_size[-1] = fc2_weight.size(0)
462462
reduce_scatter_out = torch.empty(dim_size, dtype=activation_dtype, device=device)
463463

464464
# ------------------------------------------------------

transformer_engine/pytorch/module/linear.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,9 @@ def forward(
278278
# Output buffer for Userbuffers reduce-scatter
279279
reduce_scatter_out = None
280280
if ub_overlap_rs_fprop:
281-
out_shape = [reduce(multiply_op, inp.shape[:-1]) // tp_world_size, out_features]
281+
out_shape = list(inp.shape)
282+
out_shape[0] //= tp_world_size
283+
out_shape[-1] = out_features
282284
reduce_scatter_out = torch.empty(
283285
out_shape, dtype=activation_dtype, device=inputmat_total.device
284286
)

0 commit comments

Comments
 (0)