Skip to content

Commit

Permalink
CPU implementation of step 2 - per-channel quantize operator (pytorch#…
Browse files Browse the repository at this point in the history
…2341)

Summary:
Pull Request resolved: pytorch#2341

The operator was registered in D52136852

Reviewed By: jiyuanzFB

Differential Revision: D53746793

fbshipit-source-id: 2dff2d969d3a5d457d2b20b65e84883ce92f1407
  • Loading branch information
Jiawei Zhang authored and facebook-github-bot committed Feb 22, 2024
1 parent d777a8d commit ad27fd5
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions fbgemm_gpu/src/qlinear_channelwise/qlinear_channelwise_mtia.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/library.h>

#include <algorithm>
#include <cstdint>
#include <vector>

#include "torch/types.h"

static at::Tensor qlinear_channelwise(
at::Tensor x,
Expand All @@ -32,13 +33,26 @@ static at::Tensor qlinear_quant(
at::Tensor weight_scale,
at::Tensor weight_zero_point,
at::Tensor relu) {
assert(x.options().dtype() == at::kHalf);
assert(weight.options().dtype() == at::kQInt8);
assert(bias.options().dtype() == at::kFloat);
assert(input_scale.options().dtype() == at::kFloat);
assert(weight_scale.options().dtype() == at::kFloat);
assert(weight_zero_point.options().dtype() == at::kQUInt8);
return x;
at::Tensor X = x.contiguous();
const float* x_data = X.data_ptr<float>();
const int M = X.sizes()[0];
const int K = X.sizes()[1];
at::Tensor I_S = input_scale.contiguous();
const float* input_scale_data = I_S.data_ptr<float>();

std::vector<uint8_t> X_int8_vec =
std::vector<uint8_t>(M * (std::vector<uint8_t>::size_type)(K));
for (int m = 0; m < M; m++) {
const float inv_scale = 1.0f / input_scale_data[m];
for (int k = 0; k < K; k++) {
int32_t val = int32_t(inv_scale * x_data[m * K + k]) + 127;
X_int8_vec[m * K + k] = uint8_t(std::max(0, std::min(val, UINT8_MAX)));
}
}

auto Y = at::from_blob(
X_int8_vec.data(), {M, K}, at::TensorOptions().dtype(torch::kUInt8));
return Y;
}

static at::Tensor qlinear_qparams(
Expand Down

0 comments on commit ad27fd5

Please sign in to comment.