bug fix for CPU pruned weighted TBE
Summary:
When we see a pruned row, we also need to skip the corresponding weight.
D36461772 fixed EmbeddingSpMDMNBit.cc but didn't fix EmbeddingSpMDM.cc.

Added unit tests for both the 8-bit and N-bit cases.
Fixed the random number generation in unit tests, which was seeded deterministically and produced the same sequence on every run.
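
To illustrate the bug, here is a minimal sketch (not FBGEMM's actual code; all names are made up for the example). When indices and weights are consumed through separate cursors, skipping a pruned index without also skipping its weight shifts every later (row, weight) pairing by one:

```cpp
#include <cstdint>

// Minimal sketch of the failure mode, assuming a TBE-style layout where a
// pruned row is marked by index -1. Illustrative only; not FBGEMM code.
void WeightedPool(
    const float* table, // row i starts at table + i * D
    const int64_t* indices,
    const float* weights, // one weight per entry of `indices`
    int len,
    int D,
    float* out) {
  for (int i = 0; i < len; ++i) {
    int64_t idx = *indices++;
    if (idx == -1) {
      ++weights; // the fix: skip the pruned row's weight as well;
      continue;  // without this bump, later rows get the wrong weights
    }
    float w = *weights++;
    const float* row = table + idx * D;
    for (int j = 0; j < D; ++j) {
      out[j] += w * row[j];
    }
  }
}
```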

Reviewed By: jianyuh

Differential Revision: D54163836

fbshipit-source-id: e23ff46e8079d6935a9f6133cf11127d1f521168
jspark1105 authored and facebook-github-bot committed Feb 25, 2024
1 parent dc20966 commit 2bba5cc
Showing 14 changed files with 75 additions and 23 deletions.
src/EmbeddingSpMDM.cc (3 additions, 0 deletions)

@@ -592,6 +592,9 @@ GenEmbeddingSpMDMLookup<
   }
   a->jne(ValidIndexLabel);
   a->add(indices, static_cast<asmjit::Imm>(sizeof(indxType)));
+  if (has_weight) {
+    a->add(weights, static_cast<asmjit::Imm>(sizeof(float)));
+  }
   a->jmp(LoopDataIndexBegin);
   a->bind(ValidIndexLabel);
 }
src/RefImplementations.cc (12 additions, 6 deletions)

@@ -1287,8 +1287,13 @@ bool EmbeddingSpMDM_ref(
     if (current + len > index_size) {
       return false;
     }
-    for (int i = 0; i < len; ++i) {
+    for (int i = 0; i < len; ++i, ++current) {
       int64_t idx = indices[current];
+      if (!scale_bias_last && idx == -1) {
+        // When scale_bias_last == false, assume this is for table batched
+        // embedding (TBE) that can get -1 for pruned rows.
+        continue;
+      }
       if (idx < 0 || idx >= data_size) {
         return false;
       }
@@ -1319,8 +1324,6 @@
               (scale_bias_last ? 0 : 2 * sizeof(float16))],
           buf[j] + bias);
       }
-
-      ++current;
     }
     if (normalize_by_lengths && len) {
       float scale = 1.f / len;
@@ -1450,8 +1453,13 @@ bool EmbeddingSpMDMNBit_ref(
     if (current + len > index_size) {
       return false;
     }
-    for (int i = 0; i < len; ++i) {
+    for (int i = 0; i < len; ++i, ++current) {
       int64_t idx = indices[current];
+      if (!scale_bias_last && idx == -1) {
+        // When scale_bias_last == false, assume this is for table batched
+        // embedding (TBE) that can get -1 for pruned rows.
+        continue;
+      }
       if (idx < 0 || idx >= data_size) {
         return false;
       }
@@ -1478,8 +1486,6 @@

         buf[j] = std::fma(scale, quantized, buf[j] + bias);
       }
-
-      ++current;
     }
     if (normalize_by_lengths && len) {
       float scale = 1.f / len;
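
Note the shape of the reference-implementation fix: `++current` moved from the bottom of the loop body into the `for` increment, so the new `continue` for pruned rows still advances the cursor into the flat indices array. A standalone sketch of the pitfall this avoids (illustrative, not FBGEMM code):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// `current` walks a flat indices array shared across all bags, so it must
// advance even for pruned (-1) entries. Putting ++current in the for
// increment keeps it in lockstep with `i` when an iteration is skipped.
int64_t SumValidIndices(const std::vector<int64_t>& indices, int len) {
  int64_t sum = 0;
  int current = 0;
  for (int i = 0; i < len; ++i, ++current) {
    int64_t idx = indices[current];
    if (idx == -1) {
      // Had ++current stayed at the end of the loop body, this `continue`
      // would skip it and the loop would re-read the same entry.
      continue;
    }
    sum += idx;
  }
  return sum;
}

int main() {
  assert(SumValidIndices({4, -1, 7}, 3) == 11);
  return 0;
}
```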
test/EmbeddingSpMDM8BitTest.cc (17 additions, 2 deletions)

@@ -77,7 +77,8 @@ INSTANTIATE_TEST_CASE_P(
 TEST_P(Fused8BitRowwiseEmbeddingLookupTest, basicTest) {
   vector<vector<int>> inputs(GetInputs_());

-  default_random_engine generator;
+  random_device r;
+  default_random_engine generator(r());
   uniform_int_distribution<> bool_dist(0, 1);

   bool isIndex64b = bool_dist(generator);
@@ -158,6 +159,19 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, basicTest) {
   const int32_t* offsets_or_lengths_32 =
       (use_offsets ? offsets_32 : lengths_32).data();

+  if (!scale_bias_last && use_weight) {
+    // When scale_bias_last == false, assume this is for table batched
+    // embedding (TBE) that can get -1 for pruned rows.
+    uniform_int_distribution<int> pruned_indices_distribution(
+        0, indices.size() - 1);
+    constexpr float PRUNED_INDICES_PROPORTION = 0.1;
+    for (int i = 0; i < indices.size() * PRUNED_INDICES_PROPORTION; ++i) {
+      auto idx = pruned_indices_distribution(generator);
+      indices[idx] = -1;
+      indices_32[idx] = -1;
+    }
+  }
+
   // Sentries at the end to make sure masking is done correctly not to write
   // out of bounds.
   constexpr int num_sentries = 10;
@@ -308,7 +322,8 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, basicTest) {
 TEST_P(Fused8BitRowwiseEmbeddingLookupTest, rowwiseSparseTest) {
   vector<vector<int>> inputs(GetInputs_());

-  default_random_engine generator;
+  random_device r;
+  default_random_engine generator(r());
   uniform_int_distribution<> bool_dist(0, 1);

   bool isIndex64b = bool_dist(generator);
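
The `random_device` seeding change, repeated in this and every test file below, fixes a common gotcha: a default-constructed `std::default_random_engine` uses a fixed default seed, so every run draws the identical sequence and the "random" test configurations never vary. A small standalone demonstration (illustrative):

```cpp
#include <iostream>
#include <random>

int main() {
  // Default-constructed: fixed default seed, identical sequence every run.
  std::default_random_engine fixed_engine;

  // Seeded from std::random_device, as the tests now do: a different
  // sequence per run, so repeated runs cover more input combinations.
  std::random_device r;
  std::default_random_engine seeded_engine(r());

  std::uniform_int_distribution<> bool_dist(0, 1);
  std::cout << "fixed: " << bool_dist(fixed_engine)
            << ", seeded: " << bool_dist(seeded_engine) << '\n';
  return 0;
}
```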
test/EmbeddingSpMDMNBitTest.cc (17 additions, 2 deletions)

@@ -81,7 +81,8 @@ INSTANTIATE_TEST_CASE_P(
 TEST_P(FusedNBitRowwiseEmbeddingLookupTest, basicTest) {
   vector<vector<int>> inputs(GetInputs_());

-  default_random_engine generator;
+  random_device r;
+  default_random_engine generator(r());
   uniform_int_distribution<> bool_dist(0, 1);

   bool isIndex64b = bool_dist(generator);
@@ -166,6 +167,19 @@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, basicTest) {
   const int32_t* offsets_or_lengths_32 =
       (use_offsets ? offsets_32 : lengths_32).data();

+  if (!scale_bias_last && use_weight) {
+    // When scale_bias_last == false, assume this is for table batched
+    // embedding (TBE) that can get -1 for pruned rows.
+    uniform_int_distribution<int> pruned_indices_distribution(
+        0, indices.size() - 1);
+    constexpr float PRUNED_INDICES_PROPORTION = 0.1;
+    for (int i = 0; i < indices.size() * PRUNED_INDICES_PROPORTION; ++i) {
+      auto idx = pruned_indices_distribution(generator);
+      indices[idx] = -1;
+      indices_32[idx] = -1;
+    }
+  }
+
   // Sentries at the end to make sure masking is done correctly not to write
   // out of bounds.
   constexpr int num_sentries = 10;
@@ -370,7 +384,8 @@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, basicTest) {
 TEST_P(FusedNBitRowwiseEmbeddingLookupTest, rowwiseSparseTest) {
   vector<vector<int>> inputs(GetInputs_());

-  default_random_engine generator;
+  random_device r;
+  default_random_engine generator(r());
   uniform_int_distribution<> bool_dist(0, 1);

   bool isIndex64b = bool_dist(generator);
test/EmbeddingSpMDMTest.cc (4 additions, 2 deletions)

@@ -111,7 +111,8 @@ INSTANTIATE_TEST_CASE_P(
 TEST_P(EmbeddingSpMDMTest, basicTest) {
   vector<vector<int>> inputs(GetInputs_());

-  default_random_engine generator;
+  random_device r;
+  default_random_engine generator(r());
   uniform_int_distribution<> bool_dist(0, 1);

   bool isIndex64b = bool_dist(generator);
@@ -449,7 +450,8 @@ TEST_P(EmbeddingSpMDMTest, basicTest)
 TEST_P(rowwiseSparseEmbeddingSpMDMTest, rowwiseSparseTest) {
   vector<vector<int>> inputs(GetInputs_());

-  default_random_engine generator;
+  random_device r;
+  default_random_engine generator(r());
   uniform_int_distribution<> bool_dist(0, 1);

   bool isFp16 = bool_dist(generator);
test/EmbeddingSpMDMTestUtils.cc (4 additions, 2 deletions)

@@ -28,7 +28,8 @@ int GenerateLengthsIndicesWeights(
     int average_len,
     EmbeddingSpMDMCornerCase corner_case) {
   // Generate lengths
-  default_random_engine generator;
+  random_device r;
+  default_random_engine generator(r());
   uniform_int_distribution<int> length_distribution(
       1, std::min(2 * average_len + 1, num_rows));
   lengths.resize(batch_size);
@@ -89,7 +90,8 @@ int CreateMappingTableForRowWiseSparsity(
     vector<int32_t>& mapping_table,
     int num_rows,
     float sparsity) {
-  default_random_engine generator;
+  random_device r;
+  default_random_engine generator(r());
   mapping_table.resize(num_rows);
   bernoulli_distribution row_prune_dist(sparsity);
   int num_compressed_rows = 0;
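
For context on CreateMappingTableForRowWiseSparsity, whose full body is not shown in this diff, here is a hedged sketch of what such a mapping table typically looks like, consistent with the context lines above: pruned rows map to -1 and surviving rows get compact compressed-row ids.

```cpp
#include <random>
#include <vector>

// Sketch only; the real utility's body is not part of this diff. Builds a
// row-pruning map: -1 for pruned rows, otherwise a compact compressed id.
// Returns the number of rows that survive pruning.
int CreateMappingTableSketch(
    std::vector<int32_t>& mapping_table,
    int num_rows,
    float sparsity,
    std::default_random_engine& generator) {
  mapping_table.resize(num_rows);
  std::bernoulli_distribution row_prune_dist(sparsity);
  int num_compressed_rows = 0;
  for (int i = 0; i < num_rows; ++i) {
    mapping_table[i] = row_prune_dist(generator) ? -1 : num_compressed_rows++;
  }
  return num_compressed_rows;
}
```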
test/I64Test.cc (2 additions, 1 deletion)

@@ -24,7 +24,8 @@ class Int64GemmTest : public testing::Test {
  protected:
   vector<array<int, 3>> GenParams() {
     vector<array<int, 3>> shapes;
-    default_random_engine generator;
+    random_device r;
+    default_random_engine generator(r());
     uniform_int_distribution<int> dist_dim(1, 128);
     for (int i = 0; i < 256; ++i) {
       shapes.push_back(
test/I8SpmdmTest.cc (2 additions, 1 deletion)

@@ -81,7 +81,8 @@ TEST_P(fbgemmSPMDMTest, TestsSpMDM) {
   }

   // deterministic random number
-  default_random_engine eng;
+  random_device r;
+  default_random_engine eng(r());
   binomial_distribution<> per_col_nnz_dist(K_adjusted, density);
   uniform_int_distribution<> value_dist(
       numeric_limits<int8_t>::min() / 2, numeric_limits<int8_t>::max() / 2);
test/Im2ColFusedRequantizeTest.cc (2 additions, 1 deletion)

@@ -353,7 +353,8 @@ void SConvTest() {

   float density = 0.0001f;
   CompressedSparseColumn B_csc(KDimPerGroup, conv_p.G * NDim);
-  default_random_engine eng;
+  random_device r;
+  default_random_engine eng(r());
   binomial_distribution<> per_col_nnz_dist(KDimPerGroup, density);

   // TODO: refactor CSC construction as a reusable function
test/PackedRequantizeAcc16Test.cc (2 additions, 1 deletion)

@@ -436,7 +436,8 @@ TEST_P(fbgemmu8s8acc16WithQuantGranularityTest, SpMDMTest) {
   // Make sure density is big enough. Otherwise, we're not really testing
   // spmdm.
   // deterministic random number
-  default_random_engine eng;
+  random_device r;
+  default_random_engine eng(r());
   binomial_distribution<> per_col_nnz_dist(k_per_group, density);

   vector<int> row_indices(k_per_group);
test/RowWiseSparseAdagradFusedTest.cc (2 additions, 1 deletion)

@@ -122,7 +122,8 @@ TEST_P(RowWiseSparseAdagradFusedTest, rowwiseTest) {
       h(num_rows), h_ref(num_rows),
       g(batch_size * (use_grad_stride ? grad_stride : embedding_dim));
   vector<float16> w_fp16(w.size()), w_fp16_ref(w.size());
-  default_random_engine generator;
+  random_device r;
+  default_random_engine generator(r());
   uniform_real_distribution<float> values_gen(0, 2);
   for (size_t i = 0; i < w.size(); ++i) {
     w_ref[i] = w[i] = values_gen(generator);
test/SparseAdagradTest.cc (4 additions, 2 deletions)

@@ -84,7 +84,8 @@ TEST_P(SparseAdagradTest, basicTest_two_stages) {
   vector<float> h_ref(param_size);
   vector<float> w_ref(param_size);

-  default_random_engine generator;
+  random_device r;
+  default_random_engine generator(r());

   normal_distribution<float> h_w_distribution;

@@ -217,7 +218,8 @@ TEST_P(SparseAdagradTest, rowwiseTest_two_stages) {
   vector<float> h_ref(param_size);
   vector<float> w_ref(param_size);

-  default_random_engine generator;
+  random_device r;
+  default_random_engine generator(r());
   uniform_real_distribution<float> values_gen(0, 2);
   for (int i = 0; i < param_size; i++) {
     h_ref[i] = h[i] = values_gen(generator);
test/SparseDenseMMFP32Test.cc (2 additions, 1 deletion)

@@ -20,7 +20,8 @@ using namespace fbgemm;

 namespace {
 uniform_int_distribution<int> dist_dim(1, 256);
-default_random_engine generator;
+random_device r;
+default_random_engine generator(r());

 class SparseDenseTest : public testing::Test {
  protected:
test/TransposeTest.cc (2 additions, 1 deletion)

@@ -55,7 +55,8 @@ TEST(TransposeTest, TransposeTest) {
   // Generate shapes to test
   vector<tuple<int, int, int, int>> shapes;
   uniform_int_distribution<int> dist(0, 32);
-  default_random_engine generator;
+  random_device r;
+  default_random_engine generator(r());
   for (int i = 0; i < 1024; ++i) {
     int m = dist(generator);
     int n = dist(generator);
