-
Notifications
You must be signed in to change notification settings - Fork 1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
x64: brdgmm conv: enable zps per group (closes MFDNN-12556) #2274
base: main
Are you sure you want to change the base?
Conversation
c215542
to
0741e51
Compare
src/common/primitive_attr_quant.hpp
Outdated
@@ -244,6 +244,7 @@ struct zero_points_t : public c_compatible { | |||
|
|||
// arg-specific checks | |||
bool common(int arg) const { return get_mask(arg) == 0; } | |||
bool per_oc(int arg) const { return get_mask(arg) == 2; } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggest not to introduce this as primitives may interpret mask == 2
not necessarily as per_oc
and rely on get_mask()
directly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dzarukin get_mask()
is private, do you suggest to change it to public?
I could also change per_oc
to per_dim_1
alternatively.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, wouldn't think it's in private
... Let's use per_dim_1
then, I'll replace it as a part of future refactor later. Thanks for checking.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
src/cpu/x64/brgemm/brgemm.cpp
Outdated
@@ -207,6 +207,7 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs, | |||
brgemm_p.b_zp_compensations = post_ops_data.b_zp_compensations; | |||
brgemm_p.c_zp_values = post_ops_data.c_zp_values; | |||
brgemm_p.ptr_dst_scales = post_ops_data.dst_scales; | |||
brgemm_p.a_zp_values = post_ops_data.a_zp_values; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: put above c_zp_values
to group identical values?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
src/cpu/x64/brgemm/brgemm.cpp
Outdated
zp_type = zero_points.has_default_values(mem_arg) || skip_zero_point | ||
? brgemm_broadcast_t::none | ||
: brgemm_broadcast_t::per_tensor; | ||
: is_per_oc_bcast ? brgemm_broadcast_t::per_n | ||
: brgemm_broadcast_t::per_tensor; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: for better readability and control of the value.
zp_type = brgemm_broadcast_t::none;
if (XXX) {
zp_type = brgemm_broadcast_t::per_tensor;
} else if (YYY) {
zp_type = brgemm_broadcast_t::per_n;
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
0741e51
to
2bcc3f3
Compare
Adding src zero points per-group (mask=2) support to brdgmm conv.
Related jira: MFDNN-12556
CI
Nightly