Skip to content

Commit

Permalink
Use floatCloseAll for SparseAdagradTest/RowWiseSparseAdagradFusedTest (
Browse files Browse the repository at this point in the history
…pytorch#1872)

Summary:
Pull Request resolved: pytorch#1872

Replace `EXPECT_NEAR` with `floatCloseAll` to compare results with
both absolute and relative tolerances

Reviewed By: brad-mengchi, jianyuh, r-barnes, shintaro-iwasaki

Differential Revision: D47382447

fbshipit-source-id: bd15ac7e241bd5d88ef6da1dad2b4f4765d84c95
  • Loading branch information
sryap authored and facebook-github-bot committed Jul 14, 2023
1 parent 9e24d2a commit 4a3931a
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 125 deletions.
82 changes: 1 addition & 81 deletions test/QuantUtilsTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
#include <sstream>
#include <type_traits>

#include <gtest/gtest.h>

#include "TestUtils.h"
#include "fbgemm/QuantUtils.h"
#include "fbgemm/Types.h"
#include "fbgemm/Utils.h"
Expand Down Expand Up @@ -153,85 +152,6 @@ ::testing::AssertionResult isNear(
return ::testing::AssertionFailure() << " Quantized results do not match";
}

// atol: absolute tolerance. <=0 means do not consider atol.
// rtol: relative tolerance. <=0 means do not consider rtol.
::testing::AssertionResult floatCloseAll(
vector<float>& a,
vector<float>& b,
float atol = std::numeric_limits<float>::epsilon(),
float rtol = 0) {
std::stringstream ss;
bool match = true;
if (a.size() != b.size()) {
ss << " size mismatch ";
match = false;
}
if (match) {
for (size_t i = 0; i < a.size(); i++) {
const bool consider_absDiff = atol > 0;
const bool consider_relDiff = rtol > 0 &&
fabs(a[i]) > std::numeric_limits<float>::epsilon() &&
fabs(b[i]) > std::numeric_limits<float>::epsilon();

const float absDiff = fabs(a[i] - b[i]);
const float relDiff = absDiff / fabs(a[i]);

if (consider_absDiff && consider_relDiff) {
if (absDiff > atol && relDiff > rtol) {
ss << " mismatch at (" << i << ") " << endl;
ss << "\t ref: " << a[i] << " test: " << b[i] << endl;
ss << "\t absolute diff: " << absDiff << " > " << atol << endl;
ss << "\t relative diff: " << relDiff << " > " << rtol << endl;
match = false;
}
} else if (consider_absDiff) {
if (absDiff > atol) {
ss << " mismatch at (" << i << ") " << endl;
ss << "\t ref: " << a[i] << " test: " << b[i] << endl;
ss << "\t absolute diff: " << absDiff << " > " << atol << endl;
match = false;
}
} else if (consider_relDiff) {
if (relDiff > rtol) {
ss << " mismatch at (" << i << ") " << endl;
ss << "\t ref: " << a[i] << " test: " << b[i] << endl;
ss << "\t relative diff: " << relDiff << " > " << rtol << endl;
match = false;
}
}
}
}
if (match)
return ::testing::AssertionSuccess();
else
return ::testing::AssertionFailure()
<< " results do not match. " << ss.str();
}

::testing::AssertionResult floatCloseAll(
vector<float>& a,
vector<float16>& b,
float atol = std::numeric_limits<float>::epsilon(),
float rtol = 0) {
vector<float> b_float(b.size());
const auto transform = [](float16 input) { return cpu_half2float(input); };
std::transform(b.begin(), b.end(), b_float.begin(), transform);
return floatCloseAll(a, b_float, atol, rtol);
}

::testing::AssertionResult floatCloseAll(
vector<float16>& a,
vector<float16>& b,
float atol = std::numeric_limits<float>::epsilon(),
float rtol = 0) {
vector<float> a_float(a.size());
vector<float> b_float(b.size());
const auto transform = [](float16 input) { return cpu_half2float(input); };
std::transform(a.begin(), a.end(), a_float.begin(), transform);
std::transform(b.begin(), b.end(), b_float.begin(), transform);
return floatCloseAll(a_float, b_float, atol, rtol);
}

template <typename T>
::testing::AssertionResult isQEmbeddingClose(
const vector<uint8_t>& res_ref,
Expand Down
30 changes: 11 additions & 19 deletions test/RowWiseSparseAdagradFusedTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <gtest/gtest.h>

#include "./EmbeddingSpMDMTestUtils.h"
#include "TestUtils.h"
#include "fbgemm/Fbgemm.h"
#include "fbgemm/SimdUtils.h"
#include "fbgemm/Utils.h"
Expand Down Expand Up @@ -65,6 +66,8 @@ class RowWiseSparseAdagradFusedTest : public testing::TestWithParam<tuple<
bool>> {};
}; // namespace

constexpr float DEFAULT_TOL = 1.0e-6;

INSTANTIATE_TEST_CASE_P(
InstantiationName,
RowWiseSparseAdagradFusedTest,
Expand Down Expand Up @@ -280,29 +283,18 @@ TEST_P(RowWiseSparseAdagradFusedTest, rowwiseTest) {
<< "return vals differ, reference is: " << success_ref
<< " ,fbgemm is: " << success;
if (success) {
for (size_t i = 0; i < h.size(); ++i) {
EXPECT_EQ(h[i], h_ref[i])
<< "results for h differ at (" << i << ") reference: " << h_ref[i]
<< ", FBGEMM: " << h[i] << " emb dim :" << embedding_dim;
}
EXPECT_TRUE(floatCloseAll(h, h_ref, DEFAULT_TOL, DEFAULT_TOL));

for (size_t i = 0; i < w.size(); ++i) {
float w_, w_ref_;
// for fp16 the ref impl already does the conversion
#if defined(__x86_64__) || defined(__i386__) || \
(defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
if (isWeightFp16) {
w_ = cpu_half2float(w_fp16[i]);
w_ref_ = cpu_half2float(w_fp16_ref[i]);
} else
if (isWeightFp16) {
// Set tol based on fbgemm_gpu tests
EXPECT_TRUE(floatCloseAll(w_fp16, w_fp16_ref, 1.0e-2, 1.0e-2));
} else
#endif
{
w_ = w[i];
w_ref_ = w_ref[i];
}
EXPECT_EQ(w_, w_ref_)
<< "results for w differ at (" << i << ") reference: " << w_ref_
<< ", FBGEMM: " << w_ << " emb dim :" << embedding_dim;
{
// Set tol based on fbgemm_gpu tests
EXPECT_TRUE(floatCloseAll(w, w_ref, 1.0e-4, 1.0e-4));
}
}
}
Expand Down
35 changes: 10 additions & 25 deletions test/SparseAdagradTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include <gtest/gtest.h>

#include "TestUtils.h"
#include "fbgemm/Fbgemm.h"
#include "src/RefImplementations.h"

Expand Down Expand Up @@ -48,7 +49,7 @@ class SparseAdagradTest
: public testing::TestWithParam<tuple<bool, int, bool, bool, bool>> {};
}; // namespace

constexpr float ABS_TOL = 1e-6;
constexpr float DEFAULT_TOL = 1.0e-6;

// Test:
INSTANTIATE_TEST_CASE_P(
Expand Down Expand Up @@ -186,19 +187,11 @@ TEST_P(SparseAdagradTest, basicTest_two_stages) {
counter_halflife);
}

EXPECT_NEAR(ret_fbgemm, ret_ref, ABS_TOL)
EXPECT_EQ(ret_fbgemm, ret_ref)
<< "return vals differ, reference is: " << ret_ref
<< " ,fbgemm is: " << ret_fbgemm;
for (size_t i = 0; i < h.size(); ++i) {
EXPECT_NEAR(h[i], h_ref[i], ABS_TOL)
<< "results for h differ at (" << i << ") reference: " << h_ref[i]
<< ", FBGEMM: " << h[i] << " emb dim :" << block_size;
}
for (size_t i = 0; i < w.size(); ++i) {
EXPECT_NEAR(w[i], w_ref[i], ABS_TOL)
<< "results for h differ at (" << i << ") reference: " << w_ref[i]
<< ", FBGEMM: " << w[i] << " emb dim :" << block_size;
}
EXPECT_TRUE(floatCloseAll(h, h_ref, DEFAULT_TOL, DEFAULT_TOL));
EXPECT_TRUE(floatCloseAll(w, w_ref, DEFAULT_TOL, DEFAULT_TOL));
}
}

Expand Down Expand Up @@ -324,20 +317,12 @@ TEST_P(SparseAdagradTest, rowwiseTest_two_stages) {
counter_halflife);
}

