diff --git a/fbgemm_gpu/src/qlinear_channelwise/qlinear_channelwise_mtia.cpp b/fbgemm_gpu/src/qlinear_channelwise/qlinear_channelwise_mtia.cpp
index 07c898ebe..326629d99 100644
--- a/fbgemm_gpu/src/qlinear_channelwise/qlinear_channelwise_mtia.cpp
+++ b/fbgemm_gpu/src/qlinear_channelwise/qlinear_channelwise_mtia.cpp
@@ -20,9 +20,73 @@ static at::Tensor qlinear_channelwise(
     at::Tensor weight_zero_point,
     at::Tensor relu) {
   // quantized linear function with
-  // activation: per-tensor quantization
+  // activation: per-row quantization (one scale per row of X),
   // weight: per-tensor quantization
-  return x;
+  // X.sizes = M * K, W.sizes = N * K, Y.sizes = M * N
+  // Input_scale.sizes = M, Weight_scale.sizes = 1, b.sizes = N
+  at::Tensor X = x.contiguous();
+  at::Tensor W = weight.contiguous();
+  at::Tensor b = bias.contiguous();
+  at::Tensor Input_scale = input_scale.contiguous();
+  at::Tensor Weight_scale = weight_scale.contiguous();
+  at::Tensor Weight_zero_point = weight_zero_point.contiguous();
+
+  const auto x_dimensions = X.sizes();
+  const int x_num_dim = x_dimensions.size();
+
+  TORCH_CHECK(
+      x_dimensions.back() == W.sizes().back(),
+      "X's inner-most dimension must match W's inner-most dimension!");
+
+  const int M = x_dimensions[0];
+  const int K = x_dimensions[x_num_dim - 1];
+  const int N = W.sizes()[0];
+
+  const uint8_t* X_data = (const uint8_t*)X.data_ptr();
+  const uint8_t* W_data = (const uint8_t*)W.data_ptr();
+  const float* b_data = b.data_ptr<float>();
+  const uint8_t* Weight_zero_point_data =
+      (const uint8_t*)Weight_zero_point.data_ptr();
+
+  // Matmul: accumulate (x - 127) * (w - weight_zero_point) in int32.
+  // X.sizes = M * K, W.sizes = N * K, Y.sizes = M * N
+  std::vector<float> Y_fp_vec =
+      std::vector<float>(M * (std::vector<float>::size_type)(N));
+  for (int m = 0; m < M; m++) {
+    for (int n = 0; n < N; n++) {
+      int32_t matmul = 0;
+      for (int k = 0; k < K; k++) {
+        int x = X_data[m * K + k];
+        int w = W_data[n * K + k];
+
+        matmul += (x - 127) * (w - *Weight_zero_point_data);
+      }
+      Y_fp_vec[m * N + n] = matmul;
+    }
+  }
+
+  // Re-scale to fp32 and add bias; clamp at 65504.0f (fp16 max).
+  // Input_scale.sizes = M, Weight_scale.sizes = 1, b.sizes = N
+  std::vector<float> O_scale = std::vector<float>(M);
+  for (int i = 0; i < M; i++) {
+    O_scale[i] =
+        Input_scale[i].item().toFloat() * Weight_scale[0].item().toFloat();
+  }
+
+  for (int i = 0; i < M; i++) {
+    for (int j = 0; j < N; j++) {
+      float Y_tmp = (Y_fp_vec[i * N + j] * O_scale[i]) + b_data[j];
+      if (Y_tmp > 65504.0f) {
+        Y_tmp = 65504.0f;
+      }
+      Y_fp_vec[i * N + j] = Y_tmp;
+    }
+  }
+
+  // Clone: at::from_blob does not take ownership of Y_fp_vec's buffer,
+  // which goes out of scope when this function returns.
+  auto Y = at::from_blob(
+      Y_fp_vec.data(), {M, N}, at::TensorOptions().dtype(torch::kFloat32));
+  return Y.clone();
 }
 
 static at::Tensor qlinear_quant(
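For reference, the arithmetic the new hunk implements is Y[m][n] = clamp((sum_k (X[m][k] - 127) * (W[n][k] - w_zp)) * input_scale[m] * weight_scale + bias[n], 65504): a fixed activation zero point of 127 with one scale per activation row, a single weight scale/zero-point pair, and saturation at the largest finite fp16 value. A minimal standalone sketch of that math, in plain C++ without ATen; the tensor values and the names s_x, s_w, and w_zp are illustrative, not taken from the patch:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int M = 2, K = 3, N = 2;
  const std::vector<uint8_t> X = {129, 127, 125, 130, 127, 124}; // M x K
  const std::vector<uint8_t> W = {128, 126, 128, 130, 128, 126}; // N x K
  const std::vector<float> b = {0.5f, -0.5f};    // per output column
  const std::vector<float> s_x = {0.02f, 0.03f}; // one scale per row of X
  const float s_w = 0.1f;   // per-tensor weight scale
  const uint8_t w_zp = 128; // per-tensor weight zero point

  for (int m = 0; m < M; m++) {
    for (int n = 0; n < N; n++) {
      int32_t acc = 0; // int32 accumulator, as in the patch
      for (int k = 0; k < K; k++) {
        acc += (X[m * K + k] - 127) * (W[n * K + k] - w_zp);
      }
      float y = acc * s_x[m] * s_w + b[n];
      if (y > 65504.0f) {
        y = 65504.0f; // saturate at fp16 max
      }
      std::printf("Y[%d][%d] = %f\n", m, n, y);
    }
  }
  return 0;
}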
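A possible follow-up, not part of this patch: the O_scale loop extracts each scale with Tensor::item(), which costs one ATen scalar round-trip per row of X. Since Input_scale is already made contiguous, the scales could be read once through a raw pointer, along these lines (a sketch assuming both scale tensors hold fp32 in host-visible memory; names mirror the patch):

// Hypothetical variant of the O_scale loop above.
const float* input_scale_data = Input_scale.data_ptr<float>();
const float weight_scale_val = Weight_scale.data_ptr<float>()[0];
std::vector<float> O_scale(M);
for (int i = 0; i < M; i++) {
  // Same per-row output scale, without M item() calls.
  O_scale[i] = input_scale_data[i] * weight_scale_val;
}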