diff --git a/torchao/prototype/quantized_training/int8_mm.py b/torchao/prototype/quantized_training/int8_mm.py index 7de6620d65..faaa6e463e 100644 --- a/torchao/prototype/quantized_training/int8_mm.py +++ b/torchao/prototype/quantized_training/int8_mm.py @@ -54,6 +54,7 @@ @triton.autotune(configs=configs, key=["M", "N", "K", "stride_ak", "stride_bk"]) +@triton.heuristics({"EVEN_K": lambda args: args["K"] % args["BLOCK_K"] == 0}) @triton.jit def _scaled_int8_mm_kernel( A_ptr, @@ -176,7 +177,6 @@ def scaled_int8_mm_cuda(A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tens *A.stride(), *B.stride(), *C.stride(), - EVEN_K=K % 2 == 0, COL_SCALE_SCALAR=col_scale.numel() == 1, ) return C