EXPECT_NEAR(ret_fbgemm, ret_ref, ABS_TOL)
EXPECT_EQ(ret_fbgemm, ret_ref)
<< "return vals differ, reference is: " << ret_ref
<< " ,fbgemm is: " << ret_fbgemm;
for (size_t i = 0; i < h.size(); ++i) {
// Set the absolute tolerance of rowwise momentum to 1e-3 because it a
// product of square, add, div which the rounding error can be very high
EXPECT_NEAR(h[i], h_ref[i], 1e-3)
<< "results for h differ at (" << i << ") reference: " << h_ref[i]
<< ", FBGEMM: " << h[i] << " emb dim :" << block_size;
}
for (size_t i = 0; i < w.size(); ++i) {
EXPECT_NEAR(w[i], w_ref[i], ABS_TOL)
<< "results for w differ at (" << i << ") reference: " << w_ref[i]
<< ", FBGEMM: " << w[i] << " emb dim :" << block_size;
}
// Set the absolute tolerance of rowwise momentum to 1e-3 because it a
// product of square, add, div which the rounding error can be very high
EXPECT_TRUE(floatCloseAll(h, h_ref, 1.0e-3, 1.0e-3));
EXPECT_TRUE(floatCloseAll(w, w_ref, DEFAULT_TOL, DEFAULT_TOL));
}
}
75 changes: 75 additions & 0 deletions test/TestUtils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,79 @@ check_all_zero_entries<int32_t>(const int32_t* test, int m, int n);
template bool
check_all_zero_entries<uint8_t>(const uint8_t* test, int m, int n);

