From bbcac8bcb621580c8767819206e34bd826f7a53c Mon Sep 17 00:00:00 2001 From: Wei Su Date: Sun, 25 Jun 2023 23:43:17 -0700 Subject: [PATCH] Enable TBE CPU bf16 output support on Mac and Windows platforms (#1839) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/1839 Remove the conditional macro `#if !defined(__APPLE__) && !defined(_WIN32)` to enable TBE CPU bf16 output support on MacOS and Windows Reviewed By: sryap Differential Revision: D46806887 fbshipit-source-id: 29ff3b075a5540f591b25715fc675bc68f47374d --- src/EmbeddingSpMDMNBit.cc | 16 -------- test/EmbeddingSpMDMNBitTest.cc | 73 ++++++++++++++++++++++++---------- 2 files changed, 53 insertions(+), 36 deletions(-) diff --git a/src/EmbeddingSpMDMNBit.cc b/src/EmbeddingSpMDMNBit.cc index a7c6fb0af..174f02aec 100644 --- a/src/EmbeddingSpMDMNBit.cc +++ b/src/EmbeddingSpMDMNBit.cc @@ -398,9 +398,7 @@ GenEmbeddingSpMDMNBitLookup< x86::Ymm mask_vreg; // mask for avx2 x86::Xmm mask2_vreg; x86::Xmm mask_fp16_vreg; -#if !defined(__APPLE__) && !defined(_WIN32) vec_reg_t ones_vreg; -#endif // We need 2 vec registers for 1. scale 2. 
bias --unroll_factor; @@ -408,7 +406,6 @@ GenEmbeddingSpMDMNBitLookup< --unroll_factor; bias_vreg = vec_reg_t(unroll_factor); -#if !defined(__APPLE__) && !defined(_WIN32) if (is_bf16_out) { --unroll_factor; ones_vreg = vec_reg_t(unroll_factor); @@ -416,7 +413,6 @@ GenEmbeddingSpMDMNBitLookup< a->vpinsrd(ones_vreg.xmm(), ones_vreg.xmm(), scratchReg2_, 0); a->vpbroadcastd(ones_vreg, ones_vreg.xmm()); } -#endif --unroll_factor; src_vreg = vec_reg_t(unroll_factor); @@ -883,19 +879,15 @@ GenEmbeddingSpMDMNBitLookup< } else { // 16-bit output if (instSet == inst_set_t::avx2) { -#if !defined(__APPLE__) && !defined(_WIN32) if (is_bf16_out) { a->vpaddd(out_vreg, out_vreg, ones_vreg); a->vpsrld(out_vreg, out_vreg, 16); a->vpackusdw(out_vreg, out_vreg, out_vreg); a->vpermq(out_vreg, out_vreg, 0xd8); } else { -#endif // round nearest with no exception a->vcvtps2ph(out_vreg.xmm(), out_vreg, 8); -#if !defined(__APPLE__) && !defined(_WIN32) } -#endif if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { if (remainder > 1) { a->vmaskmovps(dst_addr, mask_fp16_vreg, out_vreg.xmm()); @@ -918,31 +910,23 @@ GenEmbeddingSpMDMNBitLookup< } } else { if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { -#if !defined(__APPLE__) && !defined(_WIN32) if (is_bf16_out) { // bf16 a->k(x86::k(1)).vpaddd(out_vreg, out_vreg, ones_vreg); a->k(x86::k(1)).vpsrld(out_vreg, out_vreg, 16); a->k(x86::k(1)).vpmovdw(dst_addr, out_vreg); } else { -#endif a->k(x86::k(1)).vcvtps2ph(dst_addr, out_vreg, 8); -#if !defined(__APPLE__) && !defined(_WIN32) } -#endif } else { -#if !defined(__APPLE__) && !defined(_WIN32) if (is_bf16_out) { // bf16 a->vpaddd(out_vreg, out_vreg, ones_vreg); a->vpsrld(out_vreg, out_vreg, 16); a->vpmovdw(dst_addr, out_vreg); } else { -#endif a->vcvtps2ph(dst_addr, out_vreg, 8); -#if !defined(__APPLE__) && !defined(_WIN32) } -#endif } } } diff --git a/test/EmbeddingSpMDMNBitTest.cc b/test/EmbeddingSpMDMNBitTest.cc index 08be4e722..1106289c1 100644 --- 
a/test/EmbeddingSpMDMNBitTest.cc +++ b/test/EmbeddingSpMDMNBitTest.cc @@ -57,7 +57,8 @@ class FusedNBitRowwiseEmbeddingLookupTest : public testing::TestWithParam> {}; + EmbeddingSpMDMCornerCase, + EmbeddingSpMDMDtypeChoice>> {}; }; // namespace INSTANTIATE_TEST_CASE_P( @@ -74,7 +75,8 @@ INSTANTIATE_TEST_CASE_P( NONE, EMPTY_INDICES, OUT_OF_BOUND_INDICES, - UNMATCHED_NUM_INDICES_AND_LENGTHS_SUM))); + UNMATCHED_NUM_INDICES_AND_LENGTHS_SUM), + ::testing::Values(FLOAT, FLOAT16, BFLOAT16))); TEST_P(FusedNBitRowwiseEmbeddingLookupTest, basicTest) { vector> inputs(GetInputs_()); @@ -86,19 +88,20 @@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, basicTest) { bool isOffset64b = bool_dist(generator); bool normalize_by_lengths = bool_dist(generator); bool use_offsets = bool_dist(generator); - bool is_output_float = bool_dist(generator); bool scale_bias_last = bool_dist(generator); bool test_thread_local = bool_dist(generator); int bit_rate, prefetch; EmbeddingSpMDMWeightChoice weight_choice; EmbeddingSpMDMCornerCase corner_case; - tie(bit_rate, prefetch, weight_choice, corner_case) = GetParam(); + EmbeddingSpMDMDtypeChoice out_type; + tie(bit_rate, prefetch, weight_choice, corner_case, out_type) = GetParam(); bool is_wt_positional = weight_choice == POSITIONAL_WEIGHTED; bool use_weight = weight_choice != UNWEIGHTED; + bool is_bf16_out = out_type == BFLOAT16; if (corner_case != NONE || weight_choice == POSITIONAL_WEIGHTED) { // Check corner case only for subset of tests. 
- if (normalize_by_lengths || !is_output_float || !scale_bias_last || + if (normalize_by_lengths || out_type != FLOAT || !scale_bias_last || test_thread_local) { return; } @@ -171,11 +174,14 @@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, basicTest) { vector output_ref(output_size_wo_sentries + num_sentries); vector output(output_ref.size()); vector output_ref_fp16(output.size()), output_fp16(output.size()); + vector output_ref_bf16(output.size()), output_bf16(output.size()); for (size_t i = output_size_wo_sentries; i < output.size(); ++i) { output_ref[i] = sentry_value; output[i] = sentry_value; output_ref_fp16[i] = cpu_float2half_rn(sentry_value); output_fp16[i] = cpu_float2half_rn(sentry_value); + FloatToBfloat16_ref(&sentry_value, &output_ref_bf16[i], 1); + FloatToBfloat16_ref(&sentry_value, &output_bf16[i], 1); } bool success, success_ref; @@ -205,7 +211,8 @@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, basicTest) { use_offsets, \ /*output_stride=*/-1, \ /*input_stride=*/-1, \ - scale_bias_last); \ + scale_bias_last, \ + is_bf16_out); \ \ auto kernel = GenerateEmbeddingSpMDMNBitWithStrides< \ IndexType, \ @@ -221,7 +228,8 @@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, basicTest) { use_offsets, \ /*output_stride=*/-1, \ /*input_stride=*/-1, \ - scale_bias_last); \ + scale_bias_last, \ + is_bf16_out); \ success = kernel( \ batch_size, \ lengths_sum, \ @@ -263,7 +271,7 @@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, basicTest) { } #define TEST_OUT_TYPE(indices, offsets_or_lengths, IndexType, OffsetType) \ - if (is_output_float) { \ + if (out_type == FLOAT) { \ TEST_THREAD_LOCAL( \ indices, \ offsets_or_lengths, \ @@ -272,6 +280,15 @@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, basicTest) { IndexType, \ OffsetType, \ float); \ + } else if (out_type == BFLOAT16) { \ + TEST_THREAD_LOCAL( \ + indices, \ + offsets_or_lengths, \ + output_ref_bf16, \ + output_bf16, \ + IndexType, \ + OffsetType, \ + bfloat16); \ } else { \ TEST_THREAD_LOCAL( \ indices, \ @@ -308,12 +325,31 
@@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, basicTest) { corner_case == UNMATCHED_NUM_INDICES_AND_LENGTHS_SUM) { EXPECT_EQ(success, false); } + + auto get_actual = [&](int offset) { /* Kernel result at offset, read from the buffer matching out_type and widened to float. */ + if (out_type == FLOAT) { + return output[offset]; + } else if (out_type == BFLOAT16) { + return cpu_bf162float(output_bf16[offset]); + } else { + return cpu_half2float(output_fp16[offset]); + } + }; + + auto get_expected = [&](int offset) { /* Reference result at offset, read from the buffer matching out_type and widened to float. */ + if (out_type == FLOAT) { + return output_ref[offset]; + } else if (out_type == BFLOAT16) { + return cpu_bf162float(output_ref_bf16[offset]); + } else { + return cpu_half2float(output_ref_fp16[offset]); + } + }; + if (success) { for (size_t i = 0; i < output.size(); ++i) { - float actual = - is_output_float ? output[i] : cpu_half2float(output_fp16[i]); - float expected = is_output_float ? output_ref[i] - : cpu_half2float(output_ref_fp16[i]); + float actual = get_actual(i); + float expected = get_expected(i); EXPECT_EQ(actual, expected) << "results differ at (" << i << ") reference: " << expected << ", FBGEMM: " << actual << " emb dim :" << embedding_dim; @@ -321,11 +357,8 @@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, basicTest) { for (int offset = output_size_wo_sentries; offset < output_size_wo_sentries + num_sentries; ++offset) { - float actual = is_output_float ? output[offset] - : cpu_half2float(output_fp16[offset]); - float expected = is_output_float - ? 
output_ref[offset] - : cpu_half2float(output_ref_fp16[offset]); + float actual = get_actual(offset); + float expected = get_expected(offset); EXPECT_EQ(actual, expected) << "results differ at (" << offset << ") reference: " << expected << ", FBGEMM: " << actual << " emb dim :" << embedding_dim; @@ -344,17 +377,17 @@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, rowwiseSparseTest) { bool isOffset64b = bool_dist(generator); bool normalize_by_lengths = bool_dist(generator); bool use_offsets = bool_dist(generator); - bool is_output_float = bool_dist(generator); bool scale_bias_last = bool_dist(generator); int bit_rate, prefetch; EmbeddingSpMDMWeightChoice weight_choice; EmbeddingSpMDMCornerCase corner_case; - tie(bit_rate, prefetch, weight_choice, corner_case) = GetParam(); + EmbeddingSpMDMDtypeChoice out_type; + tie(bit_rate, prefetch, weight_choice, corner_case, out_type) = GetParam(); bool is_wt_positional = weight_choice == POSITIONAL_WEIGHTED; bool use_weight = weight_choice != UNWEIGHTED; - if (!is_output_float || !scale_bias_last) { + if (out_type != FLOAT || !scale_bias_last) { return; }