Skip to content

Commit

Permalink
Enable TBE CPU bf16 output support on Mac and Windows platforms (pyto…
Browse files Browse the repository at this point in the history
…rch#1839)

Summary:
Pull Request resolved: pytorch#1839

Remove the conditional macro `#if !defined(__APPLE__) && !defined(_WIN32)` to enable TBE CPU bf16 output support on MacOS and Windows

Reviewed By: sryap

Differential Revision: D46806887

fbshipit-source-id: 29ff3b075a5540f591b25715fc675bc68f47374d
  • Loading branch information
Wei Su authored and facebook-github-bot committed Jun 26, 2023
1 parent 8979de3 commit bbcac8b
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 36 deletions.
16 changes: 0 additions & 16 deletions src/EmbeddingSpMDMNBit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -398,25 +398,21 @@ GenEmbeddingSpMDMNBitLookup<
x86::Ymm mask_vreg; // mask for avx2
x86::Xmm mask2_vreg;
x86::Xmm mask_fp16_vreg;
#if !defined(__APPLE__) && !defined(_WIN32)
vec_reg_t ones_vreg;
#endif

// We need 2 vec registers for 1. scale 2. bias
--unroll_factor;
scale_vreg = vec_reg_t(unroll_factor);
--unroll_factor;
bias_vreg = vec_reg_t(unroll_factor);

#if !defined(__APPLE__) && !defined(_WIN32)
if (is_bf16_out) {
--unroll_factor;
ones_vreg = vec_reg_t(unroll_factor);
a->mov(scratchReg2_, 1 << 15);
a->vpinsrd(ones_vreg.xmm(), ones_vreg.xmm(), scratchReg2_, 0);
a->vpbroadcastd(ones_vreg, ones_vreg.xmm());
}
#endif

--unroll_factor;
src_vreg = vec_reg_t(unroll_factor);
Expand Down Expand Up @@ -883,19 +879,15 @@ GenEmbeddingSpMDMNBitLookup<
} else {
// 16-bit output
if (instSet == inst_set_t::avx2) {
#if !defined(__APPLE__) && !defined(_WIN32)
if (is_bf16_out) {
a->vpaddd(out_vreg, out_vreg, ones_vreg);
a->vpsrld(out_vreg, out_vreg, 16);
a->vpackusdw(out_vreg, out_vreg, out_vreg);
a->vpermq(out_vreg, out_vreg, 0xd8);
} else {
#endif
// round nearest with no exception
a->vcvtps2ph(out_vreg.xmm(), out_vreg, 8);
#if !defined(__APPLE__) && !defined(_WIN32)
}
#endif
if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
if (remainder > 1) {
a->vmaskmovps(dst_addr, mask_fp16_vreg, out_vreg.xmm());
Expand All @@ -918,31 +910,23 @@ GenEmbeddingSpMDMNBitLookup<
}
} else {
if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
#if !defined(__APPLE__) && !defined(_WIN32)
if (is_bf16_out) {
// bf16
a->k(x86::k(1)).vpaddd(out_vreg, out_vreg, ones_vreg);
a->k(x86::k(1)).vpsrld(out_vreg, out_vreg, 16);
a->k(x86::k(1)).vpmovdw(dst_addr, out_vreg);
} else {
#endif
a->k(x86::k(1)).vcvtps2ph(dst_addr, out_vreg, 8);
#if !defined(__APPLE__) && !defined(_WIN32)
}
#endif
} else {
#if !defined(__APPLE__) && !defined(_WIN32)
if (is_bf16_out) {
// bf16
a->vpaddd(out_vreg, out_vreg, ones_vreg);
a->vpsrld(out_vreg, out_vreg, 16);
a->vpmovdw(dst_addr, out_vreg);
} else {
#endif
a->vcvtps2ph(dst_addr, out_vreg, 8);
#if !defined(__APPLE__) && !defined(_WIN32)
}
#endif
}
}
}
Expand Down
73 changes: 53 additions & 20 deletions test/EmbeddingSpMDMNBitTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ class FusedNBitRowwiseEmbeddingLookupTest : public testing::TestWithParam<tuple<
int,
int,
EmbeddingSpMDMWeightChoice,
EmbeddingSpMDMCornerCase>> {};
EmbeddingSpMDMCornerCase,
EmbeddingSpMDMDtypeChoice>> {};
}; // namespace

INSTANTIATE_TEST_CASE_P(
Expand All @@ -74,7 +75,8 @@ INSTANTIATE_TEST_CASE_P(
NONE,
EMPTY_INDICES,
OUT_OF_BOUND_INDICES,
UNMATCHED_NUM_INDICES_AND_LENGTHS_SUM)));
UNMATCHED_NUM_INDICES_AND_LENGTHS_SUM),
::testing::Values(FLOAT, FLOAT16, BFLOAT16)));

