-
Notifications
You must be signed in to change notification settings - Fork 1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[GPU] Fp8 compute backports #2266
base: main
Are you sure you want to change the base?
Conversation
&& utils::one_of(mask_wei, 0, with_groups ? 3 : 1), | ||
VCHECK_CONV_UNIMPL(utils::one_of(mask_wei, 0, with_groups ? 3 : 1) | ||
&& utils::one_of(mask_dst, 0, 2) | ||
&& utils::one_of(mask_src, 0, 3), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you, please, clarify why mask=3 for src but mask=2 for dst?
| smask_t::zero_points_runtime_data_type | ||
| smask_t::scales_runtime_groups | ||
| smask_t::scales_runtime_data_type; | ||
|
||
if (engine->kind() == engine_kind::gpu) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic here and below seems duplicated. Could you, please, consolidate it?
@@ -195,7 +207,13 @@ struct matmul_pd_t : public primitive_desc_t { | |||
sc.group_dims_[0] == 1 | |||
&& K() % sc.group_dims_[1] == 0); | |||
} else { | |||
ok = ok && (mask == 0); | |||
ok = ok |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here must be a check for fp8 versus classic quantization.
@@ -185,7 +185,7 @@ struct base_cfg_t { | |||
} | |||
const int64_t safe_digits = get_safe_digits(); | |||
const int64_t safe_n_acc = (1LL << safe_digits) / max_value; | |||
return safe_n_acc; | |||
return std::max((int64_t)1L, safe_n_acc); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Returning safe_n_acc = 0
is intentional here, it says that input values are not reasonable. When it happened it returned zero?
ded32f2
to
8aa5c64
Compare
Description
Backport of mixed fp8 support, additional scale support for compute primitivies.
Checklist
General
make test
andmake test_benchdnn_*
) pass locally for each commit?