cpu: softmax: enable relaxed accumulation mode
dzarukin committed Dec 13, 2024
1 parent 5570d23 commit e03d3a8
Showing 7 changed files with 74 additions and 19 deletions.
27 changes: 21 additions & 6 deletions doc/primitives/softmax.md
@@ -100,12 +100,27 @@ argument index as specified by the following table.
Attributes enable you to modify the behavior of the softmax primitive.
The following attributes are supported by the softmax primitive:

| Propagation | Type | Operation | Description | Restrictions |
|:------------|:----------|:----------------------------------------------------------------------|:--------------------------------------------------------------|:-----------------------------------------------------------------------|
| forward | attribute | [Scales](@ref dnnl::primitive_attr::set_scales_mask) | Scales the corresponding tensor by the given scale factor(s). | Supported only for int8 softmax; only one scale per tensor is supported. |
| forward | post-op | [Binary](@ref dnnl::post_ops::append_binary) | Applies a @ref dnnl_api_binary operation to the result. | General binary post-op restrictions |
| forward | post-op | [Eltwise](@ref dnnl::post_ops::append_eltwise) | Applies an @ref dnnl_api_eltwise operation to the result. | |
| forward | attribute | [Accumulation mode](@ref dnnl::primitive_attr::set_accumulation_mode) | Defines the implementation's accumulation arithmetic. | Only the values `strict`, `relaxed`, and `any` are supported. |

#### Accumulation Mode

There is an optimization opportunity when the source and destination
floating-point data types of the operation are the same and differ from `f32`.
By default, when the destination data type is not `f32`, additional memory is
used to accumulate intermediate results before converting and storing them in
the destination memory buffer in the requested data type. This additional
memory can be avoided by setting the accumulation mode to
[relaxed](@ref dnnl::accumulation_mode::relaxed) or
[any](@ref dnnl::accumulation_mode::any), in which case intermediate results
are accumulated in the destination data type directly in the destination
memory buffer. This performance optimization, however, comes with some loss of
accuracy: depending on the actual data, the difference between `strict` and
`relaxed` accumulation can reach several ulps.
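
A minimal sketch of requesting relaxed accumulation through the C++ API
(assuming the v3.x `dnnl.hpp` interface; the CPU engine kind, `f16` data type,
`16x1024` shape, and softmax axis are arbitrary choices for illustration):

```cpp
#include "oneapi/dnnl/dnnl.hpp"

int main() {
    using namespace dnnl;

    engine eng(engine::kind::cpu, 0);

    // Same non-f32 data type for source and destination: the case where
    // relaxed accumulation removes the intermediate higher-precision buffer.
    memory::desc data_md({16, 1024}, memory::data_type::f16,
            memory::format_tag::ab);

    primitive_attr attr;
    attr.set_accumulation_mode(accumulation_mode::relaxed);

    auto pd = softmax_forward::primitive_desc(eng,
            prop_kind::forward_inference, algorithm::softmax_accurate,
            data_md, data_md, /* axis = */ 1, attr);
    auto prim = softmax_forward(pd);
    // ... create memory objects for DNNL_ARG_SRC / DNNL_ARG_DST and execute.
    (void)prim;
    return 0;
}
```

With the default `strict` accumulation mode, the same f16 problem would go
through an intermediate buffer before results are converted to `f16`.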

### Data Type Support

8 changes: 6 additions & 2 deletions src/cpu/ref_softmax.cpp
@@ -53,7 +53,9 @@ status_t ref_softmax_fwd_t::execute_forward_dense(const exec_ctx_t &ctx) const {
const memory_desc_wrapper src_d(pd()->src_md());
const memory_desc_wrapper dst_d(pd()->dst_md());

const auto interim_dt = data_type::f32;
const auto interim_dt = pd()->need_intermediate_scratchpad()
? data_type::f32
: dst_d.data_type();
const auto is_inplace = (src == dst);
const auto has_padding = is_padding(dst_d);
const auto zero_padding = has_padding && !is_inplace;
@@ -210,7 +212,9 @@ status_t ref_softmax_fwd_t::execute_forward_generic(

void *interim_ptr
= pd()->need_intermediate_scratchpad() ? interim_scratchpad : dst;
const auto interim_dt = data_type::f32;
const auto interim_dt = pd()->need_intermediate_scratchpad()
? data_type::f32
: dst_d.data_type();
const auto is_inplace = (src == dst);
const auto has_padding = is_padding(dst_d);
if (has_padding && !is_inplace) {
16 changes: 13 additions & 3 deletions src/cpu/ref_softmax.hpp
@@ -79,9 +79,19 @@ struct ref_softmax_fwd_t : public primitive_t {
int nthr_; // To not exceed the limit in execute used for set up.

bool need_intermediate_scratchpad() const {
return dst_md()->data_type
!= types::default_accum_data_type(
src_md()->data_type, dst_md()->data_type);
const auto src_dt = src_md()->data_type;
const auto dst_dt = dst_md()->data_type;
// Relaxed accumulation allows to downconvert intermediate results
// directly from xf16 or xf8 to dst avoiding scratchpad memory.
const bool relaxed_acc = src_dt == dst_dt
&& !types::is_integral_dt(dst_dt)
&& utils::one_of(attr()->acc_mode_,
accumulation_mode::relaxed, accumulation_mode::any);
const bool need_scratchpad = dst_md()->data_type
!= types::default_accum_data_type(
src_md()->data_type, dst_md()->data_type)
&& !relaxed_acc;
return need_scratchpad;
}

private:
16 changes: 14 additions & 2 deletions src/cpu/x64/jit_uni_softmax.cpp
@@ -951,7 +951,13 @@ struct jit_softmax_dense_kernel_t : jit_softmax_kernel_base_t,
, is_avx2_ne_xf16_(mayiuse(avx2_vnni_2) && !mayiuse(avx512_core)
&& (is_bf16_ || is_f16_))
// Note: must be aligned with pd_t::init()->init_scratchpad();
, need_scratchpad_(pd_->is_fwd() && dst_d_.data_type() != f32)
, need_scratchpad_(pd_->is_fwd() && dst_d_.data_type() != f32
&& /* !relaxed_acc */ !(
src_d_.data_type() == dst_d_.data_type()
&& !types::is_integral_dt(dst_d_.data_type())
&& utils::one_of(pd_->attr()->acc_mode_,
accumulation_mode::relaxed,
accumulation_mode::any)))
, use_ext_aux_vmms_(!is_logsoftmax_ && n_vregs > 16)
, axis_simd_full_(pd_->axis_size() / simd_w_)
, axis_simd_tail_(pd_->axis_size() % simd_w_) {
@@ -1495,7 +1501,13 @@ struct jit_softmax_strided_kernel_t : jit_softmax_kernel_base_t,
, src_d_(pd_->invariant_src_md())
, dst_d_(pd_->dst_md())
// Note: must be aligned with pd_t::init()->init_scratchpad();
, need_scratchpad_(pd_->is_fwd() && dst_d_.data_type() != f32)
, need_scratchpad_(pd_->is_fwd() && dst_d_.data_type() != f32
&& /* !relaxed_acc */ !(
src_d_.data_type() == dst_d_.data_type()
&& !types::is_integral_dt(dst_d_.data_type())
&& utils::one_of(pd_->attr()->acc_mode_,
accumulation_mode::relaxed,
accumulation_mode::any)))
, axis_size_(pd_->axis_size())
// `axis_stride_`, `axis_simd_full_` and `axis_simd_tail_` are only
// different pieces from the dense version.
12 changes: 11 additions & 1 deletion src/cpu/x64/jit_uni_softmax.hpp
@@ -171,7 +171,17 @@ struct jit_uni_softmax_fwd_t : public primitive_t {

private:
void init_scratchpad() {
if (dst_md()->data_type != data_type::f32) {
const auto src_dt = src_md()->data_type;
const auto dst_dt = dst_md()->data_type;
// Relaxed accumulation allows to downconvert intermediate results
// directly from xf16 or xf8 to dst avoiding scratchpad memory.
const bool relaxed_acc = src_dt == dst_dt
&& !types::is_integral_dt(dst_dt)
&& utils::one_of(attr()->acc_mode_,
accumulation_mode::relaxed, accumulation_mode::any);
const bool need_scratchpad
= dst_dt != data_type::f32 && !relaxed_acc;
if (need_scratchpad) {
auto scratchpad = scratchpad_registry().registrar();
// When stride != 1, then each thread operates over simd at a
// time, thus, increased scratchpad size.
3 changes: 3 additions & 0 deletions tests/benchdnn/inputs/softmax/test_softmax_ci
@@ -10,17 +10,20 @@
--dir=FWD_D,BWD_D
--sdt=f32,f64,bf16,f16
--ddt=f32,f64,bf16,f16
--attr-acc-mode=strict,relaxed
--batch=shapes_ci

--dir=FWD_I
--sdt=f32,bf16,f16,s8,u8
--ddt=s8,u8
--attr-acc-mode=strict,relaxed
--attr-scales=src:common:64+dst:common:0.5
--attr-post-ops=,add:f32:per_oc,mul:f32:per_tensor,linear:0.5:2
--batch=shapes_ci

--sdt=s8,u8
--ddt=f32,bf16,f16
--attr-acc-mode=strict,relaxed
--attr-scales=src:common:64
--attr-post-ops=,add:f32:per_oc,mul:f32:per_tensor,linear:0.5:2
--batch=shapes_ci
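
As a usage sketch, the relaxed path covered by the entries above can also be
run as a one-off benchdnn command; the `16x1024` problem shape is an arbitrary
stand-in for the `shapes_ci` batch entries:

```
./benchdnn --softmax --dir=FWD_D --sdt=f16 --ddt=f16 \
    --attr-acc-mode=relaxed 16x1024
```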
11 changes: 6 additions & 5 deletions tests/benchdnn/softmax/softmax.cpp
@@ -249,11 +249,12 @@ void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
const float trh = trh_f32;
#else
const bool is_strict_acc
= prb->attr.acc_mode == dnnl_accumulation_mode_strict;
// Relaxed fp16 computation can get an ulp difference with f32 ref values.
const float trh = is_flt_or_dbl || (trh_dt == dnnl_f16 && !is_strict_acc)
? trh_f32
: 0.f;
= prb->attr.acc_mode == dnnl_accumulation_mode_strict
|| prb->attr.acc_mode == dnnl_accumulation_mode_f32;
const bool is_relaxed_xf16
= !is_strict_acc && (trh_dt == dnnl_f16 || trh_dt == dnnl_bf16);
// Relaxed xf16 computation can get an ulp difference with f32 ref values.
const float trh = is_flt_or_dbl || is_relaxed_xf16 ? trh_f32 : 0.f;
#endif
cmp.set_threshold(trh);
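
As an aside on the size of "an ulp difference" mentioned in the comment above,
here is a small standalone sketch (not part of benchdnn; the mantissa widths
are the standard f16/bf16 properties) of what one ulp near 1.0 amounts to for
the 16-bit destination types that relaxed accumulation rounds to:

```cpp
#include <cmath>
#include <cstdio>

int main() {
    // IEEE binary16 (f16) has 10 explicit mantissa bits -> ulp(1.0) = 2^-10.
    // bfloat16 (bf16) has 7 explicit mantissa bits -> ulp(1.0) = 2^-7.
    std::printf("f16  ulp(1.0) = %g\n", std::ldexp(1.0, -10)); // ~9.77e-4
    std::printf("bf16 ulp(1.0) = %g\n", std::ldexp(1.0, -7));  // ~7.81e-3
    return 0;
}
```

These magnitudes put the documentation's "several ulps" statement into
concrete numbers.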