TEST_P(FusedNBitRowwiseEmbeddingLookupTest, basicTest) {
vector<vector<int>> inputs(GetInputs_());
Expand All @@ -86,19 +88,20 @@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, basicTest) {
bool isOffset64b = bool_dist(generator);
bool normalize_by_lengths = bool_dist(generator);
bool use_offsets = bool_dist(generator);
bool is_output_float = bool_dist(generator);
bool scale_bias_last = bool_dist(generator);
bool test_thread_local = bool_dist(generator);
int bit_rate, prefetch;
EmbeddingSpMDMWeightChoice weight_choice;
EmbeddingSpMDMCornerCase corner_case;
tie(bit_rate, prefetch, weight_choice, corner_case) = GetParam();
EmbeddingSpMDMDtypeChoice out_type;
tie(bit_rate, prefetch, weight_choice, corner_case, out_type) = GetParam();
bool is_wt_positional = weight_choice == POSITIONAL_WEIGHTED;
bool use_weight = weight_choice != UNWEIGHTED;
bool is_bf16_out = out_type == BFLOAT16;

if (corner_case != NONE || weight_choice == POSITIONAL_WEIGHTED) {
// Check corner case only for subset of tests.
if (normalize_by_lengths || !is_output_float || !scale_bias_last ||
if (normalize_by_lengths || out_type != FLOAT || !scale_bias_last ||
test_thread_local) {
return;
}
Expand Down Expand Up @@ -171,11 +174,14 @@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, basicTest) {
vector<float> output_ref(output_size_wo_sentries + num_sentries);
vector<float> output(output_ref.size());
vector<float16> output_ref_fp16(output.size()), output_fp16(output.size());
vector<bfloat16> output_ref_bf16(output.size()), output_bf16(output.size());
for (size_t i = output_size_wo_sentries; i < output.size(); ++i) {
output_ref[i] = sentry_value;
output[i] = sentry_value;
output_ref_fp16[i] = cpu_float2half_rn(sentry_value);
output_fp16[i] = cpu_float2half_rn(sentry_value);
FloatToBfloat16_ref(&sentry_value, &output_ref_bf16[i], 1);
FloatToBfloat16_ref(&sentry_value, &output_bf16[i], 1);
}

bool success, success_ref;
Expand Down Expand Up @@ -205,7 +211,8 @@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, basicTest) {
use_offsets, \
/*output_stride=*/-1, \
/*input_stride=*/-1, \
scale_bias_last); \
scale_bias_last, \
is_bf16_out); \
\
auto kernel = GenerateEmbeddingSpMDMNBitWithStrides< \
IndexType, \
Expand All @@ -221,7 +228,8 @@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, basicTest) {
use_offsets, \
/*output_stride=*/-1, \
/*input_stride=*/-1, \
scale_bias_last); \
scale_bias_last, \
is_bf16_out); \
success = kernel( \
batch_size, \
lengths_sum, \
Expand Down Expand Up @@ -263,7 +271,7 @@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, basicTest) {
}

#define TEST_OUT_TYPE(indices, offsets_or_lengths, IndexType, OffsetType) \
if (is_output_float) { \
if (out_type == FLOAT) { \
TEST_THREAD_LOCAL( \
indices, \
offsets_or_lengths, \
Expand All @@ -272,6 +280,15 @@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, basicTest) {
IndexType, \
OffsetType, \
float); \
} else if (out_type == BFLOAT16) { \
TEST_THREAD_LOCAL( \
indices, \
offsets_or_lengths, \
output_ref_bf16, \
output_bf16, \
IndexType, \
OffsetType, \
bfloat16); \
} else { \
TEST_THREAD_LOCAL( \
indices, \
Expand Down Expand Up @@ -308,24 +325,40 @@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, basicTest) {
corner_case == UNMATCHED_NUM_INDICES_AND_LENGTHS_SUM) {
EXPECT_EQ(success, false);
}

auto get_actual = [&](int offset) {
if (out_type == FLOAT) {
return output[offset];
} else if (out_type == BFLOAT16) {
return cpu_bf162float(output[offset]);
} else {
return cpu_half2float(output[offset]);
}
};

auto get_expected = [&](int offset) {
if (out_type == FLOAT) {
return output_ref[offset];
} else if (out_type == BFLOAT16) {
return cpu_bf162float(output_ref[offset]);
} else {
return cpu_half2float(output_ref[offset]);
}
};

if (success) {
for (size_t i = 0; i < output.size(); ++i) {
float actual =
is_output_float ? output[i] : cpu_half2float(output_fp16[i]);
float expected = is_output_float ? output_ref[i]
: cpu_half2float(output_ref_fp16[i]);
float actual = get_actual(i);
float expected = get_expected(i);
EXPECT_EQ(actual, expected)
<< "results differ at (" << i << ") reference: " << expected
<< ", FBGEMM: " << actual << " emb dim :" << embedding_dim;
}
for (int offset = output_size_wo_sentries;
offset < output_size_wo_sentries + num_sentries;
++offset) {
float actual = is_output_float ? output[offset]
: cpu_half2float(output_fp16[offset]);
float expected = is_output_float
? output_ref[offset]
: cpu_half2float(output_ref_fp16[offset]);
float actual = get_actual(offset);
float expected = get_expected(offset);
EXPECT_EQ(actual, expected)
<< "results differ at (" << offset << ") reference: " << expected
<< ", FBGEMM: " << actual << " emb dim :" << embedding_dim;
Expand All @@ -344,17 +377,17 @@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, rowwiseSparseTest) {
bool isOffset64b = bool_dist(generator);
bool normalize_by_lengths = bool_dist(generator);
bool use_offsets = bool_dist(generator);
bool is_output_float = bool_dist(generator);
bool scale_bias_last = bool_dist(generator);

int bit_rate, prefetch;
EmbeddingSpMDMWeightChoice weight_choice;
EmbeddingSpMDMCornerCase corner_case;
tie(bit_rate, prefetch, weight_choice, corner_case) = GetParam();
EmbeddingSpMDMDtypeChoice out_type;
tie(bit_rate, prefetch, weight_choice, corner_case, out_type) = GetParam();
bool is_wt_positional = weight_choice == POSITIONAL_WEIGHTED;
bool use_weight = weight_choice != UNWEIGHTED;

if (!is_output_float || !scale_bias_last) {
if (out_type != FLOAT || !scale_bias_last) {
return;
}

Expand Down

0 comments on commit bbcac8b

Please sign in to comment.