Skip to content

Commit

Permalink
add lock for l2 cache set/get (pytorch#3153)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookresearch/FBGEMM#248

Pull Request resolved: pytorch#3153

In non pipelining mode, the sequence is
- get_cuda(): L2 read and insert L2 cache misses into queue for bg L2 write
- L1 cache eviction: insert into bg queue for L2 write
- ScratchPad update: insert into bg queue for L2 write

in non-prefetch pipeline, cuda synchronization guarantee get_cuda() happen after SP update

in prefetch pipeline, cuda sync only guarantee get_cuda() happen after L1 cache eviction pipeline case, SP bwd update could happen in parallel with L2 read
lock is used for l2 cache to do read / write exclusively

add unittest to capture L2 cache functionality and the cases discussed above

Reviewed By: q10

Differential Revision: D63010906

fbshipit-source-id: 3951ce138acb53da4f7aba01a03c46409a6fc630
  • Loading branch information
Joe Wang authored and facebook-github-bot committed Sep 19, 2024
1 parent 12710ef commit 904a1c6
Show file tree
Hide file tree
Showing 6 changed files with 337 additions and 14 deletions.
20 changes: 18 additions & 2 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1850,8 +1850,8 @@ def _report_l2_cache_perf_stats(self) -> None:
self.step, stats_reporter.report_interval # pyre-ignore
)

if len(l2_cache_perf_stats) != 13:
logging.error("l2 perf stats should have 13 elements")
if len(l2_cache_perf_stats) != 15:
logging.error("l2 perf stats should have 15 elements")
return

num_cache_misses = l2_cache_perf_stats[0]
Expand All @@ -1869,6 +1869,9 @@ def _report_l2_cache_perf_stats(self) -> None:
l2_cache_free_bytes = l2_cache_perf_stats[11]
l2_cache_capacity = l2_cache_perf_stats[12]

set_cache_lock_wait_duration = l2_cache_perf_stats[13]
get_cache_lock_wait_duration = l2_cache_perf_stats[14]

stats_reporter.report_data_amount(
iteration_step=self.step,
event_name=self.l2_num_cache_misses_stats_name,
Expand Down Expand Up @@ -1944,6 +1947,19 @@ def _report_l2_cache_perf_stats(self) -> None:
time_unit="us",
)

stats_reporter.report_duration(
iteration_step=self.step,
event_name="l2_cache.perf.get.cache_lock_wait_duration_us",
duration_ms=get_cache_lock_wait_duration,
time_unit="us",
)
stats_reporter.report_duration(
iteration_step=self.step,
event_name="l2_cache.perf.set.cache_lock_wait_duration_us",
duration_ms=set_cache_lock_wait_duration,
time_unit="us",
)

