From 58b979eac6c39853199645be9de64b0347d54f56 Mon Sep 17 00:00:00 2001 From: Dmitrii Zarukin Date: Wed, 6 Nov 2024 17:41:55 -0800 Subject: [PATCH] src: introduce quant_entry_t and refactor arg_scales_t to rely on it --- src/common/binary.cpp | 15 +- src/common/binary_pd.hpp | 11 +- src/common/concat.cpp | 22 +- src/common/convolution.cpp | 17 +- src/common/convolution_pd.hpp | 4 +- src/common/deconvolution.cpp | 17 +- src/common/deconvolution_pd.hpp | 4 +- src/common/group_normalization.cpp | 18 +- src/common/group_normalization_pd.hpp | 10 +- src/common/inner_product.cpp | 17 +- src/common/inner_product_pd.hpp | 4 +- src/common/layer_normalization.cpp | 16 +- src/common/layer_normalization_pd.hpp | 16 +- src/common/matmul.cpp | 94 +++--- src/common/matmul_pd.hpp | 40 ++- src/common/primitive_attr.cpp | 12 +- src/common/primitive_attr.hpp | 2 +- src/common/primitive_attr_quant.cpp | 99 +++++++ src/common/primitive_attr_quant.hpp | 270 +++++++++--------- src/common/primitive_hashing.cpp | 16 +- src/common/primitive_hashing.hpp | 1 - src/common/reorder.cpp | 18 +- src/common/sdpa_pd.hpp | 17 +- src/common/sdpa_types.hpp | 4 +- src/common/serialization.cpp | 14 +- src/common/serialization_stream.hpp | 3 +- src/common/softmax.cpp | 20 +- src/common/softmax_pd.hpp | 15 +- src/common/verbose.cpp | 33 +-- src/common/verbose.hpp | 1 + src/cpu/aarch64/acl_reorder.hpp | 12 +- src/cpu/aarch64/brgemm/brgemm.cpp | 17 +- src/cpu/aarch64/jit_brdgmm_dw_conv.cpp | 2 +- src/cpu/aarch64/jit_brgemm_conv_utils.cpp | 4 +- src/cpu/aarch64/jit_brgemm_post_ops.hpp | 4 +- src/cpu/aarch64/jit_uni_reorder.cpp | 11 +- src/cpu/aarch64/jit_uni_reorder_utils.cpp | 13 +- src/cpu/aarch64/matmul/acl_lowp_matmul.cpp | 6 +- src/cpu/aarch64/matmul/brgemm_matmul.cpp | 2 +- .../aarch64/matmul/brgemm_matmul_utils.cpp | 15 +- src/cpu/dw_convolution_utils.hpp | 16 +- src/cpu/gemm_convolution_utils.cpp | 2 +- src/cpu/gemm_inner_product_utils.cpp | 6 +- src/cpu/gemm_x8s8s32x_convolution.cpp | 12 +- src/cpu/gemm_x8s8s32x_inner_product.cpp | 5 +- src/cpu/matmul/gemm_bf16_matmul.cpp | 19 +- src/cpu/matmul/gemm_f32_matmul.cpp | 19 +- src/cpu/matmul/gemm_x8s8s32x_matmul.cpp | 9 +- src/cpu/matmul/matmul_utils.hpp | 7 + src/cpu/matmul/ref_matmul.cpp | 13 +- src/cpu/matmul/ref_matmul_int8.cpp | 20 +- src/cpu/ref_concat.hpp | 7 +- src/cpu/ref_convolution_int8.cpp | 10 +- src/cpu/ref_deconvolution.cpp | 14 +- src/cpu/ref_fused_convolution.hpp | 3 +- src/cpu/ref_inner_product_int8.cpp | 4 +- src/cpu/ref_sum.hpp | 2 +- src/cpu/reorder/cpu_reorder_pd.hpp | 10 +- src/cpu/reorder/simple_reorder.hpp | 38 ++- src/cpu/scale_utils.cpp | 41 ++- src/cpu/scale_utils.hpp | 2 +- src/cpu/x64/brgemm/brgemm.cpp | 23 +- src/cpu/x64/brgemm/capi/brgemm_api.cpp | 2 +- .../jit_avx512_core_amx_1x1_conv_kernel.cpp | 3 +- .../jit_avx512_core_amx_1x1_convolution.cpp | 5 +- .../jit_avx512_core_amx_1x1_convolution.hpp | 4 +- .../x64/jit_avx512_core_amx_conv_kernel.cpp | 4 +- .../x64/jit_avx512_core_amx_convolution.cpp | 15 +- .../x64/jit_avx512_core_amx_convolution.hpp | 8 +- .../x64/jit_avx512_core_amx_deconvolution.cpp | 5 +- .../x64/jit_avx512_core_scale_precompute.cpp | 16 +- .../x64/jit_avx512_core_scale_precompute.hpp | 8 +- ...t_avx512_core_x8s8s32x_1x1_conv_kernel.cpp | 4 +- ...t_avx512_core_x8s8s32x_1x1_convolution.cpp | 6 +- .../jit_avx512_core_x8s8s32x_conv_kernel.cpp | 4 +- .../jit_avx512_core_x8s8s32x_convolution.cpp | 4 +- ...jit_avx512_core_x8s8s32x_deconvolution.cpp | 6 +- src/cpu/x64/jit_brdgmm_dw_conv.cpp | 11 +- src/cpu/x64/jit_brgemm_1x1_conv.cpp | 9 
+- src/cpu/x64/jit_brgemm_conv.cpp | 9 +- src/cpu/x64/jit_brgemm_conv_bwd_strided.cpp | 9 +- src/cpu/x64/jit_brgemm_conv_bwd_utils.cpp | 2 +- src/cpu/x64/jit_brgemm_conv_utils.cpp | 4 +- src/cpu/x64/jit_brgemm_inner_product.cpp | 3 +- src/cpu/x64/jit_brgemm_inner_product.hpp | 4 +- .../x64/jit_brgemm_inner_product_utils.cpp | 3 +- src/cpu/x64/jit_brgemm_post_ops.cpp | 3 +- src/cpu/x64/jit_uni_reorder.cpp | 11 +- src/cpu/x64/jit_uni_reorder_utils.cpp | 13 +- .../x64/jit_uni_x8s8s32x_1x1_conv_kernel.cpp | 4 +- .../x64/jit_uni_x8s8s32x_1x1_convolution.cpp | 6 +- src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.cpp | 4 +- src/cpu/x64/jit_uni_x8s8s32x_convolution.cpp | 4 +- .../x64/jit_uni_x8s8s32x_deconvolution.cpp | 6 +- src/cpu/x64/matmul/brgemm_matmul.cpp | 21 +- src/cpu/x64/matmul/brgemm_matmul_utils.cpp | 16 +- src/gpu/generic/ref_concat.hpp | 8 +- .../sycl/layer_normalizations_kernels.hpp | 30 +- src/gpu/generic/sycl/ref_binary.hpp | 3 +- src/gpu/generic/sycl/ref_convolution.cpp | 9 +- src/gpu/generic/sycl/ref_convolution.hpp | 2 +- .../generic/sycl/ref_layer_normalizations.cpp | 8 +- .../generic/sycl/ref_layer_normalizations.hpp | 2 +- src/gpu/generic/sycl/ref_matmul.cpp | 2 +- src/gpu/generic/sycl/ref_matmul.hpp | 3 +- src/gpu/generic/sycl/ref_reorder.cpp | 4 +- src/gpu/generic/sycl/ref_reorder.hpp | 2 +- src/gpu/generic/sycl/ref_softmax.hpp | 11 +- src/gpu/generic/sycl/ref_sum.hpp | 12 +- src/gpu/generic/sycl/reorder_kernels.hpp | 8 +- src/gpu/gpu_utils.hpp | 28 +- src/gpu/intel/jit/conv/config.cpp | 4 +- src/gpu/intel/jit/gemm/gen_gemm.hpp | 24 +- src/gpu/intel/jit/gemm/jit_gemm_pd.cpp | 32 +-- src/gpu/intel/jit/ir/post_ops.cpp | 50 ++-- src/gpu/intel/jit/ir/tensor_config.cpp | 2 +- src/gpu/intel/jit/reorder/gen_reorder.cpp | 9 +- src/gpu/intel/ocl/gemm/gemm_with_post_ops.cpp | 12 +- src/gpu/intel/ocl/gemm/ref_gemm.hpp | 13 +- src/gpu/intel/ocl/gemm_inner_product.hpp | 21 +- src/gpu/intel/ocl/gemm_matmul.hpp | 30 +- src/gpu/intel/ocl/gen9_binary.hpp | 18 +- src/gpu/intel/ocl/generic_reorder.cpp | 4 +- src/gpu/intel/ocl/micro_sdpa.cpp | 9 +- src/gpu/intel/ocl/micro_sdpa.hpp | 4 +- src/gpu/intel/ocl/multi_po_reorder_binary.hpp | 4 +- src/gpu/intel/ocl/ref_matmul.cpp | 20 +- src/gpu/intel/ocl/ref_matmul.hpp | 6 +- src/gpu/intel/primitive_conf.cpp | 10 +- src/gpu/intel/primitive_conf.hpp | 3 +- src/gpu/nvidia/cudnn_binary.hpp | 7 +- src/gpu/nvidia/cudnn_convolution.hpp | 20 +- src/gpu/nvidia/cudnn_inner_product.hpp | 20 +- src/gpu/nvidia/cudnn_matmul.hpp | 9 +- src/gpu/nvidia/cudnn_matmul_lt.hpp | 34 ++- src/gpu/nvidia/cudnn_matmul_lt_impl.hpp | 15 +- src/gpu/nvidia/cudnn_reorder.hpp | 19 +- src/gpu/nvidia/cudnn_reorder_lt.hpp | 46 +-- src/gpu/nvidia/cudnn_softmax.hpp | 8 +- tests/gtests/test_iface_attr.cpp | 8 +- 140 files changed, 1087 insertions(+), 940 deletions(-) create mode 100644 src/common/primitive_attr_quant.cpp diff --git a/src/common/binary.cpp b/src/common/binary.cpp index a29aeb017d7..7020be02b92 100644 --- a/src/common/binary.cpp +++ b/src/common/binary.cpp @@ -55,16 +55,17 @@ status_t binary_attr_check(const binary_desc_t &desc, const engine_t *engine, // Check scales if (!attr->scales_.has_default_values()) { - VCHECK_BINARY_UNIMPL(attr->scales_.has_default_values( - {DNNL_ARG_SRC_0, DNNL_ARG_SRC_1}), + static const std::vector supported_args { + DNNL_ARG_SRC_0, DNNL_ARG_SRC_1}; + VCHECK_BINARY_UNIMPL(attr->scales_.has_default_values(supported_args), VERBOSE_UNSUPPORTED_SCALES_CFG); - const auto &sc = attr->scales_; - const int mask_src_0 = sc.get(DNNL_ARG_SRC_0).mask_; - const 
int mask_src_1 = sc.get(DNNL_ARG_SRC_1).mask_; + for (int arg : supported_args) { + if (attr->scales_.get(arg).has_default_values()) continue; - VCHECK_BINARY_UNIMPL(utils::everyone_is(0, mask_src_0, mask_src_1), - VERBOSE_UNSUPPORTED_SCALES_CFG); + const int mask = attr->scales_.get_mask(arg); + VCHECK_BINARY_UNIMPL(mask == 0, VERBOSE_UNSUPPORTED_SCALES_CFG); + } } // Check post-ops diff --git a/src/common/binary_pd.hpp b/src/common/binary_pd.hpp index b9c94e58475..a9b624aa276 100644 --- a/src/common/binary_pd.hpp +++ b/src/common/binary_pd.hpp @@ -179,10 +179,13 @@ struct binary_pd_t : public primitive_desc_t { bool attr_scales_ok(const std::vector &supported_args = {DNNL_ARG_SRC_0, DNNL_ARG_SRC_1, DNNL_ARG_DST}) const { - bool ok = attr()->scales_.has_default_values(supported_args); - for (int arg : supported_args) { - const auto &mask = attr()->scales_.get(arg).mask_; - ok = ok && (mask == 0); + const auto &scales = attr()->scales_; + bool ok = scales.has_default_values(supported_args); + + for (const auto &arg : supported_args) { + if (scales.get(arg).has_default_values()) continue; + + ok = ok && scales.get_mask(arg) == 0; } return ok; } diff --git a/src/common/concat.cpp b/src/common/concat.cpp index d686df416f8..4932b7709f1 100644 --- a/src/common/concat.cpp +++ b/src/common/concat.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2023 Intel Corporation +* Copyright 2018-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -54,10 +54,22 @@ status_t concat_primitive_desc_create(std::shared_ptr &pd, VCHECK_CONCAT_UNIMPL(attr->has_default_values(smask_t::scales_runtime), VERBOSE_UNSUPPORTED_ATTR); const auto &scales = attr->scales_; - if (!scales.has_default_values()) - for (const auto &s : scales.scales_) - VCHECK_CONCAT_UNIMPL( - s.second.mask_ == 0, VERBOSE_UNSUPPORTED_SCALES_CFG); + if (!scales.has_default_values()) { + std::vector supported_args(n); + for (int i = 0; i < n; i++) { + supported_args[i] = DNNL_ARG_MULTIPLE_SRC + i; + } + VCHECK_CONCAT_UNIMPL( + attr->scales_.has_default_values(supported_args), + VERBOSE_UNSUPPORTED_SCALES_CFG); + + for (int arg : supported_args) { + if (scales.get(arg).has_default_values()) continue; + + int mask = scales.get_mask(arg); + VCHECK_CONCAT_UNIMPL(mask == 0, VERBOSE_UNSUPPORTED_SCALES_CFG); + } + } } const int ndims = src_mds[0]->ndims; diff --git a/src/common/convolution.cpp b/src/common/convolution.cpp index c74e1e2d7a8..7e5e759e284 100644 --- a/src/common/convolution.cpp +++ b/src/common/convolution.cpp @@ -178,13 +178,20 @@ status_t conv_attr_check(const convolution_desc_t &desc, const engine_t *engine, // Check scales if (!attr->scales_.has_default_values()) { const auto &sc = attr->scales_; - const int mask_src = sc.get(DNNL_ARG_SRC).mask_; - const int mask_wei = sc.get(DNNL_ARG_WEIGHTS).mask_; - const int mask_dst = sc.get(DNNL_ARG_DST).mask_; const bool with_groups = desc.src_desc.ndims != desc.weights_desc.ndims; - VCHECK_CONV_UNIMPL(utils::everyone_is(0, mask_src, mask_dst) - && utils::one_of(mask_wei, 0, with_groups ? 3 : 1), + VCHECK_CONV_UNIMPL( + IMPLICATION(!sc.get(DNNL_ARG_SRC).has_default_values(), + sc.get_mask(DNNL_ARG_SRC) == 0), + VERBOSE_UNSUPPORTED_SCALES_CFG); + VCHECK_CONV_UNIMPL( + IMPLICATION(!sc.get(DNNL_ARG_WEIGHTS).has_default_values(), + utils::one_of(sc.get_mask(DNNL_ARG_WEIGHTS), 0, + with_groups ? 
3 : 1)), + VERBOSE_UNSUPPORTED_SCALES_CFG); + VCHECK_CONV_UNIMPL( + IMPLICATION(!sc.get(DNNL_ARG_DST).has_default_values(), + sc.get_mask(DNNL_ARG_DST) == 0), VERBOSE_UNSUPPORTED_SCALES_CFG); } diff --git a/src/common/convolution_pd.hpp b/src/common/convolution_pd.hpp index ee6f631b405..7cb62ec1910 100644 --- a/src/common/convolution_pd.hpp +++ b/src/common/convolution_pd.hpp @@ -242,7 +242,9 @@ struct convolution_pd_t : public primitive_desc_t { = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) const { bool ok = attr()->scales_.has_default_values(supported_args); for (int arg : supported_args) { - const auto &mask = attr()->scales_.get(arg).mask_; + if (attr()->scales_.get(arg).has_default_values()) continue; + + const auto &mask = attr()->scales_.get_mask(arg); if (arg == DNNL_ARG_WEIGHTS) ok = ok && (mask == 0 || mask == (with_groups() ? 3 : 1)); else diff --git a/src/common/deconvolution.cpp b/src/common/deconvolution.cpp index 00f3f89d037..336a3c9095f 100644 --- a/src/common/deconvolution.cpp +++ b/src/common/deconvolution.cpp @@ -172,13 +172,20 @@ status_t deconv_attr_check(const deconvolution_desc_t &desc, // Check scales if (!attr->scales_.has_default_values()) { const auto &sc = attr->scales_; - const int mask_src = sc.get(DNNL_ARG_SRC).mask_; - const int mask_wei = sc.get(DNNL_ARG_WEIGHTS).mask_; - const int mask_dst = sc.get(DNNL_ARG_DST).mask_; const bool with_groups = desc.src_desc.ndims != desc.weights_desc.ndims; - VCHECK_DECONV_UNIMPL(utils::everyone_is(0, mask_src, mask_dst) - && utils::one_of(mask_wei, 0, with_groups ? 3 : 1), + VCHECK_DECONV_UNIMPL( + IMPLICATION(!sc.get(DNNL_ARG_SRC).has_default_values(), + sc.get_mask(DNNL_ARG_SRC) == 0), + VERBOSE_UNSUPPORTED_SCALES_CFG); + VCHECK_DECONV_UNIMPL( + IMPLICATION(!sc.get(DNNL_ARG_WEIGHTS).has_default_values(), + utils::one_of(sc.get_mask(DNNL_ARG_WEIGHTS), 0, + with_groups ? 3 : 1)), + VERBOSE_UNSUPPORTED_SCALES_CFG); + VCHECK_DECONV_UNIMPL( + IMPLICATION(!sc.get(DNNL_ARG_DST).has_default_values(), + sc.get_mask(DNNL_ARG_DST) == 0), VERBOSE_UNSUPPORTED_SCALES_CFG); } diff --git a/src/common/deconvolution_pd.hpp b/src/common/deconvolution_pd.hpp index 32dfb1a58ef..cd582e28e3e 100644 --- a/src/common/deconvolution_pd.hpp +++ b/src/common/deconvolution_pd.hpp @@ -173,7 +173,9 @@ struct deconvolution_pd_t : public primitive_desc_t { = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) const { bool ok = attr()->scales_.has_default_values(supported_args); for (int arg : supported_args) { - const auto &mask = attr()->scales_.get(arg).mask_; + if (attr()->scales_.get(arg).has_default_values()) continue; + + const auto &mask = attr()->scales_.get_mask(arg); if (arg == DNNL_ARG_WEIGHTS) ok = ok && (mask == 0 || mask == (with_groups() ? 3 : 1)); else diff --git a/src/common/group_normalization.cpp b/src/common/group_normalization.cpp index 4e0abf3b6a2..6c4aec9f72c 100644 --- a/src/common/group_normalization.cpp +++ b/src/common/group_normalization.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2023 Intel Corporation +* Copyright 2023-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
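The scale checks above (binary, conv, deconv) all move to the same per-argument pattern: an argument is validated only when the user actually set its scales, and the mask is read through the new getter rather than the removed mask_ field. A minimal sketch of that pattern follows, assuming common/primitive_attr_quant.hpp is included; the free-standing helper and its name are illustrative only and are not part of this patch:

    // Sketch only: skip entries left in their default state, then require a
    // common (mask == 0) scale for every argument the primitive supports.
    bool common_scale_masks_ok(const dnnl::impl::scales_t &sc,
            const std::vector<int> &supported_args) {
        // Rejects scales set on arguments outside supported_args.
        if (!sc.has_default_values(supported_args)) return false;
        for (int arg : supported_args) {
            if (sc.get(arg).has_default_values()) continue; // nothing was set
            if (sc.get_mask(arg) != 0) return false; // only common scales here
        }
        return true;
    }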
@@ -154,12 +154,18 @@ status_t group_normalization_attr_check(const group_normalization_desc_t &desc, // Check scales if (!attr->scales_.has_default_values()) { - const auto &sc = attr->scales_; - const int mask_src = sc.get(DNNL_ARG_SRC).mask_; - const int mask_dst = sc.get(DNNL_ARG_DST).mask_; - - VCHECK_GNORM_UNIMPL(utils::everyone_is(0, mask_src, mask_dst), + static const std::vector supported_args { + DNNL_ARG_SRC, DNNL_ARG_DST}; + VCHECK_GNORM_UNIMPL( + attr->scales_.has_default_values(supported_args), VERBOSE_UNSUPPORTED_SCALES_CFG); + + for (int arg : supported_args) { + if (attr->scales_.get(arg).has_default_values()) continue; + + const int mask = attr->scales_.get_mask(arg); + VCHECK_GNORM_UNIMPL(mask == 0, VERBOSE_UNSUPPORTED_SCALES_CFG); + } } // Check post-ops diff --git a/src/common/group_normalization_pd.hpp b/src/common/group_normalization_pd.hpp index a9a48ed29f0..574dde8193b 100644 --- a/src/common/group_normalization_pd.hpp +++ b/src/common/group_normalization_pd.hpp @@ -190,17 +190,17 @@ struct group_normalization_fwd_pd_t : public group_normalization_pd_t { return IMPLICATION(use_scale() || use_shift(), weights_md()->data_type == data_type::f32); } - bool attr_scales_ok() const { + bool attr_scales_ok(const std::vector &supported_args + = {DNNL_ARG_SRC, DNNL_ARG_DST}) const { using namespace data_type; const auto &scales = attr()->scales_; - const std::vector supported_args({DNNL_ARG_SRC, DNNL_ARG_DST}); bool ok = scales.has_default_values(supported_args); for (const auto &arg : supported_args) { - const auto &sc = scales.get(arg); - if (!sc.has_default_values()) { + if (!scales.get(arg).has_default_values()) { const data_type_t dt = arg_md(arg)->data_type; - ok = ok && utils::one_of(dt, s8, u8) && sc.mask_ == 0; + ok = ok && utils::one_of(dt, s8, u8); + ok = ok && scales.get_mask(arg) == 0; } } return ok; diff --git a/src/common/inner_product.cpp b/src/common/inner_product.cpp index 4c1943cf67f..2ed1c179f0c 100644 --- a/src/common/inner_product.cpp +++ b/src/common/inner_product.cpp @@ -133,12 +133,17 @@ status_t ip_attr_check(const inner_product_desc_t &desc, const engine_t *engine, // Check scales if (!attr->scales_.has_default_values()) { const auto &sc = attr->scales_; - const int mask_src = sc.get(DNNL_ARG_SRC).mask_; - const int mask_wei = sc.get(DNNL_ARG_WEIGHTS).mask_; - const int mask_dst = sc.get(DNNL_ARG_DST).mask_; - - VCHECK_IP_UNIMPL(utils::everyone_is(0, mask_src, mask_dst) - && utils::one_of(mask_wei, 0, 1), + VCHECK_IP_UNIMPL( + IMPLICATION(!sc.get(DNNL_ARG_SRC).has_default_values(), + sc.get_mask(DNNL_ARG_SRC) == 0), + VERBOSE_UNSUPPORTED_SCALES_CFG); + VCHECK_IP_UNIMPL( + IMPLICATION(!sc.get(DNNL_ARG_WEIGHTS).has_default_values(), + utils::one_of(sc.get_mask(DNNL_ARG_WEIGHTS), 0, 1)), + VERBOSE_UNSUPPORTED_SCALES_CFG); + VCHECK_IP_UNIMPL( + IMPLICATION(!sc.get(DNNL_ARG_DST).has_default_values(), + sc.get_mask(DNNL_ARG_DST) == 0), VERBOSE_UNSUPPORTED_SCALES_CFG); } diff --git a/src/common/inner_product_pd.hpp b/src/common/inner_product_pd.hpp index 5d7d2163ddf..cac09295bb0 100644 --- a/src/common/inner_product_pd.hpp +++ b/src/common/inner_product_pd.hpp @@ -184,7 +184,9 @@ struct inner_product_pd_t : public primitive_desc_t { = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) const { bool ok = attr()->scales_.has_default_values(supported_args); for (auto arg : supported_args) { - int mask = attr()->scales_.get(arg).mask_; + if (attr()->scales_.get(arg).has_default_values()) continue; + + int mask = attr()->scales_.get_mask(arg); if (arg == 
DNNL_ARG_WEIGHTS) ok = ok && (mask == 0 || mask == (1 << 0)); else diff --git a/src/common/layer_normalization.cpp b/src/common/layer_normalization.cpp index 79ccc98c45c..17b59bcc587 100644 --- a/src/common/layer_normalization.cpp +++ b/src/common/layer_normalization.cpp @@ -164,12 +164,18 @@ status_t layer_normalization_attr_check(const layer_normalization_desc_t &desc, // Check scales if (!attr->scales_.has_default_values()) { - const auto &sc = attr->scales_; - const int mask_src = sc.get(DNNL_ARG_SRC).mask_; - const int mask_dst = sc.get(DNNL_ARG_DST).mask_; - - VCHECK_LNORM_UNIMPL(utils::everyone_is(0, mask_src, mask_dst), + static const std::vector supported_args { + DNNL_ARG_SRC, DNNL_ARG_DST}; + VCHECK_LNORM_UNIMPL( + attr->scales_.has_default_values(supported_args), VERBOSE_UNSUPPORTED_SCALES_CFG); + + for (int arg : supported_args) { + if (attr->scales_.get(arg).has_default_values()) continue; + + const int mask = attr->scales_.get_mask(arg); + VCHECK_LNORM_UNIMPL(mask == 0, VERBOSE_UNSUPPORTED_SCALES_CFG); + } } // Check post-ops diff --git a/src/common/layer_normalization_pd.hpp b/src/common/layer_normalization_pd.hpp index 0629be31180..f1df246db9a 100644 --- a/src/common/layer_normalization_pd.hpp +++ b/src/common/layer_normalization_pd.hpp @@ -248,11 +248,19 @@ struct layer_normalization_fwd_pd_t : public layer_normalization_pd_t { return false; } - bool attr_scales_ok() const { + bool attr_scales_ok(const std::vector &supported_args + = {DNNL_ARG_SRC, DNNL_ARG_DST}) const { + using namespace data_type; const auto &scales = attr()->scales_; - bool ok = true; - for (const auto &e : scales.scales_) { - ok = ok && e.second.mask_ == 0; + bool ok = scales.has_default_values(supported_args); + + for (const auto &arg : supported_args) { + if (!scales.get(arg).has_default_values()) { + // TODO: disallow non-int8 scales? + // const data_type_t dt = arg_md(arg)->data_type; + // ok = ok && utils::one_of(dt, s8, u8); + ok = ok && scales.get_mask(arg) == 0; + } } return ok; } diff --git a/src/common/matmul.cpp b/src/common/matmul.cpp index bd9fb69f067..66de4962fe5 100644 --- a/src/common/matmul.cpp +++ b/src/common/matmul.cpp @@ -88,29 +88,65 @@ status_t matmul_attr_check(const matmul_desc_t &desc, const engine_t *engine, // Check scales if (!attr->scales_.has_default_values()) { const auto &sc = attr->scales_; - const auto &sc_src = sc.get(DNNL_ARG_SRC); - const auto &sc_wei = sc.get(DNNL_ARG_WEIGHTS); - const int mask_src = sc_src.mask_; - const int mask_wei = sc_wei.mask_; - const int mask_dst = sc.get(DNNL_ARG_DST).mask_; + dim_t src_scale_group_k = 1; + if (!sc.get(DNNL_ARG_SRC).has_default_values()) { + const int mask_src = sc.get_mask(DNNL_ARG_SRC); + + VCHECK_MATMUL_UNIMPL(utils::one_of(mask_src, 0, src_qmask_K, + src_qmask_M + src_qmask_K), + VERBOSE_UNSUPPORTED_SCALES_CFG); + + if (!sc.get(DNNL_ARG_SRC).has_default_groups()) { + if (mask_src & src_qmask_K) + src_scale_group_k = sc.get_group(DNNL_ARG_SRC, 1); + } + + // Due to hardware specifics, groups should be multiple of 32. + VCHECK_MATMUL_UNIMPL(IMPLICATION(src_scale_group_k > 1, + src_scale_group_k % 32 == 0), + VERBOSE_UNSUPPORTED_SCALES_CFG); + } + + dim_t wei_scale_group_k = 1; + dim_t wei_scale_group_n = 1; + if (!sc.get(DNNL_ARG_WEIGHTS).has_default_values()) { + const int mask_wei = sc.get_mask(DNNL_ARG_WEIGHTS); + + // Masks for weights scales can be any - skipping them. 
+ + if (!sc.get(DNNL_ARG_WEIGHTS).has_default_groups()) { + if (mask_wei & wei_qmask_K) + wei_scale_group_k = sc.get_group(DNNL_ARG_WEIGHTS, 0); + if (mask_wei & wei_qmask_N) + wei_scale_group_n = sc.get_group(DNNL_ARG_WEIGHTS, 1); + } + + // Groups per N are solely for weights decompression as it's + // impossible to get performant kernel for a single `k` element in + // chain for regular quantized case. + VCHECK_MATMUL_UNIMPL(IMPLICATION(wei_scale_group_n > 1, + attr->fpmath_.apply_to_int_), + VERBOSE_UNSUPPORTED_SCALES_CFG); + + // Due to hardware specifics, groups should be multiple of 32. + VCHECK_MATMUL_UNIMPL(IMPLICATION(wei_scale_group_k > 1, + wei_scale_group_k % 32 == 0), + VERBOSE_UNSUPPORTED_SCALES_CFG); + VCHECK_MATMUL_UNIMPL(IMPLICATION(wei_scale_group_n > 1, + wei_scale_group_n % 32 == 0), + VERBOSE_UNSUPPORTED_SCALES_CFG); + } + + if (!sc.get(DNNL_ARG_DST).has_default_values()) { + const int mask_dst = sc.get_mask(DNNL_ARG_DST); + + VCHECK_MATMUL_UNIMPL(mask_dst == 0, VERBOSE_UNSUPPORTED_SCALES_CFG); + } - VCHECK_MATMUL_UNIMPL(utils::one_of(mask_src, 0, src_qmask_K, - src_qmask_M + src_qmask_K), - VERBOSE_UNSUPPORTED_SCALES_CFG); - // Masks for weights scales can be any - skipping them. - VCHECK_MATMUL_UNIMPL(mask_dst == 0, VERBOSE_UNSUPPORTED_SCALES_CFG); // Check dependency between scales. // Source scales groups are supported for int8 source and must divide // or be divided by weights groups when both are greater than 1. - const auto src_scale_group_k - = (mask_src & src_qmask_K) && sc_src.ndims_ > 0 - ? sc_src.group_dims_[1] - : 1; - const auto wei_scale_group_k - = (mask_wei & wei_qmask_K) && sc_wei.ndims_ > 0 - ? sc_wei.group_dims_[0] - : 1; const bool groups_are_divisible = IMPLICATION( src_scale_group_k > 1 && wei_scale_group_k > 1, (src_scale_group_k % wei_scale_group_k == 0) @@ -118,28 +154,6 @@ status_t matmul_attr_check(const matmul_desc_t &desc, const engine_t *engine, VCHECK_MATMUL_UNIMPL(IMPLICATION(src_scale_group_k > 1, src_is_int8 && groups_are_divisible), VERBOSE_UNSUPPORTED_SCALES_CFG); - - // Groups per N are solely for weights decompression as it's impossible - // to get performant kernel for a single `k` element in chain for - // regular quantized case. - const auto wei_scale_group_n - = (mask_wei & wei_qmask_N) && sc_wei.ndims_ > 0 - ? sc_wei.group_dims_[1] - : 1; - VCHECK_MATMUL_UNIMPL( - IMPLICATION(wei_scale_group_n > 1, attr->fpmath_.apply_to_int_), - VERBOSE_UNSUPPORTED_SCALES_CFG); - - // Due to hardware specifics, groups should be multiple of 32. 
- VCHECK_MATMUL_UNIMPL( - IMPLICATION(src_scale_group_k > 1, src_scale_group_k % 32 == 0), - VERBOSE_UNSUPPORTED_SCALES_CFG); - VCHECK_MATMUL_UNIMPL( - IMPLICATION(wei_scale_group_k > 1, wei_scale_group_k % 32 == 0), - VERBOSE_UNSUPPORTED_SCALES_CFG); - VCHECK_MATMUL_UNIMPL( - IMPLICATION(wei_scale_group_n > 1, wei_scale_group_n % 32 == 0), - VERBOSE_UNSUPPORTED_SCALES_CFG); } // Check zero points diff --git a/src/common/matmul_pd.hpp b/src/common/matmul_pd.hpp index d1255f3b227..ce282c40819 100644 --- a/src/common/matmul_pd.hpp +++ b/src/common/matmul_pd.hpp @@ -161,29 +161,26 @@ struct matmul_pd_t : public primitive_desc_t { virtual bool attr_scales_ok(const std::vector &supported_args = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) const { - if (attr()->scales_.has_default_values()) return true; + const auto &scales = attr()->scales_; + if (scales.has_default_values()) return true; - bool ok = attr()->scales_.has_default_values(supported_args); + bool ok = scales.has_default_values(supported_args); for (int arg : supported_args) { - const auto &sc = attr()->scales_.get(arg); - const auto &mask = sc.mask_; - if (sc.has_default_values()) { continue; } + if (scales.get(arg).has_default_values()) { continue; } + const auto &mask = scales.get_mask(arg); if (arg == DNNL_ARG_WEIGHTS) { - const bool wei_k_group_ok - = IMPLICATION(sc.ndims_ == 2 && sc.group_dims_[0] > 1, - K() % sc.group_dims_[0] == 0); - const bool wei_n_group_ok - = IMPLICATION(sc.ndims_ == 2 && sc.group_dims_[1] > 1, - N() % sc.group_dims_[1] == 0); + const auto &g0 = scales.get_group(arg, 0); + const auto &g1 = scales.get_group(arg, 1); + const bool wei_k_group_ok = IMPLICATION(g0 > 1, K() % g1 == 0); + const bool wei_n_group_ok = IMPLICATION(g1 > 1, N() % g0 == 0); // Any group is allowed to be greater than 1 but only one at a // time, not both. - ok = ok && utils::one_of(sc.ndims_, 0, 2) - && IMPLICATION(sc.ndims_ == 2, - utils::one_of( - 1, sc.group_dims_[0], sc.group_dims_[1]) - && wei_k_group_ok && wei_n_group_ok); + ok = ok + && IMPLICATION(!scales.get(arg).has_default_groups(), + utils::one_of(1, g0, g1) && wei_k_group_ok + && wei_n_group_ok); // Mask over K dim is allowed for decompression or dynamic // quantization features only. 
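Outside the VCHECK macros, the grouped weights-scale rules above reduce to a few divisibility checks. The sketch below restates them, assuming group index 0 runs along K and index 1 along N as in the get_group calls above; the helper is illustrative and combines conditions that live in matmul.cpp and matmul_pd.hpp:

    // Sketch of the grouped weights-scale rules: at most one group dimension
    // may exceed 1, a group must divide the corresponding problem dimension,
    // and (hardware restriction) a group larger than 1 must be a multiple of 32.
    bool wei_scale_groups_ok(const dnnl::impl::quant_entry_t &wei_sc,
            dnnl::impl::dim_t K, dnnl::impl::dim_t N) {
        if (wei_sc.has_default_groups()) return true; // non-grouped scales
        const dnnl::impl::dim_t gk = wei_sc.get_group(0); // group along K
        const dnnl::impl::dim_t gn = wei_sc.get_group(1); // group along N
        if (gk > 1 && gn > 1) return false; // only one grouped dim at a time
        if (gk > 1 && (K % gk != 0 || gk % 32 != 0)) return false;
        if (gn > 1 && (N % gn != 0 || gn % 32 != 0)) return false;
        return true;
    }

The additional condition that groups along N imply weights decompression (attr->fpmath_.apply_to_int_) stays in matmul.cpp and is not repeated in this sketch.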
@@ -200,12 +197,13 @@ struct matmul_pd_t : public primitive_desc_t { ok = ok && utils::one_of(mask, 0, src_qmask_K(), src_qmask_M() + src_qmask_K()); - ok = ok && utils::one_of(sc.ndims_, 0, 2); - ok = ok && IMPLICATION((mask & src_qmask_K()), sc.ndims_ == 2); ok = ok - && IMPLICATION(sc.ndims_ == 2, - sc.group_dims_[0] == 1 - && K() % sc.group_dims_[1] == 0); + && IMPLICATION((mask & src_qmask_K()), + !scales.get(arg).has_default_groups()); + ok = ok + && IMPLICATION(!scales.get(arg).has_default_groups(), + scales.get_group(arg, 0) + && K() % scales.get_group(arg, 1) == 0); } else { ok = ok && (mask == 0); } diff --git a/src/common/primitive_attr.cpp b/src/common/primitive_attr.cpp index 54df07f27fc..c384e7505fa 100644 --- a/src/common/primitive_attr.cpp +++ b/src/common/primitive_attr.cpp @@ -35,11 +35,6 @@ const primitive_attr_t &default_attr() { return default_attr_instance; } -const runtime_scales_t &default_runtime_scale() { - static const runtime_scales_t default_runtime_scale_instance; - return default_runtime_scale_instance; -} - void rnn_create_time_scales_t::set_single_scale(float scale) { count_ = 1; mask_ = 0; @@ -543,8 +538,9 @@ status_t dnnl_primitive_attr_set_scratchpad_mode( status_t dnnl_primitive_attr_set_scales_mask( primitive_attr_t *attr, int arg, int mask) { - bool ok = attr && mask >= 0 && arg >= 0; - if (!ok) return invalid_arguments; + VCHECK_ATTR(attr, VERBOSE_NULL_ARG); + VCHECK_ATTR(mask >= 0, VERBOSE_BAD_PARAM, "mask"); + VCHECK_ATTR(arg >= 0, VERBOSE_BAD_PARAM, "arg"); return attr->scales_.set(arg, mask); } @@ -564,7 +560,7 @@ status_t dnnl_primitive_attr_set_scales(primitive_attr_t *attr, int arg, VERBOSE_INVALID_DATATYPE, "scales"); VCHECK_ATTR(IMPLICATION(ndims, validate_dims(ndims, group_dims)), VERBOSE_BAD_PARAM, "group_dims"); - return attr->scales_.set(arg, mask, ndims, group_dims, data_type); + return attr->scales_.set(arg, mask, data_type, ndims, group_dims); } status_t dnnl_primitive_attr_set_zero_points_mask( diff --git a/src/common/primitive_attr.hpp b/src/common/primitive_attr.hpp index 29e01e9bbf0..f4e22020917 100644 --- a/src/common/primitive_attr.hpp +++ b/src/common/primitive_attr.hpp @@ -669,7 +669,7 @@ struct dnnl_primitive_attr : public dnnl::impl::c_compatible { } // NOTE: make sure that the types below have overloaded comparison operator - dnnl::impl::arg_scales_t scales_; + dnnl::impl::scales_t scales_; dnnl::impl::zero_points_t zero_points_; dnnl::impl::scratchpad_mode_t scratchpad_mode_; dnnl::impl::fpmath_t fpmath_; diff --git a/src/common/primitive_attr_quant.cpp b/src/common/primitive_attr_quant.cpp new file mode 100644 index 00000000000..7a6bf1fbaf9 --- /dev/null +++ b/src/common/primitive_attr_quant.cpp @@ -0,0 +1,99 @@ +/******************************************************************************* +* Copyright 2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "common/primitive_attr_quant.hpp" +#include "common/primitive_hashing.hpp" +#include "common/verbose.hpp" + +namespace dnnl { +namespace impl { + +const quant_entry_t &default_quant_entry() { + static const quant_entry_t default_quant_entry; + return default_quant_entry; +} + +size_t quant_entry_t::get_hash(size_t seed) const { + seed = hash_combine(seed, mask_); + seed = hash_combine(seed, static_cast(data_type_)); + seed = hash_combine(seed, group_ndims_); + if (group_ndims_ > 0) + seed = primitive_hashing::get_array_hash( + seed, group_dims_, group_ndims_); + return seed; +} + +void quant_entry_t::serialize(serialization_stream_t &sstream) const { + sstream.write(&mask_); + sstream.write(&data_type_); + sstream.write(&group_ndims_); + if (group_ndims_ > 0) sstream.write(group_dims_, group_ndims_); +} + +std::string quant_entry_t::get_verbose() const { + std::string s; + s.append(std::to_string(mask_)); + s.append(":").append(dnnl_dt2str(data_type_)); + if (group_ndims_ > 0) { + s.append(":") + .append(std::to_string(group_dims_[0])) + .append("x") + .append(std::to_string(group_dims_[1])); + } + return s; +} + +std::ostream &operator<<(std::ostream &ss, const quant_entry_t &e) { + ss << e.get_verbose(); + return ss; +} + +size_t scales_t::get_hash(size_t seed) const { + // Go through scales for all arguments. + for (const auto &e : scales_) { + seed = hash_combine(seed, e.first); + seed = e.second.get_hash(seed); + } + return seed; +} + +void scales_t::serialize(serialization_stream_t &sstream) const { + for (const auto &e : scales_) { + sstream.write(&e.first); + e.second.serialize(sstream); + } +} + +std::string scales_t::get_verbose() const { + std::string s; + std::string empty_delim, attr_delim = "+"; + std::string delim = empty_delim; + for (const auto &scale : scales_) { + const auto &q = scale.second; + if (q.has_default_values()) continue; + + int arg = scale.first; + s.append(delim) + .append(arg2str(arg)) + .append(":") + .append(q.get_verbose()); + delim = attr_delim; + } + return s; +} + +} // namespace impl +} // namespace dnnl diff --git a/src/common/primitive_attr_quant.hpp b/src/common/primitive_attr_quant.hpp index fbf838bd176..371526ba8e7 100644 --- a/src/common/primitive_attr_quant.hpp +++ b/src/common/primitive_attr_quant.hpp @@ -28,175 +28,185 @@ // dependency between headers when it comes to inclusion of opdesc.hpp which // sdpa_desc_t is a part of. -#include "utils.hpp" +#include "common/serialization_stream.hpp" +#include "common/utils.hpp" #include -#include +#include #include +#include namespace dnnl { namespace impl { -struct runtime_scales_t; -const runtime_scales_t &default_runtime_scale(); - -struct runtime_scales_t : public c_compatible { - // Clang-3.8.1 raises an error for a default initialization of a const - // object. Const runtime_scales_t object is used as default_scales. 
- // runtime_scales_t() = default; - runtime_scales_t() {} - - runtime_scales_t &operator=(const runtime_scales_t &rhs) { - mask_ = rhs.mask_; - is_set_ = rhs.is_set_; - ndims_ = rhs.ndims_; - if (ndims_ > 0) utils::array_copy(group_dims_, rhs.group_dims_, ndims_); - data_type_ = rhs.data_type_; - return *this; - } +struct quant_entry_t; +const quant_entry_t &default_quant_entry(); - status_t set(int mask) { return set(0, mask, {}, data_type::f32); } +struct quant_entry_t : public c_compatible { + quant_entry_t() = default; - status_t set(int ndims, int mask, const dims_t group_dims, - data_type_t data_type = data_type::f32) { + // `set(...)` approach is taken over constructors as the usage model assumes + // the change of state of this object but it doesn't require its destruction + // which would come with some performance price which prevails in this case. + status_t set(int mask, data_type_t data_type) { + return set(mask, data_type, 0, {}); + } + status_t set(int mask, data_type_t data_type, int group_ndims, + const dims_t group_dims) { mask_ = mask; - is_set_ = true; - ndims_ = ndims; - if (ndims > 0) utils::array_copy(group_dims_, group_dims, ndims); data_type_ = data_type; + group_ndims_ = group_ndims; + if (group_ndims_ > 0) { + utils::array_copy(group_dims_, group_dims, group_ndims_); + } return status::success; } + status_t set(const quant_entry_t &other) { + return set(other.mask_, other.data_type_, other.group_ndims_, + other.group_dims_); + } + + quant_entry_t &operator=(const quant_entry_t &rhs) { + auto st = this->set(rhs); + assert(st == status::success); + UNUSED(st); + return *this; + } + + bool has_default_values() const { return *this == default_quant_entry(); } + bool has_default_groups() const { + return this->group_ndims_ == default_quant_entry().group_ndims_; + } + + int get_mask() const { return mask_; } + data_type_t get_data_type() const { return data_type_; } + dim_t get_group(int d) const { + // If groups were not requested, return `1` for convenience. + if (group_ndims_ == default_quant_entry().group_ndims_) return 1; + // But if they were, any out of bound access would return `0` and likely + // lead to a division by zero which is fast to catch. + if (d >= group_ndims_) return 0; + return group_dims_[d]; + } - bool operator==(const runtime_scales_t &rhs) const { - return mask_ == rhs.mask_ && is_set_ == rhs.is_set_ - && ndims_ == rhs.ndims_ - && IMPLICATION(ndims_ > 0, - utils::array_cmp(group_dims_, rhs.group_dims_, ndims_)) - && data_type_ == rhs.data_type_; + // Note: keep the definition here to satisfy the + // `gtests/internals/test_comparison_operators` linking requirements which + // mandates bodies to be in the header file. + bool operator==(const quant_entry_t &rhs) const { + return mask_ == rhs.mask_ && data_type_ == rhs.data_type_ + && group_ndims_ == rhs.group_ndims_ + && IMPLICATION(group_ndims_ > 0, + utils::array_cmp( + group_dims_, rhs.group_dims_, group_ndims_)); } - bool has_default_values() const { return *this == default_runtime_scale(); } + size_t get_hash(size_t seed) const; + + void serialize(serialization_stream_t &sstream) const; - bool has_default_groups() const { return 0 == ndims_; } - bool has_default_data_type() const { return data_type_ == data_type::f32; } + std::string get_verbose() const; - // TODO: replace with `-1` to remove `is_set_`. - // Hide `mask_` under `private:` to force interface usage. 
- int mask_ = 0; - bool is_set_ = false; - int ndims_ = 0; - dims_t group_dims_ = {}; - data_type_t data_type_ = data_type::f32; +private: + // Note: INT_MIN is used on purpose to avoid potential issues when + // `(mask & bit)` expression will return `true`. `INT_MIN` is represented + // as `10...0` in bits and will avoid such situations. + int mask_ = INT_MIN; + data_type_t data_type_ = data_type::undef; + int group_ndims_ = 0; + dims_t group_dims_ {}; }; -struct arg_scales_t : public c_compatible { - arg_scales_t() = default; +std::ostream &operator<<(std::ostream &ss, const quant_entry_t &e); - const runtime_scales_t &get(int arg) const { - static const runtime_scales_t default_scales; +struct scales_t : public c_compatible { + scales_t() = default; + + const quant_entry_t &get(int arg) const { const auto it = scales_.find(arg); - if (it == scales_.end()) return default_scales; + if (it == scales_.end()) return default_quant_entry(); return it->second; } - status_t set(int arg, const runtime_scales_t &scale) { + // See `set(...)` comment for `quant_entry_t` for a design choice + // explanation. + status_t set(int arg, int mask) { + return set(arg, mask, default_data_type, 0, {}); + } + status_t set(int arg, int mask, data_type_t data_type, int group_ndims, + const dims_t group_dims) { if (!check_arg(arg)) return status::invalid_arguments; - scales_[arg] = scale; + CHECK(scales_[arg].set(mask, data_type, group_ndims, group_dims)); return status::success; } - - bool operator==(const arg_scales_t &rhs) const { - return scales_ == rhs.scales_; - } - - bool has_default_values(const std::vector &skip_args = {}) const { - auto predicate = [](const runtime_scales_t &s) { - return s.has_default_values(); - }; - return has_default_property(skip_args, predicate); + // Use this interface with `default_quant_entry` when need to remove a + // specific scale. + status_t set(int arg, const quant_entry_t &other) { + return scales_[arg].set(other); } - bool has_default_data_type(const std::vector &skip_args = {}) const { - auto predicate = [](const runtime_scales_t &s) { - return s.has_default_data_type(); - }; - return has_default_property(skip_args, predicate); + bool has_default_values(const std::vector &supported_args = {}) const { + auto predicate + = [](const quant_entry_t &s) { return s.has_default_values(); }; + return has_default_property(supported_args, predicate); } - bool has_default_groups(const std::vector &skip_args = {}) const { - auto predicate = [](const runtime_scales_t &s) { - return s.has_default_groups(); + // This interface checks the content of all scales, and allows to ignore + // certain arguments. + bool has_default_data_type( + const std::vector &supported_args = {}) const { + auto predicate = [](const quant_entry_t &s) { + // Note: `data_type::undef` represents `default_quant_entry`. + return utils::one_of( + s.get_data_type(), default_data_type, data_type::undef); }; - return has_default_property(skip_args, predicate); + return has_default_property(supported_args, predicate); } - - status_t set(int arg, int mask) { - return set(arg, mask, 0, {}, data_type::f32); - } - - status_t set(int arg, int mask, int ndims, const dims_t group_dims, - data_type_t data_type) { - if (!check_arg(arg)) return status::invalid_arguments; - return scales_[arg].set(ndims, mask, group_dims, data_type); + // This interface checks specific argument. It exists because quant_entry_t + // doesn't have a notion of default data_type, only scales do. 
+ bool has_default_data_type(int arg) const { + // Note: `data_type::undef` represents `default_quant_entry`. + return utils::one_of( + get(arg).get_data_type(), default_data_type, data_type::undef); } - // TODO: move to `private` and keep a single interface per entry. - status_t get(int arg, int *mask, bool *is_set, int *ndims = nullptr, - dims_t group_dims = nullptr, - data_type_t *data_type = nullptr) const { - if (!check_arg(arg)) return status::invalid_arguments; - const auto &s = get(arg); - if (mask) *mask = s.mask_; - if (is_set) *is_set = s.is_set_; - if (ndims) *ndims = s.ndims_; - if (group_dims && s.ndims_ > 0) - utils::array_copy(group_dims, s.group_dims_, s.ndims_); - if (data_type) *data_type = s.data_type_; - return status::success; + bool has_default_groups(const std::vector &supported_args = {}) const { + auto predicate + = [](const quant_entry_t &s) { return s.has_default_groups(); }; + return has_default_property(supported_args, predicate); } + int get_mask(int arg) const { return get(arg).get_mask(); } data_type_t get_data_type(int arg) const { - data_type_t data_type; - auto st = get(arg, nullptr, nullptr, nullptr, nullptr, &data_type); - if (st != status::success) return data_type::undef; - return data_type; + return get(arg).get_data_type(); } + dim_t get_group(int arg, int d) const { return get(arg).get_group(d); } - status_t reset(int arg) { - if (!check_arg(arg)) return status::invalid_arguments; - const auto it = scales_.find(arg); - if (it != scales_.end()) scales_.erase(it); - return status::success; + bool operator==(const scales_t &rhs) const { + return scales_ == rhs.scales_; } - status_t copy_from(const arg_scales_t &other) { - for (auto it = other.scales_.begin(); it != other.scales_.end(); ++it) { - // Find an entry that can match the arguments without constructing a - // new object. - if (scales_.count(it->first) == 1) { - auto &entry = scales_[it->first]; - if (entry == it->second) continue; - } + size_t get_hash(size_t seed) const; - CHECK(set(it->first, it->second)); - } - return status::success; - } + void serialize(serialization_stream_t &sstream) const; - std::map scales_; + std::string get_verbose() const; private: + // Sorted property of `std::map` is used for hashing. + std::map scales_; + static constexpr data_type_t default_data_type = data_type::f32; + bool check_arg(int arg) const { + // regular + for (const auto &sa : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) { + if (arg == sa) return true; + } // binary - for (const auto &sa : {DNNL_ARG_SRC_0, DNNL_ARG_SRC_1}) { + for (const auto &sa : {DNNL_ARG_SRC_1}) { if (arg == sa) return true; } // concat if (arg & DNNL_ARG_MULTIPLE_SRC) return true; - // convolution - for (const auto &sa : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) { - if (arg == sa) return true; - } // depth-wise convolution post op for (const auto &sa : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) { if (arg == (DNNL_ARG_ATTR_POST_OP_DW | sa)) return true; @@ -206,19 +216,23 @@ struct arg_scales_t : public c_compatible { return false; } - bool has_default_property(const std::vector &skip_args, - bool (*predicate)(const runtime_scales_t &)) const { + // The function makes sure that if any argument was specified by user, that + // only `supported_args` have their value customized, rest unsupported + // values were not updated. 
+ bool has_default_property(const std::vector &supported_args, + bool (*predicate)(const quant_entry_t &)) const { for (const auto &s : scales_) { - if (!predicate(s.second)) { - bool skip = false; - for (const auto &skip_a : skip_args) - if (s.first == skip_a) { - skip = true; - break; - } - if (skip) continue; - return false; - } + // Arg passed the condition, check the next one. + if (predicate(s.second)) continue; + + bool allow_non_default = false; + for (const auto &supported_arg : supported_args) + if (s.first == supported_arg) { + allow_non_default = true; + break; + } + if (allow_non_default) continue; + return false; } return true; } diff --git a/src/common/primitive_hashing.cpp b/src/common/primitive_hashing.cpp index 75301d504b5..de308cfd57f 100644 --- a/src/common/primitive_hashing.cpp +++ b/src/common/primitive_hashing.cpp @@ -15,6 +15,7 @@ *******************************************************************************/ #include +#include "primitive_attr.hpp" #include "primitive_desc.hpp" #include "type_helpers.hpp" #include "utils.hpp" @@ -232,20 +233,7 @@ size_t get_attr_hash(const primitive_attr_t &attr) { } if (!attr.scales_.has_default_values()) { - // go through scales for all arguments - for (const auto &p : attr.scales_.scales_) { - // scales: arg - seed = hash_combine(seed, p.first); - // scales: mask - seed = hash_combine(seed, p.second.mask_); - // scales: groups - const int ndims = p.second.ndims_; - seed = hash_combine(seed, ndims); - if (ndims > 0) - seed = get_array_hash(seed, p.second.group_dims_, ndims); - // scales: data type - seed = hash_combine(seed, static_cast(p.second.data_type_)); - } + seed = attr.scales_.get_hash(seed); } // zero_points for (int arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) diff --git a/src/common/primitive_hashing.hpp b/src/common/primitive_hashing.hpp index e43bfcf9669..f8f517d9d9d 100644 --- a/src/common/primitive_hashing.hpp +++ b/src/common/primitive_hashing.hpp @@ -23,7 +23,6 @@ #include "common/c_types_map.hpp" #include "common/engine_id.hpp" -#include "common/primitive_attr.hpp" #include "common/type_helpers.hpp" #include "common/verbose.hpp" diff --git a/src/common/reorder.cpp b/src/common/reorder.cpp index 506e939de5a..2ca95e3274c 100644 --- a/src/common/reorder.cpp +++ b/src/common/reorder.cpp @@ -102,22 +102,21 @@ status_t reorder_primitive_desc_create(std::shared_ptr &pd, if (!attr->scales_.has_default_values()) { const auto &sc = attr->scales_; const auto &sc_src = sc.get(DNNL_ARG_SRC); - const int mask_src = sc_src.mask_; + const int mask_src = sc.get_mask(DNNL_ARG_SRC); VCHECK_REORDER(IMPLICATION(utils::one_of(src_md->data_type, data_type::s4, data_type::u4), mask_src > 0), VERBOSE_INVALID_DATATYPE, "mask for int4 source"); - if (sc_src.ndims_ > 0) { + if (!sc_src.has_default_groups()) { const int src_ndims = s_mdw.ndims(); const bool group_dims_are_consistent - = IMPLICATION(sc_src.group_dims_[0] > 1, - src_md->dims[src_ndims - 2] - % sc_src.group_dims_[0] + = IMPLICATION(sc_src.get_group(0) > 1, + src_md->dims[src_ndims - 2] % sc_src.get_group(0) == 0) - && IMPLICATION(sc_src.group_dims_[1] > 1, - src_md->dims[src_ndims - 1] % sc_src.group_dims_[1] + && IMPLICATION(sc_src.get_group(1) > 1, + src_md->dims[src_ndims - 1] % sc_src.get_group(1) == 0); VCHECK_REORDER(group_dims_are_consistent, "groups dimensions are not consistent with reorder " @@ -132,9 +131,8 @@ status_t reorder_primitive_desc_create(std::shared_ptr &pd, "mask is not consistent with groups"); } - const auto &sc_dst = sc.get(DNNL_ARG_DST); - 
VCHECK_REORDER(sc_dst.ndims_ == 0, VERBOSE_BAD_NDIMS, "dst scales", - sc_dst.ndims_); + VCHECK_REORDER(sc.get(DNNL_ARG_DST).has_default_groups(), + VERBOSE_UNSUPPORTED_SCALES_CFG); } bool is_cross_engine = src_engine != dst_engine diff --git a/src/common/sdpa_pd.hpp b/src/common/sdpa_pd.hpp index aece62aff54..d6ac73f8aca 100644 --- a/src/common/sdpa_pd.hpp +++ b/src/common/sdpa_pd.hpp @@ -127,7 +127,9 @@ struct sdpa_pd_t : public primitive_desc_t { } /// Returns the data type of the scales tensor for the KQ matmul - data_type_t key_scales_dt() const { return desc()->kq_scales.data_type_; } + data_type_t key_scales_dt() const { + return desc()->kq_scales.get_data_type(); + } /// Returns the data type of the zero points tensor for the KQ matmul data_type_t key_zp_dt() const { @@ -135,7 +137,9 @@ struct sdpa_pd_t : public primitive_desc_t { } /// Returns the data type of the scales tensor for the VS matmul - data_type_t value_scales_dt() const { return desc()->vs_scales.data_type_; } + data_type_t value_scales_dt() const { + return desc()->vs_scales.get_data_type(); + } /// Returns the data type of the zero points tensor for the VS matmul data_type_t value_zp_dt() const { @@ -195,18 +199,19 @@ struct sdpa_pd_t : public primitive_desc_t { private: static int scale_group_size( - const runtime_scales_t &scales, const memory_desc_t &desc) { + const quant_entry_t &scales, const memory_desc_t &desc) { dim_t out = utils::array_product(desc.dims, desc.ndims); + const auto mask = scales.get_mask(); if (scales.has_default_groups()) { - for (int idx : mask_iterator(scales.mask_)) { + for (int idx : mask_iterator(mask)) { out /= desc.dims[idx]; } } else { - for (int idx : mask_iterator(scales.mask_)) { + for (int idx : mask_iterator(mask)) { if (idx < 2) { out /= desc.dims[idx]; } else { - out /= (desc.dims[idx] / scales.group_dims_[idx - 2]); + out /= (desc.dims[idx] / scales.get_group(idx - 2)); } } } diff --git a/src/common/sdpa_types.hpp b/src/common/sdpa_types.hpp index 2e7a7b11ea0..e3cf32f32a4 100644 --- a/src/common/sdpa_types.hpp +++ b/src/common/sdpa_types.hpp @@ -47,9 +47,9 @@ struct sdpa_desc_t : public op_desc_t { // primitive_attr_t can't be used because of deleted copy-ctor, but desc_t // must be copyable. 
- runtime_scales_t kq_scales {}; + quant_entry_t kq_scales {}; zero_points_t kq_zero_points {}; - runtime_scales_t vs_scales {}; + quant_entry_t vs_scales {}; zero_points_t vs_zero_points {}; memory_desc_t dst_desc {}; diff --git a/src/common/serialization.cpp b/src/common/serialization.cpp index 0d9b59340b0..174773a45a9 100644 --- a/src/common/serialization.cpp +++ b/src/common/serialization.cpp @@ -188,19 +188,7 @@ void serialize_attr( if (!attr.scales_.has_default_values()) { sstream.write("scale:"); - // go through scales for all arguments - for (const auto &p : attr.scales_.scales_) { - // scales: arg - sstream.write(&p.first); - // scales: mask - sstream.write(&p.second.mask_); - // scales: groups - const int ndims = p.second.ndims_; - sstream.write(&ndims); - if (ndims > 0) sstream.write(p.second.group_dims_, ndims); - // scales: data type - sstream.write(&p.second.data_type_); - } + attr.scales_.serialize(sstream); } // zero_points if (!attr.zero_points_.has_default_values()) sstream.write("zp:"); diff --git a/src/common/serialization_stream.hpp b/src/common/serialization_stream.hpp index 28eb32aad61..a977b30b9a2 100644 --- a/src/common/serialization_stream.hpp +++ b/src/common/serialization_stream.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021-2023 Intel Corporation +* Copyright 2021-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ #define COMMON_SERIALIZATION_STREAM_HPP #include +#include #include #include diff --git a/src/common/softmax.cpp b/src/common/softmax.cpp index 94e6e9c4ca5..a9d581fdfaa 100644 --- a/src/common/softmax.cpp +++ b/src/common/softmax.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2023 Intel Corporation +* Copyright 2016-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
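The sdpa_desc_t change above stores a bare quant_entry_t per tensor, and most other diffs only show query call sites, so here is a short usage sketch of the scales_t/quant_entry_t interface introduced in primitive_attr_quant.hpp. All concrete values (mask, data type, group sizes) are made up for illustration, and the snippet assumes it lives in library code with that header, <cassert>, and the DNNL_ARG_* definitions available:

    using namespace dnnl::impl;

    // Illustration only: exercise the new setter/getters on a scales_t object.
    status_t quant_entry_usage_sketch() {
        scales_t sc;
        dims_t groups = {32, 1}; // {group along K, group along N}
        // Grouped weights scales: f16 scale values, groups of 32 along K.
        CHECK(sc.set(DNNL_ARG_WEIGHTS, /* mask = */ (1 << 0) + (1 << 1),
                data_type::f16, /* group_ndims = */ 2, groups));

        // Query through the new getters instead of touching fields directly.
        assert(!sc.get(DNNL_ARG_WEIGHTS).has_default_values());
        assert(!sc.get(DNNL_ARG_WEIGHTS).has_default_groups());
        assert(sc.get_mask(DNNL_ARG_WEIGHTS) == (1 << 0) + (1 << 1));
        assert(sc.get_group(DNNL_ARG_WEIGHTS, 0) == 32);
        assert(sc.get_group(DNNL_ARG_WEIGHTS, 1) == 1);
        // Arguments never set keep the default entry and report a group of 1.
        assert(sc.get(DNNL_ARG_SRC).has_default_values());
        assert(sc.get_group(DNNL_ARG_SRC, 0) == 1);
        return status::success;
    }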
@@ -112,13 +112,21 @@ status_t softmax_attr_check(const softmax_desc_t &desc, const engine_t *engine, VCHECK_SOFTMAX_UNIMPL(attr->has_default_values(fwd_attr_mask, dst_dt), VERBOSE_UNSUPPORTED_ATTR); + // Check scales if (!attr->scales_.has_default_values()) { - const auto &sc = attr->scales_; - const int mask_src = sc.get(DNNL_ARG_SRC).mask_; - const int mask_dst = sc.get(DNNL_ARG_DST).mask_; - - VCHECK_SOFTMAX_UNIMPL(utils::everyone_is(0, mask_src, mask_dst), + static const std::vector supported_args { + DNNL_ARG_SRC, DNNL_ARG_DST}; + VCHECK_SOFTMAX_UNIMPL( + attr->scales_.has_default_values(supported_args), VERBOSE_UNSUPPORTED_SCALES_CFG); + + for (int arg : supported_args) { + if (attr->scales_.get(arg).has_default_values()) continue; + + const int mask = attr->scales_.get_mask(arg); + VCHECK_SOFTMAX_UNIMPL( + mask == 0, VERBOSE_UNSUPPORTED_SCALES_CFG); + } } // Check post-ops diff --git a/src/common/softmax_pd.hpp b/src/common/softmax_pd.hpp index 6e6c12c54c2..89a54d7735d 100644 --- a/src/common/softmax_pd.hpp +++ b/src/common/softmax_pd.hpp @@ -176,11 +176,18 @@ struct softmax_fwd_pd_t : public softmax_pd_t { dst_md_, src_md_.format_desc.blocking); } - bool attr_scales_ok() const { + bool attr_scales_ok(const std::vector &supported_args + = {DNNL_ARG_SRC, DNNL_ARG_DST}) const { const auto &scales = attr()->scales_; - bool ok = true; - for (const auto &e : scales.scales_) { - ok = ok && e.second.mask_ == 0; + bool ok = scales.has_default_values(supported_args); + + for (const auto &arg : supported_args) { + if (scales.get(arg).has_default_values()) continue; + + // TODO: disallow non-int8 scales? + // const data_type_t dt = arg_md(arg)->data_type; + // ok = ok && utils::one_of(dt, s8, u8); + ok = ok && scales.get_mask(arg) == 0; } return ok; } diff --git a/src/common/verbose.cpp b/src/common/verbose.cpp index 2112080170a..a8e2612db1c 100644 --- a/src/common/verbose.cpp +++ b/src/common/verbose.cpp @@ -315,6 +315,10 @@ bool get_verbose_timestamp() { void pd_info_t::init( dnnl::impl::engine_t *, const dnnl::impl::primitive_desc_t *) {} +std::string arg2str(int arg) { + return std::string(); +} + std::string rt_mds2str(primitive_kind_t prim_kind, const memory_desc_t *src_md, const memory_desc_t *wei_md, const memory_desc_t *bia_md, const memory_desc_t *dst_md) { @@ -626,18 +630,6 @@ std::string md2desc_str(const memory_desc_t *md) { return s; } -std::ostream &operator<<(std::ostream &ss, const runtime_scales_t &scale) { - ss << scale.mask_; - ss << ":" << scale.data_type_; - if (scale.ndims_) { - ss << ":"; - for (int i = 0; i < scale.ndims_ - 1; ++i) - ss << scale.group_dims_[i] << 'x'; - ss << scale.group_dims_[scale.ndims_ - 1]; - } - return ss; -} - std::ostream &operator<<( std::ostream &ss, const rnn_create_time_scales_t &rnn_scales) { ss << rnn_scales.mask_; @@ -749,20 +741,13 @@ std::ostream &operator<<(std::ostream &ss, const primitive_attr_t *attr) { if (deterministic) { ss << field_delim() << "attr-deterministic:" << deterministic; } - if (attr->has_default_values()) return ss; - const arg_scales_t &as = attr->scales_; - if (!as.has_default_values()) { - std::string delim = empty_delim; - ss << field_delim() << "attr-scales:"; - for (const auto &map_entry : as.scales_) { - const auto &val = map_entry.second; - if (val.has_default_values()) continue; + // Fast exit if rest attributes were not specified. 
+ if (attr->has_default_values()) return ss; - int arg = map_entry.first; - ss << delim << arg2str(arg) << ":" << val; - delim = attr_delim; - } + const scales_t &scales = attr->scales_; + if (!scales.has_default_values()) { + ss << field_delim() << "attr-scales:" << scales.get_verbose(); } const zero_points_t &zp = attr->zero_points_; diff --git a/src/common/verbose.hpp b/src/common/verbose.hpp index 9c722f59654..1c4cd67f1e0 100644 --- a/src/common/verbose.hpp +++ b/src/common/verbose.hpp @@ -377,6 +377,7 @@ std::string md2fmt_str( const char *name, const memory_desc_t *md, format_kind_t user_format); std::string md2dim_str( const memory_desc_t *md, dims_type_t dims_type = dims_type_t::dims); +std::string arg2str(int arg); // Returns a verbose string of dimensions or descriptor from src, wei, and/or // dst memory descs. Can be called externally to provide info about actual // values of runtime dimensions. diff --git a/src/cpu/aarch64/acl_reorder.hpp b/src/cpu/aarch64/acl_reorder.hpp index e586ed4e304..1b25146e7fe 100644 --- a/src/cpu/aarch64/acl_reorder.hpp +++ b/src/cpu/aarch64/acl_reorder.hpp @@ -95,12 +95,12 @@ struct acl_reorder_fwd_t : public primitive_t { if (!ok) return status::unimplemented; - int mask = -1; - bool is_set = false; - CHECK(attr->scales_.get(DNNL_ARG_DST, &mask, &is_set)); - const memory_desc_wrapper input_d(src_md); - if (input_d.has_runtime_dims_or_strides() && is_set && mask > 0) - return status::unimplemented; + if (!attr->scales_.get(DNNL_ARG_DST).has_default_values()) { + int mask = attr->scales_.get_mask(DNNL_ARG_DST); + const memory_desc_wrapper input_d(src_md); + if (input_d.has_runtime_dims_or_strides() && mask > 0) + return status::unimplemented; + } // Create and check primitive descriptor auto _pd = make_unique_pd(attr, src_engine->kind(), src_md, diff --git a/src/cpu/aarch64/brgemm/brgemm.cpp b/src/cpu/aarch64/brgemm/brgemm.cpp index 64f73814e30..39700adfb9b 100644 --- a/src/cpu/aarch64/brgemm/brgemm.cpp +++ b/src/cpu/aarch64/brgemm/brgemm.cpp @@ -297,21 +297,22 @@ status_t brgemm_desc_set_postops(brgemm_t *brg, const primitive_attr_t *attr, if (brg->with_scales) { // Note. 
the current version supports only two different output scale // types: - // 1) common (mask_ = 0) + // 1) common (mask = 0) // 2) per_n_dim_scale - broadcast across n dimension; // for convolution and inner product promitives it corresponds - // to "per_oc" mask_ = 1 << 1; for matmul - to - // mask_ = (1 << (ndims - 1))), where ndims is number of + // to "per_oc" mask = 1 << 1; for matmul - to + // mask = (1 << (ndims - 1))), where ndims is number of // dimensions for original matmul problem - // So if wei_scales.mask_ != 0 (not common) it's assumed here that scale - // type is per_n_dim_scale and driver which calls brgemm kernel checked - // that mask has correct value for this case - brg->is_oc_scale = wei_scales.mask_ != 0; + // So if wei_scales.get_mask() > 0 (not common) it's assumed here that + // scale type is per_n_dim_scale and driver which calls brgemm kernel + // checked that mask has correct value for this case + brg->is_oc_scale = wei_scales.get_mask() > 0; } const auto &dst_scales = attr->scales_.get(DNNL_ARG_DST); brg->with_dst_scales = !dst_scales.has_default_values(); - const bool scales_ok = src_scales.mask_ == 0 && dst_scales.mask_ == 0 + const bool scales_ok = src_scales.get_mask() == 0 + && dst_scales.get_mask() == 0 && attr->scales_.has_default_values( {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}); if (!scales_ok) return status::unimplemented; diff --git a/src/cpu/aarch64/jit_brdgmm_dw_conv.cpp b/src/cpu/aarch64/jit_brdgmm_dw_conv.cpp index 24e018aef02..a8be300ddb0 100644 --- a/src/cpu/aarch64/jit_brdgmm_dw_conv.cpp +++ b/src/cpu/aarch64/jit_brdgmm_dw_conv.cpp @@ -200,7 +200,7 @@ status_t brdgmm_dw_convolution_fwd_t::pd_t::init(engine_t *engine) { const auto &wei_scales = attr_.scales_.get(DNNL_ARG_WEIGHTS); jcp.with_scale = !src_scales.has_default_values() || !wei_scales.has_default_values(); - jcp.is_oc_scale = wei_scales.mask_ != 0; + jcp.is_oc_scale = wei_scales.get_mask() > 0; const bool scales_ok = attr_scales_ok({DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}); diff --git a/src/cpu/aarch64/jit_brgemm_conv_utils.cpp b/src/cpu/aarch64/jit_brgemm_conv_utils.cpp index 3b9d3422594..026290c53e9 100644 --- a/src/cpu/aarch64/jit_brgemm_conv_utils.cpp +++ b/src/cpu/aarch64/jit_brgemm_conv_utils.cpp @@ -1993,7 +1993,7 @@ status_t init_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa, jcp.with_scales = !src_scales.has_default_values() || !wei_scales.has_default_values() || jcp.scale_adjust_factor != 1.0f; - jcp.is_oc_scale = wei_scales.mask_ != 0; + jcp.is_oc_scale = wei_scales.get_mask() > 0; // disables the shape with small ic but large spatial // or specific large spatial shapes for int8 conv @@ -2190,7 +2190,7 @@ status_t init_1x1_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa, jcp.with_scales = !src_scales.has_default_values() || !wei_scales.has_default_values() || jcp.scale_adjust_factor != 1.0f; - jcp.is_oc_scale = wei_scales.mask_ != 0; + jcp.is_oc_scale = wei_scales.get_mask() > 0; // enable ununroll_bd_loop for big shapes to reduce kernel sizes jcp.ununroll_bd_loop diff --git a/src/cpu/aarch64/jit_brgemm_post_ops.hpp b/src/cpu/aarch64/jit_brgemm_post_ops.hpp index 4257bfe31b8..5aed828a582 100644 --- a/src/cpu/aarch64/jit_brgemm_post_ops.hpp +++ b/src/cpu/aarch64/jit_brgemm_post_ops.hpp @@ -325,8 +325,8 @@ struct jit_brgemm_kernel_post_ops : public jit_generator { const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS); // per_oc: conv: 1 << 0, (1 << 1) + (1 << 0) (with groups) // per_oc: ip: 1 << 0 - is_oc_scale_ - = utils::one_of(wei_scales.mask_, 1 << 0, 
(1 << 1) + (1 << 0)); + is_oc_scale_ = utils::one_of( + wei_scales.get_mask(), 1 << 0, (1 << 1) + (1 << 0)); LDD_ = brg.LDD; inp_dt_ = brg.dt_c; diff --git a/src/cpu/aarch64/jit_uni_reorder.cpp b/src/cpu/aarch64/jit_uni_reorder.cpp index d6c20904c90..24ed2b491f8 100644 --- a/src/cpu/aarch64/jit_uni_reorder.cpp +++ b/src/cpu/aarch64/jit_uni_reorder.cpp @@ -2788,13 +2788,10 @@ status_t jit_uni_reorder_t::pd_t::init_scratchpad() { compensation_reduce_size); } - const memory_desc_wrapper input_d(src_md()); - int scales_mask = -1; - bool is_set = false; - CHECK(attr()->scales_.get(DNNL_ARG_DST, &scales_mask, &is_set)); - - if (is_set && scales_mask > 0) { - get_D_values(input_d, scales_mask, nullptr, &D_mask_, nullptr); + if (!attr()->scales_.get(DNNL_ARG_DST).has_default_values()) { + const memory_desc_wrapper input_d(src_md()); + int mask = attr()->scales_.get_mask(DNNL_ARG_DST); + get_D_values(input_d, mask, nullptr, &D_mask_, nullptr); if (D_mask_ > 1) { scratchpad.template book( memory_tracking::names::key_reorder_precomputed_dst_scales, diff --git a/src/cpu/aarch64/jit_uni_reorder_utils.cpp b/src/cpu/aarch64/jit_uni_reorder_utils.cpp index 63fc09a4cf2..d086e0649c3 100644 --- a/src/cpu/aarch64/jit_uni_reorder_utils.cpp +++ b/src/cpu/aarch64/jit_uni_reorder_utils.cpp @@ -276,24 +276,21 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, p.src_scale_type = scale_type_t::NONE; int src_mask = 0; - bool is_src_set = false; - CHECK(attr->scales_.get(DNNL_ARG_SRC, &src_mask, &is_src_set)); - if (is_src_set) { + if (!attr->scales_.get(DNNL_ARG_SRC).has_default_values()) { + src_mask = attr->scales_.get_mask(DNNL_ARG_SRC); p.src_scale_type = src_mask == 0 ? scale_type_t::COMMON : scale_type_t::MANY; } p.dst_scale_type = scale_type_t::NONE; int dst_mask = 0; - bool is_dst_set = false; - CHECK(attr->scales_.get(DNNL_ARG_DST, &dst_mask, &is_dst_set)); - if (is_dst_set) { + if (!attr->scales_.get(DNNL_ARG_DST).has_default_values()) { + dst_mask = attr->scales_.get_mask(DNNL_ARG_DST); p.dst_scale_type = dst_mask == 0 ? scale_type_t::COMMON : scale_type_t::MANY; } - if (is_src_set && is_dst_set && src_mask != dst_mask) - return status::unimplemented; + if (src_mask != dst_mask) return status::unimplemented; p.scale_adjust = (om_d.extra().flags & memory_extra_flags::scale_adjust) ? 
om_d.extra().scale_adjust diff --git a/src/cpu/aarch64/matmul/acl_lowp_matmul.cpp b/src/cpu/aarch64/matmul/acl_lowp_matmul.cpp index 076d5fd321a..ca826296706 100644 --- a/src/cpu/aarch64/matmul/acl_lowp_matmul.cpp +++ b/src/cpu/aarch64/matmul/acl_lowp_matmul.cpp @@ -66,11 +66,11 @@ status_t acl_lowp_matmul_t::pd_t::init(engine_t *engine) { | smask_t::zero_points_runtime | smask_t::post_ops), "only scale, zero point and post-ops attrs supported"); - VDISPATCH_MATMUL(attr()->scales_.get(DNNL_ARG_SRC).mask_ == 0 + VDISPATCH_MATMUL(attr()->scales_.get_mask(DNNL_ARG_SRC) == 0 && attr()->zero_points_.get(DNNL_ARG_SRC) == 0 - && attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ == 0 + && attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) == 0 && attr()->zero_points_.get(DNNL_ARG_WEIGHTS) == 0 - && attr()->scales_.get(DNNL_ARG_DST).mask_ == 0 + && attr()->scales_.get_mask(DNNL_ARG_DST) == 0 && attr()->zero_points_.get(DNNL_ARG_DST) == 0, "common scales and zero points only"); diff --git a/src/cpu/aarch64/matmul/brgemm_matmul.cpp b/src/cpu/aarch64/matmul/brgemm_matmul.cpp index 7ede5613803..ea6ab1482a9 100644 --- a/src/cpu/aarch64/matmul/brgemm_matmul.cpp +++ b/src/cpu/aarch64/matmul/brgemm_matmul.cpp @@ -74,7 +74,7 @@ status_t brgemm_matmul_t::pd_t::init(engine_t *engine) { bool ok = attr_scales_ok(supported_args); if (!attr()->scales_.get(DNNL_ARG_SRC).has_default_values() && !attr()->scales_.get(DNNL_ARG_WEIGHTS).has_default_values() - && attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ != 0) { + && attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) > 0) { // This case requires scratchpad if (N() == DNNL_RUNTIME_DIM_VAL) ok = false; } diff --git a/src/cpu/aarch64/matmul/brgemm_matmul_utils.cpp b/src/cpu/aarch64/matmul/brgemm_matmul_utils.cpp index 1506be014a9..0610147c752 100644 --- a/src/cpu/aarch64/matmul/brgemm_matmul_utils.cpp +++ b/src/cpu/aarch64/matmul/brgemm_matmul_utils.cpp @@ -783,21 +783,22 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc, const auto &src_scales = attr.scales_.get(DNNL_ARG_SRC); const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS); - bgmmc.with_scales = !src_scales.has_default_values() - || !wei_scales.has_default_values(); - if (bgmmc.with_scales) { - bgmmc.is_oscale_per_n = wei_scales.mask_ == 1 << (bgmmc.ndims - 1); + const bool has_wei_scales = !wei_scales.has_default_values(); + bgmmc.with_scales = !src_scales.has_default_values() || has_wei_scales; + if (has_wei_scales) { + bgmmc.is_oscale_per_n + = wei_scales.get_mask() == (1 << (bgmmc.ndims - 1)); // only common and per-oc-channel scales are supported - VCONDCHECK_BG(wei_scales.mask_ == 0 || bgmmc.is_oscale_per_n, + VCONDCHECK_BG(wei_scales.get_mask() == 0 || bgmmc.is_oscale_per_n, VERBOSE_UNSUPPORTED_SCALES_CFG); } const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST); bgmmc.with_dst_scales = !dst_scales.has_default_values(); // only common scales are supported - if (bgmmc.with_dst_scales && dst_scales.mask_ != 0) - return status::unimplemented; + VCONDCHECK_BG(!(bgmmc.with_dst_scales && dst_scales.get_mask() > 0), + VERBOSE_UNSUPPORTED_SCALES_CFG); const auto &p = attr.post_ops_; bgmmc.with_sum = p.find(primitive_kind::sum) != -1; diff --git a/src/cpu/dw_convolution_utils.hpp b/src/cpu/dw_convolution_utils.hpp index 088e01b9964..23e26c986d2 100644 --- a/src/cpu/dw_convolution_utils.hpp +++ b/src/cpu/dw_convolution_utils.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2023 Intel Corporation +* Copyright 2020-2024 Intel 
Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,23 +43,19 @@ inline status_t get_depthwise_conv_desc(convolution_desc_t &cd_dw, // post-ops after depthwise post-op. auto &dw_po = attr_1x1.post_ops_.entry_[dw_po_index].depthwise_conv; - // erase 1x1 conv scales - for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) { - auto &scale = attr_dw.scales_.get(arg); - if (!scale.has_default_values()) attr_dw.scales_.reset(arg); - } - const auto &dw_src_scales = attr_1x1.scales_.get(DNNL_ARG_DST); const auto &dw_wei_scales = attr_1x1.scales_.get(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS); const auto &dw_dst_scales = attr_1x1.scales_.get(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_DST); + + assert(attr_dw.scales_.has_default_values()); if (!dw_src_scales.has_default_values()) - attr_dw.scales_.set(DNNL_ARG_SRC, dw_src_scales.mask_); + CHECK(attr_dw.scales_.set(DNNL_ARG_SRC, dw_src_scales.get_mask())); if (!dw_wei_scales.has_default_values()) - attr_dw.scales_.set(DNNL_ARG_WEIGHTS, dw_wei_scales.mask_); + CHECK(attr_dw.scales_.set(DNNL_ARG_WEIGHTS, dw_wei_scales.get_mask())); if (!dw_dst_scales.has_default_values()) - attr_dw.scales_.set(DNNL_ARG_DST, dw_dst_scales.mask_); + CHECK(attr_dw.scales_.set(DNNL_ARG_DST, dw_dst_scales.get_mask())); auto dw_po_len = attr_1x1.post_ops_.len() - (dw_po_index + 1); attr_dw.post_ops_.entry_.resize(dw_po_len); diff --git a/src/cpu/gemm_convolution_utils.cpp b/src/cpu/gemm_convolution_utils.cpp index dcf4b688fc4..58946ba1b60 100644 --- a/src/cpu/gemm_convolution_utils.cpp +++ b/src/cpu/gemm_convolution_utils.cpp @@ -2125,7 +2125,7 @@ status_t init_conf(conv_gemm_conf_t &jcp, jcp.dst_os_stride = dst_d.is_blocking_desc() ? dst_d.blocking_desc().strides[ndims - 1] : 0; - jcp.scale_idx_mult = attr.scales_.get(DNNL_ARG_WEIGHTS).mask_ != 0; + jcp.scale_idx_mult = attr.scales_.get_mask(DNNL_ARG_WEIGHTS) > 0; jcp.with_dst_scale = !attr.scales_.get(DNNL_ARG_DST).has_default_values(); book_precomputed_scales(scratchpad, attr.scales_, jcp.ngroups * jcp.oc); diff --git a/src/cpu/gemm_inner_product_utils.cpp b/src/cpu/gemm_inner_product_utils.cpp index 815e953898b..fbdb237ca1a 100644 --- a/src/cpu/gemm_inner_product_utils.cpp +++ b/src/cpu/gemm_inner_product_utils.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2023 Intel Corporation +* Copyright 2019-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -161,8 +161,8 @@ pp_kernel_t::pp_kernel_t(size_t OC, size_t MB, dim_t dst_mb_stride, || !attr->scales_.get(DNNL_ARG_WEIGHTS).has_default_values()) , ndims_(dst_md->ndims) { - if (do_scale_) { - int wei_mask = attr->scales_.get(DNNL_ARG_WEIGHTS).mask_; + if (!attr->scales_.get(DNNL_ARG_WEIGHTS).has_default_values()) { + int wei_mask = attr->scales_.get_mask(DNNL_ARG_WEIGHTS); // matmul: per_oc: 1 << (ndims_ - 1) // ip: per_oc: 1 << 0 scale_idx_mult_ = wei_mask == (1 << (ndims_ - 1)) || wei_mask == 1 << 0; diff --git a/src/cpu/gemm_x8s8s32x_convolution.cpp b/src/cpu/gemm_x8s8s32x_convolution.cpp index 8482ae65eb0..d953ed29363 100644 --- a/src/cpu/gemm_x8s8s32x_convolution.cpp +++ b/src/cpu/gemm_x8s8s32x_convolution.cpp @@ -135,10 +135,9 @@ status_t gemm_x8s8s32x_convolution_fwd_t::execute_forward( DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); - const int wei_scale_mask - = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); const float *scales = precompute_scales(scratchpad, src_scales, wei_scales, - pd()->IC(), pd()->OC(), false, wei_scale_mask != 0, pd()->attr()); + pd()->IC(), pd()->OC(), false, wei_scale_mask > 0, pd()->attr()); parallel(jcp.nthr, [&](const int ithr, const int nthr) { status_t st_thr = execute_forward_thr(ithr, nthr, src_base, wei_base, @@ -358,16 +357,15 @@ status_t gemm_x8s8s32x_convolution_bwd_data_t::execute_backward_data_thr( const auto diff_src_dt_size = types::data_type_size(diff_src_md.data_type()); - const int scale_idx_mult = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ + const int scale_idx_mult = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) == (1 << static_cast(pd()->with_groups())); DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); - const int wei_scale_mask - = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); const float *scales = precompute_scales(scratchpad, src_scales, wei_scales, - pd()->IC(), pd()->OC(), false, wei_scale_mask != 0, pd()->attr()); + pd()->IC(), pd()->OC(), false, wei_scale_mask > 0, pd()->attr()); const dim_t work_amount = jcp.ngroups * jcp.mb; diff --git a/src/cpu/gemm_x8s8s32x_inner_product.cpp b/src/cpu/gemm_x8s8s32x_inner_product.cpp index 341a584a276..f132de24b03 100644 --- a/src/cpu/gemm_x8s8s32x_inner_product.cpp +++ b/src/cpu/gemm_x8s8s32x_inner_product.cpp @@ -64,10 +64,9 @@ status_t gemm_x8s8s32x_inner_product_fwd_t::execute_forward( DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); auto scratchpad = ctx.get_scratchpad_grantor(); - const int wei_scale_mask - = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); const float *scales = precompute_scales(scratchpad, src_scales, wei_scales, - IC, OC, false, wei_scale_mask != 0, pd()->attr()); + IC, OC, false, wei_scale_mask > 0, pd()->attr()); int32_t *acc = pd()->dst_is_acc_ ? 
(int32_t *)dst diff --git a/src/cpu/matmul/gemm_bf16_matmul.cpp b/src/cpu/matmul/gemm_bf16_matmul.cpp index cd415b94743..f6a7e310ae0 100644 --- a/src/cpu/matmul/gemm_bf16_matmul.cpp +++ b/src/cpu/matmul/gemm_bf16_matmul.cpp @@ -107,7 +107,7 @@ status_t gemm_bf16_matmul_t::pd_t::check_and_configure_attributes( bool ok = attr_scales_ok(); if (!attr()->scales_.get(DNNL_ARG_SRC).has_default_values() && !attr()->scales_.get(DNNL_ARG_WEIGHTS).has_default_values() - && attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ != 0) { + && attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) > 0) { // This case requires scratchpad with unknown size if (N() == DNNL_RUNTIME_DIM_VAL) ok = false; } @@ -145,11 +145,15 @@ status_t gemm_bf16_matmul_t::pd_t::check_and_configure_attributes( // set state CHECK(params_.pp_attr_.copy_from(*attr())); params_.gemm_applies_output_scales_ - = attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ == 0 && !with_bias(); + = attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) == 0 && !with_bias(); if (params_.gemm_applies_output_scales_) { - params_.pp_attr_.scales_.reset(DNNL_ARG_SRC); - params_.pp_attr_.scales_.reset(DNNL_ARG_WEIGHTS); + VDISPATCH_MATMUL_SC(params_.pp_attr_.scales_.set( + DNNL_ARG_SRC, default_quant_entry()), + VERBOSE_UNSUPPORTED_SCALES_CFG); + VDISPATCH_MATMUL_SC(params_.pp_attr_.scales_.set( + DNNL_ARG_WEIGHTS, default_quant_entry()), + VERBOSE_UNSUPPORTED_SCALES_CFG); } // check post-ops @@ -203,11 +207,10 @@ status_t gemm_bf16_matmul_t::execute_ref( DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); auto scratchpad = ctx.get_scratchpad_grantor(); - const int wei_scale_mask - = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); const float *scales = precompute_scales(scratchpad, src_scales, wei_scales, src_d.dims()[ndims - 1], dst_d.dims()[ndims - 1], false, - wei_scale_mask != 0, pd()->attr()); + wei_scale_mask > 0, pd()->attr()); if (src_d.has_zero_dim() || weights_d.has_zero_dim() || dst_d.has_zero_dim()) @@ -254,7 +257,7 @@ status_t gemm_bf16_matmul_t::execute_ref( const float beta = params.gemm_beta_; const dim_t acc_ldc = dst_is_acc ? 
ldc : N; const int scale_idx_mult - = this->pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ + = this->pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) == (1 << (ndims - 1)); std::atomic st(status::success); diff --git a/src/cpu/matmul/gemm_f32_matmul.cpp b/src/cpu/matmul/gemm_f32_matmul.cpp index de57af38944..6873ec13cbd 100644 --- a/src/cpu/matmul/gemm_f32_matmul.cpp +++ b/src/cpu/matmul/gemm_f32_matmul.cpp @@ -52,7 +52,7 @@ status_t gemm_f32_matmul_t::pd_t::init(engine_t *engine) { bool ok = attr_scales_ok(); if (!attr()->scales_.get(DNNL_ARG_SRC).has_default_values() && !attr()->scales_.get(DNNL_ARG_WEIGHTS).has_default_values() - && attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ != 0) { + && attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) > 0) { // This case requires scratchpad with unknown size if (N() == DNNL_RUNTIME_DIM_VAL) ok = false; } @@ -131,10 +131,14 @@ status_t gemm_f32_matmul_t::pd_t::configure_attributes() { CHECK(params_.pp_attr_.copy_from(*attr())); params_.gemm_applies_output_scales_ - = attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ == 0 && !with_bias(); + = attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) == 0 && !with_bias(); if (params_.gemm_applies_output_scales_) { - params_.pp_attr_.scales_.reset(DNNL_ARG_SRC); - params_.pp_attr_.scales_.reset(DNNL_ARG_WEIGHTS); + VDISPATCH_MATMUL_SC(params_.pp_attr_.scales_.set( + DNNL_ARG_SRC, default_quant_entry()), + VERBOSE_UNSUPPORTED_SCALES_CFG); + VDISPATCH_MATMUL_SC(params_.pp_attr_.scales_.set( + DNNL_ARG_WEIGHTS, default_quant_entry()), + VERBOSE_UNSUPPORTED_SCALES_CFG); } const auto &po = params_.pp_attr_.post_ops_; @@ -186,11 +190,10 @@ status_t gemm_f32_matmul_t::execute_ref(const exec_ctx_t &ctx) const { DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); auto scratchpad = ctx.get_scratchpad_grantor(); - const int wei_scale_mask - = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); const float *scales = precompute_scales(scratchpad, src_scales, wei_scales, src_d.dims()[ndims - 1], dst_d.dims()[ndims - 1], false, - wei_scale_mask != 0, pd()->attr()); + wei_scale_mask > 0, pd()->attr()); if (src_d.has_zero_dim() || weights_d.has_zero_dim() || dst_d.has_zero_dim()) @@ -237,7 +240,7 @@ status_t gemm_f32_matmul_t::execute_ref(const exec_ctx_t &ctx) const { const dim_t acc_ldc = dst_is_acc ? 
ldc : N; const int scale_idx_mult - = this->pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ + = this->pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) == (1 << (ndims - 1)); std::atomic st(status::success); diff --git a/src/cpu/matmul/gemm_x8s8s32x_matmul.cpp b/src/cpu/matmul/gemm_x8s8s32x_matmul.cpp index 5fab321d7af..b429c5268db 100644 --- a/src/cpu/matmul/gemm_x8s8s32x_matmul.cpp +++ b/src/cpu/matmul/gemm_x8s8s32x_matmul.cpp @@ -63,7 +63,7 @@ status_t gemm_x8s8s32x_matmul_t::pd_t::init(engine_t *engine) { bool ok = attr_scales_ok(); if (!attr()->scales_.get(DNNL_ARG_SRC).has_default_values() && !attr()->scales_.get(DNNL_ARG_WEIGHTS).has_default_values() - && attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ != 0) { + && attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) > 0) { // This case requires scratchpad with unknown size if (N() == DNNL_RUNTIME_DIM_VAL) ok = false; } @@ -203,11 +203,10 @@ status_t gemm_x8s8s32x_matmul_t::execute_ref(const exec_ctx_t &ctx) const { DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); auto &scratchpad = ctx.get_scratchpad_grantor(); - const int wei_scale_mask - = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); const float *scales = precompute_scales(scratchpad, src_scales, wei_scales, src_d.dims()[ndims - 1], dst_d.dims()[ndims - 1], false, - wei_scale_mask != 0, pd()->attr()); + wei_scale_mask > 0, pd()->attr()); DEFINE_ZERO_POINT_VALUE(src_zero_point, DNNL_ARG_SRC); DEFINE_ZERO_POINT_VALUE(weights_zero_point, DNNL_ARG_WEIGHTS); @@ -276,7 +275,7 @@ status_t gemm_x8s8s32x_matmul_t::execute_ref(const exec_ctx_t &ctx) const { const float beta = params.gemm_beta_; const dim_t acc_ldc = dst_is_acc ? ldc : N; const int scale_idx_mult - = this->pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ + = this->pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) == (1 << (ndims - 1)); std::atomic st(status::success); diff --git a/src/cpu/matmul/matmul_utils.hpp b/src/cpu/matmul/matmul_utils.hpp index 996683c522d..b05de80a8df 100644 --- a/src/cpu/matmul/matmul_utils.hpp +++ b/src/cpu/matmul/matmul_utils.hpp @@ -156,6 +156,11 @@ struct matmul_helper_t { static status_t get_quant_md(memory_desc_t &md, const int ndims, const dims_t in_dims, const int quant_mask, const dim_t g0, const dim_t g1, const data_type_t dt) { + if (dt == data_type::undef) { + md = glob_zero_md; + return status::success; + } + dims_t quant_dims {}; utils::copy_dims_with_mask(quant_dims, in_dims, ndims, quant_mask, /* fill_with_ones = */ true); @@ -172,6 +177,8 @@ struct matmul_helper_t { static dim_t get_quant_off(const dims_t &input_idx, const int ndims, const int quant_mask, const dim_t g0, const dim_t g1, const memory_desc_t &quant_md) { + if (types::is_zero_md(&quant_md)) return 0; + dims_t quant_idx {}; utils::array_copy(quant_idx, input_idx, ndims); utils::apply_mask_on_dims(quant_idx, ndims, quant_mask); diff --git a/src/cpu/matmul/ref_matmul.cpp b/src/cpu/matmul/ref_matmul.cpp index d81b63a2b4b..53ca9c79af4 100644 --- a/src/cpu/matmul/ref_matmul.cpp +++ b/src/cpu/matmul/ref_matmul.cpp @@ -118,19 +118,14 @@ status_t ref_matmul_t::execute_ref(const exec_ctx_t &ctx) const { = !attr_scales.get(DNNL_ARG_WEIGHTS).has_default_values(); const bool with_dst_scales = !attr_scales.get(DNNL_ARG_DST).has_default_values(); - const auto wei_scale_mask = attr_scales.get(DNNL_ARG_WEIGHTS).mask_; + const auto wei_scale_mask = attr_scales.get_mask(DNNL_ARG_WEIGHTS); const dim_t wei_scale_stride_n = (wei_scale_mask & pd()->wei_qmask_N()) ? 
1 : 0; - const auto &wei_scale_dt = attr_scales.get(DNNL_ARG_WEIGHTS).data_type_; + const auto &wei_scale_dt = attr_scales.get_data_type(DNNL_ARG_WEIGHTS); const auto wei_scales_d = ctx.memory_mdw(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS); - const auto wei_scale_group_ndim = attr_scales.get(DNNL_ARG_WEIGHTS).ndims_; - const auto wei_scale_group_k = wei_scale_group_ndim > 0 - ? attr_scales.get(DNNL_ARG_WEIGHTS).group_dims_[0] - : 1; - const auto wei_scale_group_n = wei_scale_group_ndim > 0 - ? attr_scales.get(DNNL_ARG_WEIGHTS).group_dims_[1] - : 1; + const auto wei_scale_group_k = attr_scales.get_group(DNNL_ARG_WEIGHTS, 0); + const auto wei_scale_group_n = attr_scales.get_group(DNNL_ARG_WEIGHTS, 1); // Initialize a memory desc for quant entries for easier offset calculation. memory_desc_t wei_scale_md {}; CHECK(matmul_helper_t::get_quant_md(wei_scale_md, ndims, weights_d.dims(), diff --git a/src/cpu/matmul/ref_matmul_int8.cpp b/src/cpu/matmul/ref_matmul_int8.cpp index bbe4804994f..4cf6d3397db 100644 --- a/src/cpu/matmul/ref_matmul_int8.cpp +++ b/src/cpu/matmul/ref_matmul_int8.cpp @@ -127,16 +127,10 @@ status_t ref_matmul_int8_t::execute_ref(const exec_ctx_t &ctx) const { = !attr_scales.get(DNNL_ARG_WEIGHTS).has_default_values(); const bool with_dst_scales = !attr_scales.get(DNNL_ARG_DST).has_default_values(); - const int wei_scale_mask = attr_scales.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = attr_scales.get_mask(DNNL_ARG_WEIGHTS); const auto &wei_scale_dt = attr_scales.get_data_type(DNNL_ARG_WEIGHTS); - const bool wei_scale_per_k = wei_scale_mask & pd()->wei_qmask_K(); - const auto wei_scale_group_ndim = attr_scales.get(DNNL_ARG_WEIGHTS).ndims_; - const auto wei_scale_group_k = wei_scale_group_ndim > 0 - ? attr_scales.get(DNNL_ARG_WEIGHTS).group_dims_[0] - : (wei_scale_per_k ? 1 : K); - const auto wei_scale_group_n = wei_scale_group_ndim > 0 - ? attr_scales.get(DNNL_ARG_WEIGHTS).group_dims_[1] - : 1; + const auto wei_scale_group_k = attr_scales.get_group(DNNL_ARG_WEIGHTS, 0); + const auto wei_scale_group_n = attr_scales.get_group(DNNL_ARG_WEIGHTS, 1); const auto wei_scale_ngroups_k = K / wei_scale_group_k; // Initialize a memory desc for quant entries for easier offset calculation. memory_desc_t wei_scale_md {}; @@ -146,13 +140,9 @@ status_t ref_matmul_int8_t::execute_ref(const exec_ctx_t &ctx) const { const bool with_src_scales = !attr_scales.get(DNNL_ARG_SRC).has_default_values(); - const int src_scale_mask = attr_scales.get(DNNL_ARG_SRC).mask_; + const int src_scale_mask = attr_scales.get_mask(DNNL_ARG_SRC); const auto &src_scale_dt = attr_scales.get_data_type(DNNL_ARG_SRC); - const bool src_scale_per_k = src_scale_mask & pd()->src_qmask_K(); - const auto src_scale_group_ndim = attr_scales.get(DNNL_ARG_SRC).ndims_; - const auto src_scale_group_k = src_scale_group_ndim > 0 - ? attr_scales.get(DNNL_ARG_SRC).group_dims_[1] - : (src_scale_per_k ? 1 : K); + const auto src_scale_group_k = attr_scales.get_group(DNNL_ARG_SRC, 1); const auto src_scale_ngroups_k = K / src_scale_group_k; // Initialize a memory desc for quant entries for easier offset calculation. 
memory_desc_t src_scale_md {}; diff --git a/src/cpu/ref_concat.hpp b/src/cpu/ref_concat.hpp index 516e2ff0aac..d6c911dbaf6 100644 --- a/src/cpu/ref_concat.hpp +++ b/src/cpu/ref_concat.hpp @@ -59,10 +59,9 @@ struct ref_concat_t : public primitive_t { for (int i = 0; i < n_; ++i) { primitive_attr_t r_attr; if (!sc.get(DNNL_ARG_MULTIPLE_SRC + i).has_default_values()) { - int mask = 0; - CHECK(sc.get(DNNL_ARG_MULTIPLE_SRC + i, &mask, nullptr)); - if (mask != 0) return status::unimplemented; - r_attr.scales_.set(DNNL_ARG_SRC, mask); + int mask = sc.get_mask(DNNL_ARG_MULTIPLE_SRC + i); + VDISPATCH_CONCAT(mask == 0, VERBOSE_UNSUPPORTED_SCALES_CFG); + CHECK(r_attr.scales_.set(DNNL_ARG_SRC, mask)); } CHECK(reorder_primitive_desc_create(reorder_pds_[i], engine, src_md(i), src_image_md(i), &r_attr)); diff --git a/src/cpu/ref_convolution_int8.cpp b/src/cpu/ref_convolution_int8.cpp index b1c99eb8cda..e969258fb1c 100644 --- a/src/cpu/ref_convolution_int8.cpp +++ b/src/cpu/ref_convolution_int8.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021-2023 Intel Corporation +* Copyright 2021-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ namespace { void dequantize(float &d, dim_t g, dim_t C, dim_t c, const float *wei_scales, bool with_groups, int wei_mask, const float *src_scales) { // scale_idx_mult = 1 for per_channel scales and 0, otherwise - const int wei_scale_idx_mult = wei_mask != 0; + const int wei_scale_idx_mult = wei_mask > 0; float scale = 1.0f; if (src_scales) scale *= src_scales[0]; if (wei_scales) scale *= wei_scales[(g * C + c) * wei_scale_idx_mult]; @@ -63,8 +63,7 @@ status_t ref_convolution_int8_fwd_t::execute_forward( DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); - const int wei_scale_mask - = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC); DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST); @@ -290,8 +289,7 @@ status_t ref_convolution_int8_bwd_data_t::execute_backward_data( DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); DEFINE_ARG_SCALES_BUFFER(diff_dst_scales, DNNL_ARG_DST); - const int wei_scale_mask - = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); diff --git a/src/cpu/ref_deconvolution.cpp b/src/cpu/ref_deconvolution.cpp index 08c2304e675..f3245b13ef2 100644 --- a/src/cpu/ref_deconvolution.cpp +++ b/src/cpu/ref_deconvolution.cpp @@ -174,8 +174,7 @@ status_t ref_deconvolution_fwd_t::compute_oscale( DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); - const int wei_scale_mask - = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); const memory_desc_wrapper dst_d(pd()->dst_md()); @@ -190,7 +189,7 @@ status_t ref_deconvolution_fwd_t::compute_oscale( const auto maybe_oscale = [](float &d, dim_t oc, const float *src_scales, const float *wei_scales, int wei_mask) { // scale_idx_mult = 1 for per_oc scales and 0, otherwise - const int wei_scale_idx_mult = 
wei_mask != 0; + const int wei_scale_idx_mult = wei_mask > 0; d *= src_scales[0] * wei_scales[oc * wei_scale_idx_mult]; }; @@ -216,7 +215,7 @@ status_t ref_deconvolution_fwd_t::compute_ref_attrs(const exec_ctx_t &ctx, auto dst = CTX_OUT_MEM(void *, DNNL_ARG_DST); DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); - const int dst_scale_mask = pd()->attr()->scales_.get(DNNL_ARG_DST).mask_; + const int dst_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_DST); DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST); const bool is_dst_zp_common @@ -242,7 +241,7 @@ status_t ref_deconvolution_fwd_t::compute_ref_attrs(const exec_ctx_t &ctx, const auto maybe_scale = [](float &d, dim_t oc, const float *scales, int mask) { // scale_idx_mult = 1 for per_oc scales and 0, otherwise - const int scale_idx_mult = mask != 0; + const int scale_idx_mult = mask > 0; d *= scales[oc * scale_idx_mult]; }; @@ -536,9 +535,8 @@ status_t ref_deconvolution_fwd_t::execute(const exec_ctx_t &ctx) const { float *conv_output = scratchpad.get(key_deconv_bias); - const auto &arg_scales = pd()->attr()->scales_; - const auto &src_scales = arg_scales.get(DNNL_ARG_SRC); - const auto &wei_scales = arg_scales.get(DNNL_ARG_WEIGHTS); + const auto &src_scales = pd()->attr()->scales_.get(DNNL_ARG_SRC); + const auto &wei_scales = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS); if (!src_scales.has_default_values() || !wei_scales.has_default_values()) { compute_oscale(ctx, conv_output); diff --git a/src/cpu/ref_fused_convolution.hpp b/src/cpu/ref_fused_convolution.hpp index 1e00b56d576..ec00848e51a 100644 --- a/src/cpu/ref_fused_convolution.hpp +++ b/src/cpu/ref_fused_convolution.hpp @@ -227,7 +227,8 @@ struct ref_fused_convolution_fwd_t : public primitive_t { auto &scale = attr_1x1.scales_.get(DNNL_ARG_ATTR_POST_OP_DW | arg); if (!scale.has_default_values()) - attr_1x1.scales_.reset(DNNL_ARG_ATTR_POST_OP_DW | arg); + CHECK(attr_1x1.scales_.set(DNNL_ARG_ATTR_POST_OP_DW | arg, + default_quant_entry())); } // erase post-ops after fusion as they will be handled separately auto &e = attr_1x1.post_ops_.entry_; diff --git a/src/cpu/ref_inner_product_int8.cpp b/src/cpu/ref_inner_product_int8.cpp index 91198c680ab..ff0bffda821 100644 --- a/src/cpu/ref_inner_product_int8.cpp +++ b/src/cpu/ref_inner_product_int8.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2021-2023 Intel Corporation +* Copyright 2021-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -82,7 +82,7 @@ status_t ref_inner_product_int8_fwd_t::execute_forward( auto maybe_oscale = [&](float &d, dim_t oc) { // scale_idx_mult = 1 for per_oc scales and 0, otherwise const int scale_idx_mult - = attr_scales.get(DNNL_ARG_WEIGHTS).mask_ == (1 << 0); + = attr_scales.get_mask(DNNL_ARG_WEIGHTS) == (1 << 0); d *= src_scales[0] * wei_scales[oc * scale_idx_mult]; }; diff --git a/src/cpu/ref_sum.hpp b/src/cpu/ref_sum.hpp index 7256114d685..06a30a753a7 100644 --- a/src/cpu/ref_sum.hpp +++ b/src/cpu/ref_sum.hpp @@ -46,7 +46,7 @@ struct ref_sum_t : public primitive_t { reorder_pds_.resize(n_ + need_output_reorder()); for (int i = 0; i < n_; ++i) { primitive_attr_t r_attr; - r_attr.scales_.set(DNNL_ARG_SRC, 0); + CHECK(r_attr.scales_.set(DNNL_ARG_SRC, 0)); if (i != 0) r_attr.post_ops_.append_sum(1.0); CHECK(reorder_primitive_desc_create(reorder_pds_[i], engine, src_md(i), dst_acc_md(), &r_attr)); diff --git a/src/cpu/reorder/cpu_reorder_pd.hpp b/src/cpu/reorder/cpu_reorder_pd.hpp index d1c8499c151..a16672c7e62 100644 --- a/src/cpu/reorder/cpu_reorder_pd.hpp +++ b/src/cpu/reorder/cpu_reorder_pd.hpp @@ -82,15 +82,15 @@ struct cpu_reorder_pd_t : public reorder_pd_t { const float *dst_scales) const { using namespace dnnl::impl::memory_tracking::names; - int mask = -1; - bool is_set = false; - auto status = attr->scales_.get(DNNL_ARG_DST, &mask, &is_set); - if (status != status::success) return nullptr; + if (attr->scales_.get(DNNL_ARG_DST).has_default_values()) { + return dst_scales; + } // It's possible that mask > 0 but `count` is still `1`. This case is // covered by `DEFINE_ARG_SCALES_BUFFER` macro and no need to inverse // in such case. - if (is_set && mask > 0 && count > 1) { + int mask = attr->scales_.get_mask(DNNL_ARG_DST); + if (mask > 0 && count > 1) { auto loc_scales = scratchpad.template get( key_reorder_precomputed_dst_scales); if (!loc_scales) return nullptr; diff --git a/src/cpu/reorder/simple_reorder.hpp b/src/cpu/reorder/simple_reorder.hpp index fbc13f38e0c..2d590689762 100644 --- a/src/cpu/reorder/simple_reorder.hpp +++ b/src/cpu/reorder/simple_reorder.hpp @@ -134,11 +134,11 @@ inline status_t get_scales_mask( *src_mask = 0; if (!s.get(DNNL_ARG_SRC).has_default_values()) - *src_mask = s.get(DNNL_ARG_SRC).mask_; + *src_mask = s.get_mask(DNNL_ARG_SRC); *dst_mask = 0; if (!s.get(DNNL_ARG_DST).has_default_values()) - *dst_mask = s.get(DNNL_ARG_DST).mask_; + *dst_mask = s.get_mask(DNNL_ARG_DST); // This is used in a check function. if (*src_mask > 0 && *dst_mask > 0 && *dst_mask != *src_mask) @@ -152,11 +152,10 @@ inline bool simple_attr_check(const primitive_attr_t *attr, if (sum_support) skip_mask = skip_mask | smask_t::post_ops; if (!attr->has_default_values(skip_mask)) return false; for (int arg : {DNNL_ARG_SRC, DNNL_ARG_DST}) { - const auto &sc = attr->scales_.get(arg); // Data type for scales is not generally supported. - if (!sc.has_default_data_type()) return false; + if (!attr->scales_.has_default_data_type(arg)) return false; // Groups are generally not supported. - if (!sc.has_default_groups()) return false; + if (!attr->scales_.get(arg).has_default_groups()) return false; } if (many_scales_support) return true; int src_mask, dst_mask; @@ -2331,11 +2330,9 @@ struct simple_reorder_impl 0 ? sc_src.group_dims_[0] : 1; + const auto src_scales_group0 = sc_src.get_group(0); // Applied to the last dimension. - const auto src_scales_group1 - = sc_src.ndims_ > 0 ? 
sc_src.group_dims_[1] : 1; + const auto src_scales_group1 = sc_src.get_group(1); memory_desc_t src_scales_md {}; if (has_src_scales) { @@ -2540,7 +2537,8 @@ struct simple_reorder_implscales_.get(DNNL_ARG_DST); const bool has_dst_scales = !sc_dst.has_default_values(); if (has_dst_scales) { - VDISPATCH_REORDER_IC(sc_dst.has_default_data_type() + VDISPATCH_REORDER_IC( + attr->scales_.has_default_data_type(DNNL_ARG_DST) && sc_dst.has_default_groups(), VERBOSE_UNSUPPORTED_SCALES_CFG); } @@ -2568,11 +2566,9 @@ struct simple_reorder_implattr()->scales_.get(DNNL_ARG_SRC); const bool has_src_scales = !sc_src.has_default_values(); // Applied to the pre-last dimension. - const auto src_scales_group0 - = sc_src.ndims_ > 0 ? sc_src.group_dims_[0] : 1; + const auto src_scales_group0 = sc_src.get_group(0); // Applied to the last dimension. - const auto src_scales_group1 - = sc_src.ndims_ > 0 ? sc_src.group_dims_[1] : 1; + const auto src_scales_group1 = sc_src.get_group(1); memory_desc_t src_scales_md {}; if (has_src_scales) { get_quant_md(src_scales_md, ndims, input_d.dims(), src_scales_mask, @@ -2690,12 +2686,14 @@ struct simple_reorder_t : public primitive_t { spec>::is_applicable(src_md, dst_md, attr); if (status != status::success) return status; - int mask = -1; - bool is_set = false; - CHECK(attr->scales_.get(DNNL_ARG_DST, &mask, &is_set)); const memory_desc_wrapper input_d(src_md); - if (input_d.has_runtime_dims_or_strides() && is_set && mask > 0) - return status::unimplemented; + + int mask = -1; + if (!attr->scales_.get(DNNL_ARG_DST).has_default_values()) { + mask = attr->scales_.get_mask(DNNL_ARG_DST); + if (input_d.has_runtime_dims_or_strides() && mask > 0) + return status::unimplemented; + } auto _pd = make_unique_pd(attr, src_engine->kind(), src_md, dst_engine->kind(), dst_md); @@ -2709,7 +2707,7 @@ struct simple_reorder_t : public primitive_t { scratchpad.book(memory_tracking::names::key_reorder_space, scratchpad_sz_, 1, 16); - if (is_set && mask > 0) { + if (mask > 0) { dim_t D_mask; _pd->get_D_values(input_d, mask, nullptr, &D_mask, nullptr); scratchpad.template book( diff --git a/src/cpu/scale_utils.cpp b/src/cpu/scale_utils.cpp index c6d92a33e2f..03e4081903f 100644 --- a/src/cpu/scale_utils.cpp +++ b/src/cpu/scale_utils.cpp @@ -32,7 +32,7 @@ constexpr size_t scales_simd_w = 16; } void book_precomputed_scales(memory_tracking::registrar_t &scratchpad, - const arg_scales_t &attr_scales, size_t wei_scale_count, + const scales_t &attr_scales, size_t wei_scale_count, bool force_scales_book) { using namespace dnnl::impl::memory_tracking::names; @@ -40,13 +40,11 @@ void book_precomputed_scales(memory_tracking::registrar_t &scratchpad, = !attr_scales.get(DNNL_ARG_SRC).has_default_values(); const bool with_wei_scales = !attr_scales.get(DNNL_ARG_WEIGHTS).has_default_values(); - const auto wei_scales_dt = attr_scales.get(DNNL_ARG_WEIGHTS).data_type_; - const auto wei_scale_groups_ndims - = attr_scales.get(DNNL_ARG_WEIGHTS).ndims_; + if ((with_src_scales && with_wei_scales) || force_scales_book - || (wei_scales_dt != data_type::f32 && with_wei_scales) - || (wei_scale_groups_ndims > 0 && with_wei_scales)) { - const int wei_mask = attr_scales.get(DNNL_ARG_WEIGHTS).mask_; + || !attr_scales.has_default_data_type(DNNL_ARG_WEIGHTS) + || !attr_scales.get(DNNL_ARG_WEIGHTS).has_default_groups()) { + const int wei_mask = attr_scales.get_mask(DNNL_ARG_WEIGHTS); const size_t precomputed_scales_size = wei_mask == 0 ? 
scales_simd_w : nstl::max( @@ -64,23 +62,23 @@ bool req_copy_scales( = !attr_scales.get(DNNL_ARG_SRC).has_default_values(); const bool with_wei_scales = !attr_scales.get(DNNL_ARG_WEIGHTS).has_default_values(); - const auto wei_scales_dt = attr_scales.get(DNNL_ARG_WEIGHTS).data_type_; - const auto wei_scale_groups_ndims = attr_scales.get(DNNL_ARG_WEIGHTS).ndims_; return (with_src_scales && with_wei_scales) || scale_adjust_factor != 1.0f - || (wei_scales_dt != data_type::f32 && with_wei_scales) - || (wei_scale_groups_ndims > 0 && with_wei_scales); + || !attr_scales.has_default_data_type(DNNL_ARG_WEIGHTS) + || !attr_scales.get(DNNL_ARG_WEIGHTS).has_default_groups(); } const float *precompute_scales(const memory_tracking::grantor_t &scratchpad, const float *src_scales, const float *wei_scales, dim_t oc, const primitive_attr_t *attr, float scale_adjust_factor) { - // Note: per-ic-channel is no supported in default - const int wei_scale_mask = attr->scales_.get(DNNL_ARG_WEIGHTS).mask_; + // Note: per-ic-channel is not supported by default. + const int wei_scale_mask = attr->scales_.get_mask(DNNL_ARG_WEIGHTS); return precompute_scales(scratchpad, src_scales, wei_scales, 1, oc, false, - wei_scale_mask != 0, attr, scale_adjust_factor, false); + wei_scale_mask > 0, attr, scale_adjust_factor, false); } +// Note: `wei_scale_per_ic` and `wei_scale_per_oc` could be identified in this +// function if different primitives had the same definition of the `per_ic` and +// `per_oc` masks. However, matmul defines them differently from everybody else. const float *precompute_scales(const memory_tracking::grantor_t &scratchpad, const float *src_scales, const float *wei_scales, dim_t IC, dim_t OC, const bool wei_scale_per_ic, const bool wei_scale_per_oc, @@ -96,7 +94,9 @@ const float *precompute_scales(const memory_tracking::grantor_t &scratchpad, const float *scales = nullptr; if (req_copy_scales(attr, scale_adjust_factor)) { - const int wei_scale_mask = attr_scales.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = attr_scales.get_mask(DNNL_ARG_WEIGHTS); + assert(wei_scale_mask >= 0); + size_t size = 0; auto loc_scales = scratchpad.template get(key_precomputed_scales, &size); @@ -108,12 +108,9 @@ const float *precompute_scales(const memory_tracking::grantor_t &scratchpad, const dim_t count = nstl::min( static_cast(size / sizeof(float)), wei_scale_count); const auto wei_scale_dt - = attr_scales.get(DNNL_ARG_WEIGHTS).data_type_; - const auto wei_scale_groups_ndims - = attr_scales.get(DNNL_ARG_WEIGHTS).ndims_; - const auto wei_scale_groups_ic = wei_scale_groups_ndims > 0 - ?
attr_scales.get(DNNL_ARG_WEIGHTS).group_dims_[0] - : 1; + = attr_scales.get_data_type(DNNL_ARG_WEIGHTS); + const auto wei_scale_groups_ic + = attr_scales.get_group(DNNL_ARG_WEIGHTS, 0); // Note: per-ic-channel scales is only supported for // weights decompression for now if ((wei_scale_per_ic && wei_scale_groups_ic > 1) diff --git a/src/cpu/scale_utils.hpp b/src/cpu/scale_utils.hpp index 7c1ce535889..9b7c6f0cefc 100644 --- a/src/cpu/scale_utils.hpp +++ b/src/cpu/scale_utils.hpp @@ -26,7 +26,7 @@ namespace impl { namespace cpu { void book_precomputed_scales(memory_tracking::registrar_t &scratchpad, - const arg_scales_t &attr_scales, size_t wei_scales_count, + const scales_t &attr_scales, size_t wei_scales_count, bool force_scales_book = false); bool req_copy_scales( diff --git a/src/cpu/x64/brgemm/brgemm.cpp b/src/cpu/x64/brgemm/brgemm.cpp index eadc19c646f..e441695d3f8 100644 --- a/src/cpu/x64/brgemm/brgemm.cpp +++ b/src/cpu/x64/brgemm/brgemm.cpp @@ -434,23 +434,26 @@ status_t brgemm_desc_set_postops(brgemm_desc_t *brg, if (brg->with_scales) { // Note. the current version supports only two different output scale // types: - // 1) common (mask_ = 0) + // 1) common (mask = 0) // 2) per_n_dim_scale - broadcast across n dimension; // for convolution and inner product promitives it corresponds - // to "per_oc" mask_ = 1 << 1; for matmul - to - // mask_ = (1 << (ndims - 1))), where ndims is number of + // to "per_oc" mask = 1 << 1; for matmul - to + // mask = (1 << (ndims - 1))), where ndims is number of // dimensions for original matmul problem - // So if wei_scales.mask_ != 0 (not common) it's assumed here that scale - // type is per_n_dim_scale and driver which calls brgemm kernel checked - // that mask has correct value for this case - brg->is_oc_scale = wei_scales.mask_ != 0; + // So if wei_scales.get_mask() > 0 (not common) it's assumed here that + // scale type is per_n_dim_scale and driver which calls brgemm kernel + // checked that mask has correct value for this case + brg->is_oc_scale = wei_scales.get_mask() > 0; } const auto &dst_scales = attr->scales_.get(DNNL_ARG_DST); brg->with_dst_scales = !dst_scales.has_default_values(); - const bool scales_ok = src_scales.mask_ == 0 && dst_scales.mask_ == 0 - && attr->scales_.has_default_values( - {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}); + const bool scales_ok = attr->scales_.has_default_values({DNNL_ARG_SRC, + DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) + && IMPLICATION(!src_scales.has_default_values(), + src_scales.get_mask() == 0) + && IMPLICATION(!dst_scales.has_default_values(), + dst_scales.get_mask() == 0); if (!scales_ok) return status::unimplemented; auto init_zp_type diff --git a/src/cpu/x64/brgemm/capi/brgemm_api.cpp b/src/cpu/x64/brgemm/capi/brgemm_api.cpp index a7b57f73706..1e5ea93e052 100644 --- a/src/cpu/x64/brgemm/capi/brgemm_api.cpp +++ b/src/cpu/x64/brgemm/capi/brgemm_api.cpp @@ -279,7 +279,7 @@ status_t brgemm_t::execute(const void *A_ptr, const void *B_ptr, const void *wei_scales_ptr = attr_params->get_scales(DNNL_ARG_WEIGHTS); if (wei_scales_ptr == nullptr) return status::invalid_arguments; - int wei_mask = attr_.scales_.get(DNNL_ARG_WEIGHTS).mask_; + int wei_mask = attr_.scales_.get_mask(DNNL_ARG_WEIGHTS); if (wei_mask > 0) { for (dim_t i = 0; i < N_; i++) { const float wei_scale_val = cpu::io::load_float_value( diff --git a/src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.cpp index e57e4b9ed76..d27a22646d4 100644 --- a/src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.cpp 
+++ b/src/cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.cpp @@ -1217,9 +1217,8 @@ status_t jit_avx512_core_amx_1x1_fwd_kernel_t::init_conf(jit_conv_conf_t &jcp, = (avaliable_ops) ? ops_tile_store / avaliable_ops + 1 : 0; if (jcp.per_one_pstore > 12) jcp.per_one_pstore = 0; - const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS); const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST); - jcp.is_oc_scale = wei_scales.mask_ != 0; + jcp.is_oc_scale = attr.scales_.get_mask(DNNL_ARG_WEIGHTS) > 0; jcp.dst_scale = !dst_scales.has_default_values(); return status::success; diff --git a/src/cpu/x64/jit_avx512_core_amx_1x1_convolution.cpp b/src/cpu/x64/jit_avx512_core_amx_1x1_convolution.cpp index 5a61a9fa38f..58473076f53 100644 --- a/src/cpu/x64/jit_avx512_core_amx_1x1_convolution.cpp +++ b/src/cpu/x64/jit_avx512_core_amx_1x1_convolution.cpp @@ -73,11 +73,10 @@ status_t jit_avx512_core_amx_1x1_convolution_fwd_t::execute_forward( DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); - const int wei_scale_mask - = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); const float *oscales = scale_utils::precompute_scales( ctx.get_scratchpad_grantor(), src_scales, wei_scales, pd()->IC(), - pd()->OC(), false, wei_scale_mask != 0, pd()->attr(), + pd()->OC(), false, wei_scale_mask > 0, pd()->attr(), jit_scale_precompute_.get()); DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC); diff --git a/src/cpu/x64/jit_avx512_core_amx_1x1_convolution.hpp b/src/cpu/x64/jit_avx512_core_amx_1x1_convolution.hpp index eae42216786..aad5959dd0e 100644 --- a/src/cpu/x64/jit_avx512_core_amx_1x1_convolution.hpp +++ b/src/cpu/x64/jit_avx512_core_amx_1x1_convolution.hpp @@ -118,8 +118,8 @@ struct jit_avx512_core_amx_1x1_convolution_fwd_t : public primitive_t { const auto attr = pd()->attr(); if (is_jit_supported && pd()->OC() > 1 && req_copy_scales(attr)) { const auto &attr_scales = attr->scales_; - int wei_scale_mask = attr_scales.get(DNNL_ARG_WEIGHTS).mask_; - if (wei_scale_mask != 0) { + int wei_scale_mask = attr_scales.get_mask(DNNL_ARG_WEIGHTS); + if (wei_scale_mask > 0) { CHECK(safe_ptr_assign(jit_scale_precompute_, new jit_avx512_core_scale_precompute_t(attr))); CHECK(jit_scale_precompute_->create_kernel()); diff --git a/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp index ef00ff87e11..682c7c7ade4 100644 --- a/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp @@ -2655,7 +2655,7 @@ status_t jit_avx512_core_amx_fwd_kernel_t::init_conf(jit_conv_conf_t &jcp, const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS); const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST); - jcp.is_oc_scale = wei_scales.mask_ != 0; + jcp.is_oc_scale = wei_scales.get_mask() > 0; jcp.dst_scale = !dst_scales.has_default_values(); // Note: currently unsupported, results in seg-fault @@ -3969,7 +3969,7 @@ status_t jit_avx512_core_amx_bwd_data_kernel_t::init_conf(jit_conv_conf_t &jcp, const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS); const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST); - jcp.is_ic_scale = wei_scales.mask_ != 0; + jcp.is_ic_scale = wei_scales.get_mask() > 0; jcp.dst_scale = !dst_scales.has_default_values(); return status::success; diff --git a/src/cpu/x64/jit_avx512_core_amx_convolution.cpp b/src/cpu/x64/jit_avx512_core_amx_convolution.cpp index 2377ca8bf47..86af8cd23be 100644 --- 
a/src/cpu/x64/jit_avx512_core_amx_convolution.cpp +++ b/src/cpu/x64/jit_avx512_core_amx_convolution.cpp @@ -99,11 +99,10 @@ jit_avx512_core_amx_convolution_fwd_t::execute_forward_reduced_lowering( DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); - const int wei_scale_mask - = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); const float *oscales = scale_utils::precompute_scales( ctx.get_scratchpad_grantor(), src_scales, wei_scales, pd()->IC(), - pd()->OC(), false, wei_scale_mask != 0, pd()->attr(), + pd()->OC(), false, wei_scale_mask > 0, pd()->attr(), jit_scale_precompute_.get()); auto inp_p_buffer = ctx.get_scratchpad_grantor().template get( @@ -457,11 +456,10 @@ status_t jit_avx512_core_amx_convolution_fwd_t::execute_forward( DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); - const int wei_scale_mask - = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); const float *oscales = scale_utils::precompute_scales( ctx.get_scratchpad_grantor(), src_scales, wei_scales, pd()->IC(), - pd()->OC(), false, wei_scale_mask != 0, pd()->attr(), + pd()->OC(), false, wei_scale_mask > 0, pd()->attr(), jit_scale_precompute_.get()); // TODO: use block offset instead of hand-calculated one @@ -831,11 +829,10 @@ status_t jit_avx512_core_amx_convolution_bwd_data_t::execute_backward( DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); - const int wei_scale_mask - = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); const float *oscales = scale_utils::precompute_scales( ctx.get_scratchpad_grantor(), src_scales, wei_scales, pd()->IC(), - pd()->OC(), false, wei_scale_mask != 0, pd()->attr(), + pd()->OC(), false, wei_scale_mask > 0, pd()->attr(), jit_scale_precompute_.get()); amx_utils::execute_backward_convolution_body(ctx, pd()->jcp_, kernel_, diff --git a/src/cpu/x64/jit_avx512_core_amx_convolution.hpp b/src/cpu/x64/jit_avx512_core_amx_convolution.hpp index ccd9a4d0a73..e781957cc44 100644 --- a/src/cpu/x64/jit_avx512_core_amx_convolution.hpp +++ b/src/cpu/x64/jit_avx512_core_amx_convolution.hpp @@ -120,8 +120,8 @@ struct jit_avx512_core_amx_convolution_fwd_t : public primitive_t { const auto attr = pd()->attr(); if (is_jit_supported && pd()->OC() > 1 && req_copy_scales(attr)) { const auto &attr_scales = attr->scales_; - int wei_scale_mask = attr_scales.get(DNNL_ARG_WEIGHTS).mask_; - if (wei_scale_mask != 0) { + int wei_scale_mask = attr_scales.get_mask(DNNL_ARG_WEIGHTS); + if (wei_scale_mask > 0) { CHECK(safe_ptr_assign(jit_scale_precompute_, new jit_avx512_core_scale_precompute_t(attr))); CHECK(jit_scale_precompute_->create_kernel()); @@ -203,8 +203,8 @@ struct jit_avx512_core_amx_convolution_bwd_data_t : public primitive_t { const auto attr = pd()->attr(); if (is_jit_supported && pd()->OC() > 1 && req_copy_scales(attr)) { const auto &attr_scales = attr->scales_; - int wei_scale_mask = attr_scales.get(DNNL_ARG_WEIGHTS).mask_; - if (wei_scale_mask != 0) { + int wei_scale_mask = attr_scales.get_mask(DNNL_ARG_WEIGHTS); + if (wei_scale_mask > 0) { CHECK(safe_ptr_assign(jit_scale_precompute_, new jit_avx512_core_scale_precompute_t(attr))); CHECK(jit_scale_precompute_->create_kernel()); diff --git 
a/src/cpu/x64/jit_avx512_core_amx_deconvolution.cpp b/src/cpu/x64/jit_avx512_core_amx_deconvolution.cpp index de930820f80..07dc8ec31f9 100644 --- a/src/cpu/x64/jit_avx512_core_amx_deconvolution.cpp +++ b/src/cpu/x64/jit_avx512_core_amx_deconvolution.cpp @@ -78,11 +78,10 @@ status_t jit_avx512_core_amx_deconvolution_fwd_t::execute_forward( DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); - const int wei_scale_mask - = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); const float *oscales = precompute_scales(ctx.get_scratchpad_grantor(), src_scales, wei_scales, src_d.dims()[1], dst_d.dims()[1], false, - wei_scale_mask != 0, pd()->attr()); + wei_scale_mask > 0, pd()->attr()); // The body of bwd/d convolution harness is called with: // 1. src as input instead of diff_dst diff --git a/src/cpu/x64/jit_avx512_core_scale_precompute.cpp b/src/cpu/x64/jit_avx512_core_scale_precompute.cpp index 1f98294a715..e36e6462317 100644 --- a/src/cpu/x64/jit_avx512_core_scale_precompute.cpp +++ b/src/cpu/x64/jit_avx512_core_scale_precompute.cpp @@ -44,7 +44,7 @@ const float *precompute_scales(const memory_tracking::grantor_t &scratchpad, ; if (jit_scale_precompute) { const auto &attr_scales = attr->scales_; - const int wei_scale_mask = attr_scales.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = attr_scales.get_mask(DNNL_ARG_WEIGHTS); size_t size = 0; auto loc_scales = scratchpad.template get( memory_tracking::names::key_precomputed_scales, &size); @@ -54,9 +54,9 @@ const float *precompute_scales(const memory_tracking::grantor_t &scratchpad, = wei_scale_per_ic ? wei_scale_per_oc ? OC : 1 : 0; const auto with_wei_scale = !attr_scales.get(DNNL_ARG_WEIGHTS).has_default_values(); - const auto wei_scale_groups_ndims - = with_wei_scale ? attr_scales.get(DNNL_ARG_WEIGHTS).ndims_ : 0; - const auto wei_scale_group_stride = wei_scale_groups_ndims > 0 + const auto wei_scale_has_groups = with_wei_scale + && !attr_scales.get(DNNL_ARG_WEIGHTS).has_default_groups(); + const auto wei_scale_group_stride = wei_scale_has_groups ? wei_scale_stride_ic * sizeof(float) : 0; @@ -66,14 +66,14 @@ const float *precompute_scales(const memory_tracking::grantor_t &scratchpad, assert(req_copy_scales(attr, scale_adjust_factor)); assert(mayiuse(avx512_core)); - assert(wei_scale_mask != 0); - if (wei_scale_groups_ndims > 0) { + assert(wei_scale_mask > 0); + if (wei_scale_has_groups) { assert(count == wei_scale_count); const auto wei_scale_groups_ic - = attr_scales.get(DNNL_ARG_WEIGHTS).group_dims_[0]; + = attr_scales.get_group(DNNL_ARG_WEIGHTS, 0); const dim_t wei_scale_nb_ic = IC / wei_scale_groups_ic; const auto wei_scale_dt_sz = types::data_type_size( - attr_scales.get(DNNL_ARG_WEIGHTS).data_type_); + attr_scales.get_data_type(DNNL_ARG_WEIGHTS)); for (int nb_ic = 0; nb_ic < wei_scale_nb_ic; nb_ic++) { const auto offset = nb_ic * wei_scale_stride_ic; jrp.nelems_ = wei_scale_stride_ic; diff --git a/src/cpu/x64/jit_avx512_core_scale_precompute.hpp b/src/cpu/x64/jit_avx512_core_scale_precompute.hpp index 432a9aee05c..771b7afdc73 100644 --- a/src/cpu/x64/jit_avx512_core_scale_precompute.hpp +++ b/src/cpu/x64/jit_avx512_core_scale_precompute.hpp @@ -73,14 +73,10 @@ struct jit_avx512_core_scale_precompute_t : public jit_generator { , with_wei_scales_( !attr_->scales_.get(DNNL_ARG_WEIGHTS).has_default_values()) , wei_scales_dt_(with_wei_scales_ - ? attr_->scales_.get(DNNL_ARG_WEIGHTS).data_type_ + ? 
attr_->scales_.get_data_type(DNNL_ARG_WEIGHTS) : data_type::f32) , wei_scales_dsz_(types::data_type_size(wei_scales_dt_)) - , wei_groups_ic_(with_wei_scales_ - && attr_->scales_.get(DNNL_ARG_WEIGHTS).ndims_ - > 0 - ? attr_->scales_.get(DNNL_ARG_WEIGHTS).group_dims_[0] - : 1) + , wei_groups_ic_(attr_->scales_.get_group(DNNL_ARG_WEIGHTS, 0)) , scale_adjust_factor_(scale_adjust_factor) , compute_scale_factor_(scale_adjust_factor_ != 1) {} diff --git a/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp index b8e581de704..0f2afaec195 100644 --- a/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp @@ -1206,7 +1206,7 @@ status_t jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_conf( const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS); const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST); - jcp.is_oc_scale = wei_scales.mask_ != 0; + jcp.is_oc_scale = wei_scales.get_mask() > 0; jcp.dst_scale = !dst_scales.has_default_values(); jcp.wei_adj_scale @@ -1222,7 +1222,7 @@ void jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_scratchpad( const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { using namespace dnnl::impl::memory_tracking::names; - const int wei_mask = attr.scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_mask = attr.scales_.get_mask(DNNL_ARG_WEIGHTS); const dim_t scales_count = wei_mask == 0 ? 1 : static_cast(jcp.oc) * jcp.ngroups; const dim_t count = nstl::max(scales_count, (dim_t)jcp.ic_block); diff --git a/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_convolution.cpp b/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_convolution.cpp index 538a89cf5ee..faa40ff3419 100644 --- a/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_convolution.cpp +++ b/src/cpu/x64/jit_avx512_core_x8s8s32x_1x1_convolution.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2018-2023 Intel Corporation +* Copyright 2018-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -70,7 +70,7 @@ status_t jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t::execute_forward( = scratchpad.template get(key_conv_adjusted_scales); // Src scale is always a single value float src_scale = src_scales[0]; - int wei_mask = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + int wei_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); float factor = (pd()->jcp_.signed_input && (!pd()->jcp_.has_vnni)) ? 1.f / pd()->jcp_.wei_adj_scale : 1.f; @@ -92,7 +92,7 @@ status_t jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t::execute_forward( auto dw_local_scales = dw_scratchpad.template get(key_conv_adjusted_scales); auto attr_dw = pd()->dw_conv_pd_->attr(); - int wei_mask = attr_dw->scales_.get(DNNL_ARG_WEIGHTS).mask_; + int wei_mask = attr_dw->scales_.get_mask(DNNL_ARG_WEIGHTS); dim_t count = wei_mask == 0 ? 
1 : pd()->dw_conv_pd_->OC(); float factor = 1.f / jcp_dw->wei_adj_scale; if (count == 1) { diff --git a/src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.cpp b/src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.cpp index a808e692752..1686b7b7852 100644 --- a/src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.cpp +++ b/src/cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.cpp @@ -1486,7 +1486,7 @@ status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp, const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS); const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST); - jcp.is_oc_scale = wei_scales.mask_ != 0; + jcp.is_oc_scale = wei_scales.get_mask() > 0; jcp.dst_scale = !dst_scales.has_default_values(); jcp.has_vnni = mayiuse(avx512_core_vnni); @@ -1765,7 +1765,7 @@ status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp, void jit_avx512_core_x8s8s32x_fwd_kernel::init_scratchpad( memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, const primitive_attr_t &attr) { - const int wei_mask = attr.scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_mask = attr.scales_.get_mask(DNNL_ARG_WEIGHTS); const dim_t scales_count = wei_mask == 0 ? 1 : jcp.oc * jcp.ngroups; dim_t count = wei_mask == 0 ? (dim_t)16 : scales_count; scratchpad.book(key_conv_adjusted_scales, count); diff --git a/src/cpu/x64/jit_avx512_core_x8s8s32x_convolution.cpp b/src/cpu/x64/jit_avx512_core_x8s8s32x_convolution.cpp index 1242fd8e3f8..b216b90e293 100644 --- a/src/cpu/x64/jit_avx512_core_x8s8s32x_convolution.cpp +++ b/src/cpu/x64/jit_avx512_core_x8s8s32x_convolution.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2016-2023 Intel Corporation +* Copyright 2016-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,7 +43,7 @@ const float *jit_avx512_core_x8s8s32x_convolution_fwd_t::adjust_oscales( const float *wei_scales) const { auto loc_scales = scratchpad.template get(key_conv_adjusted_scales); const float src_scale = src_scales[0]; - const int wei_mask = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); float factor = (pd()->jcp_.signed_input && (!pd()->jcp_.has_vnni)) ? 1.f / pd()->jcp_.wei_adj_scale : 1.f; diff --git a/src/cpu/x64/jit_avx512_core_x8s8s32x_deconvolution.cpp b/src/cpu/x64/jit_avx512_core_x8s8s32x_deconvolution.cpp index c02bc84bdfe..7a73114e002 100644 --- a/src/cpu/x64/jit_avx512_core_x8s8s32x_deconvolution.cpp +++ b/src/cpu/x64/jit_avx512_core_x8s8s32x_deconvolution.cpp @@ -294,7 +294,7 @@ status_t _jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_conf( const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS); const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST); - jcp.is_oc_scale = wei_scales.mask_ != 0; + jcp.is_oc_scale = wei_scales.get_mask() > 0; jcp.dst_scale = !dst_scales.has_default_values(); jcp.dst_dt = dst_d.data_type(); @@ -386,7 +386,7 @@ bool _jit_avx512_core_x8s8s32x_deconv_fwd_kernel::post_ops_ok( void _jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_scratchpad( memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, const primitive_attr_t &attr) { - const int mask = attr.scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int mask = attr.scales_.get_mask(DNNL_ARG_WEIGHTS); const dim_t scales_count = mask == 0 ? 
1 : static_cast(jcp.oc) * jcp.ngroups; const dim_t count = nstl::max(scales_count, 16); @@ -1393,7 +1393,7 @@ const float *jit_avx512_core_x8s8s32x_deconvolution_fwd_t::adjust_oscales( const memory_tracking::grantor_t &scratchpad, const float *src_scales, const float *wei_scales) const { auto loc_scales = scratchpad.template get(key_conv_adjusted_scales); - int wei_mask = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + int wei_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); float factor = (pd()->jcp_.signed_input && (!pd()->jcp_.has_vnni)) ? 1.f / pd()->jcp_.wei_adj_scale : 1.0f; diff --git a/src/cpu/x64/jit_brdgmm_dw_conv.cpp b/src/cpu/x64/jit_brdgmm_dw_conv.cpp index 75a9481b2a8..6ef246b69cc 100644 --- a/src/cpu/x64/jit_brdgmm_dw_conv.cpp +++ b/src/cpu/x64/jit_brdgmm_dw_conv.cpp @@ -240,7 +240,7 @@ status_t brdgmm_dw_convolution_fwd_t::pd_t::init(engine_t *engine) { const auto &wei_scales = attr_.scales_.get(DNNL_ARG_WEIGHTS); jcp.with_scale = !src_scales.has_default_values() || !wei_scales.has_default_values(); - jcp.is_oc_scale = wei_scales.mask_ != 0; + jcp.is_oc_scale = wei_scales.get_mask() > 0; const bool scales_ok = attr_scales_ok({DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}); @@ -555,8 +555,8 @@ status_t brdgmm_dw_convolution_fwd_t::init(engine_t *engine) { const auto attr = pd()->attr(); if (is_jit_supported && pd()->OC() > 1 && req_copy_scales(attr)) { const auto &attr_scales = attr->scales_; - int wei_scale_mask = attr_scales.get(DNNL_ARG_WEIGHTS).mask_; - if (wei_scale_mask != 0) { + int wei_scale_mask = attr_scales.get_mask(DNNL_ARG_WEIGHTS); + if (wei_scale_mask > 0) { CHECK(safe_ptr_assign(jit_scale_precompute_, new jit_avx512_core_scale_precompute_t(attr))); CHECK(jit_scale_precompute_->create_kernel()); @@ -586,11 +586,10 @@ status_t brdgmm_dw_convolution_fwd_t::execute(const exec_ctx_t &ctx) const { DEFINE_ZERO_POINT_VALUE(src_zero_point, DNNL_ARG_SRC); DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST); - const int wei_scale_mask - = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); const float *oscales = scale_utils::precompute_scales( ctx.get_scratchpad_grantor(), src_scales, wei_scales, pd()->IC(), - pd()->OC(), false, wei_scale_mask != 0, pd()->attr(), + pd()->OC(), false, wei_scale_mask > 0, pd()->attr(), jit_scale_precompute_.get()); const memory_desc_wrapper weights_d(pd()->weights_md(0)); diff --git a/src/cpu/x64/jit_brgemm_1x1_conv.cpp b/src/cpu/x64/jit_brgemm_1x1_conv.cpp index 66588ac3f34..4b9b78a7760 100644 --- a/src/cpu/x64/jit_brgemm_1x1_conv.cpp +++ b/src/cpu/x64/jit_brgemm_1x1_conv.cpp @@ -281,8 +281,8 @@ status_t brgemm_1x1_convolution_fwd_t::init(engine_t *engine) { if (is_jit_supported && pd()->OC() > 1 && req_copy_scales(attr, jcp.scale_adjust_factor)) { const auto &attr_scales = attr->scales_; - int wei_scale_mask = attr_scales.get(DNNL_ARG_WEIGHTS).mask_; - if (wei_scale_mask != 0) { + int wei_scale_mask = attr_scales.get_mask(DNNL_ARG_WEIGHTS); + if (wei_scale_mask > 0) { CHECK(safe_ptr_assign(jit_scale_precompute_, new jit_avx512_core_scale_precompute_t( attr, jcp.scale_adjust_factor))); @@ -713,11 +713,10 @@ status_t brgemm_1x1_convolution_fwd_t::execute_forward_all( DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); - const int wei_scale_mask - = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); const float 
*oscales = scale_utils::precompute_scales(scratchpad, src_scales, wei_scales, pd()->IC(), pd()->OC(), false, - wei_scale_mask != 0, pd()->attr(), jit_scale_precompute_.get(), + wei_scale_mask > 0, pd()->attr(), jit_scale_precompute_.get(), jcp.scale_adjust_factor); DEFINE_ZERO_POINT_VALUE(src_zero_point, DNNL_ARG_SRC); diff --git a/src/cpu/x64/jit_brgemm_conv.cpp b/src/cpu/x64/jit_brgemm_conv.cpp index 618395344bf..2f393834e2e 100644 --- a/src/cpu/x64/jit_brgemm_conv.cpp +++ b/src/cpu/x64/jit_brgemm_conv.cpp @@ -984,8 +984,8 @@ status_t brgemm_convolution_fwd_t::init(engine_t *engine) { if (is_jit_supported && pd()->OC() > 1 && req_copy_scales(attr, jcp.scale_adjust_factor)) { const auto &attr_scales = attr->scales_; - int wei_scale_mask = attr_scales.get(DNNL_ARG_WEIGHTS).mask_; - if (wei_scale_mask != 0) { + int wei_scale_mask = attr_scales.get_mask(DNNL_ARG_WEIGHTS); + if (wei_scale_mask > 0) { CHECK(safe_ptr_assign(jit_scale_precompute_, new jit_avx512_core_scale_precompute_t( attr, jcp.scale_adjust_factor))); @@ -1318,11 +1318,10 @@ status_t brgemm_convolution_fwd_t::execute(const exec_ctx_t &ctx) const { const memory_tracking::grantor_t scratchpad = ctx.get_scratchpad_grantor(); - const int wei_scale_mask - = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); const float *oscales = scale_utils::precompute_scales(scratchpad, src_scales, wei_scales, pd()->IC(), pd()->OC(), false, - wei_scale_mask != 0, pd()->attr(), jit_scale_precompute_.get(), + wei_scale_mask > 0, pd()->attr(), jit_scale_precompute_.get(), jcp.scale_adjust_factor); brgemm_exec_ctx_t brgemm_ctx(ctx, _pd); diff --git a/src/cpu/x64/jit_brgemm_conv_bwd_strided.cpp b/src/cpu/x64/jit_brgemm_conv_bwd_strided.cpp index 01841aca2af..2c14383ceb1 100644 --- a/src/cpu/x64/jit_brgemm_conv_bwd_strided.cpp +++ b/src/cpu/x64/jit_brgemm_conv_bwd_strided.cpp @@ -646,8 +646,8 @@ status_t brgemm_convolution_bwd_strided_t::init(engine_t *engine) { if (is_jit_supported && pd()->IC() > 1 && req_copy_scales(attr, jcp.scale_adjust_factor)) { const auto &attr_scales = attr->scales_; - int wei_scale_mask = attr_scales.get(DNNL_ARG_WEIGHTS).mask_; - if (wei_scale_mask != 0) { + int wei_scale_mask = attr_scales.get_mask(DNNL_ARG_WEIGHTS); + if (wei_scale_mask > 0) { CHECK(safe_ptr_assign(jit_scale_precompute_, new jit_avx512_core_scale_precompute_t( attr, jcp.scale_adjust_factor))); @@ -697,11 +697,10 @@ status_t brgemm_convolution_bwd_strided_t::execute( const memory_tracking::grantor_t scratchpad = ctx.get_scratchpad_grantor(); - const int wei_scale_mask - = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); const float *oscales = scale_utils::precompute_scales(scratchpad, src_scales, wei_scales, pd()->OC(), pd()->IC(), false, - wei_scale_mask != 0, pd()->attr(), jit_scale_precompute_.get(), + wei_scale_mask > 0, pd()->attr(), jit_scale_precompute_.get(), jcp.scale_adjust_factor); brgemm_bwd_exec_ctx_t brgemm_ctx(ctx, _pd); diff --git a/src/cpu/x64/jit_brgemm_conv_bwd_utils.cpp b/src/cpu/x64/jit_brgemm_conv_bwd_utils.cpp index 268755c69ab..558e3b0e191 100644 --- a/src/cpu/x64/jit_brgemm_conv_bwd_utils.cpp +++ b/src/cpu/x64/jit_brgemm_conv_bwd_utils.cpp @@ -2048,7 +2048,7 @@ status_t init_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa, jcp.with_scales = !src_scales.has_default_values() || !wei_scales.has_default_values() || jcp.scale_adjust_factor != 1.0f; - jcp.is_ic_scale = 
wei_scales.mask_ != 0; + jcp.is_ic_scale = wei_scales.get_mask() > 0; } jcp.req_brg_comp_pad = false; diff --git a/src/cpu/x64/jit_brgemm_conv_utils.cpp b/src/cpu/x64/jit_brgemm_conv_utils.cpp index 176d8201fd1..bf67831c2e5 100644 --- a/src/cpu/x64/jit_brgemm_conv_utils.cpp +++ b/src/cpu/x64/jit_brgemm_conv_utils.cpp @@ -2287,7 +2287,7 @@ status_t init_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa, jcp.with_scales = !src_scales.has_default_values() || !wei_scales.has_default_values() || jcp.scale_adjust_factor != 1.0f; - jcp.is_oc_scale = wei_scales.mask_ != 0; + jcp.is_oc_scale = wei_scales.get_mask() > 0; const bool compensation_w_padding = (jcp.s8s8_compensation_required || jcp.src_zero_point) @@ -2548,7 +2548,7 @@ status_t init_1x1_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa, jcp.with_scales = !src_scales.has_default_values() || !wei_scales.has_default_values() || jcp.scale_adjust_factor != 1.0f; - jcp.is_oc_scale = wei_scales.mask_ != 0; + jcp.is_oc_scale = wei_scales.get_mask() > 0; // enable ununroll_bd_loop for big shapes to reduce kernel sizes jcp.ununroll_bd_loop diff --git a/src/cpu/x64/jit_brgemm_inner_product.cpp b/src/cpu/x64/jit_brgemm_inner_product.cpp index 45d40417af8..d0662824b7b 100644 --- a/src/cpu/x64/jit_brgemm_inner_product.cpp +++ b/src/cpu/x64/jit_brgemm_inner_product.cpp @@ -94,8 +94,7 @@ status_t brgemm_inner_product_fwd_t::execute_forward( DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); - const int wei_scale_mask - = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); const float *oscales = scale_utils::precompute_scales(scratchpad, src_scales, wei_scales, pd()->IC(), pd()->OC(), false, wei_scale_mask == (1 << 0), diff --git a/src/cpu/x64/jit_brgemm_inner_product.hpp b/src/cpu/x64/jit_brgemm_inner_product.hpp index dca15bb151c..b05903a0677 100644 --- a/src/cpu/x64/jit_brgemm_inner_product.hpp +++ b/src/cpu/x64/jit_brgemm_inner_product.hpp @@ -222,8 +222,8 @@ struct brgemm_inner_product_fwd_t : public primitive_t { const auto attr = pd()->attr(); if (is_jit_supported && pd()->OC() > 1 && req_copy_scales(attr)) { const auto &attr_scales = attr->scales_; - int wei_scale_mask = attr_scales.get(DNNL_ARG_WEIGHTS).mask_; - if (wei_scale_mask != 0) { + int wei_scale_mask = attr_scales.get_mask(DNNL_ARG_WEIGHTS); + if (wei_scale_mask > 0) { CHECK(safe_ptr_assign(jit_scale_precompute_, new jit_avx512_core_scale_precompute_t(attr))); CHECK(jit_scale_precompute_->create_kernel()); diff --git a/src/cpu/x64/jit_brgemm_inner_product_utils.cpp b/src/cpu/x64/jit_brgemm_inner_product_utils.cpp index 3de2b4c10eb..835bbac37e1 100644 --- a/src/cpu/x64/jit_brgemm_inner_product_utils.cpp +++ b/src/cpu/x64/jit_brgemm_inner_product_utils.cpp @@ -444,8 +444,7 @@ status_t jit_brgemm_ip_fwd_conf_t::init_conf(cpu_isa_t isa, const memory_desc_wrapper dst_d(&dst_md); if (!post_ops_ok(attr, dst_d)) return status::unimplemented; if (jbgp.with_scales) { - const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS); - jbgp.is_oc_scale = wei_scales.mask_ != 0; + jbgp.is_oc_scale = attr.scales_.get_mask(DNNL_ARG_WEIGHTS) > 0; } const int min_ic_divisor = is_amx_int8 ? 4 : is_amx_xf16 ? 
2 : 1; diff --git a/src/cpu/x64/jit_brgemm_post_ops.cpp b/src/cpu/x64/jit_brgemm_post_ops.cpp index d762fbad7f8..fce7613b1d3 100644 --- a/src/cpu/x64/jit_brgemm_post_ops.cpp +++ b/src/cpu/x64/jit_brgemm_post_ops.cpp @@ -374,7 +374,8 @@ dnnl::impl::cpu::x64::jit_brgemm_kernel_post_ops_t< const auto &wei_scales = attr_.scales_.get(DNNL_ARG_WEIGHTS); // per_oc: conv: 1 << 0, (1 << 1) + (1 << 0) (with groups) // per_oc: ip: 1 << 0 - is_oc_scale_ = utils::one_of(wei_scales.mask_, 1 << 0, (1 << 1) + (1 << 0)); + is_oc_scale_ + = utils::one_of(wei_scales.get_mask(), 1 << 0, (1 << 1) + (1 << 0)); inp_dt_ = brg_.dt_c; out_dt_ = brg_.dt_d; diff --git a/src/cpu/x64/jit_uni_reorder.cpp b/src/cpu/x64/jit_uni_reorder.cpp index f79a5d32b1f..291c2311d31 100644 --- a/src/cpu/x64/jit_uni_reorder.cpp +++ b/src/cpu/x64/jit_uni_reorder.cpp @@ -2324,13 +2324,10 @@ status_t jit_uni_reorder_t::pd_t::init_scratchpad() { compensation_reduce_size); } - const memory_desc_wrapper input_d(src_md()); - int scales_mask = -1; - bool is_set = false; - CHECK(attr()->scales_.get(DNNL_ARG_DST, &scales_mask, &is_set)); - - if (is_set && scales_mask > 0) { - get_D_values(input_d, scales_mask, nullptr, &D_mask_, nullptr); + if (!attr()->scales_.get(DNNL_ARG_DST).has_default_values()) { + const memory_desc_wrapper input_d(src_md()); + int mask = attr()->scales_.get_mask(DNNL_ARG_DST); + get_D_values(input_d, mask, nullptr, &D_mask_, nullptr); if (D_mask_ > 1) { scratchpad.template book( memory_tracking::names::key_reorder_precomputed_dst_scales, diff --git a/src/cpu/x64/jit_uni_reorder_utils.cpp b/src/cpu/x64/jit_uni_reorder_utils.cpp index cf9b343cb37..a3a04bfd2a1 100644 --- a/src/cpu/x64/jit_uni_reorder_utils.cpp +++ b/src/cpu/x64/jit_uni_reorder_utils.cpp @@ -273,24 +273,21 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, p.src_scale_type = scale_type_t::NONE; int src_mask = 0; - bool is_src_set = false; - CHECK(attr->scales_.get(DNNL_ARG_SRC, &src_mask, &is_src_set)); - if (is_src_set) { + if (!attr->scales_.get(DNNL_ARG_SRC).has_default_values()) { + src_mask = attr->scales_.get_mask(DNNL_ARG_SRC); p.src_scale_type = src_mask == 0 ? scale_type_t::COMMON : scale_type_t::MANY; } p.dst_scale_type = scale_type_t::NONE; int dst_mask = 0; - bool is_dst_set = false; - CHECK(attr->scales_.get(DNNL_ARG_DST, &dst_mask, &is_dst_set)); - if (is_dst_set) { + if (!attr->scales_.get(DNNL_ARG_DST).has_default_values()) { + dst_mask = attr->scales_.get_mask(DNNL_ARG_DST); p.dst_scale_type = dst_mask == 0 ? scale_type_t::COMMON : scale_type_t::MANY; } - if (is_src_set && is_dst_set && src_mask != dst_mask) - return status::unimplemented; + if (src_mask != dst_mask) return status::unimplemented; p.scale_adjust = (om_d.extra().flags & memory_extra_flags::scale_adjust) ? 
om_d.extra().scale_adjust diff --git a/src/cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.cpp b/src/cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.cpp index 8186789adc4..627b20e536f 100644 --- a/src/cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.cpp +++ b/src/cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.cpp @@ -910,7 +910,7 @@ status_t jit_uni_x8s8s32x_1x1_conv_kernel::init_conf( const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS); const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST); - jcp.is_oc_scale = wei_scales.mask_ != 0; + jcp.is_oc_scale = wei_scales.get_mask() > 0; jcp.dst_scale = !dst_scales.has_default_values(); jcp.wei_adj_scale @@ -927,7 +927,7 @@ void jit_uni_x8s8s32x_1x1_conv_kernel::init_scratchpad( const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { using namespace dnnl::impl::memory_tracking::names; - const int wei_mask = attr.scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int wei_mask = attr.scales_.get_mask(DNNL_ARG_WEIGHTS); const dim_t scales_count = wei_mask == 0 ? 1 : static_cast(jcp.oc) * jcp.ngroups; const dim_t count = nstl::max(scales_count, 8); diff --git a/src/cpu/x64/jit_uni_x8s8s32x_1x1_convolution.cpp b/src/cpu/x64/jit_uni_x8s8s32x_1x1_convolution.cpp index 6b0e7e9d9e9..d068c1e3517 100644 --- a/src/cpu/x64/jit_uni_x8s8s32x_1x1_convolution.cpp +++ b/src/cpu/x64/jit_uni_x8s8s32x_1x1_convolution.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2023 Intel Corporation +* Copyright 2019-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -76,7 +76,7 @@ status_t jit_uni_x8s8s32x_1x1_convolution_fwd_t::execute_forward( const float factor = (pd()->jcp_.signed_input && (!pd()->jcp_.has_vnni)) ? 1.f / pd()->jcp_.wei_adj_scale : 1.0f; - int wei_mask = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + int wei_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); if (wei_mask == 0) { utils::array_set( local_scales, src_scales[0] * wei_scales[0] * factor, 8); @@ -94,7 +94,7 @@ status_t jit_uni_x8s8s32x_1x1_convolution_fwd_t::execute_forward( auto dw_local_scales = dw_scratchpad.template get(key_conv_adjusted_scales); - int wei_mask = attr_dw->scales_.get(DNNL_ARG_WEIGHTS).mask_; + int wei_mask = attr_dw->scales_.get_mask(DNNL_ARG_WEIGHTS); float factor = 1.f / jcp_dw->wei_adj_scale; if (wei_mask == 0) { utils::array_set(dw_local_scales, diff --git a/src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.cpp b/src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.cpp index d4512765ee3..309b8c1b63a 100644 --- a/src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.cpp +++ b/src/cpu/x64/jit_uni_x8s8s32x_conv_kernel.cpp @@ -1315,7 +1315,7 @@ status_t jit_uni_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp, const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS); const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST); - jcp.is_oc_scale = wei_scales.mask_ != 0; + jcp.is_oc_scale = wei_scales.get_mask() > 0; jcp.dst_scale = !dst_scales.has_default_values(); const auto zp = attr.zero_points_; @@ -1614,7 +1614,7 @@ void jit_uni_x8s8s32x_fwd_kernel::init_scratchpad( memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, const primitive_attr_t &attr) { - const int mask = attr.scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int mask = attr.scales_.get_mask(DNNL_ARG_WEIGHTS); const dim_t scales_count = mask == 0 ? 1 : static_cast(jcp.oc) * jcp.ngroups; dim_t count = scales_count == 1 ? 
(dim_t)8 : scales_count; diff --git a/src/cpu/x64/jit_uni_x8s8s32x_convolution.cpp b/src/cpu/x64/jit_uni_x8s8s32x_convolution.cpp index bfa20f0d33e..be3c69b9b42 100644 --- a/src/cpu/x64/jit_uni_x8s8s32x_convolution.cpp +++ b/src/cpu/x64/jit_uni_x8s8s32x_convolution.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2019-2023 Intel Corporation +* Copyright 2019-2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -44,7 +44,7 @@ const float *jit_uni_x8s8s32x_convolution_fwd_t::adjust_oscales( const memory_tracking::grantor_t &scratchpad, const float *src_scales, const float *wei_scales) const { auto loc_scales = scratchpad.template get(key_conv_adjusted_scales); - int wei_mask = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + int wei_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); float factor = (pd()->jcp_.signed_input && (!pd()->jcp_.has_vnni)) ? 1.f / pd()->jcp_.wei_adj_scale : 1.0f; diff --git a/src/cpu/x64/jit_uni_x8s8s32x_deconvolution.cpp b/src/cpu/x64/jit_uni_x8s8s32x_deconvolution.cpp index 9b03681d900..9b1c487ce6d 100644 --- a/src/cpu/x64/jit_uni_x8s8s32x_deconvolution.cpp +++ b/src/cpu/x64/jit_uni_x8s8s32x_deconvolution.cpp @@ -272,7 +272,7 @@ status_t jit_uni_x8s8s32x_deconv_fwd_kernel::init_conf( const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS); const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST); - jcp.is_oc_scale = wei_scales.mask_ != 0; + jcp.is_oc_scale = wei_scales.get_mask() > 0; jcp.dst_scale = !dst_scales.has_default_values(); jcp.post_ops = p; @@ -386,7 +386,7 @@ template void jit_uni_x8s8s32x_deconv_fwd_kernel::init_scratchpad( memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, const primitive_attr_t &attr) { - const int mask = attr.scales_.get(DNNL_ARG_WEIGHTS).mask_; + const int mask = attr.scales_.get_mask(DNNL_ARG_WEIGHTS); const dim_t scales_count = mask == 0 ? 1 : static_cast(jcp.oc) * jcp.ngroups; dim_t count = nstl::max(scales_count, 8); @@ -1451,7 +1451,7 @@ const float *jit_uni_x8s8s32x_deconvolution_fwd_t::adjust_oscales( const memory_tracking::grantor_t &scratchpad, const float *src_scales, const float *wei_scales) const { auto loc_scales = scratchpad.template get(key_conv_adjusted_scales); - int wei_mask = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + int wei_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); float factor = (pd()->jcp_.signed_input && (!pd()->jcp_.has_vnni)) ? 1.f / pd()->jcp_.wei_adj_scale : 1.0f; diff --git a/src/cpu/x64/matmul/brgemm_matmul.cpp b/src/cpu/x64/matmul/brgemm_matmul.cpp index 56115a31cdb..d0f84cc0df1 100644 --- a/src/cpu/x64/matmul/brgemm_matmul.cpp +++ b/src/cpu/x64/matmul/brgemm_matmul.cpp @@ -86,7 +86,7 @@ status_t brgemm_matmul_t::pd_t::init(engine_t *engine) { const auto &asc = attr()->scales_; if (!asc.get(DNNL_ARG_SRC).has_default_values() && !asc.get(DNNL_ARG_WEIGHTS).has_default_values() - && asc.get(DNNL_ARG_WEIGHTS).mask_ != 0) { + && asc.get_mask(DNNL_ARG_WEIGHTS) > 0) { // This case requires scratchpad if (N() == DNNL_RUNTIME_DIM_VAL) ok = false; } @@ -100,11 +100,11 @@ status_t brgemm_matmul_t::pd_t::init(engine_t *engine) { if (!asc.get(DNNL_ARG_WEIGHTS).has_default_values()) { if (!asc.get(DNNL_ARG_WEIGHTS).has_default_groups()) { // Only grouping over K is supported. 
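The adjust_oscales hunks above all implement the same arithmetic: fold the single source scale, the (possibly per-output-channel) weight scales and the s8s8 weight-adjustment factor into one precomputed buffer before the kernel runs. Below is a minimal standalone sketch of that arithmetic with hypothetical names and sample values; it is not the patched kernels' code.

#include <cstddef>
#include <cstdio>
#include <vector>

std::vector<float> adjust_output_scales(float src_scale,
        const std::vector<float> &wei_scales, int wei_mask,
        float wei_adj_scale) {
    // Without VNNI the s8s8 weights were pre-scaled (e.g. by 0.5f) at reorder
    // time; the 1/wei_adj_scale factor undoes that here.
    const float factor = 1.f / wei_adj_scale;
    if (wei_mask == 0) // a single common weight scale
        return {src_scale * wei_scales[0] * factor};
    std::vector<float> out(wei_scales.size()); // one value per output channel
    for (std::size_t oc = 0; oc < wei_scales.size(); ++oc)
        out[oc] = src_scale * wei_scales[oc] * factor;
    return out;
}

int main() {
    const auto scales = adjust_output_scales(
            0.05f, {1.f, 2.f, 4.f}, /*wei_mask=*/1 << 0, 0.5f);
    for (float s : scales)
        std::printf("%g ", s);
    std::printf("\n");
    return 0;
}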
- ok = ok && asc.get(DNNL_ARG_WEIGHTS).group_dims_[1] == 1; + ok = ok && asc.get_group(DNNL_ARG_WEIGHTS, 1) == 1; // Only 'per_ocic' mask is supported, but not 'per_tensor' in // benchdnn terms. In numbers, it's '12' is supported while for // 4D '15' is required. - const int mask = asc.get(DNNL_ARG_WEIGHTS).mask_; + const int mask = asc.get_mask(DNNL_ARG_WEIGHTS); const int ndims = weights_md_.ndims; const int last_dim = (1 << (ndims - 1)); const int prelast_dim = (1 << (ndims - 2)); @@ -298,8 +298,8 @@ status_t brgemm_matmul_t::init(engine_t *engine) { if (is_jit_supported && wei_scale_count > 1 && req_copy_scales(attr) && !bgmmc.req_transpose_scales) { const auto &attr_scales = attr->scales_; - int wei_scale_mask = attr_scales.get(DNNL_ARG_WEIGHTS).mask_; - if (wei_scale_mask != 0) { + int wei_scale_mask = attr_scales.get_mask(DNNL_ARG_WEIGHTS); + if (wei_scale_mask > 0) { CHECK(safe_ptr_assign(jit_scale_precompute_, new jit_avx512_core_scale_precompute_t(attr))); CHECK(jit_scale_precompute_->create_kernel()); @@ -324,10 +324,13 @@ status_t brgemm_matmul_t::execute_body(const exec_ctx_t &ctx) const { matmul_helper_t helper(src_d, weights_d, dst_d); const auto &bgmmc = pd()->get_brgemm_matmul_conf(); - const int wei_scale_mask - = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; - const bool wei_scale_per_k = wei_scale_mask & pd()->wei_qmask_K(); - const bool wei_scale_per_n = wei_scale_mask & pd()->wei_qmask_N(); + const bool has_wei_scales + = !pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).has_default_values(); + const int wei_scale_mask = pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS); + const bool wei_scale_per_k + = has_wei_scales && (wei_scale_mask & pd()->wei_qmask_K()); + const bool wei_scale_per_n + = has_wei_scales && (wei_scale_mask & pd()->wei_qmask_N()); const float *oscales = scale_utils::precompute_scales( ctx.get_scratchpad_grantor(), src_scales, wei_scales, pd()->K(), pd()->N(), wei_scale_per_k, wei_scale_per_n, pd()->attr(), diff --git a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp index a0b0a7f82e2..fa4cacff031 100644 --- a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp +++ b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp @@ -1316,19 +1316,19 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc, const auto &src_scales = attr.scales_.get(DNNL_ARG_SRC); const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS); - bgmmc.with_scales = !src_scales.has_default_values() - || !wei_scales.has_default_values(); - if (bgmmc.with_scales) { + const bool has_wei_scales = !wei_scales.has_default_values(); + bgmmc.with_scales = !src_scales.has_default_values() || has_wei_scales; + if (has_wei_scales) { const auto wei_qmask_N = 1 << (bgmmc.ndims - 1); const auto wei_qmask_K = 1 << (bgmmc.ndims - 2); - bgmmc.is_oscale_per_k = wei_scales.mask_ & wei_qmask_K; - bgmmc.is_oscale_per_n = wei_scales.mask_ & wei_qmask_N; + bgmmc.is_oscale_per_k = wei_scales.get_mask() & wei_qmask_K; + bgmmc.is_oscale_per_n = wei_scales.get_mask() & wei_qmask_N; bgmmc.apply_scales_in_buffer_b = bgmmc.is_oscale_per_k && bgmmc.with_wei_decompression && bgmmc.N * bgmmc.K != 1; // only common and per-oc-channel scales are supported // only per-ic-channel scales is supprted with weight decompression - VCONDCHECK_BG(wei_scales.mask_ == 0 || bgmmc.is_oscale_per_n + VCONDCHECK_BG(wei_scales.get_mask() == 0 || bgmmc.is_oscale_per_n || IMPLICATION(bgmmc.is_oscale_per_k, bgmmc.with_wei_decompression), VERBOSE_UNSUPPORTED_SCALES_CFG); @@ -1337,8 +1337,8 @@ 
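The weights-scale mask checks in the matmul hunks above depend on which mask bits select the N and K dimensions of the weights tensor, and those bits shift with the tensor rank; that is why the supported per-N/per-K combination is mask 3 for a 2D matmul but mask 12 for a 4D one. A self-contained illustration of that bit layout follows (not library code, values chosen for the example).

#include <cstdio>

struct wei_scale_mask_info_t {
    int qmask_n, qmask_k;
    bool per_n, per_k;
};

wei_scale_mask_info_t decode(int ndims, int mask) {
    wei_scale_mask_info_t r;
    r.qmask_n = 1 << (ndims - 1); // innermost dimension of the weights is N
    r.qmask_k = 1 << (ndims - 2); // the next one is K
    r.per_n = (mask & r.qmask_n) != 0;
    r.per_k = (mask & r.qmask_k) != 0;
    return r;
}

int main() {
    const auto m2d = decode(/*ndims=*/2, /*mask=*/3);  // K|N bits: 1 + 2
    const auto m4d = decode(/*ndims=*/4, /*mask=*/12); // K|N bits: 4 + 8
    std::printf("2D: per_k=%d per_n=%d\n", m2d.per_k, m2d.per_n);
    std::printf("4D: per_k=%d per_n=%d\n", m4d.per_k, m4d.per_n);
    return 0;
}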
status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc, const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST); bgmmc.with_dst_scales = !dst_scales.has_default_values(); // only common scales are supported - VCONDCHECK_BG(!(bgmmc.with_dst_scales && dst_scales.mask_ != 0), - VERBOSE_UNSUPPORTED_SCALES_CFG) + VCONDCHECK_BG(!(bgmmc.with_dst_scales && dst_scales.get_mask() > 0), + VERBOSE_UNSUPPORTED_SCALES_CFG); const auto &p = attr.post_ops_; bgmmc.with_sum = p.find(primitive_kind::sum) != -1; diff --git a/src/gpu/generic/ref_concat.hpp b/src/gpu/generic/ref_concat.hpp index 5c230cc634d..7c0a23081d3 100644 --- a/src/gpu/generic/ref_concat.hpp +++ b/src/gpu/generic/ref_concat.hpp @@ -60,12 +60,8 @@ struct ref_concat_t : public gpu::primitive_t { reorder_pds_.resize(n_ + use_tent_dst()); for (int i = 0; i < n_; ++i) { primitive_attr_t r_attr; - int mask = 0; - bool is_set = false; - VDISPATCH_CONCAT_SC( - sc.get(DNNL_ARG_MULTIPLE_SRC + i, &mask, &is_set), - VERBOSE_UNSUPPORTED_SCALES_CFG); - if (is_set) { + if (!sc.get(DNNL_ARG_MULTIPLE_SRC + i).has_default_values()) { + int mask = sc.get_mask(DNNL_ARG_MULTIPLE_SRC + i); VDISPATCH_CONCAT(mask == 0, "non-zero mask"); VDISPATCH_CONCAT_SC(r_attr.scales_.set(DNNL_ARG_SRC, mask), VERBOSE_UNSUPPORTED_SCALES_CFG); diff --git a/src/gpu/generic/sycl/layer_normalizations_kernels.hpp b/src/gpu/generic/sycl/layer_normalizations_kernels.hpp index 1c09075158a..ab8e4c076a0 100644 --- a/src/gpu/generic/sycl/layer_normalizations_kernels.hpp +++ b/src/gpu/generic/sycl/layer_normalizations_kernels.hpp @@ -79,12 +79,20 @@ struct layer_normalization_fwd_kernel_vec_t { memory_tensor_t data_mem(data_, conf_.data_md); memory_tensor_t scale_mem(scale_, conf_.data_scaleshift_md); memory_tensor_t shift_mem(shift_, conf_.data_scaleshift_md); - memory_plain_t rt_scale_mem(rt_scale_, conf_.scales_src_dt); - memory_plain_t dst_scale_mem(dst_scale_, conf_.scales_dst_dt); memory_tensor_t stat_mem(stat_, conf_.stat_md); memory_plain_t var_mem(var_, conf_.var_dt); memory_tensor_t dst_mem(dst_, conf_.dst_md); + float sr = 1.f; + if (!conf_.src_def) { + memory_plain_t rt_scale_mem(rt_scale_, conf_.scales_src_dt); + sr = rt_scale_mem.load(0); + } + float ds = 1.f; + if (!conf_.dst_def) { + memory_plain_t dst_scale_mem(dst_scale_, conf_.scales_dst_dt); + ds = dst_scale_mem.load(0); + } float eps = epsilon(); const size_t s_off = conf_.stat_md.off_l(idx); auto v_mean = stat_mem.load(s_off); @@ -104,8 +112,6 @@ struct layer_normalization_fwd_kernel_vec_t { float s = data_mem.load(src_off); float d = sm * (s - v_mean) + sv; - float sr = conf_.src_def ? 1.f : rt_scale_mem.load(0); - float ds = conf_.dst_def ? 
1.f : dst_scale_mem.load(0); d = (d * sr * (1.f / ds)); dst_mem.store(d, d_off); } @@ -171,8 +177,6 @@ struct layer_normalization_fwd_kernel_vec1_t { memory_tensor_t data_mem(data_, conf_.data_md); memory_tensor_t scale_mem(scale_, conf_.data_scaleshift_md); memory_tensor_t shift_mem(shift_, conf_.data_scaleshift_md); - memory_plain_t rt_scale_mem(rt_scale_, conf_.scales_src_dt); - memory_plain_t dst_scale_mem(dst_scale_, conf_.scales_dst_dt); memory_tensor_t stat_out_mem(mean_out_, conf_.stat_md); memory_plain_t var_out_mem(var_out_, conf_.var_dt); memory_tensor_t dst_mem(dst_, conf_.dst_md); @@ -181,6 +185,17 @@ struct layer_normalization_fwd_kernel_vec1_t { stat_out_mem.store(0, idx); var_out_mem.store(0, idx); } + float sr = 1.f; + if (!conf_.src_def) { + memory_plain_t rt_scale_mem(rt_scale_, conf_.scales_src_dt); + sr = rt_scale_mem.load(0); + } + float ds = 1.f; + if (!conf_.dst_def) { + memory_plain_t dst_scale_mem(dst_scale_, conf_.scales_dst_dt); + ds = dst_scale_mem.load(0); + } + float eps = epsilon(); const size_t s_off = conf_.stat_md.off_l(idx); float v_mean = 0.f; @@ -217,9 +232,6 @@ struct layer_normalization_fwd_kernel_vec1_t { const auto d_off = dst_md().off_l(index); float s = data_mem.load(src_off); float d = sm * (s - v_mean) + sv; - - float sr = conf_.src_def ? 1.f : rt_scale_mem.load(0); - float ds = conf_.dst_def ? 1.f : dst_scale_mem.load(0); d = (d * sr * (1.f / ds)); dst_mem.store(d, d_off); diff --git a/src/gpu/generic/sycl/ref_binary.hpp b/src/gpu/generic/sycl/ref_binary.hpp index e6b8ec6f2ce..2e633d5240c 100644 --- a/src/gpu/generic/sycl/ref_binary.hpp +++ b/src/gpu/generic/sycl/ref_binary.hpp @@ -87,8 +87,7 @@ struct ref_binary_t : public gpu::generic::sycl::primitive_t { const auto &scales = attr()->scales_; bool dt_ok = true; for (auto arg : supported_args) { - auto &s = scales.get(arg); - dt_ok = dt_ok && is_supported_type(s.data_type_); + dt_ok = dt_ok && is_supported_type(scales.get_data_type(arg)); } return dt_ok && attr_scales_ok(supported_args); } diff --git a/src/gpu/generic/sycl/ref_convolution.cpp b/src/gpu/generic/sycl/ref_convolution.cpp index b2f0458eba3..3da156cb7b5 100644 --- a/src/gpu/generic/sycl/ref_convolution.cpp +++ b/src/gpu/generic/sycl/ref_convolution.cpp @@ -44,8 +44,7 @@ status_t ref_convolution_fwd_t::pd_t::init_conf() { = !attr()->scales_.get(DNNL_ARG_WEIGHTS).has_default_values(); conf_.do_scale_dst = !attr()->scales_.get(DNNL_ARG_DST).has_default_values(); - conf_.single_weight_scale - = attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ == 0; + conf_.single_weight_scale = attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) == 0; conf_.use_data_zeropoints = !attr()->zero_points_.has_default_values(DNNL_ARG_SRC_0); @@ -109,8 +108,7 @@ status_t ref_convolution_bwd_data_t::pd_t::init_conf() { = !attr()->scales_.get(DNNL_ARG_WEIGHTS).has_default_values(); conf_.do_scale_dst = !attr()->scales_.get(DNNL_ARG_DST).has_default_values(); - conf_.single_weight_scale - = attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ == 0; + conf_.single_weight_scale = attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) == 0; conf_.use_data_zeropoints = !attr()->zero_points_.has_default_values(DNNL_ARG_SRC_0); @@ -175,8 +173,7 @@ status_t ref_convolution_bwd_weights_t::pd_t::init_conf() { = !attr()->scales_.get(DNNL_ARG_WEIGHTS).has_default_values(); conf_.do_scale_dst = !attr()->scales_.get(DNNL_ARG_DST).has_default_values(); - conf_.single_weight_scale - = attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ == 0; + conf_.single_weight_scale = attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) == 
0; conf_.use_data_zeropoints = !attr()->zero_points_.has_default_values(DNNL_ARG_SRC_0); diff --git a/src/gpu/generic/sycl/ref_convolution.hpp b/src/gpu/generic/sycl/ref_convolution.hpp index 55faae92c08..f930b713f4e 100644 --- a/src/gpu/generic/sycl/ref_convolution.hpp +++ b/src/gpu/generic/sycl/ref_convolution.hpp @@ -66,7 +66,7 @@ inline bool check_convolution_scales_types(const primitive_attr_t *attr) { const auto &scales = attr->scales_; for (auto arg : supported_args) { - auto dt = scales.get(arg).data_type_; + const auto dt = scales.get_data_type(arg); if (!is_supported_type(dt)) { return false; } } return true; diff --git a/src/gpu/generic/sycl/ref_layer_normalizations.cpp b/src/gpu/generic/sycl/ref_layer_normalizations.cpp index 69763aa2568..c3b3f9aa032 100644 --- a/src/gpu/generic/sycl/ref_layer_normalizations.cpp +++ b/src/gpu/generic/sycl/ref_layer_normalizations.cpp @@ -44,12 +44,8 @@ status_t ref_layer_normalization_fwd_t::pd_t::init_conf() { conf_.src_def = attr()->scales_.get(DNNL_ARG_SRC).has_default_values(); conf_.dst_def = attr()->scales_.get(DNNL_ARG_DST).has_default_values(); - conf_.scales_src_dt = conf_.src_def - ? data_type_t::dnnl_f32 - : attr()->scales_.get(DNNL_ARG_SRC).data_type_; - conf_.scales_dst_dt = conf_.dst_def - ? data_type_t::dnnl_f32 - : attr()->scales_.get(DNNL_ARG_DST).data_type_; + conf_.scales_src_dt = attr()->scales_.get_data_type(DNNL_ARG_SRC); + conf_.scales_dst_dt = attr()->scales_.get_data_type(DNNL_ARG_DST); conf_.use_scale = use_scale(); conf_.use_shift = use_shift(); diff --git a/src/gpu/generic/sycl/ref_layer_normalizations.hpp b/src/gpu/generic/sycl/ref_layer_normalizations.hpp index 52bd8e6f16f..19d943d70e5 100644 --- a/src/gpu/generic/sycl/ref_layer_normalizations.hpp +++ b/src/gpu/generic/sycl/ref_layer_normalizations.hpp @@ -81,7 +81,7 @@ struct ref_layer_normalization_fwd_t : public gpu::generic::sycl::primitive_t { const auto &scales = attr()->scales_; for (auto arg : supported_args) { - auto dt = scales.get(arg).data_type_; + const auto dt = scales.get_data_type(arg); if (!is_supported_type(dt)) { return false; } } return true; diff --git a/src/gpu/generic/sycl/ref_matmul.cpp b/src/gpu/generic/sycl/ref_matmul.cpp index d79f80b3cf0..7c428dcd92e 100644 --- a/src/gpu/generic/sycl/ref_matmul.cpp +++ b/src/gpu/generic/sycl/ref_matmul.cpp @@ -33,7 +33,7 @@ void ref_matmul_t::pd_t::init_conf() { conf_.do_scale_dst = !attr()->scales_.get(DNNL_ARG_DST).has_default_values(); conf_.single_weights_scale - = attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ == 0; + = attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) == 0; conf_.use_data_zeropoints = !attr()->zero_points_.has_default_values(DNNL_ARG_SRC_0); diff --git a/src/gpu/generic/sycl/ref_matmul.hpp b/src/gpu/generic/sycl/ref_matmul.hpp index 2fa308b419f..63ed8afd3e6 100644 --- a/src/gpu/generic/sycl/ref_matmul.hpp +++ b/src/gpu/generic/sycl/ref_matmul.hpp @@ -122,8 +122,7 @@ struct ref_matmul_t : public gpu::generic::sycl::primitive_t { const auto &scales = attr()->scales_; bool dt_ok = true; for (auto arg : supported_args) { - auto &s = scales.get(arg); - dt_ok = dt_ok && is_supported_type(s.data_type_); + dt_ok = dt_ok && is_supported_type(scales.get_data_type(arg)); } return dt_ok && attr_scales_ok(supported_args); } diff --git a/src/gpu/generic/sycl/ref_reorder.cpp b/src/gpu/generic/sycl/ref_reorder.cpp index c7c8c25bfdc..478453d218c 100644 --- a/src/gpu/generic/sycl/ref_reorder.cpp +++ b/src/gpu/generic/sycl/ref_reorder.cpp @@ -34,10 +34,10 @@ status_t ref_reorder_t::pd_t::init_conf() { 
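A pattern that recurs throughout this patch is visible in the hunks above: mask, data-type and group queries are only meaningful once an argument's scales have actually been set, so call sites now check has_default_values() before asking for a mask. A toy model of that contract, using illustrative names rather than the real attribute classes:

#include <cassert>
#include <cstdio>
#include <map>

// Toy scale registry: an argument that was never set reports default values,
// and querying its mask is then a programming error, so callers guard first.
class toy_scales_t {
public:
    void set(int arg, int mask) { entries_[arg] = mask; }
    bool has_default_values(int arg) const {
        return entries_.find(arg) == entries_.end();
    }
    int get_mask(int arg) const {
        auto it = entries_.find(arg);
        assert(it != entries_.end() && "query a mask only for a set entry");
        return it->second;
    }

private:
    std::map<int, int> entries_;
};

int main() {
    constexpr int ARG_WEIGHTS = 1, ARG_DST = 2;
    toy_scales_t scales;
    scales.set(ARG_WEIGHTS, 1 << 0); // per-OC weight scales

    // Call-site pattern used throughout the patch: guard, then query.
    const bool per_oc = !scales.has_default_values(ARG_WEIGHTS)
            && scales.get_mask(ARG_WEIGHTS) > 0;
    const bool dst_common = scales.has_default_values(ARG_DST)
            || scales.get_mask(ARG_DST) == 0;
    std::printf("per_oc=%d dst_common=%d\n", per_oc, dst_common);
    return 0;
}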
conf_.do_scale_src = !attr()->scales_.get(DNNL_ARG_SRC_0).has_default_values(); - conf_.scale_src_mask = attr()->scales_.get(DNNL_ARG_SRC_0).mask_; + conf_.scale_src_mask = attr()->scales_.get_mask(DNNL_ARG_SRC_0); conf_.do_scale_dst = !attr()->scales_.get(DNNL_ARG_DST).has_default_values(); - conf_.scale_dst_mask = attr()->scales_.get(DNNL_ARG_DST).mask_; + conf_.scale_dst_mask = attr()->scales_.get_mask(DNNL_ARG_DST); conf_.post_ops = sycl_post_ops_t(attr(), dst_md()); return status::success; diff --git a/src/gpu/generic/sycl/ref_reorder.hpp b/src/gpu/generic/sycl/ref_reorder.hpp index 3155d8d1486..ff8ea8ba510 100644 --- a/src/gpu/generic/sycl/ref_reorder.hpp +++ b/src/gpu/generic/sycl/ref_reorder.hpp @@ -103,7 +103,7 @@ struct ref_reorder_t : public gpu::generic::sycl::primitive_t { const auto &scales = attr()->scales_; for (auto arg : supported_args) { - auto dt = scales.get(arg).data_type_; + const auto dt = scales.get_data_type(arg); if (!is_supported_type(dt)) { return false; } } return true; diff --git a/src/gpu/generic/sycl/ref_softmax.hpp b/src/gpu/generic/sycl/ref_softmax.hpp index 8143eb6bc51..94910c72b32 100644 --- a/src/gpu/generic/sycl/ref_softmax.hpp +++ b/src/gpu/generic/sycl/ref_softmax.hpp @@ -50,7 +50,7 @@ struct ref_sycl_softmax_fwd_t : public gpu::generic::sycl::primitive_t { VDISPATCH_SOFTMAX(attr()->has_default_values( sm::scales_runtime | sm::post_ops), VERBOSE_UNSUPPORTED_ATTR); - VDISPATCH_SOFTMAX(attr_oscale_ok(), VERBOSE_UNSUPPORTED_SCALES_CFG); + VDISPATCH_SOFTMAX(attr_scales_ok(), VERBOSE_UNSUPPORTED_SCALES_CFG); VDISPATCH_SOFTMAX(sycl_post_ops_t::post_ops_ok(attr(), true, false), VERBOSE_UNSUPPORTED_POSTOP); VDISPATCH_SOFTMAX_SC( @@ -67,15 +67,6 @@ struct ref_sycl_softmax_fwd_t : public gpu::generic::sycl::primitive_t { sycl_softmax_conf_t conf_; status_t init_conf(); - bool attr_oscale_ok() const { - const auto &scales = attr()->scales_; - bool ok = true; - for (const auto &e : scales.scales_) { - ok = ok && e.second.mask_ == 0; - } - return ok; - } - bool check_data_types(data_type_t src) { return utils::one_of(src, data_type::f32, data_type::bf16, data_type::f16, data_type::s8, data_type::u8); diff --git a/src/gpu/generic/sycl/ref_sum.hpp b/src/gpu/generic/sycl/ref_sum.hpp index 89dd8b5aee1..09754b21075 100644 --- a/src/gpu/generic/sycl/ref_sum.hpp +++ b/src/gpu/generic/sycl/ref_sum.hpp @@ -60,15 +60,13 @@ struct ref_sum_t : public gpu::generic::sycl::primitive_t { // Block formats are not yet supported // Dimensions can not be > 6 - VDISPATCH_SUM( - !(!src_d.is_plain() - || src_d.ndims() > xpu::sycl::md_t::max_dims), + VDISPATCH_SUM(src_d.is_plain() + && src_d.ndims() <= xpu::sycl::md_t::max_dims, VERBOSE_UNSUPPORTED_TENSOR_LAYOUT, "src"); - VDISPATCH_SUM(!(!attr()->scales_.has_default_values() - && !is_supported_type( - scales.get(DNNL_ARG_SRC + i) - .data_type_)), + VDISPATCH_SUM(attr()->scales_.has_default_values() + || is_supported_type( + scales.get_data_type(DNNL_ARG_SRC + i)), VERBOSE_UNSUPPORTED_ATTR); } diff --git a/src/gpu/generic/sycl/reorder_kernels.hpp b/src/gpu/generic/sycl/reorder_kernels.hpp index e9e317fcd99..4320d7645a7 100644 --- a/src/gpu/generic/sycl/reorder_kernels.hpp +++ b/src/gpu/generic/sycl/reorder_kernels.hpp @@ -73,14 +73,14 @@ struct reorder_kernel_t { = (i < src_md().ndims()) ? src_md().strides()[i] : INT_MAX; } dims_t dims_scales_src; - if (conf_.scale_src_mask != 0) { + if (conf_.scale_src_mask > 0) { for (int i = 0; i < max_supported_ndims; i++) { dims_scales_src[i] = conf_.scale_src_mask >> i & 1 ? 
dims[i] : 1; } } dims_t dims_scales_dst; - if (conf_.scale_dst_mask != 0) { + if (conf_.scale_dst_mask > 0) { for (int i = 0; i < max_supported_ndims; i++) { dims_scales_dst[i] = conf_.scale_dst_mask >> i & 1 ? dims[i] : 1; @@ -97,7 +97,7 @@ struct reorder_kernel_t { auto src = src_mem.load(idx); if (conf_.do_scale_src) { - if (conf_.scale_src_mask != 0) { + if (conf_.scale_src_mask > 0) { int scale_idx = 0; for (int i = 0; i < max_supported_ndims; i++) { if (i < src_md().ndims()) { @@ -116,7 +116,7 @@ struct reorder_kernel_t { auto acc = src; acc = conf_.post_ops.apply(acc, dst_, dst_idx); if (conf_.do_scale_dst) { - if (conf_.scale_dst_mask != 0) { + if (conf_.scale_dst_mask > 0) { int scale_idx = 0; for (int i = 0; i < max_supported_ndims; i++) { if (i < src_md().ndims()) { diff --git a/src/gpu/gpu_utils.hpp b/src/gpu/gpu_utils.hpp index 18c82b1dccc..0cbfa2bc4a8 100644 --- a/src/gpu/gpu_utils.hpp +++ b/src/gpu/gpu_utils.hpp @@ -32,7 +32,7 @@ namespace gpu { inline dim_t get_attr_oscales_count(int mask, const memory_desc_wrapper &md) { dim_t count = 1; - if (mask == 0) return count; + if (mask <= 0) return count; for (int d = 0; d < md.ndims(); d++) { const int dim_mask = 1 << d; @@ -45,13 +45,14 @@ inline dim_t get_attr_oscales_count(int mask, const memory_desc_wrapper &md) { class scales_query_t { public: bool has_default_values() const { return scales_.has_default_values(); } - int get_mask() const { return scales_.mask_; } + int get_mask() const { return scales_.get_mask(); } size_t get_count() const { return count_; } - data_type_t get_data_type() const { return scales_.data_type_; } + data_type_t get_data_type() const { return scales_.get_data_type(); } dim_t get_group() const { - if (scales_.ndims_ < 2) return 1; - const auto g0 = scales_.group_dims_[0]; - const auto g1 = scales_.group_dims_[1]; + if (scales_.has_default_groups()) return 1; + + const auto g0 = scales_.get_group(0); + const auto g1 = scales_.get_group(1); assert(utils::one_of(1, g0, g1)); return g0 > 1 ? g0 : g1; } @@ -59,13 +60,16 @@ class scales_query_t { int get_group_dim() const { // If groups are not identified, they should be set to `1`, and // it shouldn't hurt to divide by 1 any dim. Just use 0th for that. - if (scales_.ndims_ < 2) return 0; - const auto g0 = scales_.group_dims_[0]; - const auto g1 = scales_.group_dims_[1]; + if (scales_.has_default_groups()) return 0; + + const auto g0 = scales_.get_group(0); + const auto g1 = scales_.get_group(1); assert(utils::one_of(1, g0, g1)); UNUSED(g1); const int g_dim = g0 > 1 ? 0 : 1; - return ndims_ - scales_.ndims_ + g_dim; + // Note: hardcoded value so far. + // TODO: replace with some API when ndims can be different from 2. 
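The helper shown above maps a scale mask onto a scale count; the rule is simple enough to restate in isolation: a non-positive mask means a single common value, otherwise the count is the product of the dimensions whose bits are set. A standalone sketch of the same rule, with example shapes of my own choosing:

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

int64_t oscales_count(int mask, const std::vector<int64_t> &dims) {
    if (mask <= 0) return 1; // one common scale for the whole tensor
    int64_t count = 1;
    for (std::size_t d = 0; d < dims.size(); ++d)
        if (mask & (1 << d)) count *= dims[d]; // one value per masked dim
    return count;
}

int main() {
    const std::vector<int64_t> conv_wei = {64, 32, 3, 3}; // OIHW
    std::printf("common: %lld\n", (long long)oscales_count(0, conv_wei));
    std::printf("per-OC: %lld\n", (long long)oscales_count(1 << 0, conv_wei));
    return 0;
}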
+ return ndims_ - /* scales_.get_groups_ndims() = */ 2 + g_dim; } memory_storage_t &get_scales(const exec_ctx_t &ctx) const { @@ -77,11 +81,11 @@ class scales_query_t { int arg) : arg_(arg), ndims_(mdw.ndims()) { scales_ = attr->scales_.get(arg); - count_ = get_attr_oscales_count(scales_.mask_, mdw); + count_ = get_attr_oscales_count(scales_.get_mask(), mdw); } private: - runtime_scales_t scales_; + quant_entry_t scales_; dim_t count_ = 0; int arg_ = 0; int ndims_ = 0; diff --git a/src/gpu/intel/jit/conv/config.cpp b/src/gpu/intel/jit/conv/config.cpp index d326f6bd172..3256d8ba3b8 100644 --- a/src/gpu/intel/jit/conv/config.cpp +++ b/src/gpu/intel/jit/conv/config.cpp @@ -888,7 +888,9 @@ bool post_ops_ok(const conv_problem_t &prb, const hw_t &hw) { scales[i] = scale_args[i].second; if (!attr->scales_.has_default_values(scales)) return false; for (int arg : scales) { - int mask = attr->scales_.get(arg).mask_; + if (attr->scales_.get(arg).has_default_values()) continue; + + int mask = attr->scales_.get(arg).get_mask(); // XXX: per_oc for BWD_D is treated as per_ic assuming it's called from // deconvolution. if (arg == DNNL_ARG_WEIGHTS) { diff --git a/src/gpu/intel/jit/gemm/gen_gemm.hpp b/src/gpu/intel/jit/gemm/gen_gemm.hpp index 934c02cb235..b3526b7ba16 100644 --- a/src/gpu/intel/jit/gemm/gen_gemm.hpp +++ b/src/gpu/intel/jit/gemm/gen_gemm.hpp @@ -299,11 +299,15 @@ struct gen_gemm_t : public gpu_gemm_t { auto &wei_scales = attr()->scales_.get(DNNL_ARG_WEIGHTS); auto &src_scales = attr()->scales_.get(DNNL_ARG_SRC); - if (quant_enabled_ && wei_scales.ndims_ > 1) wei_scales_2d_ = true; - if (quant_enabled_ && src_scales.ndims_ > 1) src_scales_2d_ = true; + if (quant_enabled_ && !wei_scales.has_default_groups()) + wei_scales_2d_ = true; + if (quant_enabled_ && !src_scales.has_default_groups()) + src_scales_2d_ = true; for (auto s : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) { - auto mask = attr()->scales_.get(s).mask_; + if (attr()->scales_.get(s).has_default_values()) continue; + + auto mask = attr()->scales_.get_mask(s); VDISPATCH_GEMM(utils::one_of(mask, 0, mask_scalar, mask_per_oc, mask_per_ic) || (s == DNNL_ARG_WEIGHTS && wei_scales_2d_ @@ -314,12 +318,11 @@ struct gen_gemm_t : public gpu_gemm_t { } if (wei_scales_2d_) { - auto scales_group_k - = wei_scales.ndims_ > 0 ? wei_scales.group_dims_[0] : 1; + auto scales_group_k = wei_scales.get_group(0); if (scales_group_k >= d->k()) { wei_scales_2d_ = false; } else { - wei_scales_type = wei_scales.data_type_; + wei_scales_type = wei_scales.get_data_type(); if (!wei_zp_2d) wei_q2d_group_k = scales_group_k; else { @@ -328,14 +331,13 @@ struct gen_gemm_t : public gpu_gemm_t { } } // Non-trivial N group unsupported. - VDISPATCH_GEMM(wei_scales.group_dims_[1] == 1, + VDISPATCH_GEMM(wei_scales.get_group(1) == 1, VERBOSE_UNSUPPORTED_SCALES_CFG); } if (src_scales_2d_) { - src_scales_type = src_scales.data_type_; - src_po_sc_ = src_scales.mask_ == 2; - auto scales_group_k - = src_scales.ndims_ > 0 ? 
src_scales.group_dims_[1] : 1; + src_scales_type = src_scales.get_data_type(); + src_po_sc_ = src_scales.get_mask() == 2; + auto scales_group_k = src_scales.get_group(1); if (scales_group_k >= d->k()) src_scales_2d_ = false; else { diff --git a/src/gpu/intel/jit/gemm/jit_gemm_pd.cpp b/src/gpu/intel/jit/gemm/jit_gemm_pd.cpp index 9b9b5a2bc6c..c212db2a237 100644 --- a/src/gpu/intel/jit/gemm/jit_gemm_pd.cpp +++ b/src/gpu/intel/jit/gemm/jit_gemm_pd.cpp @@ -92,56 +92,52 @@ status_t jit_gemm_pd_t::init_post_ops() { } if (!wei_scales->has_default_values()) { - const auto &mask = wei_scales->mask_; + const auto &mask = wei_scales->get_mask(); bool convert = (mask == 0 || math::is_pow2(mask)); - if (wei_scales->ndims_ > 1) - convert |= (wei_scales->group_dims_[0] >= d->k()); + if (!wei_scales->has_default_groups()) + convert |= (wei_scales->get_group(0) >= d->k()); if (convert) { - ok = ok && (mask == 0 || mask == (1 << (d->c_desc.ndims - 1))); - dim_t dims = {(mask > 0) ? d->m() : 1}; CHECK(memory_desc_init_by_tag(wei_scales_md, 1, &dims, - wei_scales->data_type_, format_tag::a)); + wei_scales->get_data_type(), format_tag::a)); - auto status = post_ops_.prepend_binary(binary_mul, &wei_scales_md); - if (status != status::success) return status; + CHECK(post_ops_.prepend_binary(binary_mul, &wei_scales_md)); binary_srcs_.insert(binary_srcs_.begin(), binary_src_t {binary_src_t::scales, DNNL_ARG_WEIGHTS}); } } if (!src_scales->has_default_values()) { - const auto &mask = src_scales->mask_; + const auto &mask = src_scales->get_mask(); bool convert = (mask == 0); - if (src_scales->ndims_ > 1) - convert |= (src_scales->group_dims_[1] >= d->k()); + if (!src_scales->has_default_groups()) + convert |= (src_scales->get_group(1) >= d->k()); if (convert) { if (mask == 0) { dim_t dims = 1; CHECK(memory_desc_init_by_tag(src_scales_md, 1, &dims, - src_scales->data_type_, format_tag::a)); + src_scales->get_data_type(), format_tag::a)); } else { dim_t dims[] = {d->n(), 1}; CHECK(memory_desc_init_by_tag(src_scales_md, 2, dims, - src_scales->data_type_, format_tag::ab)); + src_scales->get_data_type(), format_tag::ab)); } - auto status = post_ops_.prepend_binary(binary_mul, &src_scales_md); - if (status != status::success) return status; + CHECK(post_ops_.prepend_binary(binary_mul, &src_scales_md)); binary_srcs_.insert(binary_srcs_.begin(), binary_src_t {binary_src_t::scales, DNNL_ARG_SRC}); } } if (!c_scales->has_default_values()) { - ok = ok && (c_scales->mask_ == 0); + ok = ok && (c_scales->get_mask() == 0); + if (!ok) return status::unimplemented; dim_t dims = {1}; CHECK(memory_desc_init_by_tag( c_scales_md, 1, &dims, f32, format_tag::a)); - auto status = post_ops_.append_binary(binary_div, &c_scales_md); - if (status != status::success) return status; + CHECK(post_ops_.append_binary(binary_div, &c_scales_md)); binary_srcs_.push_back( binary_src_t {binary_src_t::scales, DNNL_ARG_DST}); diff --git a/src/gpu/intel/jit/ir/post_ops.cpp b/src/gpu/intel/jit/ir/post_ops.cpp index e80370192ae..a1f79e83741 100644 --- a/src/gpu/intel/jit/ir/post_ops.cpp +++ b/src/gpu/intel/jit/ir/post_ops.cpp @@ -48,30 +48,32 @@ post_op_context_t::post_op_context_t(const primitive_attr_t &attr, if (buf.is_empty()) continue; int key = kernel_info.key(scale_args[i].first) & ~DNNL_ARG_ATTR_SCALES; - int mask = attr.scales_.get(key).mask_; - view_t view; - switch (key) { - case DNNL_ARG_SRC: - ir_assert(mask == 0); - view = po_vm_.create_view(type_t::f32(), mask); - src_scales = add_input_tensor(view, buf); - src_scales_mask = mask; - break; - 
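In the jit_gemm_pd changes above, scales the kernel cannot consume natively are lowered onto the post-op chain: source and weights scales are prepended as binary multiplies and destination scales are appended as a binary divide. The sketch below is a deliberately simplified, plain-C++ illustration of why that ordering reproduces the dequantize-compute-requantize math; it uses no oneDNN API.

#include <cstdio>
#include <functional>
#include <vector>

int main() {
    // Pretend int8 GEMM accumulator value for one output element.
    float acc = 42.f;
    const float src_scale = 0.1f, wei_scale = 0.5f, dst_scale = 0.25f;

    // Post-op chain as an ordered list of element-wise functions.
    std::vector<std::function<float(float)>> chain;
    chain.emplace_back([](float x) { return x + 1.f; }); // pre-existing post-op

    // Scales lowered onto the chain: multiplies go first, the divide goes last.
    chain.insert(chain.begin(), [=](float x) { return x * wei_scale; });
    chain.insert(chain.begin(), [=](float x) { return x * src_scale; });
    chain.emplace_back([=](float x) { return x / dst_scale; });

    for (const auto &op : chain)
        acc = op(acc);
    std::printf("dequantized, post-processed, requantized: %g\n", acc);
    return 0;
}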
case DNNL_ARG_WEIGHTS: - // Convert o/i weights mask to src/dst. - // XXX: per_oc for BWD_D is treated as per_ic assuming it's - // called from deconvolution. - ir_assert(utils::one_of(mask, 0, 1, 3)); - view = po_vm_.create_view( - type_t::f32(), (mask) ? 1 << 1 : 0); - wei_scales = add_input_tensor(view, buf); - wei_scales_mask = mask; - break; - case DNNL_ARG_DST: // Invert dst scales right after load. - ir_assert(mask == 0); - view = po_vm_.create_view(type_t::f32(), mask); - dst_scales = add_input_tensor(view, buf); - break; + if (!attr.scales_.get(key).has_default_values()) { + int mask = attr.scales_.get_mask(key); + view_t view; + switch (key) { + case DNNL_ARG_SRC: + ir_assert(mask == 0); + view = po_vm_.create_view(type_t::f32(), mask); + src_scales = add_input_tensor(view, buf); + src_scales_mask = mask; + break; + case DNNL_ARG_WEIGHTS: + // Convert o/i weights mask to src/dst. + // XXX: per_oc for BWD_D is treated as per_ic assuming + // it's called from deconvolution. + ir_assert(utils::one_of(mask, 0, 1, 3)); + view = po_vm_.create_view( + type_t::f32(), (mask) ? 1 << 1 : 0); + wei_scales = add_input_tensor(view, buf); + wei_scales_mask = mask; + break; + case DNNL_ARG_DST: // Invert dst scales right after load. + ir_assert(mask == 0); + view = po_vm_.create_view(type_t::f32(), mask); + dst_scales = add_input_tensor(view, buf); + break; + } } } // Use virtual tensors for scalar scales airthmetic: diff --git a/src/gpu/intel/jit/ir/tensor_config.cpp b/src/gpu/intel/jit/ir/tensor_config.cpp index 20b8765df2b..14965e4dc6a 100644 --- a/src/gpu/intel/jit/ir/tensor_config.cpp +++ b/src/gpu/intel/jit/ir/tensor_config.cpp @@ -61,7 +61,7 @@ void init_extra_tensors(const zero_points_config_t &zp_cfg, int arg = scale_args[i].second; auto &s = attr.scales_.get(arg); if (s.has_default_values()) continue; - std::vector dims = {(s.mask_ == 0) ? 1 : oc}; + std::vector dims = {(s.get_mask() == 0) ? 
1 : oc}; layout_t layout(type_t::f32(), 0, dims); int arg_key = DNNL_ARG_ATTR_SCALES | arg; tensor_cfg.add_tensor(scale_args[i].first, arg_key, /*is_input=*/true, diff --git a/src/gpu/intel/jit/reorder/gen_reorder.cpp b/src/gpu/intel/jit/reorder/gen_reorder.cpp index 5f048447146..c488f8155a4 100644 --- a/src/gpu/intel/jit/reorder/gen_reorder.cpp +++ b/src/gpu/intel/jit/reorder/gen_reorder.cpp @@ -58,8 +58,13 @@ status_t gen_reorder_t::pd_t::init(impl::engine_t *engine, return true; }; auto scales_ok = [&]() { - return (attr()->scales_.get(DNNL_ARG_SRC).mask_ == 0) - && (attr()->scales_.get(DNNL_ARG_DST).mask_ == 0); + const bool src_scale_ok + = attr()->scales_.get(DNNL_ARG_SRC).has_default_values() + || attr()->scales_.get_mask(DNNL_ARG_SRC) == 0; + const bool dst_scale_ok + = attr()->scales_.get(DNNL_ARG_DST).has_default_values() + || attr()->scales_.get_mask(DNNL_ARG_DST) == 0; + return src_scale_ok && dst_scale_ok; }; auto is_bf16_or_f32_or_f8 = [](data_type_t dt) { return utils::one_of(dt, bf16, f32, f8_e5m2, f8_e4m3); diff --git a/src/gpu/intel/ocl/gemm/gemm_with_post_ops.cpp b/src/gpu/intel/ocl/gemm/gemm_with_post_ops.cpp index 519409a914d..0cf64a7498c 100644 --- a/src/gpu/intel/ocl/gemm/gemm_with_post_ops.cpp +++ b/src/gpu/intel/ocl/gemm/gemm_with_post_ops.cpp @@ -46,7 +46,9 @@ status_t gemm_with_post_ops_t::pd_t::init(impl::engine_t *engine) { const primitive_attr_t *attributes_with_po = attr(); for (int arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) { - const auto &mask = attr()->scales_.get(arg).mask_; + if (attr()->scales_.get(arg).has_default_values()) continue; + + const auto &mask = attr()->scales_.get_mask(arg); if (arg == DNNL_ARG_WEIGHTS && !wei_decomp) VDISPATCH_GEMM((mask == 0 || mask == (1 << (dst_md()->ndims - 1))), VERBOSE_UNSUPPORTED_SCALES_CFG); @@ -88,7 +90,7 @@ status_t gemm_with_post_ops_t::pd_t::init(impl::engine_t *engine) { // Setup empty attributes but keep zero points for gemm. primitive_attr_t attributes_without_po = *attr(); attributes_without_po.set_post_ops(post_ops_t()); - attributes_without_po.scales_ = arg_scales_t(); + attributes_without_po.scales_ = scales_t(); attributes_without_po.zero_points_ = zero_points_t(); int src_mask, wei_mask; auto zp = attributes_with_po->zero_points_; @@ -182,10 +184,10 @@ status_t gemm_with_post_ops_t::pd_t::init_kernel_ctx( kernel_ctx.define_int("A_SCALES", with_src_scales); kernel_ctx.define_int("B_SCALES", with_wei_scales); kernel_ctx.define_int("C_SCALES", with_dst_scales); - def_data_type(kernel_ctx, attr_scales.get(DNNL_ARG_WEIGHTS).data_type_, + def_data_type(kernel_ctx, attr_scales.get_data_type(DNNL_ARG_WEIGHTS), "WEI_SCALES"); def_data_type( - kernel_ctx, attr_scales.get(DNNL_ARG_DST).data_type_, "DST_SCALES"); + kernel_ctx, attr_scales.get_data_type(DNNL_ARG_DST), "DST_SCALES"); int dst_zp_mask; attr()->zero_points_.get(DNNL_ARG_DST, &dst_zp_mask); kernel_ctx.define_int("DST_ZERO_POINT", @@ -254,7 +256,7 @@ status_t gemm_with_post_ops_t::execute(const gemm_exec_ctx_t &ctx) const { arg_list.set(idx++, GEMM_CTX_ARG_STORAGE(a_scales)); arg_list.set(idx++, GEMM_CTX_ARG_STORAGE(c_scales)); arg_list.set(idx++, - pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ != 0 ? 1 : 0); + pd()->attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) > 0 ? 
1 : 0); arg_list.set(idx, GEMM_CTX_ARG_STORAGE(c_zero_point)); auto nd_range = pd()->dispatch_.nd_range(); exec_status = parallel_for(ctx, nd_range, post_process_kernel_, arg_list); diff --git a/src/gpu/intel/ocl/gemm/ref_gemm.hpp b/src/gpu/intel/ocl/gemm/ref_gemm.hpp index b3452437ea7..00727f4f2db 100644 --- a/src/gpu/intel/ocl/gemm/ref_gemm.hpp +++ b/src/gpu/intel/ocl/gemm/ref_gemm.hpp @@ -136,9 +136,16 @@ struct ref_gemm_t : public gpu_gemm_t { bool attr_oscale_ok() const { const auto &scales = attr()->scales_; - return scales.get(DNNL_ARG_SRC).mask_ == 0 - && scales.get(DNNL_ARG_WEIGHTS).mask_ == 0 - && scales.get(DNNL_ARG_DST).mask_ == 0; + const bool src_scale_ok + = scales.get(DNNL_ARG_SRC).has_default_values() + || scales.get_mask(DNNL_ARG_SRC) == 0; + const bool wei_scale_ok + = scales.get(DNNL_ARG_WEIGHTS).has_default_values() + || scales.get_mask(DNNL_ARG_WEIGHTS) == 0; + const bool dst_scale_ok + = scales.get(DNNL_ARG_DST).has_default_values() + || scales.get_mask(DNNL_ARG_DST) == 0; + return src_scale_ok && wei_scale_ok && dst_scale_ok; } bool attr_zp_ok() const { diff --git a/src/gpu/intel/ocl/gemm_inner_product.hpp b/src/gpu/intel/ocl/gemm_inner_product.hpp index 84e99657113..27c758db4d9 100644 --- a/src/gpu/intel/ocl/gemm_inner_product.hpp +++ b/src/gpu/intel/ocl/gemm_inner_product.hpp @@ -84,14 +84,19 @@ struct gemm_inner_product_fwd_t : public gpu_primitive_t { VDISPATCH_INNER_PRODUCT_SC( init_2d_desc(&c_md, dst_md()), "init_2d_desc()"); primitive_attr_t gemm_attr = *attr(); - auto wei_mask = gemm_attr.scales_.get(DNNL_ARG_WEIGHTS).mask_; - if (wei_mask == 1) //transpose mask for gemm - VDISPATCH_INNER_PRODUCT_SC( - gemm_attr.scales_.set( - DNNL_ARG_WEIGHTS, 1 << (b_md.ndims - 1)), - VERBOSE_UNSUPPORTED_ATTR); - else if (wei_mask != 0) - return status::unimplemented; + if (!gemm_attr.scales_.get(DNNL_ARG_WEIGHTS).has_default_values()) { + auto wei_mask = gemm_attr.scales_.get_mask(DNNL_ARG_WEIGHTS); + if (wei_mask == 1) { + // Transpose the mask for gemm. + VDISPATCH_INNER_PRODUCT_SC( + gemm_attr.scales_.set( + DNNL_ARG_WEIGHTS, 1 << (b_md.ndims - 1)), + VERBOSE_UNSUPPORTED_SCALES_CFG); + } else { + VDISPATCH_INNER_PRODUCT( + wei_mask == 0, VERBOSE_UNSUPPORTED_ATTR); + } + } VDISPATCH_INNER_PRODUCT_SC( create_gemm_pd(gemm_pd_, engine, &a_md, &b_md, &c_md, weights_md(1), desc()->accum_data_type, &gemm_attr, diff --git a/src/gpu/intel/ocl/gemm_matmul.hpp b/src/gpu/intel/ocl/gemm_matmul.hpp index ab1d5d9c7a0..957e4fab330 100644 --- a/src/gpu/intel/ocl/gemm_matmul.hpp +++ b/src/gpu/intel/ocl/gemm_matmul.hpp @@ -70,17 +70,25 @@ struct gemm_matmul_t : public gpu_primitive_t { return status::success; }; - auto adjust_scales_mask = [&](arg_scales_t &scales, int arg, - int diff_dims) { - int mask = 0, nd = 0; - bool is_set = false; - data_type_t dt = dnnl_data_type_undef; - dims_t dims = {}; - CHECK(attr()->scales_.get(arg, &mask, &is_set, &nd, dims, &dt)); - mask = mask >> diff_dims; - if (is_set) { CHECK(scales.set(arg, mask, nd, dims, dt)); } - return status::success; - }; + // The function shrinks the mask for scales and updates it in + // `scales` object. + auto adjust_scales_mask + = [&](scales_t &scales, int arg, int diff_dims) { + if (attr()->scales_.get(arg).has_default_values()) + return status::success; + + int mask = attr()->scales_.get_mask(arg) >> diff_dims; + data_type_t dt = attr()->scales_.get_data_type(arg); + int nd = 0; + dims_t dims {}; + if (!attr()->scales_.get(arg).has_default_groups()) { + nd = 2; // Note: hardcoded so far. 
+ dims[0] = attr()->scales_.get_group(arg, 0); + dims[1] = attr()->scales_.get_group(arg, 1); + } + CHECK(scales.set(arg, mask, dt, nd, dims)); + return status::success; + }; if (!attr()->zero_points_.has_default_values()) { CHECK(map_gemm_zp(DNNL_ARG_SRC, DNNL_ARG_B)); CHECK(map_gemm_zp( diff --git a/src/gpu/intel/ocl/gen9_binary.hpp b/src/gpu/intel/ocl/gen9_binary.hpp index 2ef8df1c52e..2b45b870a83 100644 --- a/src/gpu/intel/ocl/gen9_binary.hpp +++ b/src/gpu/intel/ocl/gen9_binary.hpp @@ -63,13 +63,17 @@ struct gen9_binary_t : public gpu_primitive_t { VDISPATCH_BINARY(!is_ternary_op(), VERBOSE_BAD_ALGORITHM); VDISPATCH_BINARY( IMPLICATION(!attr()->scales_.has_default_values(), - utils::one_of(dst_md()->data_type, s8, u8) - && utils::everyone_is( - attr()->scales_.get(DNNL_ARG_SRC_0) - .mask_, - attr()->scales_.get(DNNL_ARG_SRC_1) - .mask_, - 0)), + utils::one_of(dst_md()->data_type, s8, u8)), + VERBOSE_UNSUPPORTED_SCALES_CFG); + VDISPATCH_BINARY( + IMPLICATION(!attr()->scales_.get(DNNL_ARG_SRC_0) + .has_default_values(), + attr()->scales_.get_mask(DNNL_ARG_SRC_0) == 0), + VERBOSE_UNSUPPORTED_SCALES_CFG); + VDISPATCH_BINARY( + IMPLICATION(!attr()->scales_.get(DNNL_ARG_SRC_1) + .has_default_values(), + attr()->scales_.get_mask(DNNL_ARG_SRC_1) == 0), VERBOSE_UNSUPPORTED_SCALES_CFG); VDISPATCH_BINARY(attr()->has_default_values(attr_skip_mask), VERBOSE_UNSUPPORTED_ATTR); diff --git a/src/gpu/intel/ocl/generic_reorder.cpp b/src/gpu/intel/ocl/generic_reorder.cpp index 3094556cef4..b7cf72fc89c 100644 --- a/src/gpu/intel/ocl/generic_reorder.cpp +++ b/src/gpu/intel/ocl/generic_reorder.cpp @@ -791,8 +791,8 @@ status_t generic_reorder_t::pd_t::init_conf(impl::engine_t *engine) { memcpy(&new_a, src_md(), sizeof(new_a)); memcpy(&new_b, dst_md(), sizeof(new_b)); compress(new_a, new_b, src_mask, dst_mask); - if (src_mask) CHECK(attr_copy.scales_.set(DNNL_ARG_SRC, src_mask)); - if (dst_mask) CHECK(attr_copy.scales_.set(DNNL_ARG_DST, dst_mask)); + if (src_mask >= 0) { CHECK(attr_copy.scales_.set(DNNL_ARG_SRC, src_mask)); } + if (dst_mask >= 0) { CHECK(attr_copy.scales_.set(DNNL_ARG_DST, dst_mask)); } if (!is_generic_faster_than_ref(new_a, new_b)) return status::unimplemented; diff --git a/src/gpu/intel/ocl/micro_sdpa.cpp b/src/gpu/intel/ocl/micro_sdpa.cpp index d37bfd5cd98..0be452e9c1d 100644 --- a/src/gpu/intel/ocl/micro_sdpa.cpp +++ b/src/gpu/intel/ocl/micro_sdpa.cpp @@ -157,10 +157,11 @@ sdpa_config_t *choose_config_xehpc(dim_t head_size, dim_t seq, bool thin_q) { /// | 3 (1100) | true | /// | 1 (1000) | true | /// | 8 (0001) | false | -bool with_quantize_common(const runtime_scales_t &scales) { - return !scales.has_default_values() - && (((scales.mask_ & 3) != 0 && (scales.mask_ & 12) == 0) - || scales.mask_ == 0); +bool with_quantize_common(const quant_entry_t &scale_entry) { + return !scale_entry.has_default_values() + && (((scale_entry.get_mask() & 3) != 0 + && (scale_entry.get_mask() & 12) == 0) + || scale_entry.get_mask() == 0); } /// Returns true if a common zero points value is used for each slice of the diff --git a/src/gpu/intel/ocl/micro_sdpa.hpp b/src/gpu/intel/ocl/micro_sdpa.hpp index a34c17cb5c4..7c20851a978 100644 --- a/src/gpu/intel/ocl/micro_sdpa.hpp +++ b/src/gpu/intel/ocl/micro_sdpa.hpp @@ -86,7 +86,7 @@ struct micro_sdpa_t : public gpu_primitive_t { "tensors", qry_md()->dims[1], key_md()->dims[1], val_md()->dims[1]); - int kq_scales_mask = desc()->kq_scales.mask_; + int kq_scales_mask = desc()->kq_scales.get_mask(); int kq_zp_mask = desc()->kq_zero_points.get(DNNL_ARG_WEIGHTS); if 
(!desc()->kq_scales.has_default_values()
                    && !desc()->kq_zero_points.has_default_values())
@@ -113,7 +113,7 @@ struct micro_sdpa_t : public gpu_primitive_t {
                         key_group_size());
             }
-            int vs_scales_mask = desc()->vs_scales.mask_;
+            int vs_scales_mask = desc()->vs_scales.get_mask();
             int vs_zp_mask = desc()->vs_zero_points.get(DNNL_ARG_WEIGHTS);
             if (!desc()->vs_scales.has_default_values()
                     && !desc()->vs_zero_points.has_default_values())
diff --git a/src/gpu/intel/ocl/multi_po_reorder_binary.hpp b/src/gpu/intel/ocl/multi_po_reorder_binary.hpp
index e10b1f4e961..808a298ca53 100644
--- a/src/gpu/intel/ocl/multi_po_reorder_binary.hpp
+++ b/src/gpu/intel/ocl/multi_po_reorder_binary.hpp
@@ -41,8 +41,8 @@ struct multi_po_reorder_binary : public gpu_primitive_t {
         DECLARE_COMMON_PD_T("multi_po_reorder_binary", multi_po_reorder_binary);
 
         status_t init(impl::engine_t *engine) {
-            if (attr()->scales_.get(DNNL_ARG_SRC_0).is_set_
-                    || attr()->scales_.get(DNNL_ARG_SRC_1).is_set_
+            if (!attr()->scales_.get(DNNL_ARG_SRC_0).has_default_values()
+                    || !attr()->scales_.get(DNNL_ARG_SRC_1).has_default_values()
                     || attr()->post_ops_.len() >= 1) {
                 VDISPATCH_BINARY(false, VERBOSE_UNSUPPORTED_ATTR);
             }
diff --git a/src/gpu/intel/ocl/ref_matmul.cpp b/src/gpu/intel/ocl/ref_matmul.cpp
index 8fa152073d5..5510dc25c6b 100644
--- a/src/gpu/intel/ocl/ref_matmul.cpp
+++ b/src/gpu/intel/ocl/ref_matmul.cpp
@@ -84,15 +84,13 @@ status_t ref_matmul_t::execute_ref(const exec_ctx_t &ctx) const {
     const dim_t K = a_d.dims()[last];
 
     const auto &attr_scales = pd()->attr()->scales_;
-    const int wei_scale_mask = attr_scales.get(DNNL_ARG_WEIGHTS).mask_;
+    const int wei_scale_mask = attr_scales.get_mask(DNNL_ARG_WEIGHTS);
     const bool wei_scale_per_k = wei_scale_mask & pd()->wei_qmask_K();
-    const auto wei_scale_group_ndim = attr_scales.get(DNNL_ARG_WEIGHTS).ndims_;
-    const auto wei_scale_group_k = wei_scale_group_ndim > 0
-            ? attr_scales.get(DNNL_ARG_WEIGHTS).group_dims_[0]
+    const auto wei_scale_group_k
+            = !attr_scales.get(DNNL_ARG_WEIGHTS).has_default_groups()
+            ? attr_scales.get_group(DNNL_ARG_WEIGHTS, 0)
             : (wei_scale_per_k ? 1 : K);
-    const auto wei_scale_group_n = wei_scale_group_ndim > 0
-            ? attr_scales.get(DNNL_ARG_WEIGHTS).group_dims_[1]
-            : 1;
+    const auto wei_scale_group_n = attr_scales.get_group(DNNL_ARG_WEIGHTS, 1);
     const auto wei_scale_ngroups_k = K / wei_scale_group_k;
     // Identify wei_scales dimensions as user may not pass them.
     dims_t wei_scale_dims {};
@@ -120,11 +118,11 @@ status_t ref_matmul_t::execute_ref(const exec_ctx_t &ctx) const {
     const dim_t wei_scale_stride_b1
             = b_d.ndims() > 3 ? wei_scale_strides[b_d.ndims() - 4] : 0;
-    const int src_scale_mask = attr_scales.get(DNNL_ARG_SRC).mask_;
+    const int src_scale_mask = attr_scales.get_mask(DNNL_ARG_SRC);
     const bool src_scale_per_k = src_scale_mask & pd()->src_qmask_K();
-    const auto src_scale_group_ndim = attr_scales.get(DNNL_ARG_SRC).ndims_;
-    const auto src_scale_group_k = src_scale_group_ndim > 0
-            ? attr_scales.get(DNNL_ARG_SRC).group_dims_[1]
+    const auto src_scale_group_k
+            = !attr_scales.get(DNNL_ARG_SRC).has_default_groups()
+            ? attr_scales.get_group(DNNL_ARG_SRC, 1)
             : (src_scale_per_k ? 1 : K);
     const auto src_scale_ngroups_k = K / src_scale_group_k;
     // Identify src_scales dimensions as user may not pass them.
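For reference, a standalone sketch of the group-size fallback used in the ref_matmul.cpp change above: an explicit group dimension wins; otherwise per-K scales imply a group of 1 and a common scale implies a group of K. The values below are hypothetical and the snippet is plain C++, not oneDNN code.

// Standalone illustration (hypothetical values, not oneDNN code) of the
// weights-scales group-size fallback along K shown above.
#include <cstdio>

int main() {
    const int K = 64;               // reduction dimension of the matmul
    const bool has_groups = false;  // scales were set without group dims
    const int group_dim_k = 16;     // used only when has_groups is true
    const bool per_k = true;        // the scales mask selects the K dim

    // Mirrors: group_k = !has_default_groups() ? get_group(0)
    //                                          : (per_k ? 1 : K);
    const int group_k = has_groups ? group_dim_k : (per_k ? 1 : K);
    const int ngroups_k = K / group_k;
    std::printf("group_k=%d ngroups_k=%d\n", group_k, ngroups_k);
    return 0;
}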
diff --git a/src/gpu/intel/ocl/ref_matmul.hpp b/src/gpu/intel/ocl/ref_matmul.hpp
index 428f0289c46..a6f76e3bbcd 100644
--- a/src/gpu/intel/ocl/ref_matmul.hpp
+++ b/src/gpu/intel/ocl/ref_matmul.hpp
@@ -230,19 +230,19 @@ struct ref_matmul_t : public gpu_primitive_t {
         def_data_type(kernel_ctx, pd()->bia_dt_, "BIA");
         def_data_type(kernel_ctx, pd()->desc()->accum_data_type, "ACC");
         def_data_type(kernel_ctx,
-                pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).data_type_,
+                pd()->attr()->scales_.get_data_type(DNNL_ARG_WEIGHTS),
                 "WEI_SCALES");
         def_data_type(kernel_ctx,
                 pd()->attr()->zero_points_.get_data_type(DNNL_ARG_WEIGHTS),
                 "WEI_ZP");
         def_data_type(kernel_ctx,
-                pd()->attr()->scales_.get(DNNL_ARG_SRC).data_type_,
+                pd()->attr()->scales_.get_data_type(DNNL_ARG_SRC),
                 "SRC_SCALES");
         def_data_type(kernel_ctx,
                 pd()->attr()->zero_points_.get_data_type(DNNL_ARG_SRC),
                 "SRC_ZP");
         def_data_type(kernel_ctx,
-                pd()->attr()->scales_.get(DNNL_ARG_DST).data_type_,
+                pd()->attr()->scales_.get_data_type(DNNL_ARG_DST),
                 "DST_SCALES");
         kernels_.resize(2);
         CHECK(create_kernel(engine, &kernels_[0], "ref_matmul", kernel_ctx));
diff --git a/src/gpu/intel/primitive_conf.cpp b/src/gpu/intel/primitive_conf.cpp
index 66a86a36617..7f9c15e2131 100644
--- a/src/gpu/intel/primitive_conf.cpp
+++ b/src/gpu/intel/primitive_conf.cpp
@@ -104,19 +104,20 @@ attr_info_t attr_info_t::create(const primitive_attr_t *attr) {
     const auto &src_scales = attr->scales_.get(DNNL_ARG_SRC);
     attr_info.with_src_scales = !src_scales.has_default_values();
     attr_info.with_src0_scale = !src_scales.has_default_values();
-    attr_info.src_scales_mask = src_scales.mask_;
 
     const auto &src1_scales = attr->scales_.get(DNNL_ARG_SRC_1);
     attr_info.with_src1_scale = !src1_scales.has_default_values();
-    gpu_assert(src1_scales.mask_ == 0);
+    if (attr_info.with_src1_scale) { gpu_assert(src1_scales.get_mask() == 0); }
 
     const auto &wei_scales = attr->scales_.get(DNNL_ARG_WEIGHTS);
     attr_info.with_wei_scales = !wei_scales.has_default_values();
-    attr_info.wei_scales_mask = wei_scales.mask_;
+    // TODO: remove the default `0` value.
+    attr_info.wei_scales_mask
+            = attr_info.with_wei_scales ?
wei_scales.get_mask() : 0;
 
     const auto &dst_scales = attr->scales_.get(DNNL_ARG_DST);
     attr_info.with_dst_scales = !dst_scales.has_default_values();
-    gpu_assert(dst_scales.mask_ == 0);
+    if (attr_info.with_dst_scales) { gpu_assert(dst_scales.get_mask() == 0); }
 
     // zero points
     const auto &zp = attr->zero_points_;
@@ -823,7 +824,6 @@ status_t def_attr_info_impl(compute::kernel_ctx_t &kernel_ctx,
     kernel_ctx.define_int("WITH_SRC_SCALES", attr_info.with_src_scales);
     kernel_ctx.define_int("WITH_WEI_SCALES", attr_info.with_wei_scales);
     kernel_ctx.define_int("WITH_DST_SCALES", attr_info.with_dst_scales);
-    kernel_ctx.define_int("SRC_SCALES_MASK", attr_info.src_scales_mask);
     kernel_ctx.define_int("WEI_SCALES_MASK", attr_info.wei_scales_mask);
 
     kernel_ctx.define_int("WITH_SRC_ZPOINTS", attr_info.with_src_zpoints);
diff --git a/src/gpu/intel/primitive_conf.hpp b/src/gpu/intel/primitive_conf.hpp
index e3341fe55f8..596c8bc3e79 100644
--- a/src/gpu/intel/primitive_conf.hpp
+++ b/src/gpu/intel/primitive_conf.hpp
@@ -83,8 +83,7 @@ struct attr_info_t {
     bool with_src_scales;
     bool with_wei_scales;
     bool with_dst_scales;
-    bool src_scales_mask;
-    bool wei_scales_mask;
+    int wei_scales_mask;
 
     bool with_src_zpoints;
     bool with_wei_zpoints;
diff --git a/src/gpu/nvidia/cudnn_binary.hpp b/src/gpu/nvidia/cudnn_binary.hpp
index 4bba9c0ba6b..37964272dc9 100644
--- a/src/gpu/nvidia/cudnn_binary.hpp
+++ b/src/gpu/nvidia/cudnn_binary.hpp
@@ -66,12 +66,7 @@ struct cudnn_binary_t : public gpu::primitive_t {
                     || has_zero_dims(dst_md()->dims, dst_md()->ndims);
         }
 
-        bool check_scales_mask() const {
-            for (const auto &s : attr()->scales_.scales_) {
-                if (s.second.mask_ != 0) return false;
-            }
-            return true;
-        }
+        bool check_scales_mask() const { return attr_scales_ok(); }
 
         bool check_no_blocking() const {
             // Blocking is not supported by cudnnOpTensor, return false if any
diff --git a/src/gpu/nvidia/cudnn_convolution.hpp b/src/gpu/nvidia/cudnn_convolution.hpp
index 329081c4633..db9a99dd142 100644
--- a/src/gpu/nvidia/cudnn_convolution.hpp
+++ b/src/gpu/nvidia/cudnn_convolution.hpp
@@ -155,15 +155,17 @@ struct cudnn_convolution_fwd_t : public gpu::primitive_t {
                     && ndims() < 5;
         }
 
-        bool attr_scales_ok() const {
-            const auto &scales = attr()->scales_;
-            const auto &supported_args
-                    = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
-            if (!scales.has_default_values(supported_args)) return false;
-            // cuDNN does not support scaling per dimension.
-            for (auto arg : supported_args)
-                if (scales.get(arg).mask_ != 0) return false;
-            return true;
+        bool attr_scales_ok(const std::vector<int> &supported_args
+                = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) const {
+            bool ok = attr()->scales_.has_default_values(supported_args);
+            for (int arg : supported_args) {
+                if (attr()->scales_.get(arg).has_default_values()) continue;
+
+                const auto &mask = attr()->scales_.get_mask(arg);
+                // cuDNN does not support scaling per dimension.
+                ok = ok && (mask == 0);
+            }
+            return ok;
         }
     };
diff --git a/src/gpu/nvidia/cudnn_inner_product.hpp b/src/gpu/nvidia/cudnn_inner_product.hpp
index b4d5c7aa438..1507b841ee6 100644
--- a/src/gpu/nvidia/cudnn_inner_product.hpp
+++ b/src/gpu/nvidia/cudnn_inner_product.hpp
@@ -39,15 +39,17 @@ struct cudnn_inner_product_fwd_t : public gpu::primitive_t {
     struct pd_t : public inner_product_fwd_pd_t {
         using inner_product_fwd_pd_t::inner_product_fwd_pd_t;
 
-        bool attr_scales_ok() const {
-            const auto &scales = attr()->scales_;
-            const auto &supported_args
-                    = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
-            if (!scales.has_default_values(supported_args)) return false;
-            // cuDNN does not support scaling per dimension.
-            for (auto arg : supported_args)
-                if (scales.get(arg).mask_ != 0) return false;
-            return true;
+        bool attr_scales_ok(const std::vector<int> &supported_args
+                = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) const {
+            bool ok = attr()->scales_.has_default_values(supported_args);
+            for (int arg : supported_args) {
+                if (attr()->scales_.get(arg).has_default_values()) continue;
+
+                const auto &mask = attr()->scales_.get_mask(arg);
+                // cuDNN does not support scaling per dimension.
+                ok = ok && (mask == 0);
+            }
+            return ok;
         }
 
         std::shared_ptr inner_product_impl_;
diff --git a/src/gpu/nvidia/cudnn_matmul.hpp b/src/gpu/nvidia/cudnn_matmul.hpp
index f1674c860fe..f0853356a5b 100644
--- a/src/gpu/nvidia/cudnn_matmul.hpp
+++ b/src/gpu/nvidia/cudnn_matmul.hpp
@@ -101,10 +101,11 @@ struct cudnn_matmul_t : public gpu::primitive_t {
             if (!scales.has_default_values(supported_args)) return false;
             // cuDNN does not support scaling per dimension.
             for (auto arg : supported_args) {
-                auto &s = scales.get(arg);
-                if (scales.get(arg).mask_ != 0
-                        || !utils::one_of(
-                                s.data_type_, s8, s32, f32, f16, bf16))
+                if (scales.get(arg).has_default_values()) continue;
+
+                if (scales.get_mask(arg) > 0) return false;
+                if (!utils::one_of(
+                            scales.get_data_type(arg), s8, s32, f32, f16, bf16))
                     return false;
             }
             return true;
diff --git a/src/gpu/nvidia/cudnn_matmul_lt.hpp b/src/gpu/nvidia/cudnn_matmul_lt.hpp
index 0b6907d094c..2356398c7fe 100644
--- a/src/gpu/nvidia/cudnn_matmul_lt.hpp
+++ b/src/gpu/nvidia/cudnn_matmul_lt.hpp
@@ -215,10 +215,9 @@ struct cudnn_matmul_lt_t : public gpu::primitive_t {
         }
 
         status_t create_scale_binary_pd(impl::engine_t *engine, int ARG) {
-            auto scale_md = dnnl_memory_desc();
-            scale_md.ndims = attr()->scales_.get(ARG).ndims_;
-            scale_md.data_type = attr()->scales_.get(ARG).data_type_;
-            scale_md.format_kind = dnnl_blocked;
+            memory_desc_t scale_md;
+            scale_md.data_type = attr()->scales_.get_data_type(ARG);
+            scale_md.format_kind = format_kind::blocked;
             auto format_desc = create_scaling_format_desc(ARG, scale_md);
             scale_md.format_desc = {format_desc};
@@ -246,7 +245,7 @@ struct cudnn_matmul_lt_t : public gpu::primitive_t {
         }
 
         blocking_desc_t create_scaling_format_desc(
-                int ARG, dnnl_memory_desc &scale_md) {
+                int ARG, memory_desc_t &scale_md) {
             blocking_desc_t format_desc;
             memory_desc_t md;
             if (ARG == DNNL_ARG_SRC) {
@@ -255,11 +254,13 @@ struct cudnn_matmul_lt_t : public gpu::primitive_t {
                 md = *weights_md(0);
             } else if (ARG == DNNL_ARG_DST) {
                 md = *dst_md();
+            } else {
+                assert(!"unexpected arg");
             }
             scale_md.ndims = md.ndims;
             for (int i = 0; i < md.ndims; i++) {
-                if (attr()->scales_.get(1).mask_ & (1 << i)) {
+                if (attr()->scales_.get_mask(ARG) & (1 << i)) {
                     scale_md.dims[i] = md.dims[i];
                 } else {
                     scale_md.dims[i] = 1;
@@ -303,20 +304,17 @@ struct cudnn_matmul_lt_t : public gpu::primitive_t {
         bool single_scale(int ARG) const {
             const auto &scales = attr()->scales_;
-            return scales.get(ARG).mask_ == 0;
+            return scales.get_mask(ARG) == 0;
         }
 
-        bool scales_ok() {
-            data_type_t src_scale_dt
-                    = attr()->scales_.get(DNNL_ARG_SRC).data_type_;
-            data_type_t wei_scale_dt
-                    = attr()->scales_.get(DNNL_ARG_WEIGHTS).data_type_;
-            bool src_scales_ok = default_scale(DNNL_ARG_SRC)
-                    || utils::one_of(
-                            src_scale_dt, data_type::s8, data_type::s32);
-            bool wei_scales_ok = default_scale(DNNL_ARG_WEIGHTS)
-                    || utils::one_of(
-                            wei_scale_dt, data_type::s8, data_type::s32);
+        bool scales_ok() const {
+            bool src_scales_ok = IMPLICATION(!default_scale(DNNL_ARG_SRC),
+                    utils::one_of(attr()->scales_.get_data_type(DNNL_ARG_SRC),
+                            data_type::s8, data_type::s32));
+            bool wei_scales_ok = IMPLICATION(!default_scale(DNNL_ARG_WEIGHTS),
+                    utils::one_of(
+                            attr()->scales_.get_data_type(DNNL_ARG_WEIGHTS),
+                            data_type::s8, data_type::s32));
             return src_scales_ok && wei_scales_ok;
         }
diff --git a/src/gpu/nvidia/cudnn_matmul_lt_impl.hpp b/src/gpu/nvidia/cudnn_matmul_lt_impl.hpp
index fda74f94e51..b6dfeea4346 100644
--- a/src/gpu/nvidia/cudnn_matmul_lt_impl.hpp
+++ b/src/gpu/nvidia/cudnn_matmul_lt_impl.hpp
@@ -56,19 +56,22 @@ struct cublas_lt_params : cublas_base_params {
                 || weights_d.has_runtime_dims_or_strides();
 
         if (attr->scales_.get(DNNL_ARG_SRC).has_default_values()) {
-            auto src_scale = attr->scales_.get(DNNL_ARG_SRC);
-            if (src_scale.mask_ != 0) { multi_src_scale_ = true; }
+            if (attr->scales_.get_mask(DNNL_ARG_SRC) > 0) {
+                multi_src_scale_ = true;
+            }
         }
         if (attr->scales_.get(DNNL_ARG_WEIGHTS).has_default_values()) {
-            auto wei_scale = attr->scales_.get(DNNL_ARG_WEIGHTS);
-            if (wei_scale.mask_ != 0) { multi_wei_scale_ = true; }
+            if (attr->scales_.get_mask(DNNL_ARG_WEIGHTS) > 0) {
+                multi_wei_scale_ = true;
+            }
         }
         with_dst_scale_ = !attr->scales_.get(DNNL_ARG_DST).has_default_values();
         if (with_dst_scale_) {
-            auto dst_scale = attr->scales_.get(DNNL_ARG_DST);
-            if (dst_scale.mask_ != 0) { multi_dst_scale_ = true; }
+            if (attr->scales_.get_mask(DNNL_ARG_DST) > 0) {
+                multi_dst_scale_ = true;
+            }
         }
 
         // Initialise flags and variables for the imma case (E.g. imma_case_ flag).
diff --git a/src/gpu/nvidia/cudnn_reorder.hpp b/src/gpu/nvidia/cudnn_reorder.hpp
index 30c7bc185dd..492afc8dff2 100644
--- a/src/gpu/nvidia/cudnn_reorder.hpp
+++ b/src/gpu/nvidia/cudnn_reorder.hpp
@@ -83,14 +83,17 @@ struct cudnn_reorder_t : public gpu::primitive_t {
             return ok;
         }
 
-        bool scales_ok() const {
-            const auto &scales = attr()->scales_;
-            const auto &supported_args = {DNNL_ARG_FROM, DNNL_ARG_TO};
-            if (!scales.has_default_values(supported_args)) return false;
-            // cuDNN does not support scaling per dimension.
-            for (auto arg : supported_args)
-                if (scales.get(arg).mask_ != 0) return false;
-            return true;
+        bool scales_ok(const std::vector<int> &supported_args
+                = {DNNL_ARG_FROM, DNNL_ARG_TO}) const {
+            bool ok = attr()->scales_.has_default_values(supported_args);
+            for (int arg : supported_args) {
+                if (attr()->scales_.get(arg).has_default_values()) continue;
+
+                const auto &mask = attr()->scales_.get_mask(arg);
+                // cuDNN does not support scaling per dimension.
+                ok = ok && (mask == 0);
+            }
+            return ok;
         }
 
         bool post_ops_ok() const {
diff --git a/src/gpu/nvidia/cudnn_reorder_lt.hpp b/src/gpu/nvidia/cudnn_reorder_lt.hpp
index b8e1d2afc7f..b97c7d0de0d 100644
--- a/src/gpu/nvidia/cudnn_reorder_lt.hpp
+++ b/src/gpu/nvidia/cudnn_reorder_lt.hpp
@@ -117,14 +117,17 @@ struct cudnn_reorder_lt_t : public gpu::primitive_t {
             return ok;
         }
 
-        bool scales_ok() const {
-            const auto &scales = attr()->scales_;
-            const auto &supported_args = {DNNL_ARG_FROM, DNNL_ARG_TO};
-            if (!scales.has_default_values(supported_args)) return false;
-            // cuDNN does not support scaling per dimension.
-            for (auto arg : supported_args)
-                if (scales.get(arg).mask_ != 0) return false;
-            return true;
+        bool scales_ok(const std::vector<int> &supported_args
+                = {DNNL_ARG_FROM, DNNL_ARG_TO}) const {
+            bool ok = attr()->scales_.has_default_values(supported_args);
+            for (int arg : supported_args) {
+                if (attr()->scales_.get(arg).has_default_values()) continue;
+
+                const auto &mask = attr()->scales_.get_mask(arg);
+                // cuDNN does not support scaling per dimension.
+                ok = ok && (mask == 0);
+            }
+            return ok;
         }
 
         bool post_ops_ok() const {
@@ -144,11 +147,6 @@ struct cudnn_reorder_lt_t : public gpu::primitive_t {
                     && post_ops_ok();
             if (!ok) return status::unimplemented;
 
-            primitive_attr_t r_attr;
-            int mask = 0;
-            bool is_set = false;
-            auto src = DNNL_ARG_DST;
-            auto dst = DNNL_ARG_SRC;
             if (src_float_) {
                 src_scratch_md_ = *src_md();
                 dst_scratch_md_ = create_temp_md(src_scratch_md_);
@@ -157,21 +155,23 @@ struct cudnn_reorder_lt_t : public gpu::primitive_t {
                 src_scratch_md_ = create_temp_md(dst_scratch_md_);
                 dst_scratch_md_ = *dst_md();
             }
-            attr()->scales_.get(src, &mask, &is_set);
-            if (is_set) { r_attr.scales_.set(src, mask); }
-            attr()->scales_.get(dst, &mask, &is_set);
-            if (is_set) { r_attr.scales_.set(dst, mask); }
+            primitive_attr_t r_attr;
+            if (!attr()->scales_.get(DNNL_ARG_SRC).has_default_values()) {
+                const auto mask = attr()->scales_.get_mask(DNNL_ARG_SRC);
+                r_attr.scales_.set(DNNL_ARG_SRC, mask);
+            }
 
-            status_t generic_ok = reorder_primitive_desc_create(
-                    generic_reorder_desc_, engine, &src_scratch_md_,
-                    &dst_scratch_md_, &r_attr);
-            ok = ok && (generic_ok == status::success);
+            if (!attr()->scales_.get(DNNL_ARG_DST).has_default_values()) {
+                const auto mask = attr()->scales_.get_mask(DNNL_ARG_DST);
+                r_attr.scales_.set(DNNL_ARG_DST, mask);
+            }
 
-            if (!ok) return status::unimplemented;
+            CHECK(reorder_primitive_desc_create(generic_reorder_desc_, engine,
+                    &src_scratch_md_, &dst_scratch_md_, &r_attr));
 
             init_scratchpad();
-            return dnnl_success;
+            return status::success;
         }
 
         void init_scratchpad() {
diff --git a/src/gpu/nvidia/cudnn_softmax.hpp b/src/gpu/nvidia/cudnn_softmax.hpp
index 3ce67c4acaf..1268ee4b341 100644
--- a/src/gpu/nvidia/cudnn_softmax.hpp
+++ b/src/gpu/nvidia/cudnn_softmax.hpp
@@ -56,7 +56,7 @@ struct cudnn_softmax_fwd_t : public gpu::primitive_t {
                     && set_default_formats() == status::success
                     && src_d.is_plain() && dst_d.is_plain() && dst_d == src_d
                     && IMPLICATION(!attr()->scales_.has_default_values(),
-                            check_scales_mask()
+                            attr_scales_ok()
                                     && dst_d.data_type() != data_type::s8);
             if (!ok) return status::unimplemented;
 
@@ -64,12 +64,6 @@ struct cudnn_softmax_fwd_t : public gpu::primitive_t {
             return softmax_impl_->init(this);
         }
 
-        bool check_scales_mask() const {
-            for (const auto &s : attr()->scales_.scales_) {
-                if (s.second.mask_ != 0) return false;
-            }
-            return true;
-        }
 
         std::shared_ptr softmax_impl_;
     };
diff --git a/tests/gtests/test_iface_attr.cpp b/tests/gtests/test_iface_attr.cpp
index 720aed6eb7f..d695913c147 100644
--- a/tests/gtests/test_iface_attr.cpp
+++ b/tests/gtests/test_iface_attr.cpp
@@ -281,15 +281,17 @@ HANDLE_EXCEPTIONS_FOR_TEST_F(attr_test_t, TestScalesWithGroups) {
     for (auto arg : supported_args) {
         // single non-default scales for supported arg
         attr.set_scales(arg, 0, {});
-        // multiple scales with groups
+        // multiple scales with a single group dim
         attr.set_scales(arg, 1 << 0, {4});
+        // multiple scales with multiple group dims
+        attr.set_scales(arg, 1 << 0, {4, 1});
         // scales with groups and a data type
-        attr.set_scales(arg, 1 << 0, {4}, data_type::f32);
+        attr.set_scales(arg, 1 << 0, {4, 1}, data_type::f32);
     }
 
     for (auto arg : unsupported_args) {
         // multiple scales with groups for unsupported args
-        EXPECT_ANY_THROW(attr.set_scales(arg, 1 << 0, {4}));
+        EXPECT_ANY_THROW(attr.set_scales(arg, 1 << 0, {4, 1}));
        // multiple scales with non-default data type for unsupported args
        EXPECT_ANY_THROW(attr.set_scales(arg, 1 << 0, {}, data_type::bf16));
    }
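The test changes above exercise the public C++ API overload that accepts group dimensions alongside the mask and data type. A minimal usage sketch follows; it assumes oneDNN's dnnl.hpp and the same set_scales overload the test calls, and is an illustration rather than part of the patch.

// Minimal sketch (not part of the patch): setting grouped weights scales
// through the public attribute API, mirroring the cases added in
// test_iface_attr.cpp.
#include "dnnl.hpp"

int main() {
    dnnl::primitive_attr attr;
    // Scales vary along dim 0 (mask = 1 << 0), with group sizes {4, 1},
    // stored as f32.
    attr.set_scales(DNNL_ARG_WEIGHTS, 1 << 0, {4, 1},
            dnnl::memory::data_type::f32);
    return 0;
}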