Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

src: attr: quantization refactor (part 1) #2270

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions src/common/binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,17 @@ status_t binary_attr_check(const binary_desc_t &desc, const engine_t *engine,

// Check scales
if (!attr->scales_.has_default_values()) {
VCHECK_BINARY_UNIMPL(attr->scales_.has_default_values(
{DNNL_ARG_SRC_0, DNNL_ARG_SRC_1}),
static const std::vector<int> supported_args {
DNNL_ARG_SRC_0, DNNL_ARG_SRC_1};
VCHECK_BINARY_UNIMPL(attr->scales_.has_default_values(supported_args),
VERBOSE_UNSUPPORTED_SCALES_CFG);

const auto &sc = attr->scales_;
const int mask_src_0 = sc.get(DNNL_ARG_SRC_0).mask_;
const int mask_src_1 = sc.get(DNNL_ARG_SRC_1).mask_;
for (int arg : supported_args) {
if (attr->scales_.get(arg).has_default_values()) continue;

VCHECK_BINARY_UNIMPL(utils::everyone_is(0, mask_src_0, mask_src_1),
VERBOSE_UNSUPPORTED_SCALES_CFG);
const int mask = attr->scales_.get_mask(arg);
VCHECK_BINARY_UNIMPL(mask == 0, VERBOSE_UNSUPPORTED_SCALES_CFG);
}
}

// Check post-ops
Expand Down
11 changes: 7 additions & 4 deletions src/common/binary_pd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,13 @@ struct binary_pd_t : public primitive_desc_t {

bool attr_scales_ok(const std::vector<int> &supported_args
= {DNNL_ARG_SRC_0, DNNL_ARG_SRC_1, DNNL_ARG_DST}) const {
bool ok = attr()->scales_.has_default_values(supported_args);
for (int arg : supported_args) {
const auto &mask = attr()->scales_.get(arg).mask_;
ok = ok && (mask == 0);
const auto &scales = attr()->scales_;
bool ok = scales.has_default_values(supported_args);

for (const auto &arg : supported_args) {
if (scales.get(arg).has_default_values()) continue;

ok = ok && scales.get_mask(arg) == 0;
}
return ok;
}
Expand Down
22 changes: 17 additions & 5 deletions src/common/concat.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2018-2023 Intel Corporation
* Copyright 2018-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -54,10 +54,22 @@ status_t concat_primitive_desc_create(std::shared_ptr<primitive_desc_t> &pd,
VCHECK_CONCAT_UNIMPL(attr->has_default_values(smask_t::scales_runtime),
VERBOSE_UNSUPPORTED_ATTR);
const auto &scales = attr->scales_;
if (!scales.has_default_values())
for (const auto &s : scales.scales_)
VCHECK_CONCAT_UNIMPL(
s.second.mask_ == 0, VERBOSE_UNSUPPORTED_SCALES_CFG);
if (!scales.has_default_values()) {
std::vector<int> supported_args(n);
for (int i = 0; i < n; i++) {
supported_args[i] = DNNL_ARG_MULTIPLE_SRC + i;
}
VCHECK_CONCAT_UNIMPL(
attr->scales_.has_default_values(supported_args),
VERBOSE_UNSUPPORTED_SCALES_CFG);

for (int arg : supported_args) {
if (scales.get(arg).has_default_values()) continue;

int mask = scales.get_mask(arg);
VCHECK_CONCAT_UNIMPL(mask == 0, VERBOSE_UNSUPPORTED_SCALES_CFG);
}
}
}

const int ndims = src_mds[0]->ndims;
Expand Down
17 changes: 12 additions & 5 deletions src/common/convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,20 @@ status_t conv_attr_check(const convolution_desc_t &desc, const engine_t *engine,
// Check scales
if (!attr->scales_.has_default_values()) {
const auto &sc = attr->scales_;
const int mask_src = sc.get(DNNL_ARG_SRC).mask_;
const int mask_wei = sc.get(DNNL_ARG_WEIGHTS).mask_;
const int mask_dst = sc.get(DNNL_ARG_DST).mask_;
const bool with_groups
= desc.src_desc.ndims != desc.weights_desc.ndims;
VCHECK_CONV_UNIMPL(utils::everyone_is(0, mask_src, mask_dst)
&& utils::one_of(mask_wei, 0, with_groups ? 3 : 1),
VCHECK_CONV_UNIMPL(
IMPLICATION(!sc.get(DNNL_ARG_SRC).has_default_values(),
sc.get_mask(DNNL_ARG_SRC) == 0),
VERBOSE_UNSUPPORTED_SCALES_CFG);
VCHECK_CONV_UNIMPL(
IMPLICATION(!sc.get(DNNL_ARG_WEIGHTS).has_default_values(),
utils::one_of(sc.get_mask(DNNL_ARG_WEIGHTS), 0,
with_groups ? 3 : 1)),
VERBOSE_UNSUPPORTED_SCALES_CFG);
VCHECK_CONV_UNIMPL(
IMPLICATION(!sc.get(DNNL_ARG_DST).has_default_values(),
sc.get_mask(DNNL_ARG_DST) == 0),
VERBOSE_UNSUPPORTED_SCALES_CFG);
}

Expand Down
4 changes: 3 additions & 1 deletion src/common/convolution_pd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,9 @@ struct convolution_pd_t : public primitive_desc_t {
= {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) const {
bool ok = attr()->scales_.has_default_values(supported_args);
for (int arg : supported_args) {
const auto &mask = attr()->scales_.get(arg).mask_;
if (attr()->scales_.get(arg).has_default_values()) continue;

const auto &mask = attr()->scales_.get_mask(arg);
if (arg == DNNL_ARG_WEIGHTS)
ok = ok && (mask == 0 || mask == (with_groups() ? 3 : 1));
else
Expand Down
17 changes: 12 additions & 5 deletions src/common/deconvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,20 @@ status_t deconv_attr_check(const deconvolution_desc_t &desc,
// Check scales
if (!attr->scales_.has_default_values()) {
const auto &sc = attr->scales_;
const int mask_src = sc.get(DNNL_ARG_SRC).mask_;
const int mask_wei = sc.get(DNNL_ARG_WEIGHTS).mask_;
const int mask_dst = sc.get(DNNL_ARG_DST).mask_;
const bool with_groups
= desc.src_desc.ndims != desc.weights_desc.ndims;
VCHECK_DECONV_UNIMPL(utils::everyone_is(0, mask_src, mask_dst)
&& utils::one_of(mask_wei, 0, with_groups ? 3 : 1),
VCHECK_DECONV_UNIMPL(
IMPLICATION(!sc.get(DNNL_ARG_SRC).has_default_values(),
sc.get_mask(DNNL_ARG_SRC) == 0),
VERBOSE_UNSUPPORTED_SCALES_CFG);
VCHECK_DECONV_UNIMPL(
IMPLICATION(!sc.get(DNNL_ARG_WEIGHTS).has_default_values(),
utils::one_of(sc.get_mask(DNNL_ARG_WEIGHTS), 0,
with_groups ? 3 : 1)),
VERBOSE_UNSUPPORTED_SCALES_CFG);
VCHECK_DECONV_UNIMPL(
IMPLICATION(!sc.get(DNNL_ARG_DST).has_default_values(),
sc.get_mask(DNNL_ARG_DST) == 0),
VERBOSE_UNSUPPORTED_SCALES_CFG);
}

Expand Down
4 changes: 3 additions & 1 deletion src/common/deconvolution_pd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,9 @@ struct deconvolution_pd_t : public primitive_desc_t {
= {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) const {
bool ok = attr()->scales_.has_default_values(supported_args);
for (int arg : supported_args) {
const auto &mask = attr()->scales_.get(arg).mask_;
if (attr()->scales_.get(arg).has_default_values()) continue;

const auto &mask = attr()->scales_.get_mask(arg);
if (arg == DNNL_ARG_WEIGHTS)
ok = ok && (mask == 0 || mask == (with_groups() ? 3 : 1));
else
Expand Down
18 changes: 12 additions & 6 deletions src/common/group_normalization.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2023 Intel Corporation
* Copyright 2023-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -154,12 +154,18 @@ status_t group_normalization_attr_check(const group_normalization_desc_t &desc,

// Check scales
if (!attr->scales_.has_default_values()) {
const auto &sc = attr->scales_;
const int mask_src = sc.get(DNNL_ARG_SRC).mask_;
const int mask_dst = sc.get(DNNL_ARG_DST).mask_;

VCHECK_GNORM_UNIMPL(utils::everyone_is(0, mask_src, mask_dst),
static const std::vector<int> supported_args {
DNNL_ARG_SRC, DNNL_ARG_DST};
VCHECK_GNORM_UNIMPL(
attr->scales_.has_default_values(supported_args),
VERBOSE_UNSUPPORTED_SCALES_CFG);

for (int arg : supported_args) {
if (attr->scales_.get(arg).has_default_values()) continue;

const int mask = attr->scales_.get_mask(arg);
VCHECK_GNORM_UNIMPL(mask == 0, VERBOSE_UNSUPPORTED_SCALES_CFG);
}
}

// Check post-ops
Expand Down
10 changes: 5 additions & 5 deletions src/common/group_normalization_pd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,17 +190,17 @@ struct group_normalization_fwd_pd_t : public group_normalization_pd_t {
return IMPLICATION(use_scale() || use_shift(),
weights_md()->data_type == data_type::f32);
}
bool attr_scales_ok() const {
bool attr_scales_ok(const std::vector<int> &supported_args
= {DNNL_ARG_SRC, DNNL_ARG_DST}) const {
using namespace data_type;
const auto &scales = attr()->scales_;
const std::vector<int> supported_args({DNNL_ARG_SRC, DNNL_ARG_DST});
bool ok = scales.has_default_values(supported_args);

for (const auto &arg : supported_args) {
const auto &sc = scales.get(arg);
if (!sc.has_default_values()) {
if (!scales.get(arg).has_default_values()) {
const data_type_t dt = arg_md(arg)->data_type;
ok = ok && utils::one_of(dt, s8, u8) && sc.mask_ == 0;
ok = ok && utils::one_of(dt, s8, u8);
ok = ok && scales.get_mask(arg) == 0;
}
}
return ok;
Expand Down
17 changes: 11 additions & 6 deletions src/common/inner_product.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,17 @@ status_t ip_attr_check(const inner_product_desc_t &desc, const engine_t *engine,
// Check scales
if (!attr->scales_.has_default_values()) {
const auto &sc = attr->scales_;
const int mask_src = sc.get(DNNL_ARG_SRC).mask_;
const int mask_wei = sc.get(DNNL_ARG_WEIGHTS).mask_;
const int mask_dst = sc.get(DNNL_ARG_DST).mask_;

VCHECK_IP_UNIMPL(utils::everyone_is(0, mask_src, mask_dst)
&& utils::one_of(mask_wei, 0, 1),
VCHECK_IP_UNIMPL(
IMPLICATION(!sc.get(DNNL_ARG_SRC).has_default_values(),
sc.get_mask(DNNL_ARG_SRC) == 0),
VERBOSE_UNSUPPORTED_SCALES_CFG);
VCHECK_IP_UNIMPL(
IMPLICATION(!sc.get(DNNL_ARG_WEIGHTS).has_default_values(),
utils::one_of(sc.get_mask(DNNL_ARG_WEIGHTS), 0, 1)),
VERBOSE_UNSUPPORTED_SCALES_CFG);
VCHECK_IP_UNIMPL(
IMPLICATION(!sc.get(DNNL_ARG_DST).has_default_values(),
sc.get_mask(DNNL_ARG_DST) == 0),
VERBOSE_UNSUPPORTED_SCALES_CFG);
}

Expand Down
4 changes: 3 additions & 1 deletion src/common/inner_product_pd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ struct inner_product_pd_t : public primitive_desc_t {
= {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) const {
bool ok = attr()->scales_.has_default_values(supported_args);
for (auto arg : supported_args) {
int mask = attr()->scales_.get(arg).mask_;
if (attr()->scales_.get(arg).has_default_values()) continue;

int mask = attr()->scales_.get_mask(arg);
if (arg == DNNL_ARG_WEIGHTS)
ok = ok && (mask == 0 || mask == (1 << 0));
else
Expand Down
16 changes: 11 additions & 5 deletions src/common/layer_normalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,18 @@ status_t layer_normalization_attr_check(const layer_normalization_desc_t &desc,

// Check scales
if (!attr->scales_.has_default_values()) {
const auto &sc = attr->scales_;
const int mask_src = sc.get(DNNL_ARG_SRC).mask_;
const int mask_dst = sc.get(DNNL_ARG_DST).mask_;

VCHECK_LNORM_UNIMPL(utils::everyone_is(0, mask_src, mask_dst),
static const std::vector<int> supported_args {
DNNL_ARG_SRC, DNNL_ARG_DST};
VCHECK_LNORM_UNIMPL(
attr->scales_.has_default_values(supported_args),
VERBOSE_UNSUPPORTED_SCALES_CFG);

for (int arg : supported_args) {
if (attr->scales_.get(arg).has_default_values()) continue;

const int mask = attr->scales_.get_mask(arg);
VCHECK_LNORM_UNIMPL(mask == 0, VERBOSE_UNSUPPORTED_SCALES_CFG);
}
}

// Check post-ops
Expand Down
16 changes: 12 additions & 4 deletions src/common/layer_normalization_pd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,11 +248,19 @@ struct layer_normalization_fwd_pd_t : public layer_normalization_pd_t {
return false;
}

bool attr_scales_ok() const {
bool attr_scales_ok(const std::vector<int> &supported_args
= {DNNL_ARG_SRC, DNNL_ARG_DST}) const {
using namespace data_type;
const auto &scales = attr()->scales_;
bool ok = true;
for (const auto &e : scales.scales_) {
ok = ok && e.second.mask_ == 0;
bool ok = scales.has_default_values(supported_args);

for (const auto &arg : supported_args) {
if (!scales.get(arg).has_default_values()) {
// TODO: disallow non-int8 scales?
// const data_type_t dt = arg_md(arg)->data_type;
// ok = ok && utils::one_of(dt, s8, u8);
ok = ok && scales.get_mask(arg) == 0;
}
}
return ok;
}
Expand Down
Loading
Loading