diff --git a/fbgemm_gpu/fbgemm_gpu/quantize_comm.py b/fbgemm_gpu/fbgemm_gpu/quantize_comm.py
index cb0ec2816..a79010cb2 100644
--- a/fbgemm_gpu/fbgemm_gpu/quantize_comm.py
+++ b/fbgemm_gpu/fbgemm_gpu/quantize_comm.py
@@ -193,9 +193,10 @@ def calc_quantized_size(
             self._comm_precision == SparseType.FP8 and self._row_dim > 0
         ):
             ctx = none_throws(ctx)
-            assert (
-                input_len % ctx.row_dim == 0
-            ), f"input_len {input_len} is not a multiple of row dim {ctx.row_dim}"
+            assert input_len % ctx.row_dim == 0, (
+                f"input_len {input_len} is not a multiple of row dim {ctx.row_dim}. "
+                "Please check your batch size (power of 2 batch size is recommended)"
+            )
             nrows = input_len // ctx.row_dim
             ncols = (ctx.row_dim + 3) // 4 * 4 + 2 * 4
             return nrows * ncols
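
For context on what the assert guards, here is a minimal standalone sketch of the size arithmetic in `calc_quantized_size` (the function name `quantized_size_sketch` is hypothetical and not part of the FBGEMM API): each row of `row_dim` elements is padded up to a multiple of 4 bytes and carries `2 * 4` extra bytes of per-row quantization metadata (commonly a float scale and a float bias), so `input_len` must split evenly into rows of `row_dim` for the result to be meaningful.

```python
# Standalone sketch (not part of the patch) of the padded-size arithmetic
# guarded by the assert above, assuming the fused 8-bit row-wise layout:
# each row of `row_dim` elements is padded to a multiple of 4 bytes and
# carries 2 * 4 extra bytes of per-row quantization metadata.
def quantized_size_sketch(input_len: int, row_dim: int) -> int:
    if input_len % row_dim != 0:
        raise ValueError(
            f"input_len {input_len} is not a multiple of row dim {row_dim}. "
            "Please check your batch size (power of 2 batch size is recommended)"
        )
    nrows = input_len // row_dim
    ncols = (row_dim + 3) // 4 * 4 + 2 * 4
    return nrows * ncols


# Example: 4 rows of 130 elements -> each row is padded to 132 bytes
# plus 8 bytes of metadata, so the quantized buffer is 4 * 140 bytes.
assert quantized_size_sketch(4 * 130, 130) == 4 * (132 + 8)
```

The new message presumably points users at the batch size because, in the communication path, `input_len` scales with the local batch size, so a non-conforming batch size is the usual reason the divisibility check fails.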