Skip to content

Commit

Permalink
gtests: update attr_quant with relevant arguments and conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
dzarukin committed Dec 16, 2024
1 parent 08a5101 commit 0b8567e
Showing 1 changed file with 24 additions and 23 deletions.
47 changes: 24 additions & 23 deletions tests/gtests/test_iface_attr_quantization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ class attr_quantization_test_t : public ::testing::Test {
engine eng = get_test_engine();
void SetUp() override {}

static primitive_attr gen_attr_with_scales() {
static primitive_attr gen_attr_with_scales(bool with_wei = true) {
primitive_attr attr;
attr.set_scales_mask(DNNL_ARG_SRC, 0);
attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
if (with_wei) attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
attr.set_scales_mask(DNNL_ARG_DST, 0);
return attr;
}
Expand Down Expand Up @@ -84,7 +84,7 @@ TEST_F(attr_quantization_test_t, TestBNorm) {
eng, prop_kind::forward_inference, md, md, 0.1f, flags));
CHECK_UNIMPL(batch_normalization_forward::primitive_desc(eng,
prop_kind::forward_inference, md, md, 0.1f, flags,
gen_attr_with_scales()));
gen_attr_with_scales(/* with_wei = */ false)));

for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_BIAS,
DNNL_ARG_MEAN, DNNL_ARG_VARIANCE, DNNL_ARG_DST}) {
Expand Down Expand Up @@ -117,8 +117,7 @@ TEST_F(attr_quantization_test_t, TestConcat) {
memory::desc md {{1, 16, 3, 3}, data_type::s8, tag::abcd};
CHECK_OK(concat::primitive_desc(eng, 1, {md, md}));

for (auto arg :
{DNNL_ARG_MULTIPLE_SRC, DNNL_ARG_MULTIPLE_SRC + 1, DNNL_ARG_DST}) {
for (auto arg : {DNNL_ARG_MULTIPLE_SRC, DNNL_ARG_MULTIPLE_SRC + 1}) {
CHECK_OK(concat::primitive_desc(
eng, 1, {md, md}, gen_attr_with_scales(arg)));
CHECK_UNIMPL(concat::primitive_desc(
Expand Down Expand Up @@ -359,7 +358,8 @@ TEST_F(attr_quantization_test_t, TestEltwise) {
eng, prop_kind::forward, algorithm::eltwise_relu, md, md, 0.f));

CHECK_UNIMPL(eltwise_forward::primitive_desc(eng, prop_kind::forward,
algorithm::eltwise_relu, md, md, 0.f, gen_attr_with_scales()));
algorithm::eltwise_relu, md, md, 0.f,
gen_attr_with_scales(/* with_wei = */ false)));

for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_DST}) {
CHECK_UNIMPL(eltwise_forward::primitive_desc(eng,
Expand Down Expand Up @@ -403,7 +403,7 @@ TEST_F(attr_quantization_test_t, TestLNorm) {
eng, prop_kind::forward_inference, md, md, stat_md, 0.1f, flags));
CHECK_OK(layer_normalization_forward::primitive_desc(eng,
prop_kind::forward_inference, md, md, stat_md, 0.1f, flags,
gen_attr_with_scales()));
gen_attr_with_scales(/* with_wei = */ false)));

for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_MEAN, DNNL_ARG_VARIANCE,
DNNL_ARG_WEIGHTS, DNNL_ARG_BIAS, DNNL_ARG_DST}) {
Expand All @@ -418,9 +418,10 @@ TEST_F(attr_quantization_test_t, TestLRN) {
memory::desc md {{1, 16, 3, 3}, dt, tag::abcd};
CHECK_OK(lrn_forward::primitive_desc(eng, prop_kind::forward_inference,
algorithm::lrn_across_channels, md, md, 5, 1.f, 0.75f, 1.0f));
CHECK_UNIMPL(lrn_forward::primitive_desc(eng,
prop_kind::forward_inference, algorithm::lrn_across_channels,
md, md, 5, 1.f, 0.75f, 1.0f, gen_attr_with_scales()));
CHECK_UNIMPL(
lrn_forward::primitive_desc(eng, prop_kind::forward_inference,
algorithm::lrn_across_channels, md, md, 5, 1.f, 0.75f,
1.0f, gen_attr_with_scales(/* with_wei = */ false)));

for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_DST}) {
CHECK_UNIMPL(lrn_forward::primitive_desc(eng,
Expand Down Expand Up @@ -483,9 +484,9 @@ TEST_F(attr_quantization_test_t, TestMatmul) {
if (arg == DNNL_ARG_WEIGHTS) {
CHECK_OK(matmul::primitive_desc(eng, a_md, b_md, c_md,
gen_attr_with_scales(arg, 1 << 1)));
CHECK_OK(matmul::primitive_desc(eng, a_md, b_md, c_md,
gen_attr_with_scales(arg, (1 << 1) + (1 << 0))));
if (b_dt == data_type::s8) {
CHECK_OK(matmul::primitive_desc(eng, a_md, b_md, c_md,
gen_attr_with_scales(arg, (1 << 1) + (1 << 0))));
// Groups non divisible by 32 are not supported.
CHECK_UNIMPL(matmul::primitive_desc(eng, a_md, b_md, c_md,
gen_attr_with_scales(arg, (1 << 1) + (1 << 0),
Expand Down Expand Up @@ -597,10 +598,10 @@ TEST_F(attr_quantization_test_t, TestPool) {
CHECK_OK(pooling_forward::primitive_desc(eng, prop_kind::forward_inference,
algorithm::pooling_max, src_md, dst_md, {2, 2}, {2, 2}, {0, 0},
{0, 0}, {0, 0}));
CHECK_UNIMPL(
pooling_forward::primitive_desc(eng, prop_kind::forward_inference,
algorithm::pooling_max, src_md, dst_md, {2, 2}, {2, 2},
{0, 0}, {0, 0}, {0, 0}, gen_attr_with_scales()));
CHECK_UNIMPL(pooling_forward::primitive_desc(eng,
prop_kind::forward_inference, algorithm::pooling_max, src_md,
dst_md, {2, 2}, {2, 2}, {0, 0}, {0, 0}, {0, 0},
gen_attr_with_scales(/* with_wei = */ false)));

for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_DST}) {
CHECK_UNIMPL(pooling_forward::primitive_desc(eng,
Expand All @@ -620,7 +621,7 @@ TEST_F(attr_quantization_test_t, TestPReLU) {
eng, prop_kind::forward, data_md, weights_md, data_md));

CHECK_UNIMPL(prelu_forward::primitive_desc(eng, prop_kind::forward, data_md,
weights_md, data_md, gen_attr_with_scales()));
weights_md, data_md, gen_attr_with_scales(/* with_wei = */ false)));

for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_DST}) {
CHECK_UNIMPL(prelu_forward::primitive_desc(eng, prop_kind::forward,
Expand All @@ -634,8 +635,8 @@ CPU_TEST_F(attr_quantization_test_t, TestReorder) {
CHECK_OK(reorder::primitive_desc(eng, src_md, eng, dst_md));

for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_DST}) {
CHECK_OK(reorder::primitive_desc(
eng, src_md, eng, dst_md, gen_attr_with_scales()));
CHECK_OK(reorder::primitive_desc(eng, src_md, eng, dst_md,
gen_attr_with_scales(/* with_wei = */ false)));
CHECK_OK(reorder::primitive_desc(
eng, src_md, eng, dst_md, gen_attr_with_zp(arg)));
}
Expand Down Expand Up @@ -704,8 +705,8 @@ TEST_F(attr_quantization_test_t, TestShuffle) {

CHECK_OK(shuffle_forward::primitive_desc pd(
eng, prop_kind::forward, md, md, 1, 4));
CHECK_UNIMPL(shuffle_forward::primitive_desc pd(
eng, prop_kind::forward, md, md, 1, 4, gen_attr_with_scales()));
CHECK_UNIMPL(shuffle_forward::primitive_desc pd(eng, prop_kind::forward, md,
md, 1, 4, gen_attr_with_scales(/* with_wei = */ false)));

for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_DST}) {
CHECK_UNIMPL(shuffle_forward::primitive_desc pd(
Expand Down Expand Up @@ -735,8 +736,8 @@ TEST_F(attr_quantization_test_t, TestSum) {
SKIP_IF_HIP(true, "Unsupported operator for HIP");
memory::desc md {{1, 16, 3, 3}, data_type::s8, tag::abcd};
CHECK_OK(sum::primitive_desc(eng, {1.f, 1.f}, {md, md}));
CHECK_UNIMPL(sum::primitive_desc(
eng, {1.f, 1.f}, {md, md}, gen_attr_with_scales()));
CHECK_UNIMPL(sum::primitive_desc(eng, {1.f, 1.f}, {md, md},
gen_attr_with_scales(/* with_wei = */ false)));

for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_DST}) {
CHECK_UNIMPL(sum::primitive_desc(
Expand Down

0 comments on commit 0b8567e

Please sign in to comment.