# pyre-ignore
def _recording_to_timer(
self, timer: Optional[AsyncSeriesTimer], **kwargs: Any
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ void EmbeddingKVDB::set_cuda(
std::vector<double> EmbeddingKVDB::get_l2cache_perf(
const int64_t step,
const int64_t interval) {
std::vector<double> ret(13, 0); // num metrics
std::vector<double> ret(15, 0); // num metrics
if (step > 0 && step % interval == 0) {
int reset_val = 0;
auto num_cache_misses = num_cache_misses_.exchange(reset_val);
Expand All @@ -215,6 +215,12 @@ std::vector<double> EmbeddingKVDB::get_l2cache_perf(
get_tensor_copy_for_cache_update_.exchange(reset_val);
auto set_tensor_copy_for_cache_update_dur =
set_tensor_copy_for_cache_update_.exchange(reset_val);

auto set_cache_lock_wait_duration =
set_cache_lock_wait_duration_.exchange(reset_val);
auto get_cache_lock_wait_duration =
get_cache_lock_wait_duration_.exchange(reset_val);

ret[0] = (double(num_cache_misses) / interval);
ret[1] = (double(num_lookups) / interval);
ret[2] = (double(get_total_duration) / interval);
Expand All @@ -231,10 +237,16 @@ std::vector<double> EmbeddingKVDB::get_l2cache_perf(
ret[11] = (cache_mem_stats[0]); // free cache in bytes
ret[12] = (cache_mem_stats[1]); // total cache capacity in bytes
}
ret[13] = (double(set_cache_lock_wait_duration) / interval);
ret[14] = (double(get_cache_lock_wait_duration) / interval);
}
return ret;
}

void EmbeddingKVDB::reset_l2_cache() {
l2_cache_ = nullptr;
}

void EmbeddingKVDB::set(
const at::Tensor& indices,
const at::Tensor& weights,
Expand Down Expand Up @@ -264,7 +276,8 @@ void EmbeddingKVDB::set(
void EmbeddingKVDB::get(
const at::Tensor& indices,
const at::Tensor& weights,
const at::Tensor& count) {
const at::Tensor& count,
int64_t sleep_ms) {
if (auto num_lookups = get_maybe_uvm_scalar(count); num_lookups <= 0) {
XLOG_EVERY_MS(INFO, 60000)
<< "[TBE_ID" << unique_id_ << "]skip get_cuda since number lookups is "
Expand All @@ -274,6 +287,17 @@ void EmbeddingKVDB::get(
ASSERT_EQ(max_D_, weights.size(1));
auto start_ts = facebook::WallClockUtil::NowInUsecFast();
wait_util_filling_work_done();

std::unique_lock<std::mutex> lock(l2_cache_mtx_);
get_cache_lock_wait_duration_ +=
facebook::WallClockUtil::NowInUsecFast() - start_ts;

// this is for unittest to repro synchronization situation deterministically
if (sleep_ms > 0) {
std::this_thread::sleep_for(std::chrono::milliseconds(sleep_ms));
XLOG(INFO) << "get sleep end";
}

auto cache_context = get_cache(indices, count);
if (cache_context != nullptr) {
if (cache_context->num_misses > 0) {
Expand Down Expand Up @@ -407,12 +431,14 @@ EmbeddingKVDB::set_cache(
if (l2_cache_ == nullptr) {
return folly::none;
}

// TODO: consider whether need to reconstruct indices/weights/count and free
// the original tensor since most of the tensor elem will be invalid,
// this will trade some perf for peak DRAM util saving

auto cache_update_start_ts = facebook::WallClockUtil::NowInUsecFast();
std::unique_lock<std::mutex> lock(l2_cache_mtx_);
set_cache_lock_wait_duration_ +=
facebook::WallClockUtil::NowInUsecFast() - cache_update_start_ts;

l2_cache_->init_tensor_for_l2_eviction(indices, weights, count);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,17 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
/// relative element in <indices>
/// @param count A single element tensor that contains the number of indices
/// to be processed
/// @param sleep_ms this is used to specifically sleep in get function, this
/// is needed to reproduce synchronization situation deterministicly, in prod
/// case this will be 0 for sure
///
/// @return None
/// @note weights will be updated from either L2 cache or storage tier
void get(
const at::Tensor& indices,
const at::Tensor& weights,
const at::Tensor& count);
const at::Tensor& count,
int64_t sleep_ms = 0);

/// storage tier counterpart of function get()
virtual folly::coro::Task<void> get_kv_db_async(
Expand Down Expand Up @@ -227,6 +231,14 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
const int64_t step,
const int64_t interval);

// reset L2 cache, this is used for unittesting to bypass l2 cache
void reset_l2_cache();

// block waiting for working items in queue to be finished, this is called by
// get_cache() as embedding read should wait until previous write to be
// finished, it could also be called in unitest to sync
void wait_util_filling_work_done();

private:
/// Find non-negative embedding indices in <indices> and shard them into
/// #cachelib_pools pieces to be lookedup in parallel
Expand Down Expand Up @@ -283,10 +295,6 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
const at::Tensor& indices,
const at::Tensor& weights);

// waiting for working item queue to be empty, this is called by get_cache()
// as embedding read should wait until previous write to be finished
void wait_util_filling_work_done();

std::unique_ptr<l2_cache::CacheLibCache> l2_cache_;
const int64_t unique_id_;
const int64_t num_shards_;
Expand All @@ -299,6 +307,17 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
// buffer queue that stores all the needed indices/weights/action_count to
// fill up cache
folly::USPSCQueue<QueueItem, true> weights_to_fill_queue_;
// In non pipelining mode, the sequence is
// - get_cuda(): L2 read and insert L2 cache misses into queue for
// bg L2 write
// - L1 cache eviction: insert into bg queue for L2 write
// - ScratchPad update: insert into bg queue for L2 write
// in non-prefetch pipeline, cuda synchronization guarantee get_cuda() happen
// after SP update
// in prefetch pipeline, cuda sync only guarantee get_cuda() happen after L1
// cache eviction pipeline case, SP bwd update could happen in parallel with
// L2 read mutex is used for l2 cache to do read / write exclusively
std::mutex l2_cache_mtx_;

// perf stats
// -- perf of get() function
Expand All @@ -313,9 +332,11 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
std::atomic<int64_t> get_weights_fillup_total_duration_{0};
std::atomic<int64_t> get_cache_memcpy_duration_{0};
std::atomic<int64_t> get_tensor_copy_for_cache_update_{0};
std::atomic<int64_t> get_cache_lock_wait_duration_{0};

// -- perf of set() function
std::atomic<int64_t> set_tensor_copy_for_cache_update_{0};
std::atomic<int64_t> set_cache_lock_wait_duration_{0};

// -- commone path
std::atomic<int64_t> total_cache_update_duration_{0};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,8 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
return impl_->set(indices, weights, count);
}

void get(Tensor indices, Tensor weights, Tensor count) {
return impl_->get(indices, weights, count);
void get(Tensor indices, Tensor weights, Tensor count, int64_t sleep_ms) {
return impl_->get(indices, weights, count, sleep_ms);
}

std::vector<int64_t> get_mem_usage() {
Expand All @@ -343,6 +343,14 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
return impl_->flush();
}

void reset_l2_cache() {
return impl_->reset_l2_cache();
}

void wait_util_filling_work_done() {
return impl_->wait_util_filling_work_done();
}

private:
// shared pointer since we use shared_from_this() in callbacks.
std::shared_ptr<ssd::EmbeddingRocksDB> impl_;
Expand Down Expand Up @@ -413,7 +421,20 @@ static auto embedding_rocks_db_wrapper =
&EmbeddingRocksDBWrapper::get_rocksdb_io_duration)
.def("get_l2cache_perf", &EmbeddingRocksDBWrapper::get_l2cache_perf)
.def("set", &EmbeddingRocksDBWrapper::set)
.def("get", &EmbeddingRocksDBWrapper::get);
.def(
"get",
&EmbeddingRocksDBWrapper::get,
"",
{
torch::arg("indices"),
torch::arg("weights"),
torch::arg("count"),
torch::arg("sleep_ms") = 0,
})
.def("reset_l2_cache", &EmbeddingRocksDBWrapper::reset_l2_cache)
.def(
"wait_util_filling_work_done",
&EmbeddingRocksDBWrapper::wait_util_filling_work_done);

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
Expand Down
10 changes: 9 additions & 1 deletion fbgemm_gpu/test/tbe/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,16 @@
# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
open_source: bool = getattr(fbgemm_gpu, "open_source", False)

if not open_source:
if open_source:
# pyre-ignore[21]
from test_utils import gpu_unavailable, running_on_github
else:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:cumem_utils")
from fbgemm_gpu.test.test_utils import ( # noqa F401
gpu_unavailable,
running_on_github,
)


torch.ops.import_module("fbgemm_gpu.sparse_ops")
settings.register_profile("derandomize", derandomize=True)
Expand Down
Loading

0 comments on commit 904a1c6

Please sign in to comment.