File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed
transformer_engine/pytorch/csrc/extensions Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -406,8 +406,8 @@ void inplace_swizzle_scale_for_gemm(py::handle &tensor) {
406406 auto is_empty = [] (const NVTEBasicTensor &t) -> bool {
407407 return t.shape .ndim == 1 && t.shape .data [0 ] == 0 ;
408408 };
409- const bool has_rowwise_scales = is_empty (tensor_nvte.get_rowwise_scale_inv ());
410- const bool has_columnwise_scales = is_empty (tensor_nvte.get_columnwise_scale_inv ());
409+ const bool has_rowwise_scales = ! is_empty (tensor_nvte.get_rowwise_scale_inv ());
410+ const bool has_columnwise_scales = ! is_empty (tensor_nvte.get_columnwise_scale_inv ());
411411
412412 // Swizzle scaling factors
413413 auto [rowwise_scales, columnwise_scales] = swizzle_scales_for_gemm (tensor_nvte,
You can’t perform that action at this time.
0 commit comments