// atol: absolute tolerance. <=0 means do not consider atol.
// rtol: relative tolerance. <=0 means do not consider rtol.
template <>
::testing::AssertionResult floatCloseAll<float, float>(
const std::vector<float>& a,
const std::vector<float>& b,
const float atol,
const float rtol) {
std::stringstream ss;
bool match = true;
if (a.size() != b.size()) {
ss << " size mismatch ";
match = false;
}
if (!match) {
return ::testing::AssertionFailure()
<< " results do not match. " << ss.str();
}
for (size_t i = 0; i < a.size(); i++) {
const bool consider_absDiff = atol > 0;
const bool consider_relDiff = rtol > 0 &&
std::fabs(a[i]) > std::numeric_limits<float>::epsilon() &&
std::fabs(b[i]) > std::numeric_limits<float>::epsilon();

const float absDiff = std::fabs(a[i] - b[i]);
const float relDiff = absDiff / std::fabs(a[i]);

if (consider_absDiff && consider_relDiff) {
match = absDiff <= atol || relDiff <= rtol;
} else if (consider_absDiff) {
match = absDiff <= atol;
} else if (consider_relDiff) {
match = relDiff <= rtol;
}
if (!match) {
ss << " mismatch at (" << i << ") " << std::endl;
ss << "\t ref: " << a[i] << " test: " << b[i] << std::endl;
if (consider_absDiff) {
ss << "\t absolute diff: " << absDiff << " > " << atol << std::endl;
}
if (consider_relDiff) {
ss << "\t relative diff: " << relDiff << " > " << rtol << std::endl;
}
return ::testing::AssertionFailure()
<< " results do not match. " << ss.str();
}
}
return ::testing::AssertionSuccess();
}

template <>
::testing::AssertionResult floatCloseAll<float, float16>(
const std::vector<float>& a,
const std::vector<float16>& b,
const float atol,
const float rtol) {
std::vector<float> b_float(b.size());
const auto transform = [](float16 input) { return cpu_half2float(input); };
std::transform(b.begin(), b.end(), b_float.begin(), transform);
return floatCloseAll(a, b_float, atol, rtol);
}

template <>
::testing::AssertionResult floatCloseAll<float16, float16>(
const std::vector<float16>& a,
const std::vector<float16>& b,
const float atol,
const float rtol) {
std::vector<float> a_float(a.size());
std::vector<float> b_float(b.size());
const auto transform = [](float16 input) { return cpu_half2float(input); };
std::transform(a.begin(), a.end(), a_float.begin(), transform);
std::transform(b.begin(), b.end(), b_float.begin(), transform);
return floatCloseAll(a_float, b_float, atol, rtol);
}
} // namespace fbgemm
10 changes: 10 additions & 0 deletions test/TestUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
*/

#pragma once
#include <gtest/gtest.h>
#include <cmath>
#include <vector>

Expand All @@ -32,4 +33,13 @@ int compare_validate_buffers(
template <typename T>
bool check_all_zero_entries(const T* test, int m, int n);

// atol: absolute tolerance. <=0 means do not consider atol.
// rtol: relative tolerance. <=0 means do not consider rtol.
template <typename a_T, typename b_T>
::testing::AssertionResult floatCloseAll(
const std::vector<a_T>& a,
const std::vector<b_T>& b,
const float atol = std::numeric_limits<float>::epsilon(),
const float rtol = 0);

} // namespace fbgemm

0 comments on commit 4a3931a

Please sign in to comment.