Skip to content

Commit b796c96

Browse files
committed
Fix bug in inplace swizzle function
Signed-off-by: Tim Moon <[email protected]>
1 parent fa7e7c0 commit b796c96

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

transformer_engine/pytorch/csrc/extensions/swizzle.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -406,8 +406,8 @@ void inplace_swizzle_scale_for_gemm(py::handle &tensor) {
406406
auto is_empty = [] (const NVTEBasicTensor &t) -> bool {
407407
return t.shape.ndim == 1 && t.shape.data[0] == 0;
408408
};
409-
const bool has_rowwise_scales = is_empty(tensor_nvte.get_rowwise_scale_inv());
410-
const bool has_columnwise_scales = is_empty(tensor_nvte.get_columnwise_scale_inv());
409+
const bool has_rowwise_scales = !is_empty(tensor_nvte.get_rowwise_scale_inv());
410+
const bool has_columnwise_scales = !is_empty(tensor_nvte.get_columnwise_scale_inv());
411411

412412
// Swizzle scaling factors
413413
auto [rowwise_scales, columnwise_scales] = swizzle_scales_for_gemm(tensor_nvte,

0 commit comments

Comments
 (0)