From 2da0adf1b9ead6c1d5fe1fb3756761a16c72fe92 Mon Sep 17 00:00:00 2001 From: Roy Oursler Date: Wed, 18 Dec 2024 14:26:14 -0800 Subject: [PATCH] xe: ocl: gemm: fix gemm_with_post_ops accumulator type --- src/gpu/intel/ocl/gemm/gemm_with_post_ops.cl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/gpu/intel/ocl/gemm/gemm_with_post_ops.cl b/src/gpu/intel/ocl/gemm/gemm_with_post_ops.cl index f8181ea8990..90b066c9f08 100644 --- a/src/gpu/intel/ocl/gemm/gemm_with_post_ops.cl +++ b/src/gpu/intel/ocl/gemm/gemm_with_post_ops.cl @@ -98,7 +98,7 @@ __kernel void gemm_post_ops(__global SRC_DATA_T *src, __global BIA_DATA_T *bias, #else ACC_DATA_T acc = SRC_TO_ACC(src[data_idx]); #endif - float accumulator = acc; + float accumulator = convert_float(acc); if ((d0 == D0_WO_PADDING && d1 == D1_WO_PADDING && d2 == D2_WO_PADDING && d3 == D3_WO_PADDING) || (d0 < D0_WO_PADDING && d1 < D1_WO_PADDING && d2 < D2_WO_PADDING @@ -116,7 +116,7 @@ __kernel void gemm_post_ops(__global SRC_DATA_T *src, __global BIA_DATA_T *bias, const float b_scale = B_SCALES ? WEI_SCALES_TO_REF(b_scales[scale_stride * d3]) : 1; #endif - acc *= A_SCALE * b_scale; + accumulator *= A_SCALE * b_scale; #endif #if WITH_BIAS == 1 @@ -127,7 +127,7 @@ __kernel void gemm_post_ops(__global SRC_DATA_T *src, __global BIA_DATA_T *bias, #else size_t bia_idx = BIA_OFF(d0, d1, 0, 0, 0); #endif - acc += BIA_TO_ACC(bias[bia_idx]); + accumulator += convert_float(bias[bia_idx]); #endif // Apply postops @@ -136,7 +136,6 @@ __kernel void gemm_post_ops(__global SRC_DATA_T *src, __global BIA_DATA_T *bias, sum_src = DST_TO_ACC(dst[data_idx]); #endif - accumulator = acc; #if NDIMS == 2 APPLY_POST_OPS_SERIAL(accumulator, float, sum_src, float, d0, 1, d1, 1, 0, 1, 0, 1, 0, 1, 0, 1);