diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py index cdcda6bf8..07765fa21 100644 --- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py @@ -975,7 +975,7 @@ def matmul_fp8_row( allow_tf32: bool = True, fp8_fast_accum: bool = True, imprecise_acc: bool = False, - tma_persistent: bool = False, + tma_persistent: bool = True, ) -> torch.Tensor: """ Performs matmul on [M, K] and [N, K] fp8 matrices with row-wise scalings [M], [N].