Skip to content

Commit

Permalink
src: introduce quant_entry_t and refactor arg_scales_t to rely on it
Browse files Browse the repository at this point in the history
  • Loading branch information
dzarukin committed Dec 16, 2024
1 parent 96e5888 commit 58b979e
Show file tree
Hide file tree
Showing 140 changed files with 1,087 additions and 940 deletions.
15 changes: 8 additions & 7 deletions src/common/binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,17 @@ status_t binary_attr_check(const binary_desc_t &desc, const engine_t *engine,

// Check scales
if (!attr->scales_.has_default_values()) {
VCHECK_BINARY_UNIMPL(attr->scales_.has_default_values(
{DNNL_ARG_SRC_0, DNNL_ARG_SRC_1}),
static const std::vector<int> supported_args {
DNNL_ARG_SRC_0, DNNL_ARG_SRC_1};
VCHECK_BINARY_UNIMPL(attr->scales_.has_default_values(supported_args),
VERBOSE_UNSUPPORTED_SCALES_CFG);

const auto &sc = attr->scales_;
const int mask_src_0 = sc.get(DNNL_ARG_SRC_0).mask_;
const int mask_src_1 = sc.get(DNNL_ARG_SRC_1).mask_;
for (int arg : supported_args) {
if (attr->scales_.get(arg).has_default_values()) continue;

VCHECK_BINARY_UNIMPL(utils::everyone_is(0, mask_src_0, mask_src_1),
VERBOSE_UNSUPPORTED_SCALES_CFG);
const int mask = attr->scales_.get_mask(arg);
VCHECK_BINARY_UNIMPL(mask == 0, VERBOSE_UNSUPPORTED_SCALES_CFG);
}
}

// Check post-ops
Expand Down
11 changes: 7 additions & 4 deletions src/common/binary_pd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,13 @@ struct binary_pd_t : public primitive_desc_t {

bool attr_scales_ok(const std::vector<int> &supported_args
= {DNNL_ARG_SRC_0, DNNL_ARG_SRC_1, DNNL_ARG_DST}) const {
bool ok = attr()->scales_.has_default_values(supported_args);
for (int arg : supported_args) {
const auto &mask = attr()->scales_.get(arg).mask_;
ok = ok && (mask == 0);
const auto &scales = attr()->scales_;
bool ok = scales.has_default_values(supported_args);

for (const auto &arg : supported_args) {
if (scales.get(arg).has_default_values()) continue;

ok = ok && scales.get_mask(arg) == 0;
}
return ok;
}
Expand Down
22 changes: 17 additions & 5 deletions src/common/concat.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2018-2023 Intel Corporation
* Copyright 2018-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -54,10 +54,22 @@ status_t concat_primitive_desc_create(std::shared_ptr<primitive_desc_t> &pd,
VCHECK_CONCAT_UNIMPL(attr->has_default_values(smask_t::scales_runtime),
VERBOSE_UNSUPPORTED_ATTR);
const auto &scales = attr->scales_;
if (!scales.has_default_values())
for (const auto &s : scales.scales_)
VCHECK_CONCAT_UNIMPL(
s.second.mask_ == 0, VERBOSE_UNSUPPORTED_SCALES_CFG);
if (!scales.has_default_values()) {
std::vector<int> supported_args(n);
for (int i = 0; i < n; i++) {
supported_args[i] = DNNL_ARG_MULTIPLE_SRC + i;
}
VCHECK_CONCAT_UNIMPL(
attr->scales_.has_default_values(supported_args),
VERBOSE_UNSUPPORTED_SCALES_CFG);

for (int arg : supported_args) {
if (scales.get(arg).has_default_values()) continue;

int mask = scales.get_mask(arg);
VCHECK_CONCAT_UNIMPL(mask == 0, VERBOSE_UNSUPPORTED_SCALES_CFG);
}
}
}

const int ndims = src_mds[0]->ndims;
Expand Down
17 changes: 12 additions & 5 deletions src/common/convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,20 @@ status_t conv_attr_check(const convolution_desc_t &desc, const engine_t *engine,
// Check scales
if (!attr->scales_.has_default_values()) {
const auto &sc = attr->scales_;
const int mask_src = sc.get(DNNL_ARG_SRC).mask_;
const int mask_wei = sc.get(DNNL_ARG_WEIGHTS).mask_;
const int mask_dst = sc.get(DNNL_ARG_DST).mask_;
const bool with_groups
= desc.src_desc.ndims != desc.weights_desc.ndims;
VCHECK_CONV_UNIMPL(utils::everyone_is(0, mask_src, mask_dst)
&& utils::one_of(mask_wei, 0, with_groups ? 3 : 1),
VCHECK_CONV_UNIMPL(
IMPLICATION(!sc.get(DNNL_ARG_SRC).has_default_values(),
sc.get_mask(DNNL_ARG_SRC) == 0),
VERBOSE_UNSUPPORTED_SCALES_CFG);
VCHECK_CONV_UNIMPL(
IMPLICATION(!sc.get(DNNL_ARG_WEIGHTS).has_default_values(),
utils::one_of(sc.get_mask(DNNL_ARG_WEIGHTS), 0,
with_groups ? 3 : 1)),
VERBOSE_UNSUPPORTED_SCALES_CFG);
VCHECK_CONV_UNIMPL(
IMPLICATION(!sc.get(DNNL_ARG_DST).has_default_values(),
sc.get_mask(DNNL_ARG_DST) == 0),
VERBOSE_UNSUPPORTED_SCALES_CFG);
}

Expand Down
4 changes: 3 additions & 1 deletion src/common/convolution_pd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,9 @@ struct convolution_pd_t : public primitive_desc_t {
= {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) const {
bool ok = attr()->scales_.has_default_values(supported_args);
for (int arg : supported_args) {
const auto &mask = attr()->scales_.get(arg).mask_;
if (attr()->scales_.get(arg).has_default_values()) continue;

const auto &mask = attr()->scales_.get_mask(arg);
if (arg == DNNL_ARG_WEIGHTS)
ok = ok && (mask == 0 || mask == (with_groups() ? 3 : 1));
else
Expand Down
17 changes: 12 additions & 5 deletions src/common/deconvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,20 @@ status_t deconv_attr_check(const deconvolution_desc_t &desc,
// Check scales
if (!attr->scales_.has_default_values()) {
const auto &sc = attr->scales_;
const int mask_src = sc.get(DNNL_ARG_SRC).mask_;
const int mask_wei = sc.get(DNNL_ARG_WEIGHTS).mask_;
const int mask_dst = sc.get(DNNL_ARG_DST).mask_;
const bool with_groups
= desc.src_desc.ndims != desc.weights_desc.ndims;
VCHECK_DECONV_UNIMPL(utils::everyone_is(0, mask_src, mask_dst)
&& utils::one_of(mask_wei, 0, with_groups ? 3 : 1),
VCHECK_DECONV_UNIMPL(
IMPLICATION(!sc.get(DNNL_ARG_SRC).has_default_values(),
sc.get_mask(DNNL_ARG_SRC) == 0),
VERBOSE_UNSUPPORTED_SCALES_CFG);
VCHECK_DECONV_UNIMPL(
IMPLICATION(!sc.get(DNNL_ARG_WEIGHTS).has_default_values(),
utils::one_of(sc.get_mask(DNNL_ARG_WEIGHTS), 0,
with_groups ? 3 : 1)),
VERBOSE_UNSUPPORTED_SCALES_CFG);
VCHECK_DECONV_UNIMPL(
IMPLICATION(!sc.get(DNNL_ARG_DST).has_default_values(),
sc.get_mask(DNNL_ARG_DST) == 0),
VERBOSE_UNSUPPORTED_SCALES_CFG);
}

Expand Down
4 changes: 3 additions & 1 deletion src/common/deconvolution_pd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,9 @@ struct deconvolution_pd_t : public primitive_desc_t {
= {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) const {
bool ok = attr()->scales_.has_default_values(supported_args);
for (int arg : supported_args) {
const auto &mask = attr()->scales_.get(arg).mask_;
if (attr()->scales_.get(arg).has_default_values()) continue;

const auto &mask = attr()->scales_.get_mask(arg);
if (arg == DNNL_ARG_WEIGHTS)
ok = ok && (mask == 0 || mask == (with_groups() ? 3 : 1));
else
Expand Down
18 changes: 12 additions & 6 deletions src/common/group_normalization.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2023 Intel Corporation
* Copyright 2023-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -154,12 +154,18 @@ status_t group_normalization_attr_check(const group_normalization_desc_t &desc,

// Check scales
if (!attr->scales_.has_default_values()) {
const auto &sc = attr->scales_;
const int mask_src = sc.get(DNNL_ARG_SRC).mask_;
const int mask_dst = sc.get(DNNL_ARG_DST).mask_;

VCHECK_GNORM_UNIMPL(utils::everyone_is(0, mask_src, mask_dst),
static const std::vector<int> supported_args {
DNNL_ARG_SRC, DNNL_ARG_DST};
VCHECK_GNORM_UNIMPL(
attr->scales_.has_default_values(supported_args),
VERBOSE_UNSUPPORTED_SCALES_CFG);

for (int arg : supported_args) {
if (attr->scales_.get(arg).has_default_values()) continue;

const int mask = attr->scales_.get_mask(arg);
VCHECK_GNORM_UNIMPL(mask == 0, VERBOSE_UNSUPPORTED_SCALES_CFG);
}
}

// Check post-ops
Expand Down
10 changes: 5 additions & 5 deletions src/common/group_normalization_pd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,17 +190,17 @@ struct group_normalization_fwd_pd_t : public group_normalization_pd_t {
return IMPLICATION(use_scale() || use_shift(),
weights_md()->data_type == data_type::f32);
}
bool attr_scales_ok() const {
bool attr_scales_ok(const std::vector<int> &supported_args
= {DNNL_ARG_SRC, DNNL_ARG_DST}) const {
using namespace data_type;
const auto &scales = attr()->scales_;
const std::vector<int> supported_args({DNNL_ARG_SRC, DNNL_ARG_DST});
bool ok = scales.has_default_values(supported_args);

for (const auto &arg : supported_args) {
const auto &sc = scales.get(arg);
if (!sc.has_default_values()) {
if (!scales.get(arg).has_default_values()) {
const data_type_t dt = arg_md(arg)->data_type;
ok = ok && utils::one_of(dt, s8, u8) && sc.mask_ == 0;
ok = ok && utils::one_of(dt, s8, u8);
ok = ok && scales.get_mask(arg) == 0;
}
}
return ok;
Expand Down
17 changes: 11 additions & 6 deletions src/common/inner_product.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,17 @@ status_t ip_attr_check(const inner_product_desc_t &desc, const engine_t *engine,
// Check scales
if (!attr->scales_.has_default_values()) {
const auto &sc = attr->scales_;
const int mask_src = sc.get(DNNL_ARG_SRC).mask_;
const int mask_wei = sc.get(DNNL_ARG_WEIGHTS).mask_;
const int mask_dst = sc.get(DNNL_ARG_DST).mask_;

VCHECK_IP_UNIMPL(utils::everyone_is(0, mask_src, mask_dst)
&& utils::one_of(mask_wei, 0, 1),
VCHECK_IP_UNIMPL(
IMPLICATION(!sc.get(DNNL_ARG_SRC).has_default_values(),
sc.get_mask(DNNL_ARG_SRC) == 0),
VERBOSE_UNSUPPORTED_SCALES_CFG);
VCHECK_IP_UNIMPL(
IMPLICATION(!sc.get(DNNL_ARG_WEIGHTS).has_default_values(),
utils::one_of(sc.get_mask(DNNL_ARG_WEIGHTS), 0, 1)),
VERBOSE_UNSUPPORTED_SCALES_CFG);
VCHECK_IP_UNIMPL(
IMPLICATION(!sc.get(DNNL_ARG_DST).has_default_values(),
sc.get_mask(DNNL_ARG_DST) == 0),
VERBOSE_UNSUPPORTED_SCALES_CFG);
}

Expand Down
4 changes: 3 additions & 1 deletion src/common/inner_product_pd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ struct inner_product_pd_t : public primitive_desc_t {
= {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) const {
bool ok = attr()->scales_.has_default_values(supported_args);
for (auto arg : supported_args) {
int mask = attr()->scales_.get(arg).mask_;
if (attr()->scales_.get(arg).has_default_values()) continue;

int mask = attr()->scales_.get_mask(arg);
if (arg == DNNL_ARG_WEIGHTS)
ok = ok && (mask == 0 || mask == (1 << 0));
else
Expand Down
16 changes: 11 additions & 5 deletions src/common/layer_normalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,18 @@ status_t layer_normalization_attr_check(const layer_normalization_desc_t &desc,

// Check scales
if (!attr->scales_.has_default_values()) {
const auto &sc = attr->scales_;
const int mask_src = sc.get(DNNL_ARG_SRC).mask_;
const int mask_dst = sc.get(DNNL_ARG_DST).mask_;

VCHECK_LNORM_UNIMPL(utils::everyone_is(0, mask_src, mask_dst),
static const std::vector<int> supported_args {
DNNL_ARG_SRC, DNNL_ARG_DST};
VCHECK_LNORM_UNIMPL(
attr->scales_.has_default_values(supported_args),
VERBOSE_UNSUPPORTED_SCALES_CFG);

for (int arg : supported_args) {
if (attr->scales_.get(arg).has_default_values()) continue;

const int mask = attr->scales_.get_mask(arg);
VCHECK_LNORM_UNIMPL(mask == 0, VERBOSE_UNSUPPORTED_SCALES_CFG);
}
}

// Check post-ops
Expand Down
16 changes: 12 additions & 4 deletions src/common/layer_normalization_pd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,11 +248,19 @@ struct layer_normalization_fwd_pd_t : public layer_normalization_pd_t {
return false;
}

bool attr_scales_ok() const {
bool attr_scales_ok(const std::vector<int> &supported_args
= {DNNL_ARG_SRC, DNNL_ARG_DST}) const {
using namespace data_type;
const auto &scales = attr()->scales_;
bool ok = true;
for (const auto &e : scales.scales_) {
ok = ok && e.second.mask_ == 0;
bool ok = scales.has_default_values(supported_args);

for (const auto &arg : supported_args) {
if (!scales.get(arg).has_default_values()) {
// TODO: disallow non-int8 scales?
// const data_type_t dt = arg_md(arg)->data_type;
// ok = ok && utils::one_of(dt, s8, u8);
ok = ok && scales.get_mask(arg) == 0;
}
}
return ok;
}
Expand Down
Loading

0 comments on commit 58b979e

Please sign in to comment.