Skip to content

Commit

Permalink
Add __obj_flatten__ for AtomicCounter (pytorch#2697)
Browse files Browse the repository at this point in the history
Summary:
X-link: pytorch/torchrec#2081

Pull Request resolved: pytorch#2697

Fix the test failure that blocks D57991002

Reviewed By: angelayi

Differential Revision: D58154762

fbshipit-source-id: c2d5eb6b4840741f349241cea6b216b092a72352
  • Loading branch information
ydwu4 authored and facebook-github-bot committed Jun 7, 2024
1 parent c7720e8 commit 3889721
Showing 1 changed file with 6 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,11 @@ class AtomicCounter : public torch::jit::CustomClassHolder {
counter_ = val;
}

std::tuple<std::tuple<std::string, int64_t>> __obj_flatten__() {
return std::make_tuple(
std::make_tuple(std::string("counter_"), counter_.load()));
}

std::string serialize() const {
std::ostringstream oss;
oss << counter_;
Expand All @@ -465,6 +470,7 @@ static auto AtomicCounterRegistry =
.def("reset", &AtomicCounter::reset)
.def("get", &AtomicCounter::get)
.def("set", &AtomicCounter::set)
.def("__obj_flatten__", &AtomicCounter::__obj_flatten__)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<AtomicCounter>& self) -> std::string {
Expand Down

0 comments on commit 3889721

Please sign in to comment.