diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
index f0e16b840..e2d35072f 100644
--- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
+++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -1850,8 +1850,8 @@ def _report_l2_cache_perf_stats(self) -> None:
             self.step, stats_reporter.report_interval  # pyre-ignore
         )
 
-        if len(l2_cache_perf_stats) != 13:
-            logging.error("l2 perf stats should have 13 elements")
+        if len(l2_cache_perf_stats) != 15:
+            logging.error("l2 perf stats should have 15 elements")
             return
 
         num_cache_misses = l2_cache_perf_stats[0]
@@ -1869,6 +1869,9 @@ def _report_l2_cache_perf_stats(self) -> None:
         l2_cache_free_bytes = l2_cache_perf_stats[11]
         l2_cache_capacity = l2_cache_perf_stats[12]
 
+        set_cache_lock_wait_duration = l2_cache_perf_stats[13]
+        get_cache_lock_wait_duration = l2_cache_perf_stats[14]
+
         stats_reporter.report_data_amount(
             iteration_step=self.step,
             event_name=self.l2_num_cache_misses_stats_name,
@@ -1944,6 +1947,19 @@ def _report_l2_cache_perf_stats(self) -> None:
             time_unit="us",
         )
 
+        stats_reporter.report_duration(
+            iteration_step=self.step,
+            event_name="l2_cache.perf.get.cache_lock_wait_duration_us",
+            duration_ms=get_cache_lock_wait_duration,
+            time_unit="us",
+        )
+        stats_reporter.report_duration(
+            iteration_step=self.step,
+            event_name="l2_cache.perf.set.cache_lock_wait_duration_us",
+            duration_ms=set_cache_lock_wait_duration,
+            time_unit="us",
+        )
+
     # pyre-ignore
     def _recording_to_timer(
         self, timer: Optional[AsyncSeriesTimer], **kwargs: Any
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp
index 21aa00290..fdc91b0e9 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp
@@ -193,7 +193,7 @@ void EmbeddingKVDB::set_cuda(
 std::vector<double> EmbeddingKVDB::get_l2cache_perf(
     const int64_t step,
     const int64_t interval) {
-  std::vector<double> ret(13, 0); // num metrics
+  std::vector<double> ret(15, 0); // num metrics
   if (step > 0 && step % interval == 0) {
     int reset_val = 0;
     auto num_cache_misses = num_cache_misses_.exchange(reset_val);
@@ -215,6 +215,12 @@ std::vector<double> EmbeddingKVDB::get_l2cache_perf(
         get_tensor_copy_for_cache_update_.exchange(reset_val);
     auto set_tensor_copy_for_cache_update_dur =
         set_tensor_copy_for_cache_update_.exchange(reset_val);
+
+    auto set_cache_lock_wait_duration =
+        set_cache_lock_wait_duration_.exchange(reset_val);
+    auto get_cache_lock_wait_duration =
+        get_cache_lock_wait_duration_.exchange(reset_val);
+
     ret[0] = (double(num_cache_misses) / interval);
     ret[1] = (double(num_lookups) / interval);
     ret[2] = (double(get_total_duration) / interval);
@@ -231,10 +237,16 @@ std::vector<double> EmbeddingKVDB::get_l2cache_perf(
       ret[11] = (cache_mem_stats[0]); // free cache in bytes
       ret[12] = (cache_mem_stats[1]); // total cache capacity in bytes
     }
+    ret[13] = (double(set_cache_lock_wait_duration) / interval);
+    ret[14] = (double(get_cache_lock_wait_duration) / interval);
   }
   return ret;
 }
 
+void EmbeddingKVDB::reset_l2_cache() {
+  l2_cache_ = nullptr;
+}
+
 void EmbeddingKVDB::set(
     const at::Tensor& indices,
     const at::Tensor& weights,
@@ -264,7 +276,8 @@ void EmbeddingKVDB::set(
 void EmbeddingKVDB::get(
     const at::Tensor& indices,
     const at::Tensor& weights,
-    const at::Tensor& count) {
+    const at::Tensor& count,
+    int64_t sleep_ms) {
   if (auto num_lookups = get_maybe_uvm_scalar(count); num_lookups <= 0) {
     XLOG_EVERY_MS(INFO, 60000)
         << "[TBE_ID" << unique_id_ << "]skip get_cuda since number lookups is "
@@ -274,6 +287,17 @@ void EmbeddingKVDB::get(
   ASSERT_EQ(max_D_, weights.size(1));
   auto start_ts = facebook::WallClockUtil::NowInUsecFast();
   wait_util_filling_work_done();
+
+  std::unique_lock<std::mutex> lock(l2_cache_mtx_);
+  get_cache_lock_wait_duration_ +=
+      facebook::WallClockUtil::NowInUsecFast() - start_ts;
+
+  // Used by unit tests to reproduce the synchronization scenario deterministically
+  if (sleep_ms > 0) {
+    std::this_thread::sleep_for(std::chrono::milliseconds(sleep_ms));
+    XLOG(INFO) << "get sleep end";
+  }
+
   auto cache_context = get_cache(indices, count);
   if (cache_context != nullptr) {
     if (cache_context->num_misses > 0) {
@@ -407,12 +431,14 @@ EmbeddingKVDB::set_cache(
   if (l2_cache_ == nullptr) {
     return folly::none;
   }
-
   // TODO: consider whether need to reconstruct indices/weights/count and free
   // the original tensor since most of the tensor elem will be invalid,
   // this will trade some perf for peak DRAM util saving
   auto cache_update_start_ts = facebook::WallClockUtil::NowInUsecFast();
+  std::unique_lock<std::mutex> lock(l2_cache_mtx_);
+  set_cache_lock_wait_duration_ +=
+      facebook::WallClockUtil::NowInUsecFast() - cache_update_start_ts;
 
   l2_cache_->init_tensor_for_l2_eviction(indices, weights, count);
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h
index da98ba9c3..93495f2da 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h
@@ -171,13 +171,17 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
   /// relative element in
   /// @param count A single element tensor that contains the number of indices
   /// to be processed
+  /// @param sleep_ms optional sleep inside get(), used by unit tests to
+  /// reproduce the synchronization scenario deterministically; in production
+  /// this is always 0
   ///
   /// @return None
   /// @note weights will be updated from either L2 cache or storage tier
   void get(
       const at::Tensor& indices,
       const at::Tensor& weights,
-      const at::Tensor& count);
+      const at::Tensor& count,
+      int64_t sleep_ms = 0);
 
   /// storage tier counterpart of function get()
   virtual folly::coro::Task<void> get_kv_db_async(
@@ -227,6 +231,14 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
       const int64_t step,
       const int64_t interval);
 
+  // Reset the L2 cache; used by unit tests to bypass the L2 cache
+  void reset_l2_cache();
+
+  // Block until pending items in the fill queue are processed; called by
+  // get_cache() because an embedding read must wait for previous writes to
+  // finish, and also callable from unit tests for synchronization
+  void wait_util_filling_work_done();
+
  private:
   /// Find non-negative embedding indices in and shard them into
   /// #cachelib_pools pieces to be lookedup in parallel
@@ -283,10 +295,6 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
       const at::Tensor& indices,
       const at::Tensor& weights);
 
-  // waiting for working item queue to be empty, this is called by get_cache()
-  // as embedding read should wait until previous write to be finished
-  void wait_util_filling_work_done();
-
   std::unique_ptr l2_cache_;
   const int64_t unique_id_;
   const int64_t num_shards_;
@@ -299,6 +307,17 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
   // buffer queue that stores all the needed indices/weights/action_count to
   // fill up cache
   folly::USPSCQueue weights_to_fill_queue_;
+  // In non-pipelining mode, the sequence is:
+  //   - get_cuda(): L2 read, then insert L2 cache misses into the queue for
+  //     background L2 write
+  //   - L1 cache eviction: insert into the background queue for L2 write
+  //   - ScratchPad update: insert into the background queue for L2 write
+  // In the non-prefetch pipeline, cuda synchronization guarantees that
+  // get_cuda() happens after the ScratchPad update.
+  // In the prefetch pipeline, cuda sync only guarantees that get_cuda()
+  // happens after L1 cache eviction; the ScratchPad backward update can run in
+  // parallel with the L2 read, so this mutex makes L2 reads and writes exclusive
+  std::mutex l2_cache_mtx_;
 
   // perf stats
   // -- perf of get() function
@@ -313,9 +332,11 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
   std::atomic<int64_t> get_weights_fillup_total_duration_{0};
   std::atomic<int64_t> get_cache_memcpy_duration_{0};
   std::atomic<int64_t> get_tensor_copy_for_cache_update_{0};
+  std::atomic<int64_t> get_cache_lock_wait_duration_{0};
 
   // -- perf of set() function
   std::atomic<int64_t> set_tensor_copy_for_cache_update_{0};
+  std::atomic<int64_t> set_cache_lock_wait_duration_{0};
 
   // -- commone path
   std::atomic<int64_t> total_cache_update_duration_{0};
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
index 1e91cefb5..6238fe1be 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
@@ -315,8 +315,8 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
     return impl_->set(indices, weights, count);
   }
 
-  void get(Tensor indices, Tensor weights, Tensor count) {
-    return impl_->get(indices, weights, count);
+  void get(Tensor indices, Tensor weights, Tensor count, int64_t sleep_ms) {
+    return impl_->get(indices, weights, count, sleep_ms);
   }
 
   std::vector get_mem_usage() {
@@ -343,6 +343,14 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
     return impl_->flush();
  }
 
+  void reset_l2_cache() {
+    return impl_->reset_l2_cache();
+  }
+
+  void wait_util_filling_work_done() {
+    return impl_->wait_util_filling_work_done();
+  }
+
  private:
  // shared pointer since we use shared_from_this() in callbacks.
  std::shared_ptr impl_;
@@ -413,7 +421,20 @@ static auto embedding_rocks_db_wrapper =
           &EmbeddingRocksDBWrapper::get_rocksdb_io_duration)
       .def("get_l2cache_perf", &EmbeddingRocksDBWrapper::get_l2cache_perf)
       .def("set", &EmbeddingRocksDBWrapper::set)
-      .def("get", &EmbeddingRocksDBWrapper::get);
+      .def(
+          "get",
+          &EmbeddingRocksDBWrapper::get,
+          "",
+          {
+              torch::arg("indices"),
+              torch::arg("weights"),
+              torch::arg("count"),
+              torch::arg("sleep_ms") = 0,
+          })
+      .def("reset_l2_cache", &EmbeddingRocksDBWrapper::reset_l2_cache)
+      .def(
+          "wait_util_filling_work_done",
+          &EmbeddingRocksDBWrapper::wait_util_filling_work_done);
 
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
diff --git a/fbgemm_gpu/test/tbe/common.py b/fbgemm_gpu/test/tbe/common.py
index df38e15de..40f1b49e7 100644
--- a/fbgemm_gpu/test/tbe/common.py
+++ b/fbgemm_gpu/test/tbe/common.py
@@ -17,8 +17,16 @@
 # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
 open_source: bool = getattr(fbgemm_gpu, "open_source", False)
-if not open_source:
+if open_source:
+    # pyre-ignore[21]
+    from test_utils import gpu_unavailable, running_on_github
+else:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:cumem_utils")
+    from fbgemm_gpu.test.test_utils import (  # noqa F401
+        gpu_unavailable,
+        running_on_github,
+    )
+
 torch.ops.import_module("fbgemm_gpu.sparse_ops")
 
 settings.register_profile("derandomize", derandomize=True)
diff --git a/fbgemm_gpu/test/tbe/ssd/ssd_l2_cache_test.py b/fbgemm_gpu/test/tbe/ssd/ssd_l2_cache_test.py
new file mode 100644
index 000000000..ef5a9f0a2
--- /dev/null
+++ b/fbgemm_gpu/test/tbe/ssd/ssd_l2_cache_test.py
@@ -0,0 +1,231 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+# pyre-ignore-all-errors[3,6,56]
+
+import tempfile
+
+import threading
+import time
+import unittest
+
+from typing import Any, Dict, List, Tuple
+
+import hypothesis.strategies as st
+import numpy as np
+import torch
+from fbgemm_gpu.split_embedding_configs import SparseType
+from fbgemm_gpu.tbe.ssd import SSDTableBatchedEmbeddingBags
+from fbgemm_gpu.tbe.utils import round_up
+from hypothesis import given, settings, Verbosity
+
+from .. import common  # noqa E402
+from ..common import gpu_unavailable, running_on_github
+
+MAX_EXAMPLES = 20
+default_st: Dict[str, Any] = {
+    "T": st.integers(min_value=1, max_value=10),
+    "D": st.integers(min_value=2, max_value=128),
+    "log_E": st.integers(min_value=2, max_value=3),
+    "mixed": st.booleans(),
+    "weights_precision": st.sampled_from([SparseType.FP32, SparseType.FP16]),
+}
+
+default_settings: Dict[str, Any] = {
+    "verbosity": Verbosity.verbose,
+    "max_examples": MAX_EXAMPLES,
+    "deadline": None,
+}
+
+
+@unittest.skipIf(*running_on_github)
+@unittest.skipIf(*gpu_unavailable)
+class SSDCheckpointTest(unittest.TestCase):
+    def generate_fbgemm_ssd_tbe(
+        self,
+        T: int,
+        D: int,
+        log_E: int,
+        weights_precision: SparseType,
+        mixed: bool,
+        enable_l2: bool = True,
+    ) -> Tuple[SSDTableBatchedEmbeddingBags, List[int], List[int], int]:
+        E = int(10**log_E)
+        D = D * 4
+        if not mixed:
+            Ds = [D] * T
+            Es = [E] * T
+        else:
+            Ds = [
+                round_up(np.random.randint(low=int(0.25 * D), high=int(1.0 * D)), 4)
+                for _ in range(T)
+            ]
+            Es = [
+                np.random.randint(low=int(0.5 * E), high=int(2.0 * E)) for _ in range(T)
+            ]
+
+        feature_table_map = list(range(T))
+        emb = SSDTableBatchedEmbeddingBags(
+            embedding_specs=[(E, D) for (E, D) in zip(Es, Ds)],
+            feature_table_map=feature_table_map,
+            ssd_storage_directory=tempfile.mkdtemp(),
+            cache_sets=1,
+            ssd_uniform_init_lower=-0.1,
+            ssd_uniform_init_upper=0.1,
+            weights_precision=weights_precision,
+            l2_cache_size=1 if enable_l2 else 0,
+        )
+        return emb, Es, Ds, max(Ds)
+
+    # @given(**default_st, do_flush=st.sampled_from([True, False]))
+    # @settings(**default_settings)
+    # def test_l2_flush(
+    #     self,
+    #     T: int,
+    #     D: int,
+    #     log_E: int,
+    #     mixed: bool,
+    #     weights_precision: SparseType,
+    #     do_flush: bool,
+    # ) -> None:
+    #     emb, Es, Ds, max_D = self.generate_fbgemm_ssd_tbe(
+    #         T, D, log_E, weights_precision, mixed
+    #     )
+    #     indices = torch.arange(start=0, end=sum(Es))
+    #     weights = torch.randn(
+    #         indices.numel(), max_D, dtype=weights_precision.as_dtype()
+    #     )
+    #     weights_from_l2 = torch.empty_like(weights)
+    #     count = torch.as_tensor([indices.numel()])
+    #     emb.ssd_db.set_cuda(indices, weights, count, 1)
+    #     emb.ssd_db.get_cuda(indices.clone(), weights_from_l2, count)
+
+    #     torch.cuda.synchronize()
+    #     assert torch.equal(weights, weights_from_l2)
+    #     import logging
+
+    #     logging.info(f"wgqtest {do_flush=}")
+    #     weights_from_ssd = torch.empty_like(weights)
+    #     if do_flush:
+    #         emb.ssd_db.flush()
+    #     emb.ssd_db.reset_l2_cache()
+    #     emb.ssd_db.get_cuda(indices, weights_from_ssd, count)
+    #     torch.cuda.synchronize()
+    #     if do_flush:
+    #         assert torch.equal(weights, weights_from_ssd)
+    #     else:
+    #         assert not torch.equal(weights, weights_from_ssd)
+
+    # @given(**default_st, enable_l2=st.sampled_from([True, False]))
+    # @settings(**default_settings)
+    # def test_l2_io(
+    #     self,
+    #     T: int,
+    #     D: int,
+    #     log_E: int,
+    #     mixed: bool,
+    #     weights_precision: SparseType,
+    #     enable_l2: bool,
+    # ) -> None:
+    #     emb, Es, Ds, max_D = self.generate_fbgemm_ssd_tbe(
+    #         T, D, log_E, weights_precision, mixed, enable_l2
+    #     )
+    #     E = int(10**log_E)
+    #     num_rounds = 10
+    #     N = E
+    #     total_indices = torch.tensor([])
+
+    #     indices = torch.as_tensor(
+    #         np.random.choice(E, replace=False, size=(N,)), dtype=torch.int64
+    #     )
+    #     weights = torch.randn(
+    #         indices.numel(), max_D, dtype=weights_precision.as_dtype()
+    #     )
+    #     sub_N = N // num_rounds
+
+    #     for _ in range(num_rounds):
+    #         sub_indices = torch.as_tensor(
+    #             np.random.choice(E, replace=False, size=(sub_N,)), dtype=torch.int64
+    #         )
+    #         sub_weights = weights[sub_indices, :]
+    #         sub_weights_out = torch.empty_like(sub_weights)
+    #         count = torch.as_tensor([sub_indices.numel()])
+    #         emb.ssd_db.set_cuda(sub_indices, sub_weights, count, 1)
+    #         emb.ssd_db.get_cuda(sub_indices.clone(), sub_weights_out, count)
+    #         torch.cuda.synchronize()
+    #         assert torch.equal(sub_weights, sub_weights_out)
+    #         total_indices = torch.cat((total_indices, sub_indices))
+
+    #     # dedup
+    #     used_unique_indices = torch.tensor(
+    #         list(set(total_indices.tolist())), dtype=torch.int64
+    #     )
+    #     stored_weights = weights[used_unique_indices, :]
+    #     weights_out = torch.empty_like(stored_weights)
+    #     count = torch.as_tensor([used_unique_indices.numel()])
+    #     emb.ssd_db.get_cuda(used_unique_indices.clone(), weights_out, count)
+    #     torch.cuda.synchronize()
+    #     assert torch.equal(stored_weights, weights_out)
+
+    #     emb.ssd_db.flush()
+    #     emb.ssd_db.reset_l2_cache()
+    #     weights_out = torch.empty_like(stored_weights)
+    #     count = torch.as_tensor([used_unique_indices.numel()])
+    #     emb.ssd_db.get_cuda(used_unique_indices.clone(), weights_out, count)
+    #     torch.cuda.synchronize()
+    #     assert torch.equal(stored_weights, weights_out)
+
+    @given(**default_st)
+    @settings(**default_settings)
+    def test_l2_prefetch_compatibility(
+        self,
+        T: int,
+        D: int,
+        log_E: int,
+        mixed: bool,
+        weights_precision: SparseType,
+    ) -> None:
+        weights_precision: SparseType = SparseType.FP32
+        emb, Es, Ds, max_D = self.generate_fbgemm_ssd_tbe(
+            T, D, log_E, weights_precision, mixed
+        )
+        E = int(10**log_E)
+        N = E
+        indices = torch.as_tensor(
+            np.random.choice(E, replace=False, size=(N,)), dtype=torch.int64
+        )
+        weights = torch.randn(N, max_D, dtype=weights_precision.as_dtype())
+        new_weights = weights + 1
+        weights_out = torch.empty_like(weights)
+        count = torch.as_tensor([E])
+        emb.ssd_db.set(indices, weights, count)
+        emb.ssd_db.wait_util_filling_work_done()
+
+        event = threading.Event()
+        get_sleep_ms = 50
+
+        # pyre-ignore
+        def trigger_get() -> None:
+            event.set()
+            emb.ssd_db.get(indices.clone(), weights_out, count, get_sleep_ms)
+
+        # pyre-ignore
+        def trigger_set() -> None:
+            event.wait()
+            time.sleep(
+                get_sleep_ms / 1000.0 / 2
+            )  # sleep for half of get's sleep time, so set is triggered after get starts but before get finishes
+            emb.ssd_db.set(indices, new_weights, count)
+
+        thread1 = threading.Thread(target=trigger_get)
+        thread2 = threading.Thread(target=trigger_set)
+        thread1.start()
+        thread2.start()
+        thread1.join()
+        thread2.join()
+        assert torch.equal(weights, weights_out)
+        emb.ssd_db.get(indices.clone(), weights_out, count)
+        assert torch.equal(new_weights, weights_out)
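
Background note (not part of the patch): the core of this change is the lock-wait accounting added to EmbeddingKVDB::get() and EmbeddingKVDB::set_cache(). Each path times how long it blocks on l2_cache_mtx_ and accumulates the wait into an atomic counter, which get_l2cache_perf() later exports as ret[13]/ret[14] and the Python reporter surfaces as *.cache_lock_wait_duration_us. The sketch below is a minimal standalone illustration of that pattern only; it is not FBGEMM code, and std::chrono::steady_clock stands in for facebook::WallClockUtil.

```cpp
// Minimal sketch of the lock-wait accounting pattern used in this patch.
#include <atomic>
#include <chrono>
#include <cstdint>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

std::mutex l2_cache_mtx;                   // serializes L2 reads and writes
std::atomic<int64_t> get_lock_wait_us{0};  // accumulated wait on the get path
std::atomic<int64_t> set_lock_wait_us{0};  // accumulated wait on the set path
std::vector<int> cache(16, 0);             // stand-in for the L2 cache

int64_t now_us() {
  return std::chrono::duration_cast<std::chrono::microseconds>(
             std::chrono::steady_clock::now().time_since_epoch())
      .count();
}

void get_path() {
  auto start = now_us();
  std::unique_lock<std::mutex> lock(l2_cache_mtx);  // blocks while a set holds it
  get_lock_wait_us += now_us() - start;             // record time spent waiting
  int sum = 0;
  for (int v : cache) {                             // "read" under the lock
    sum += v;
  }
  (void)sum;
}

void set_path() {
  auto start = now_us();
  std::unique_lock<std::mutex> lock(l2_cache_mtx);  // excludes concurrent gets
  set_lock_wait_us += now_us() - start;
  for (auto& v : cache) {                           // "write" under the lock
    v += 1;
  }
}

int main() {
  std::thread reader(get_path);
  std::thread writer(set_path);
  reader.join();
  writer.join();
  // These counters correspond to what get_l2cache_perf() exports (divided by
  // the reporting interval) as ret[13]/ret[14] in the patch above.
  std::cout << "get wait us: " << get_lock_wait_us
            << ", set wait us: " << set_lock_wait_us << "\n";
  return 0;
}
```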