From 4a3931ace03b7a6e98c3c9f267fc48a09fb2997d Mon Sep 17 00:00:00 2001 From: Sarunya Pumma Date: Fri, 14 Jul 2023 10:13:49 -0700 Subject: [PATCH] Use floatCloseAll for SparseAdagradTest/RowWiseSparseAdagradFusedTest (#1872) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/1872 Replace `EXPECT_NEAR` with `floatCloseAll` to compare results with both absolute and relative tolerances Reviewed By: brad-mengchi, jianyuh, r-barnes, shintaro-iwasaki Differential Revision: D47382447 fbshipit-source-id: bd15ac7e241bd5d88ef6da1dad2b4f4765d84c95 --- test/QuantUtilsTest.cc | 82 +-------------------------- test/RowWiseSparseAdagradFusedTest.cc | 30 ++++------ test/SparseAdagradTest.cc | 35 ++++-------- test/TestUtils.cc | 75 ++++++++++++++++++++++++ test/TestUtils.h | 10 ++++ 5 files changed, 107 insertions(+), 125 deletions(-) diff --git a/test/QuantUtilsTest.cc b/test/QuantUtilsTest.cc index 196fdc959..4ec5c0dc0 100644 --- a/test/QuantUtilsTest.cc +++ b/test/QuantUtilsTest.cc @@ -13,8 +13,7 @@ #include #include -#include - +#include "TestUtils.h" #include "fbgemm/QuantUtils.h" #include "fbgemm/Types.h" #include "fbgemm/Utils.h" @@ -153,85 +152,6 @@ ::testing::AssertionResult isNear( return ::testing::AssertionFailure() << " Quantized results do not match"; } -// atol: absolute tolerance. <=0 means do not consider atol. -// rtol: relative tolerance. <=0 means do not consider rtol. -::testing::AssertionResult floatCloseAll( - vector& a, - vector& b, - float atol = std::numeric_limits::epsilon(), - float rtol = 0) { - std::stringstream ss; - bool match = true; - if (a.size() != b.size()) { - ss << " size mismatch "; - match = false; - } - if (match) { - for (size_t i = 0; i < a.size(); i++) { - const bool consider_absDiff = atol > 0; - const bool consider_relDiff = rtol > 0 && - fabs(a[i]) > std::numeric_limits::epsilon() && - fabs(b[i]) > std::numeric_limits::epsilon(); - - const float absDiff = fabs(a[i] - b[i]); - const float relDiff = absDiff / fabs(a[i]); - - if (consider_absDiff && consider_relDiff) { - if (absDiff > atol && relDiff > rtol) { - ss << " mismatch at (" << i << ") " << endl; - ss << "\t ref: " << a[i] << " test: " << b[i] << endl; - ss << "\t absolute diff: " << absDiff << " > " << atol << endl; - ss << "\t relative diff: " << relDiff << " > " << rtol << endl; - match = false; - } - } else if (consider_absDiff) { - if (absDiff > atol) { - ss << " mismatch at (" << i << ") " << endl; - ss << "\t ref: " << a[i] << " test: " << b[i] << endl; - ss << "\t absolute diff: " << absDiff << " > " << atol << endl; - match = false; - } - } else if (consider_relDiff) { - if (relDiff > rtol) { - ss << " mismatch at (" << i << ") " << endl; - ss << "\t ref: " << a[i] << " test: " << b[i] << endl; - ss << "\t relative diff: " << relDiff << " > " << rtol << endl; - match = false; - } - } - } - } - if (match) - return ::testing::AssertionSuccess(); - else - return ::testing::AssertionFailure() - << " results do not match. " << ss.str(); -} - -::testing::AssertionResult floatCloseAll( - vector& a, - vector& b, - float atol = std::numeric_limits::epsilon(), - float rtol = 0) { - vector b_float(b.size()); - const auto transform = [](float16 input) { return cpu_half2float(input); }; - std::transform(b.begin(), b.end(), b_float.begin(), transform); - return floatCloseAll(a, b_float, atol, rtol); -} - -::testing::AssertionResult floatCloseAll( - vector& a, - vector& b, - float atol = std::numeric_limits::epsilon(), - float rtol = 0) { - vector a_float(a.size()); - vector b_float(b.size()); - const auto transform = [](float16 input) { return cpu_half2float(input); }; - std::transform(a.begin(), a.end(), a_float.begin(), transform); - std::transform(b.begin(), b.end(), b_float.begin(), transform); - return floatCloseAll(a_float, b_float, atol, rtol); -} - template ::testing::AssertionResult isQEmbeddingClose( const vector& res_ref, diff --git a/test/RowWiseSparseAdagradFusedTest.cc b/test/RowWiseSparseAdagradFusedTest.cc index 89ae626df..0e295462a 100644 --- a/test/RowWiseSparseAdagradFusedTest.cc +++ b/test/RowWiseSparseAdagradFusedTest.cc @@ -16,6 +16,7 @@ #include #include "./EmbeddingSpMDMTestUtils.h" +#include "TestUtils.h" #include "fbgemm/Fbgemm.h" #include "fbgemm/SimdUtils.h" #include "fbgemm/Utils.h" @@ -65,6 +66,8 @@ class RowWiseSparseAdagradFusedTest : public testing::TestWithParam> {}; }; // namespace +constexpr float DEFAULT_TOL = 1.0e-6; + INSTANTIATE_TEST_CASE_P( InstantiationName, RowWiseSparseAdagradFusedTest, @@ -280,29 +283,18 @@ TEST_P(RowWiseSparseAdagradFusedTest, rowwiseTest) { << "return vals differ, reference is: " << success_ref << " ,fbgemm is: " << success; if (success) { - for (size_t i = 0; i < h.size(); ++i) { - EXPECT_EQ(h[i], h_ref[i]) - << "results for h differ at (" << i << ") reference: " << h_ref[i] - << ", FBGEMM: " << h[i] << " emb dim :" << embedding_dim; - } + EXPECT_TRUE(floatCloseAll(h, h_ref, DEFAULT_TOL, DEFAULT_TOL)); - for (size_t i = 0; i < w.size(); ++i) { - float w_, w_ref_; -// for fp16 the ref impl already does the conversion #if defined(__x86_64__) || defined(__i386__) || \ (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))) - if (isWeightFp16) { - w_ = cpu_half2float(w_fp16[i]); - w_ref_ = cpu_half2float(w_fp16_ref[i]); - } else + if (isWeightFp16) { + // Set tol based on fbgemm_gpu tests + EXPECT_TRUE(floatCloseAll(w_fp16, w_fp16_ref, 1.0e-2, 1.0e-2)); + } else #endif - { - w_ = w[i]; - w_ref_ = w_ref[i]; - } - EXPECT_EQ(w_, w_ref_) - << "results for w differ at (" << i << ") reference: " << w_ref_ - << ", FBGEMM: " << w_ << " emb dim :" << embedding_dim; + { + // Set tol based on fbgemm_gpu tests + EXPECT_TRUE(floatCloseAll(w, w_ref, 1.0e-4, 1.0e-4)); } } } diff --git a/test/SparseAdagradTest.cc b/test/SparseAdagradTest.cc index a818a4f7e..3f4b193b3 100644 --- a/test/SparseAdagradTest.cc +++ b/test/SparseAdagradTest.cc @@ -13,6 +13,7 @@ #include +#include "TestUtils.h" #include "fbgemm/Fbgemm.h" #include "src/RefImplementations.h" @@ -48,7 +49,7 @@ class SparseAdagradTest : public testing::TestWithParam> {}; }; // namespace -constexpr float ABS_TOL = 1e-6; +constexpr float DEFAULT_TOL = 1.0e-6; // Test: INSTANTIATE_TEST_CASE_P( @@ -186,19 +187,11 @@ TEST_P(SparseAdagradTest, basicTest_two_stages) { counter_halflife); } - EXPECT_NEAR(ret_fbgemm, ret_ref, ABS_TOL) + EXPECT_EQ(ret_fbgemm, ret_ref) << "return vals differ, reference is: " << ret_ref << " ,fbgemm is: " << ret_fbgemm; - for (size_t i = 0; i < h.size(); ++i) { - EXPECT_NEAR(h[i], h_ref[i], ABS_TOL) - << "results for h differ at (" << i << ") reference: " << h_ref[i] - << ", FBGEMM: " << h[i] << " emb dim :" << block_size; - } - for (size_t i = 0; i < w.size(); ++i) { - EXPECT_NEAR(w[i], w_ref[i], ABS_TOL) - << "results for h differ at (" << i << ") reference: " << w_ref[i] - << ", FBGEMM: " << w[i] << " emb dim :" << block_size; - } + EXPECT_TRUE(floatCloseAll(h, h_ref, DEFAULT_TOL, DEFAULT_TOL)); + EXPECT_TRUE(floatCloseAll(w, w_ref, DEFAULT_TOL, DEFAULT_TOL)); } } @@ -324,20 +317,12 @@ TEST_P(SparseAdagradTest, rowwiseTest_two_stages) { counter_halflife); } - EXPECT_NEAR(ret_fbgemm, ret_ref, ABS_TOL) + EXPECT_EQ(ret_fbgemm, ret_ref) << "return vals differ, reference is: " << ret_ref << " ,fbgemm is: " << ret_fbgemm; - for (size_t i = 0; i < h.size(); ++i) { - // Set the absolute tolerance of rowwise momentum to 1e-3 because it a - // product of square, add, div which the rounding error can be very high - EXPECT_NEAR(h[i], h_ref[i], 1e-3) - << "results for h differ at (" << i << ") reference: " << h_ref[i] - << ", FBGEMM: " << h[i] << " emb dim :" << block_size; - } - for (size_t i = 0; i < w.size(); ++i) { - EXPECT_NEAR(w[i], w_ref[i], ABS_TOL) - << "results for w differ at (" << i << ") reference: " << w_ref[i] - << ", FBGEMM: " << w[i] << " emb dim :" << block_size; - } + // Set the absolute tolerance of rowwise momentum to 1e-3 because it a + // product of square, add, div which the rounding error can be very high + EXPECT_TRUE(floatCloseAll(h, h_ref, 1.0e-3, 1.0e-3)); + EXPECT_TRUE(floatCloseAll(w, w_ref, DEFAULT_TOL, DEFAULT_TOL)); } } diff --git a/test/TestUtils.cc b/test/TestUtils.cc index 46a570cec..c945df572 100644 --- a/test/TestUtils.cc +++ b/test/TestUtils.cc @@ -87,4 +87,79 @@ check_all_zero_entries(const int32_t* test, int m, int n); template bool check_all_zero_entries(const uint8_t* test, int m, int n); +// atol: absolute tolerance. <=0 means do not consider atol. +// rtol: relative tolerance. <=0 means do not consider rtol. +template <> +::testing::AssertionResult floatCloseAll( + const std::vector& a, + const std::vector& b, + const float atol, + const float rtol) { + std::stringstream ss; + bool match = true; + if (a.size() != b.size()) { + ss << " size mismatch "; + match = false; + } + if (!match) { + return ::testing::AssertionFailure() + << " results do not match. " << ss.str(); + } + for (size_t i = 0; i < a.size(); i++) { + const bool consider_absDiff = atol > 0; + const bool consider_relDiff = rtol > 0 && + std::fabs(a[i]) > std::numeric_limits::epsilon() && + std::fabs(b[i]) > std::numeric_limits::epsilon(); + + const float absDiff = std::fabs(a[i] - b[i]); + const float relDiff = absDiff / std::fabs(a[i]); + + if (consider_absDiff && consider_relDiff) { + match = absDiff <= atol || relDiff <= rtol; + } else if (consider_absDiff) { + match = absDiff <= atol; + } else if (consider_relDiff) { + match = relDiff <= rtol; + } + if (!match) { + ss << " mismatch at (" << i << ") " << std::endl; + ss << "\t ref: " << a[i] << " test: " << b[i] << std::endl; + if (consider_absDiff) { + ss << "\t absolute diff: " << absDiff << " > " << atol << std::endl; + } + if (consider_relDiff) { + ss << "\t relative diff: " << relDiff << " > " << rtol << std::endl; + } + return ::testing::AssertionFailure() + << " results do not match. " << ss.str(); + } + } + return ::testing::AssertionSuccess(); +} + +template <> +::testing::AssertionResult floatCloseAll( + const std::vector& a, + const std::vector& b, + const float atol, + const float rtol) { + std::vector b_float(b.size()); + const auto transform = [](float16 input) { return cpu_half2float(input); }; + std::transform(b.begin(), b.end(), b_float.begin(), transform); + return floatCloseAll(a, b_float, atol, rtol); +} + +template <> +::testing::AssertionResult floatCloseAll( + const std::vector& a, + const std::vector& b, + const float atol, + const float rtol) { + std::vector a_float(a.size()); + std::vector b_float(b.size()); + const auto transform = [](float16 input) { return cpu_half2float(input); }; + std::transform(a.begin(), a.end(), a_float.begin(), transform); + std::transform(b.begin(), b.end(), b_float.begin(), transform); + return floatCloseAll(a_float, b_float, atol, rtol); +} } // namespace fbgemm diff --git a/test/TestUtils.h b/test/TestUtils.h index 43667899e..039940457 100644 --- a/test/TestUtils.h +++ b/test/TestUtils.h @@ -7,6 +7,7 @@ */ #pragma once +#include #include #include @@ -32,4 +33,13 @@ int compare_validate_buffers( template bool check_all_zero_entries(const T* test, int m, int n); +// atol: absolute tolerance. <=0 means do not consider atol. +// rtol: relative tolerance. <=0 means do not consider rtol. +template +::testing::AssertionResult floatCloseAll( + const std::vector& a, + const std::vector& b, + const float atol = std::numeric_limits::epsilon(), + const float rtol = 0); + } // namespace fbgemm