From be1d5cadead271fdf6952d38e55219a50728a194 Mon Sep 17 00:00:00 2001 From: Sarunya Pumma Date: Fri, 15 Sep 2023 11:54:38 -0700 Subject: [PATCH] Improve quantize_comm error message (#2018) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2018 As titled Reviewed By: jianyuh, henrylhtsang, edqwerty10 Differential Revision: D49295738 fbshipit-source-id: 45524d8e220ba6b686a99d201e24c6a3d839aed7 --- fbgemm_gpu/fbgemm_gpu/quantize_comm.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/quantize_comm.py b/fbgemm_gpu/fbgemm_gpu/quantize_comm.py index cb0ec2816..a79010cb2 100644 --- a/fbgemm_gpu/fbgemm_gpu/quantize_comm.py +++ b/fbgemm_gpu/fbgemm_gpu/quantize_comm.py @@ -193,9 +193,10 @@ def calc_quantized_size( self._comm_precision == SparseType.FP8 and self._row_dim > 0 ): ctx = none_throws(ctx) - assert ( - input_len % ctx.row_dim == 0 - ), f"input_len {input_len} is not a multiple of row dim {ctx.row_dim}" + assert input_len % ctx.row_dim == 0, ( + f"input_len {input_len} is not a multiple of row dim {ctx.row_dim} " + "Please check your batch size (power of 2 batch size is recommended)" + ) nrows = input_len // ctx.row_dim ncols = (ctx.row_dim + 3) // 4 * 4 + 2 * 4 return nrows * ncols