diff --git a/src/gpu/intel/jit/ir/epilogue.cpp b/src/gpu/intel/jit/ir/epilogue.cpp index 727af86caec..d0f910e586b 100644 --- a/src/gpu/intel/jit/ir/epilogue.cpp +++ b/src/gpu/intel/jit/ir/epilogue.cpp @@ -299,6 +299,7 @@ class post_op_tensor_t { // Assign new f32 layout and buffer. reg_layout_ = std::move(f32_layout); reg_buf_ = std::move(f32_buf); + info_.retype(type_t::f32()); return ret; } diff --git a/src/gpu/intel/jit/ir/post_ops.cpp b/src/gpu/intel/jit/ir/post_ops.cpp index 49c87987c10..e10294ef5d1 100644 --- a/src/gpu/intel/jit/ir/post_ops.cpp +++ b/src/gpu/intel/jit/ir/post_ops.cpp @@ -43,18 +43,25 @@ post_op_context_t::post_op_context_t(const primitive_attr_t &attr, int src_scales_mask = 0; int wei_scales_mask = 0; int dst_scales_mask = 0; + type_t src_scales_type, wei_scales_type, dst_scales_type; for (int i = 0; i < (int)scale_args.size(); i++) { auto buf = kernel_info.find_arg( scale_args[i].first, /*allow_empty=*/true); if (buf.is_empty()) continue; int key = kernel_info.key(scale_args[i].first) & ~DNNL_ARG_ATTR_SCALES; - int mask = attr.scales_.get(key).mask_; + auto scales = attr.scales_.get(key); + if (scales.has_default_values()) continue; + int mask = scales.mask_; + auto sc_type = scales.data_type_ == data_type::undef + ? type_t::f32() + : scales.data_type_; view_t view; switch (key) { case DNNL_ARG_SRC: ir_assert(mask == 0); - view = po_vm_.create_view(type_t::f32(), mask); + src_scales_type = sc_type; + view = po_vm_.create_view(sc_type, mask); src_scales = add_input_tensor(view, buf); src_scales_mask = mask; break; @@ -63,14 +70,15 @@ post_op_context_t::post_op_context_t(const primitive_attr_t &attr, // XXX: per_oc for BWD_D is treated as per_ic assuming it's // called from deconvolution. ir_assert(utils::one_of(mask, 0, 1, 3)); - view = po_vm_.create_view( - type_t::f32(), (mask) ? 1 << 1 : 0); + wei_scales_type = sc_type; + view = po_vm_.create_view(sc_type, (mask) ? 1 << 1 : 0); wei_scales = add_input_tensor(view, buf); wei_scales_mask = mask; break; case DNNL_ARG_DST: // Invert dst scales right after load. ir_assert(utils::one_of(mask, 0, 2)); - view = po_vm_.create_view(type_t::f32(), mask); + dst_scales_type = sc_type; + view = po_vm_.create_view(sc_type, mask); dst_scales = add_input_tensor(view, buf); dst_scales_mask = mask; break; @@ -273,6 +281,8 @@ bool post_op_context_t::init_need_to_restore_zero_padding( if (zp_cfg.do_dst_compensation && zp_cfg.is_common_dst_zero_point && out_md.dims[1] != out_md.padded_dims[1]) return true; + auto dst_scales = attr.scales_.get(DNNL_ARG_DST); + if (!dst_scales.has_default_values() && dst_scales.mask_ != 0) return true; return false; } diff --git a/src/gpu/intel/jit/ir/post_ops.hpp b/src/gpu/intel/jit/ir/post_ops.hpp index beec0b42bc3..84de8f0939d 100644 --- a/src/gpu/intel/jit/ir/post_ops.hpp +++ b/src/gpu/intel/jit/ir/post_ops.hpp @@ -160,6 +160,8 @@ class post_op_tensor_info_t { return ret; } + void retype(const type_t &new_type) { view_ = view_.retype(new_type); } + void require_masked_update() { needs_masked_update_ = true; } private: