Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

xe: ocl: gemm: fix gemm_with_post_ops accumulator type #2289

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/gpu/intel/ocl/gemm/gemm_with_post_ops.cl
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ __kernel void gemm_post_ops(__global SRC_DATA_T *src, __global BIA_DATA_T *bias,
#else
ACC_DATA_T acc = SRC_TO_ACC(src[data_idx]);
#endif
float accumulator = acc;
float accumulator = convert_float(acc);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated to this issue, but do we need to be concerned about potential f64 precision loss?

Copy link
Contributor Author

@rjoursler rjoursler Dec 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💯 Yes. Here is a tracker: MFDNN-12893.

if ((d0 == D0_WO_PADDING && d1 == D1_WO_PADDING && d2 == D2_WO_PADDING
&& d3 == D3_WO_PADDING)
|| (d0 < D0_WO_PADDING && d1 < D1_WO_PADDING && d2 < D2_WO_PADDING
Expand All @@ -116,7 +116,7 @@ __kernel void gemm_post_ops(__global SRC_DATA_T *src, __global BIA_DATA_T *bias,
const float b_scale
= B_SCALES ? WEI_SCALES_TO_REF(b_scales[scale_stride * d3]) : 1;
#endif
acc *= A_SCALE * b_scale;
accumulator *= A_SCALE * b_scale;
#endif

#if WITH_BIAS == 1
Expand All @@ -127,7 +127,7 @@ __kernel void gemm_post_ops(__global SRC_DATA_T *src, __global BIA_DATA_T *bias,
#else
size_t bia_idx = BIA_OFF(d0, d1, 0, 0, 0);
#endif
acc += BIA_TO_ACC(bias[bia_idx]);
accumulator += BIA_TO_ACC(bias[bia_idx]);
#endif

// Apply postops
Expand All @@ -136,7 +136,6 @@ __kernel void gemm_post_ops(__global SRC_DATA_T *src, __global BIA_DATA_T *bias,
sum_src = DST_TO_ACC(dst[data_idx]);
#endif

accumulator = acc;
#if NDIMS == 2
APPLY_POST_OPS_SERIAL(accumulator, float, sum_src, float, d0, 1, d1, 1,
0, 1, 0, 1, 0, 1, 0, 1);
Expand Down
Loading