diff --git a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h
index bb6edff00..80955120f 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h
@@ -39,7 +39,9 @@ class CacheLibCache {
     int64_t max_D_;
   };
 
-  explicit CacheLibCache(const CacheConfig& cache_config);
+  explicit CacheLibCache(
+      const CacheConfig& cache_config,
+      int64_t unique_tbe_id);
 
   std::unique_ptr<Cache> initializeCacheLib(const CacheConfig& config);
 
@@ -48,7 +50,7 @@ class CacheLibCache {
   /// Find the stored embeddings from a given embedding indices, aka key
   ///
-  /// @param key embedding index to look up
+  /// @param key_tensor embedding index (tensor with only one element) to look up
   ///
   /// @return an optional value, return none on cache misses, if cache hit
   /// return a pointer to the cachelib underlying storage of associated
@@ -57,7 +59,7 @@ class CacheLibCache {
   /// @note that this is not thread safe, caller needs to make sure the data is
   /// fully processed before doing cache insertion, otherwise the returned space
   /// might be overwritten if cache is full
-  std::optional<void*> get(int64_t key);
+  folly::Optional<void*> get(const at::Tensor& key_tensor);
 
   /// Cachelib wrapper specific hash function
   ///
@@ -84,7 +86,8 @@ class CacheLibCache {
   /// Add an embedding index and embeddings into cachelib
   ///
-  /// @param key embedding index to insert
+  /// @param key_tensor embedding index (tensor with only one element) to insert
+  /// @param data embedding weights to insert
   ///
   /// @return true on success insertion, false on failure insertion, a failure
   /// insertion could happen if the refcount of bottom K items in LRU queue
@@ -94,11 +97,12 @@ class CacheLibCache {
   /// bulk read and bulk write sequentially
   ///
   /// @note cache_->allocation will trigger eviction callback func
-  bool put(int64_t key, const at::Tensor& data);
+  bool put(const at::Tensor& key_tensor, const at::Tensor& data);
 
   /// iterate through all items in L2 cache, fill them in indices and weights
   /// respectively and return indices, weights and count
   ///
+  /// @return optional value, if cache is empty return none
   /// @return indices The 1D embedding index tensor, should skip on negative
   /// value
   /// @return weights The 2D tensor that each row(embeddings) is paired up with
@@ -108,7 +112,8 @@ class CacheLibCache {
   ///
   /// @note this isn't thread safe, caller needs to make sure put isn't called
   /// while this is executed.
-  std::tuple<at::Tensor, at::Tensor, at::Tensor> get_all_items();
+  folly::Optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>>
+  get_all_items();
 
   /// instantiate eviction related indices and weights tensors (size of <count>)
   /// for L2 eviction using the same dtype and device from <indices> and <weights>
@@ -141,12 +146,15 @@ class CacheLibCache {
 
  private:
   const CacheConfig cache_config_;
+  const int64_t unique_tbe_id_;
   std::unique_ptr<Cache> cache_;
   std::vector<facebook::cachelib::PoolId> pool_ids_;
   std::unique_ptr<facebook::cachelib::CacheAdmin> admin_;
 
   folly::Optional<at::Tensor> evicted_indices_opt_{folly::none};
   folly::Optional<at::Tensor> evicted_weights_opt_{folly::none};
+  folly::Optional<at::ScalarType> index_dtype_{folly::none};
+  folly::Optional<at::ScalarType> weights_dtype_{folly::none};
   std::atomic<int64_t> eviction_row_id{0};
 };
diff --git a/fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp b/fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp
index d7eb220d8..f5002375c 100644
--- a/fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp
+++ b/fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp
@@ -15,25 +15,11 @@ namespace l2_cache {
 
 using Cache = facebook::cachelib::LruAllocator;
 
-// this is a general predictor for weights data type, might not be general
-// enough for all the cases
-at::ScalarType bytes_to_dtype(int num_bytes) {
-  switch (num_bytes) {
-    case 1:
-      return at::kByte;
-    case 2:
-      return at::kHalf;
-    case 4:
-      return at::kFloat;
-    case 8:
-      return at::kDouble;
-    default:
-      throw std::runtime_error("Unsupported dtype");
-  }
-}
-
-CacheLibCache::CacheLibCache(const CacheConfig& cache_config)
+CacheLibCache::CacheLibCache(
+    const CacheConfig& cache_config,
+    int64_t unique_tbe_id)
     : cache_config_(cache_config),
+      unique_tbe_id_(unique_tbe_id),
       cache_(initializeCacheLib(cache_config_)),
       admin_(createCacheAdmin(*cache_)) {
   for (size_t i = 0; i < cache_config_.num_shards; i++) {
@@ -50,30 +36,41 @@ CacheLibCache::CacheLibCache(const CacheConfig& cache_config)
 
 std::unique_ptr<Cache> CacheLibCache::initializeCacheLib(
     const CacheConfig& config) {
-  auto eviction_cb =
-      [this](const facebook::cachelib::LruAllocator::RemoveCbData& data) {
-        FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(
-            evicted_weights_opt_->scalar_type(), "l2_eviction_handling", [&] {
-              if (data.context ==
-                  facebook::cachelib::RemoveContext::kEviction) {
-                auto indices_data_ptr =
-                    evicted_indices_opt_->data_ptr<int64_t>();
-                auto weights_data_ptr =
-                    evicted_weights_opt_->data_ptr<scalar_t>();
-                auto row_id = eviction_row_id++;
-                auto weight_dim = evicted_weights_opt_->size(1);
-                const auto key_ptr =
-                    reinterpret_cast<const int64_t*>(data.item.getKey().data());
-                indices_data_ptr[row_id] = *key_ptr;
-
-                std::copy(
-                    reinterpret_cast<const scalar_t*>(data.item.getMemory()),
-                    reinterpret_cast<const scalar_t*>(data.item.getMemory()) +
-                        weight_dim,
-                    &weights_data_ptr[row_id * weight_dim]); // dst_start
-              }
-            });
-      };
+  auto eviction_cb = [this](
+                         const facebook::cachelib::LruAllocator::RemoveCbData&
+                             data) {
+    if (evicted_weights_opt_.has_value()) {
+      FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(
+          evicted_weights_opt_->scalar_type(), "l2_eviction_handling", [&] {
+            using value_t = scalar_t;
+            FBGEMM_DISPATCH_INTEGRAL_TYPES(
+                evicted_indices_opt_->scalar_type(),
+                "l2_eviction_handling",
+                [&] {
+                  using index_t = scalar_t;
+                  if (data.context ==
+                      facebook::cachelib::RemoveContext::kEviction) {
+                    auto indices_data_ptr =
+                        evicted_indices_opt_->data_ptr<index_t>();
+                    auto weights_data_ptr =
+                        evicted_weights_opt_->data_ptr<value_t>();
+                    auto row_id = eviction_row_id++;
+                    auto weight_dim = evicted_weights_opt_->size(1);
+                    const auto key_ptr = reinterpret_cast<const index_t*>(
+                        data.item.getKey().data());
+                    indices_data_ptr[row_id] = *key_ptr;
+
+                    std::copy(
+                        reinterpret_cast<const value_t*>(
+                            data.item.getMemory()),
+                        reinterpret_cast<const value_t*>(
+                            data.item.getMemory()) +
+                            weight_dim,
+                        &weights_data_ptr[row_id * weight_dim]); // dst_start
+                  }
+                });
+          });
+    }
+  };
   Cache::Config cacheLibConfig;
   int64_t rough_num_items =
       cache_config_.cache_size_bytes / cache_config_.item_size_bytes;
@@ -82,8 +79,9 @@ std::unique_ptr<Cache> CacheLibCache::initializeCacheLib(
   unsigned int lock_power =
       std::log(cache_config_.num_shards * 15) / std::log(2) + 1;
   XLOG(INFO) << fmt::format(
-      "Setting up Cachelib for L2 cache, capacity: {}GB, "
+      "[TBE_ID{}] Setting up Cachelib for L2 cache, capacity: {}GB, "
       "item_size: {}B, max_num_items: {}, bucket_power: {}, lock_power: {}",
+      unique_tbe_id_,
       config.cache_size_bytes / 1024 / 1024 / 1024,
       cache_config_.item_size_bytes,
       rough_num_items,
@@ -106,14 +104,21 @@ std::unique_ptr<facebook::cachelib::CacheAdmin> CacheLibCache::createCacheAdmin(
       cache, std::move(adminConfig));
 }
 
-std::optional<void*> CacheLibCache::get(int64_t key) {
-  auto key_str =
-      folly::StringPiece(reinterpret_cast<const char*>(&key), sizeof(int64_t));
-  auto item = cache_->find(key_str);
-  if (!item) {
-    return std::nullopt;
-  }
-  return const_cast<void*>(item->getMemory());
+folly::Optional<void*> CacheLibCache::get(const at::Tensor& key_tensor) {
+  folly::Optional<void*> res;
+  FBGEMM_DISPATCH_INTEGRAL_TYPES(key_tensor.scalar_type(), "get", [&] {
+    using index_t = scalar_t;
+    auto key = *(key_tensor.data_ptr<index_t>());
+    auto key_str = folly::StringPiece(
+        reinterpret_cast<const char*>(&key), sizeof(index_t));
+    auto item = cache_->find(key_str);
+    if (!item) {
+      res = folly::none;
+      return;
+    }
+    res = const_cast<void*>(item->getMemory());
+  });
+  return res;
 }
 
 size_t CacheLibCache::get_shard_id(int64_t key) {
@@ -136,55 +141,79 @@ void CacheLibCache::batchMarkUseful(
   }
 }
 
-bool CacheLibCache::put(int64_t key, const at::Tensor& data) {
-  auto key_str =
-      folly::StringPiece(reinterpret_cast<const char*>(&key), sizeof(int64_t));
-  auto item = cache_->findToWrite(key_str);
-  if (!item) {
-    auto alloc_item =
-        cache_->allocate(get_pool_id(key), key_str, data.nbytes());
-    if (!alloc_item) {
-      XLOG(ERR) << fmt::format(
-          "Failed to allocate item {} in cache, skip", key);
-      return false;
-    }
-    std::memcpy(alloc_item->getMemory(), data.data_ptr(), data.nbytes());
-    cache_->insertOrReplace(std::move(alloc_item));
-  } else {
-    std::memcpy(item->getMemory(), data.data_ptr(), data.nbytes());
+bool CacheLibCache::put(const at::Tensor& key_tensor, const at::Tensor& data) {
+  if (!index_dtype_.has_value()) {
+    index_dtype_ = key_tensor.scalar_type();
+  }
+  if (!weights_dtype_.has_value()) {
+    weights_dtype_ = data.scalar_type();
   }
-  return true;
+  bool res;
+  FBGEMM_DISPATCH_INTEGRAL_TYPES(key_tensor.scalar_type(), "put", [&] {
+    using index_t = scalar_t;
+    auto key = *(key_tensor.data_ptr<index_t>());
+    auto key_str = folly::StringPiece(
+        reinterpret_cast<const char*>(&key), sizeof(index_t));
+    auto item = cache_->findToWrite(key_str);
+    if (!item) {
+      auto alloc_item =
+          cache_->allocate(get_pool_id(key), key_str, data.nbytes());
+      if (!alloc_item) {
+        XLOG(ERR) << fmt::format(
+            "[TBE_ID{}]Failed to allocate item {} in cache, skip",
+            unique_tbe_id_,
+            key);
+        res = false;
+        return;
+      }
+      std::memcpy(alloc_item->getMemory(), data.data_ptr(), data.nbytes());
+      cache_->insertOrReplace(std::move(alloc_item));
+    } else {
+      std::memcpy(item->getMemory(), data.data_ptr(), data.nbytes());
+    }
+    res = true;
+  });
+  return res;
 }
 
-std::tuple<at::Tensor, at::Tensor, at::Tensor> CacheLibCache::get_all_items() {
+folly::Optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>>
+CacheLibCache::get_all_items() {
+  if (!index_dtype_.has_value() || !weights_dtype_.has_value()) {
+    return folly::none;
+  }
   int total_num_items = 0;
   for (auto& pool_id : pool_ids_) {
     total_num_items += cache_->getPoolStats(pool_id).numItems();
   }
   auto weight_dim = cache_config_.max_D_;
-  auto weights_dtype =
-      bytes_to_dtype(cache_config_.item_size_bytes / weight_dim);
   auto indices = at::empty(
-      total_num_items, at::TensorOptions().dtype(at::kLong).device(at::kCPU));
+      total_num_items,
+      at::TensorOptions().dtype(index_dtype_.value()).device(at::kCPU));
   auto weights = at::empty(
       {total_num_items, weight_dim},
-      at::TensorOptions().dtype(weights_dtype).device(at::kCPU));
+      at::TensorOptions().dtype(weights_dtype_.value()).device(at::kCPU));
   FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(
       weights.scalar_type(), "get_all_items", [&] {
-        auto indices_data_ptr = indices.data_ptr<int64_t>();
-        auto weights_data_ptr = weights.data_ptr<scalar_t>();
-        int64_t item_idx = 0;
-        for (auto itr = cache_->begin(); itr != cache_->end(); ++itr) {
-          const auto key_ptr =
-              reinterpret_cast<const int64_t*>(itr->getKey().data());
-          indices_data_ptr[item_idx] = *key_ptr;
-          std::copy(
-              reinterpret_cast<const scalar_t*>(itr->getMemory()),
-              reinterpret_cast<const scalar_t*>(itr->getMemory()) + weight_dim,
-              &weights_data_ptr[item_idx * weight_dim]); // dst_start
-          item_idx++;
-        }
-        CHECK_EQ(total_num_items, item_idx);
+        using value_t = scalar_t;
+        FBGEMM_DISPATCH_INTEGRAL_TYPES(
+            indices.scalar_type(), "get_all_items", [&] {
+              using index_t = scalar_t;
+              auto indices_data_ptr = indices.data_ptr<index_t>();
+              auto weights_data_ptr = weights.data_ptr<value_t>();
+              int64_t item_idx = 0;
+              for (auto itr = cache_->begin(); itr != cache_->end(); ++itr) {
+                const auto key_ptr =
+                    reinterpret_cast<const index_t*>(itr->getKey().data());
+                indices_data_ptr[item_idx] = *key_ptr;
+                std::copy(
+                    reinterpret_cast<const value_t*>(itr->getMemory()),
+                    reinterpret_cast<const value_t*>(itr->getMemory()) +
+                        weight_dim,
+                    &weights_data_ptr[item_idx * weight_dim]); // dst_start
+                item_idx++;
+              }
+              CHECK_EQ(total_num_items, item_idx);
+            });
       });
   return std::make_tuple(
       indices,
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp
index fcf04fff0..21aa00290 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp
@@ -44,19 +44,24 @@ QueueItem tensor_copy(
       at::empty({1}, at::TensorOptions().device(at::kCPU).dtype(at::kLong));
   FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(
       weights.scalar_type(), "tensor_copy", [&] {
-        auto indices_addr = indices.data_ptr<int64_t>();
-        auto new_indices_addr = new_indices.data_ptr<int64_t>();
-        std::copy(
-            indices_addr,
-            indices_addr + num_sets,
-            new_indices_addr); // dst_start
-
-        auto weights_addr = weights.data_ptr<scalar_t>();
-        auto new_weightss_addr = new_weights.data_ptr<scalar_t>();
-        std::copy(
-            weights_addr,
-            weights_addr + num_sets * weights.size(1),
-            new_weightss_addr); // dst_start
+        using value_t = scalar_t;
+        FBGEMM_DISPATCH_INTEGRAL_TYPES(
+            indices.scalar_type(), "tensor_copy", [&] {
+              using index_t = scalar_t;
+              auto indices_addr = indices.data_ptr<index_t>();
+              auto new_indices_addr = new_indices.data_ptr<index_t>();
+              std::copy(
+                  indices_addr,
+                  indices_addr + num_sets,
+                  new_indices_addr); // dst_start
+
+              auto weights_addr = weights.data_ptr<value_t>();
+              auto new_weightss_addr = new_weights.data_ptr<value_t>();
+              std::copy(
+                  weights_addr,
+                  weights_addr + num_sets * weights.size(1),
+                  new_weightss_addr); // dst_start
+            });
       });
   *new_count.data_ptr<int64_t>() = num_sets;
   return QueueItem{new_indices, new_weights, new_count, mode};
@@ -79,7 +84,8 @@ EmbeddingKVDB::EmbeddingKVDB(
     cache_config.num_shards = num_shards_;
     cache_config.item_size_bytes = max_D_ * ele_size_bytes;
     cache_config.max_D_ = max_D_;
-    l2_cache_ = std::make_unique<l2_cache::CacheLibCache>(cache_config);
+    l2_cache_ =
+        std::make_unique<l2_cache::CacheLibCache>(cache_config, unique_id);
   } else {
     l2_cache_ = nullptr;
   }
@@ -128,10 +134,15 @@ EmbeddingKVDB::~EmbeddingKVDB() {
 void EmbeddingKVDB::flush() {
   wait_util_filling_work_done();
   if (l2_cache_) {
-    auto tensor_tuple = l2_cache_->get_all_items();
-    auto& indices = std::get<0>(tensor_tuple);
-    auto& weights = std::get<1>(tensor_tuple);
-    auto& count = std::get<2>(tensor_tuple);
+    auto tensor_tuple_opt = l2_cache_->get_all_items();
+    if (!tensor_tuple_opt.has_value()) {
+      XLOG(INFO) << "[TBE_ID" << unique_id_
+                 << "]no items exist in L2 cache, flush nothing";
+      return;
+    }
+    auto& indices = std::get<0>(tensor_tuple_opt.value());
+    auto& weights = std::get<1>(tensor_tuple_opt.value());
+    auto& count = std::get<2>(tensor_tuple_opt.value());
     folly::coro::blockingWait(set_kv_db_async(
         indices, weights, count, kv_db::RocksdbWriteMode::FLUSH));
   }
@@ -143,6 +154,7 @@ void EmbeddingKVDB::get_cuda(
     const at::Tensor& count) {
   auto rec = torch::autograd::profiler::record_function_enter_new(
       "## EmbeddingKVDB::get_cuda ##");
+  check_tensor_type_consistency(indices, weights);
   // take reference to self to avoid lifetime issues.
   auto self = shared_from_this();
   std::function<void()>* functor =
@@ -163,6 +175,7 @@ void EmbeddingKVDB::set_cuda(
     const bool is_bwd) {
   auto rec = torch::autograd::profiler::record_function_enter_new(
       "## EmbeddingKVDB::set_cuda ##");
+  check_tensor_type_consistency(indices, weights);
   // take reference to self to avoid lifetime issues.
   auto self = shared_from_this();
   std::function<void()>* functor = new std::function<void()>([=]() {
@@ -229,8 +242,8 @@ void EmbeddingKVDB::set(
     const bool is_bwd) {
   if (auto num_evictions = get_maybe_uvm_scalar(count); num_evictions <= 0) {
     XLOG_EVERY_MS(INFO, 60000)
-        << "[" << unique_id_ << "]skip set_cuda since number evictions is "
-        << num_evictions;
+        << "[TBE_ID" << unique_id_
+        << "]skip set_cuda since number evictions is " << num_evictions;
     return;
   }
   auto start_ts = facebook::WallClockUtil::NowInUsecFast();
@@ -254,7 +267,7 @@ void EmbeddingKVDB::get(
     const at::Tensor& count) {
   if (auto num_lookups = get_maybe_uvm_scalar(count); num_lookups <= 0) {
     XLOG_EVERY_MS(INFO, 60000)
-        << "[" << unique_id_ << "]skip get_cuda since number lookups is "
+        << "[TBE_ID" << unique_id_ << "]skip get_cuda since number lookups is "
         << num_lookups;
     return;
   }
@@ -303,63 +316,67 @@ std::shared_ptr<CacheContext> EmbeddingKVDB::get_cache(
     return nullptr;
   }
   auto start_ts = facebook::WallClockUtil::NowInUsecFast();
-  auto indices_addr = indices.data_ptr<int64_t>();
+  auto num_lookups = get_maybe_uvm_scalar(count);
   auto cache_context = std::make_shared<CacheContext>(num_lookups);
-
-  auto num_shards = executor_tp_->numThreads();
-
-  std::vector<folly::coro::TaskWithExecutor<void>> tasks;
-  std::vector<std::vector<int64_t>> row_ids_per_shard(num_shards);
-  for (int i = 0; i < num_shards; i++) {
-    row_ids_per_shard[i].reserve(num_lookups / num_shards * 2);
-  }
-  for (uint32_t row_id = 0; row_id < num_lookups; ++row_id) {
-    row_ids_per_shard[l2_cache_->get_shard_id(indices_addr[row_id])]
-        .emplace_back(row_id);
-  }
-  for (uint32_t shard_id = 0; shard_id < num_shards; ++shard_id) {
-    tasks.emplace_back(
-        folly::coro::co_invoke(
-            [this,
-             &indices_addr,
-             cache_context,
-             shard_id,
-             &row_ids_per_shard]() mutable -> folly::coro::Task<void> {
-              for (const auto& row_id : row_ids_per_shard[shard_id]) {
-                auto emb_idx = indices_addr[row_id];
-                if (emb_idx < 0) {
-                  continue;
-                }
-                auto cached_addr_opt = l2_cache_->get(emb_idx);
-                if (cached_addr_opt.has_value()) { // cache hit
-                  cache_context->cached_addr_list[row_id] =
-                      cached_addr_opt.value();
-                  indices_addr[row_id] = -1; // mark to sentinel value
-                } else { // cache miss
-                  cache_context->num_misses += 1;
+  FBGEMM_DISPATCH_INTEGRAL_TYPES(indices.scalar_type(), "get_cache", [&] {
+    using index_t = scalar_t;
+    auto indices_addr = indices.data_ptr<index_t>();
+    auto num_shards = executor_tp_->numThreads();
+
+    std::vector<folly::coro::TaskWithExecutor<void>> tasks;
+    std::vector<std::vector<int64_t>> row_ids_per_shard(num_shards);
+    for (int i = 0; i < num_shards; i++) {
+      row_ids_per_shard[i].reserve(num_lookups / num_shards * 2);
+    }
+    for (uint32_t row_id = 0; row_id < num_lookups; ++row_id) {
+      row_ids_per_shard[l2_cache_->get_shard_id(indices_addr[row_id])]
+          .emplace_back(row_id);
+    }
+    for (uint32_t shard_id = 0; shard_id < num_shards; ++shard_id) {
+      tasks.emplace_back(
+          folly::coro::co_invoke(
+              [this,
+               &indices_addr,
+               &indices,
+               cache_context,
+               shard_id,
+               &row_ids_per_shard]() mutable -> folly::coro::Task<void> {
+                for (const auto& row_id : row_ids_per_shard[shard_id]) {
+                  auto emb_idx = indices_addr[row_id];
+                  if (emb_idx < 0) {
+                    continue;
+                  }
+                  auto cached_addr_opt = l2_cache_->get(indices[row_id]);
+                  if (cached_addr_opt.has_value()) { // cache hit
+                    cache_context->cached_addr_list[row_id] =
+                        cached_addr_opt.value();
+                    indices_addr[row_id] = -1; // mark to sentinel value
+                  } else { // cache miss
+                    cache_context->num_misses += 1;
+                  }
                 }
-              }
-              co_return;
-            })
-            .scheduleOn(executor_tp_.get()));
-  }
-  folly::coro::blockingWait(folly::coro::collectAllRange(std::move(tasks)));
-
-  // the following metrics added here as the current assumption is
-  // get_cache will only be called in get_cuda path, if assumption no longer
-  // true, we should wrap this up on the caller side
-  auto dur = facebook::WallClockUtil::NowInUsecFast() - start_ts;
-  get_cache_lookup_total_duration_ += dur;
-  auto cache_misses = cache_context->num_misses.load();
-  if (num_lookups > 0) {
-    num_cache_misses_ += cache_misses;
-    num_lookups_ += num_lookups;
-  } else {
-    XLOG_EVERY_MS(INFO, 60000)
-        << "[" << unique_id_
-        << "]num_lookups is 0, skip collecting the L2 cache miss stats";
-  }
+                co_return;
+              })
+              .scheduleOn(executor_tp_.get()));
+    }
+    folly::coro::blockingWait(folly::coro::collectAllRange(std::move(tasks)));
+
+    // the following metrics added here as the current assumption is
+    // get_cache will only be called in get_cuda path, if assumption no longer
+    // true, we should wrap this up on the caller side
+    auto dur = facebook::WallClockUtil::NowInUsecFast() - start_ts;
+    get_cache_lookup_total_duration_ += dur;
+    auto cache_misses = cache_context->num_misses.load();
+    if (num_lookups > 0) {
+      num_cache_misses_ += cache_misses;
+      num_lookups_ += num_lookups;
+    } else {
+      XLOG_EVERY_MS(INFO, 60000)
+          << "[TBE_ID" << unique_id_
+          << "]num_lookups is 0, skip collecting the L2 cache miss stats";
+    }
+  });
   return cache_context;
 }
@@ -373,7 +390,8 @@ void EmbeddingKVDB::wait_util_filling_work_done() {
       total_wait_time_ms += 1;
       if (total_wait_time_ms > 100) {
         XLOG_EVERY_MS(ERR, 1000)
-            << "get_cache: waiting for L2 caching filling embeddings for "
+            << "[TBE_ID" << unique_id_
+            << "]get_cache: waiting for L2 caching filling embeddings for "
             << total_wait_time_ms << " ms, somethings is likely wrong";
       }
     }
@@ -397,45 +415,50 @@ EmbeddingKVDB::set_cache(
   auto cache_update_start_ts = facebook::WallClockUtil::NowInUsecFast();
 
   l2_cache_->init_tensor_for_l2_eviction(indices, weights, count);
-  auto indices_addr = indices.data_ptr<int64_t>();
-  const int64_t num_lookups = get_maybe_uvm_scalar(count);
-  auto num_shards = executor_tp_->numThreads();
-  std::vector<folly::coro::TaskWithExecutor<void>> tasks;
-  std::vector<std::vector<int64_t>> row_ids_per_shard(num_shards);
+  FBGEMM_DISPATCH_INTEGRAL_TYPES(indices.scalar_type(), "set_cache", [&] {
+    using index_t = scalar_t;
+    auto indices_addr = indices.data_ptr<index_t>();
+    const int64_t num_lookups = get_maybe_uvm_scalar(count);
+    auto num_shards = executor_tp_->numThreads();
 
-  for (int i = 0; i < num_shards; i++) {
-    row_ids_per_shard[i].reserve(num_lookups / num_shards * 2);
-  }
-  for (uint32_t row_id = 0; row_id < num_lookups; ++row_id) {
-    row_ids_per_shard[l2_cache_->get_shard_id(indices_addr[row_id])]
-        .emplace_back(row_id);
-  }
+    std::vector<folly::coro::TaskWithExecutor<void>> tasks;
+    std::vector<std::vector<int64_t>> row_ids_per_shard(num_shards);
 
-  for (uint32_t shard_id = 0; shard_id < num_shards; ++shard_id) {
-    tasks.emplace_back(
-        folly::coro::co_invoke(
-            [this,
-             &indices_addr,
-             &weights,
-             shard_id,
-             &row_ids_per_shard]() mutable -> folly::coro::Task<void> {
-              for (const auto& row_id : row_ids_per_shard[shard_id]) {
-                auto emb_idx = indices_addr[row_id];
-                if (emb_idx < 0) {
-                  continue;
-                }
-                if (!l2_cache_->put(emb_idx, weights[row_id])) {
-                  XLOG_EVERY_MS(ERR, 1000)
-                      << "[" << unique_id_
-                      << "]Failed to insert into cache, this shouldn't happen";
+    for (int i = 0; i < num_shards; i++) {
+      row_ids_per_shard[i].reserve(num_lookups / num_shards * 2);
+    }
+    for (uint32_t row_id = 0; row_id < num_lookups; ++row_id) {
+      row_ids_per_shard[l2_cache_->get_shard_id(indices_addr[row_id])]
+          .emplace_back(row_id);
+    }
+
+    for (uint32_t shard_id = 0; shard_id < num_shards; ++shard_id) {
+      tasks.emplace_back(
+          folly::coro::co_invoke(
+              [this,
+               &indices_addr,
+               &indices,
+               &weights,
+               shard_id,
+               &row_ids_per_shard]() mutable -> folly::coro::Task<void> {
+                for (const auto& row_id : row_ids_per_shard[shard_id]) {
+                  auto emb_idx = indices_addr[row_id];
+                  if (emb_idx < 0) {
+                    continue;
+                  }
+                  if (!l2_cache_->put(indices[row_id], weights[row_id])) {
+                    XLOG_EVERY_MS(ERR, 1000)
+                        << "[TBE_ID" << unique_id_
+                        << "]Failed to insert into cache, this shouldn't happen";
+                  }
                 }
-              }
-              co_return;
-            })
-            .scheduleOn(executor_tp_.get()));
-  }
-  folly::coro::blockingWait(folly::coro::collectAllRange(std::move(tasks)));
+                co_return;
+              })
+              .scheduleOn(executor_tp_.get()));
+    }
+    folly::coro::blockingWait(folly::coro::collectAllRange(std::move(tasks)));
+  });
   total_cache_update_duration_ +=
       facebook::WallClockUtil::NowInUsecFast() - cache_update_start_ts;
   auto tensor_tuple_opt = l2_cache_->get_tensors_and_reset();
@@ -492,4 +515,24 @@ folly::coro::Task<void> EmbeddingKVDB::cache_memcpy(
   co_return;
 }
 
+void EmbeddingKVDB::check_tensor_type_consistency(
+    const at::Tensor& indices,
+    const at::Tensor& weights) {
+  if (index_dtype_.has_value()) {
+    assert(index_dtype_.value() == indices.scalar_type());
+  } else {
+    index_dtype_ = indices.scalar_type();
+    XLOG(INFO) << "[TBE_ID" << unique_id_ << "]L2 cache index dtype is "
+               << index_dtype_.value();
+  }
+
+  if (weights_dtype_.has_value()) {
+    assert(weights_dtype_.value() == weights.scalar_type());
+  } else {
+    weights_dtype_ = weights.scalar_type();
+    XLOG(INFO) << "[TBE_ID" << unique_id_ << "]L2 cache weights dtype is "
+               << weights_dtype_.value();
+  }
+}
+
 } // namespace kv_db
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h
index d63ff5418..da98ba9c3 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h
@@ -279,6 +279,10 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
 
   virtual void flush_or_compact(const int64_t timestep) = 0;
 
+  void check_tensor_type_consistency(
+      const at::Tensor& indices,
+      const at::Tensor& weights);
+
   // waiting for working item queue to be empty, this is called by get_cache()
   // as embedding read should wait until previous write to be finished
   void wait_util_filling_work_done();
@@ -287,6 +291,8 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
   const int64_t unique_id_;
   const int64_t num_shards_;
   const int64_t max_D_;
+  folly::Optional<at::ScalarType> index_dtype_{folly::none};
+  folly::Optional<at::ScalarType> weights_dtype_{folly::none};
   std::unique_ptr<folly::CPUThreadPoolExecutor> executor_tp_;
   std::unique_ptr<std::thread> cache_filling_thread_;
   std::atomic<bool> stop_{false};
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h
index 3d1f2c977..e7d32682e 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h
@@ -432,36 +432,43 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
             -> folly::coro::Task<void> {
           FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(
               weights.scalar_type(), "ssd_set", [&] {
-                CHECK(indices.is_contiguous());
-                CHECK(weights.is_contiguous());
-                auto indices_acc = indices.accessor<int64_t, 1>();
-                auto D = weights.size(1);
-                CHECK_EQ(indices.size(0), weights.size(0));
-                {
-                  rocksdb::WriteBatch batch(
-                      (2 * (count_ + dbs_.size() - 1) / dbs_.size()) *
-                      (sizeof(int64_t) + sizeof(scalar_t) * D));
-                  for (auto i = 0; i < count_; ++i) {
-                    if (indices_acc[i] < 0) {
-                      continue;
-                    }
-                    if (kv_db_utils::hash_shard(
-                            indices_acc[i], dbs_.size()) != shard) {
-                      continue;
-                    }
-                    batch.Put(
-                        rocksdb::Slice(
-                            reinterpret_cast<const char*>(
-                                &(indices.data_ptr<int64_t>()[i])),
-                            sizeof(int64_t)),
-                        rocksdb::Slice(
-                            reinterpret_cast<const char*>(
-                                &(weights.data_ptr<scalar_t>()[i * D])),
-                            D * sizeof(scalar_t)));
-                  }
-                  auto s = dbs_[shard]->Write(wo_, &batch);
-                  CHECK(s.ok());
-                }
+                using value_t = scalar_t;
+                FBGEMM_DISPATCH_INTEGRAL_TYPES(
+                    indices.scalar_type(), "ssd_set", [&] {
+                      using index_t = scalar_t;
+                      CHECK(indices.is_contiguous());
+                      CHECK(weights.is_contiguous());
+                      auto indices_acc = indices.accessor<index_t, 1>();
+                      auto D = weights.size(1);
+                      CHECK_EQ(indices.size(0), weights.size(0));
+                      {
+                        rocksdb::WriteBatch batch(
+                            (2 * (count_ + dbs_.size() - 1) /
+                             dbs_.size()) *
+                            (sizeof(index_t) + sizeof(value_t) * D));
+                        for (auto i = 0; i < count_; ++i) {
+                          if (indices_acc[i] < 0) {
+                            continue;
+                          }
+                          if (kv_db_utils::hash_shard(
+                                  indices_acc[i], dbs_.size()) != shard) {
+                            continue;
+                          }
+                          batch.Put(
+                              rocksdb::Slice(
+                                  reinterpret_cast<const char*>(
+                                      &(indices.data_ptr<index_t>()[i])),
+                                  sizeof(index_t)),
+                              rocksdb::Slice(
+                                  reinterpret_cast<const char*>(
+                                      &(weights
+                                            .data_ptr<value_t>()[i * D])),
+                                  D * sizeof(value_t)));
+                        }
+                        auto s = dbs_[shard]->Write(wo_, &batch);
+                        CHECK(s.ok());
+                      }
+                    });
               });
           co_return;
         })
@@ -661,107 +668,117 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
             -> folly::coro::Task<void> {
           FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(
               weights.scalar_type(), "ssd_get", [&] {
-                CHECK(indices.is_contiguous());
-                CHECK(weights.is_contiguous());
-                auto indices_data_ptr = indices.data_ptr<int64_t>();
-                auto D = weights.size(1);
-                CHECK_EQ(indices.size(0), weights.size(0));
-                auto weights_data_ptr = weights.data_ptr<scalar_t>();
-                FOLLY_DECLARE_REUSED(keys, std::vector<rocksdb::Slice>);
-                FOLLY_DECLARE_REUSED(shard_ids, std::vector<int64_t>);
-                FOLLY_DECLARE_REUSED(
-                    cfs, std::vector<rocksdb::ColumnFamilyHandle*>);
-                FOLLY_DECLARE_REUSED(
-                    values, std::vector<rocksdb::PinnableSlice>);
-                FOLLY_DECLARE_REUSED(
-                    statuses, std::vector<rocksdb::Status>);
-                auto* dcf = dbs_[shard]->DefaultColumnFamily();
-                for (auto i = 0; i < count_; ++i) {
-                  // "no-op"/empty evicted tensor
-                  if (indices_data_ptr[i] == -1) {
-                    continue;
-                  }
-                  if (kv_db_utils::hash_shard(
-                          indices_data_ptr[i], dbs_.size()) != shard) {
-                    continue;
-                  }
-                  shard_ids.push_back(i);
-                }
-                std::sort(
-                    shard_ids.begin(),
-                    shard_ids.end(),
-                    [&](int32_t lhs, int32_t rhs) {
-                      const auto lhs_key = rocksdb::Slice(
-                          reinterpret_cast<const char*>(
-                              &(indices_data_ptr[lhs])),
-                          sizeof(int64_t));
-                      const auto rhs_key = rocksdb::Slice(
-                          reinterpret_cast<const char*>(
-                              &(indices_data_ptr[rhs])),
-                          sizeof(int64_t));
-                      return lhs_key.compare(rhs_key) < 0;
+                using value_t = scalar_t;
+                FBGEMM_DISPATCH_INTEGRAL_TYPES(
+                    indices.scalar_type(), "ssd_get", [&] {
+                      using index_t = scalar_t;
+                      CHECK(indices.is_contiguous());
+                      CHECK(weights.is_contiguous());
+                      auto indices_data_ptr = indices.data_ptr<index_t>();
+                      auto D = weights.size(1);
+                      CHECK_EQ(indices.size(0), weights.size(0));
+                      auto weights_data_ptr = weights.data_ptr<value_t>();
+                      FOLLY_DECLARE_REUSED(
+                          keys, std::vector<rocksdb::Slice>);
+                      FOLLY_DECLARE_REUSED(
+                          shard_ids, std::vector<int64_t>);
+                      FOLLY_DECLARE_REUSED(
+                          cfs, std::vector<rocksdb::ColumnFamilyHandle*>);
+                      FOLLY_DECLARE_REUSED(
+                          values, std::vector<rocksdb::PinnableSlice>);
+                      FOLLY_DECLARE_REUSED(
+                          statuses, std::vector<rocksdb::Status>);
+                      auto* dcf = dbs_[shard]->DefaultColumnFamily();
+                      for (auto i = 0; i < count_; ++i) {
+                        // "no-op"/empty evicted tensor
+                        if (indices_data_ptr[i] == -1) {
+                          continue;
+                        }
+                        if (kv_db_utils::hash_shard(
+                                indices_data_ptr[i], dbs_.size()) !=
+                            shard) {
+                          continue;
+                        }
+                        shard_ids.push_back(i);
+                      }
+                      std::sort(
+                          shard_ids.begin(),
+                          shard_ids.end(),
+                          [&](int32_t lhs, int32_t rhs) {
+                            const auto lhs_key = rocksdb::Slice(
+                                reinterpret_cast<const char*>(
+                                    &(indices_data_ptr[lhs])),
+                                sizeof(index_t));
+                            const auto rhs_key = rocksdb::Slice(
+                                reinterpret_cast<const char*>(
+                                    &(indices_data_ptr[rhs])),
+                                sizeof(index_t));
+                            return lhs_key.compare(rhs_key) < 0;
+                          });
+                      for (const auto& i : shard_ids) {
+                        const auto key = rocksdb::Slice(
+                            reinterpret_cast<const char*>(
+                                &(indices_data_ptr[i])),
+                            sizeof(index_t));
+                        keys.push_back(key);
+                        cfs.push_back(dcf);
+                      }
+                      CHECK_EQ(shard_ids.size(), keys.size());
+                      CHECK_EQ(shard_ids.size(), cfs.size());
+
+                      values.resize(keys.size());
+                      statuses.resize(keys.size());
+                      // Set a snapshot if it is available
+                      ro_.snapshot = snapshot;
+                      dbs_[shard]->MultiGet(
+                          ro_,
+                          keys.size(),
+                          cfs.data(),
+                          keys.data(),
+                          values.data(),
+                          statuses.data(),
+                          /*sorted_input=*/true);
+                      const auto& init_storage =
+                          initializers_[shard]->row_storage_;
+                      // Sanity check
+                      TORCH_CHECK(
+                          init_storage.scalar_type() ==
+                              weights.scalar_type(),
+                          "init_storage (",
+                          toString(init_storage.scalar_type()),
+                          ") and weights scalar (",
+                          toString(weights.scalar_type()),
+                          ") types mismatch");
+                      auto row_storage_data_ptr =
+                          init_storage.data_ptr<value_t>();
+                      for (auto j = 0; j < keys.size(); ++j) {
+                        const auto& s = statuses[j];
+                        int64_t i = shard_ids[j];
+                        const auto& value = values[j];
+                        if (s.ok()) {
+                          if (!std::is_same<value_t, uint8_t>::value) {
+                            CHECK_EQ(value.size(), D * sizeof(value_t));
+                          }
+                          std::copy(
+                              reinterpret_cast<const value_t*>(
+                                  value.data()),
+                              reinterpret_cast<const value_t*>(
+                                  value.data() + value.size()),
+                              &(weights_data_ptr[i * D]));
+                        } else {
+                          CHECK(s.IsNotFound());
+                          int64_t row_index;
+                          initializers_[shard]->producer_queue_.dequeue(
+                              row_index);
+                          std::copy(
+                              &(row_storage_data_ptr[row_index * D]),
+                              &(row_storage_data_ptr[row_index * D + D]),
+                              &(weights_data_ptr[i * D]));
+                          initializers_[shard]->consumer_queue_.enqueue(
+                              row_index);
+                        }
+                      }
                     });
-                for (const auto& i : shard_ids) {
-                  const auto key = rocksdb::Slice(
-                      reinterpret_cast<const char*>(
-                          &(indices_data_ptr[i])),
-                      sizeof(int64_t));
-                  keys.push_back(key);
-                  cfs.push_back(dcf);
-                }
-                CHECK_EQ(shard_ids.size(), keys.size());
-                CHECK_EQ(shard_ids.size(), cfs.size());
-
-                values.resize(keys.size());
-                statuses.resize(keys.size());
-                // Set a snapshot if it is available
-                ro_.snapshot = snapshot;
-                dbs_[shard]->MultiGet(
-                    ro_,
-                    keys.size(),
-                    cfs.data(),
-                    keys.data(),
-                    values.data(),
-                    statuses.data(),
-                    /*sorted_input=*/true);
-                const auto& init_storage =
-                    initializers_[shard]->row_storage_;
-                // Sanity check
-                TORCH_CHECK(
-                    init_storage.scalar_type() == weights.scalar_type(),
-                    "init_storage (",
-                    toString(init_storage.scalar_type()),
-                    ") and weights scalar (",
-                    toString(weights.scalar_type()),
-                    ") types mismatch");
-                auto row_storage_data_ptr =
-                    init_storage.data_ptr<scalar_t>();
-                for (auto j = 0; j < keys.size(); ++j) {
-                  const auto& s = statuses[j];
-                  int64_t i = shard_ids[j];
-                  const auto& value = values[j];
-                  if (s.ok()) {
-                    if (!std::is_same<scalar_t, uint8_t>::value) {
-                      CHECK_EQ(value.size(), D * sizeof(scalar_t));
-                    }
-                    std::copy(
-                        reinterpret_cast<const scalar_t*>(value.data()),
-                        reinterpret_cast<const scalar_t*>(
-                            value.data() + value.size()),
-                        &(weights_data_ptr[i * D]));
-                  } else {
-                    CHECK(s.IsNotFound());
-                    int64_t row_index;
-                    initializers_[shard]->producer_queue_.dequeue(
-                        row_index);
-                    std::copy(
-                        &(row_storage_data_ptr[row_index * D]),
-                        &(row_storage_data_ptr[row_index * D + D]),
-                        &(weights_data_ptr[i * D]));
-                    initializers_[shard]->consumer_queue_.enqueue(
-                        row_index);
-                  }
-                }
               });
           co_return;
         })
diff --git a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py
index 7cc75d9e9..61ac98d6e 100644
--- a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py
+++ b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py
@@ -94,15 +94,23 @@ def get_physical_table_arg_indices_(self, feature_table_map: List[int]):
 
     @given(
         weights_precision=st.sampled_from([SparseType.FP32, SparseType.FP16]),
+        indice_int64_t=st.sampled_from([True, False]),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
-    def test_ssd(self, weights_precision: SparseType) -> None:
+    def test_ssd(self, indice_int64_t: bool, weights_precision: SparseType) -> None:
         import tempfile
 
         E = int(1e4)
         D = 128
         N = 100
-        indices = torch.as_tensor(np.random.choice(E, replace=False, size=(N,)))
+        if indice_int64_t:
+            indices = torch.as_tensor(
+                np.random.choice(E, replace=False, size=(N,)), dtype=torch.int64
+            )
+        else:
+            indices = torch.as_tensor(
+                np.random.choice(E, replace=False, size=(N,)), dtype=torch.int32
+            )
         weights = torch.randn(N, D, dtype=weights_precision.as_dtype())
         output_weights = torch.empty_like(weights)
         count = torch.tensor([N])
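
Usage sketch (illustrative only, not part of the patch): a minimal example of how the tensor-keyed L2 cache API introduced above might be exercised. It assumes a CacheLibCache constructed from a CacheConfig the same way EmbeddingKVDB does; the key dtype, embedding width, and values below are hypothetical.

    // Sketch: keys are passed as single-element tensors so that int32 and
    // int64 indices both work; CacheLibCache records the index/weights dtypes
    // on the first put() and reuses them in get_all_items().
    #include <ATen/ATen.h>

    void example_usage(l2_cache::CacheLibCache& cache) {
      // Hypothetical single-element int32 key and a 128-wide fp32 row
      // (assumes CacheConfig::max_D_ == 128).
      auto key = at::tensor({42}, at::kInt);
      auto row = at::randn({128}, at::kFloat);

      if (cache.put(key, row)) {
        // get() now returns folly::Optional<void*>; folly::none means a miss.
        auto addr_opt = cache.get(key);
        if (addr_opt.has_value()) {
          // Cache hit: addr_opt.value() points at cachelib's copy of `row`.
          // Per the header comment, consume it before the next insertion.
        }
      }
    }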