Skip to content

Commit

Permalink
support different index dtype (#3140)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3140

X-link: facebookresearch/FBGEMM#233

Before this diff, we hacked the indices to be int64.
After this diff, SSD TBE supports int32 indices.

Reviewed By: q10

Differential Revision: D62761615

fbshipit-source-id: 5a08b022c0ceaa30bc0b7c2ec66e8b8444e15657
  • Loading branch information
Joe Wang authored and facebook-github-bot committed Sep 19, 2024
1 parent 46e309d commit 12710ef
Show file tree
Hide file tree
Showing 6 changed files with 446 additions and 335 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ class CacheLibCache {
int64_t max_D_;
};

explicit CacheLibCache(const CacheConfig& cache_config);
explicit CacheLibCache(
const CacheConfig& cache_config,
int64_t unique_tbe_id);

std::unique_ptr<Cache> initializeCacheLib(const CacheConfig& config);

Expand All @@ -48,7 +50,7 @@ class CacheLibCache {

/// Find the stored embeddings from a given embedding indices, aka key
///
/// @param key embedding index to look up
/// @param key_tensor embedding index (a tensor with only one element) to look up
///
/// @return an optional value, return none on cache misses, if cache hit
/// return a pointer to the cachelib underlying storage of associated
Expand All @@ -57,7 +59,7 @@ class CacheLibCache {
/// @note that this is not thread safe, caller needs to make sure the data is
/// fully processed before doing cache insertion, otherwise the returned space
/// might be overwritten if cache is full
std::optional<void*> get(int64_t key);
folly::Optional<void*> get(const at::Tensor& key_tensor);

/// Cachelib wrapper specific hash function
///
Expand All @@ -84,7 +86,8 @@ class CacheLibCache {

/// Add an embedding index and embeddings into cachelib
///
/// @param key embedding index to insert
/// @param key_tensor embedding index (a tensor with only one element) to insert
/// @param data embedding weights to insert
///
/// @return true on success insertion, false on failure insertion, a failure
/// insertion could happen if the refcount of bottom K items in LRU queue
Expand All @@ -94,11 +97,12 @@ class CacheLibCache {
/// bulk read and bulk write sequentially
///
/// @note cache_->allocation will trigger eviction callback func
bool put(int64_t key, const at::Tensor& data);
bool put(const at::Tensor& key_tensor, const at::Tensor& data);

/// iterate through all items in L2 cache, fill them in indices and weights
/// respectively and return indices, weights and count
///
/// @return optional value, if cache is empty return none
/// @return indices The 1D embedding index tensor, should skip on negative
/// value
/// @return weights The 2D tensor that each row(embeddings) is paired up with
Expand All @@ -108,7 +112,8 @@ class CacheLibCache {
///
/// @note this isn't thread safe, caller needs to make sure put isn't called
/// while this is executed.
std::tuple<at::Tensor, at::Tensor, at::Tensor> get_all_items();
folly::Optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>>
get_all_items();

/// instantiate eviction related indices and weights tensors(size of <count>)
/// for L2 eviction using the same dtype and device from <indices> and
Expand Down Expand Up @@ -141,12 +146,15 @@ class CacheLibCache {

private:
const CacheConfig cache_config_;
const int64_t unique_tbe_id_;
std::unique_ptr<Cache> cache_;
std::vector<facebook::cachelib::PoolId> pool_ids_;
std::unique_ptr<facebook::cachelib::CacheAdmin> admin_;

folly::Optional<at::Tensor> evicted_indices_opt_{folly::none};
folly::Optional<at::Tensor> evicted_weights_opt_{folly::none};
folly::Optional<at::ScalarType> index_dtype_{folly::none};
folly::Optional<at::ScalarType> weights_dtype_{folly::none};
std::atomic<int64_t> eviction_row_id{0};
};

Expand Down
201 changes: 115 additions & 86 deletions fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,11 @@ namespace l2_cache {

using Cache = facebook::cachelib::LruAllocator;

// this is a general predictor for weights data type, might not be general
// enough for all the cases
at::ScalarType bytes_to_dtype(int num_bytes) {
switch (num_bytes) {
case 1:
return at::kByte;
case 2:
return at::kHalf;
case 4:
return at::kFloat;
case 8:
return at::kDouble;
default:
throw std::runtime_error("Unsupported dtype");
}
}

CacheLibCache::CacheLibCache(const CacheConfig& cache_config)
CacheLibCache::CacheLibCache(
const CacheConfig& cache_config,
int64_t unique_tbe_id)
: cache_config_(cache_config),
unique_tbe_id_(unique_tbe_id),
cache_(initializeCacheLib(cache_config_)),
admin_(createCacheAdmin(*cache_)) {
for (size_t i = 0; i < cache_config_.num_shards; i++) {
Expand All @@ -50,30 +36,41 @@ CacheLibCache::CacheLibCache(const CacheConfig& cache_config)

std::unique_ptr<Cache> CacheLibCache::initializeCacheLib(
const CacheConfig& config) {
auto eviction_cb =
[this](const facebook::cachelib::LruAllocator::RemoveCbData& data) {
FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(
evicted_weights_opt_->scalar_type(), "l2_eviction_handling", [&] {
if (data.context ==
facebook::cachelib::RemoveContext::kEviction) {
auto indices_data_ptr =
evicted_indices_opt_->data_ptr<int64_t>();
auto weights_data_ptr =
evicted_weights_opt_->data_ptr<scalar_t>();
auto row_id = eviction_row_id++;
auto weight_dim = evicted_weights_opt_->size(1);
const auto key_ptr =
reinterpret_cast<const int64_t*>(data.item.getKey().data());
indices_data_ptr[row_id] = *key_ptr;
auto eviction_cb = [this](
const facebook::cachelib::LruAllocator::RemoveCbData&
data) {
if (evicted_weights_opt_.has_value()) {
FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(
evicted_weights_opt_->scalar_type(), "l2_eviction_handling", [&] {
using value_t = scalar_t;
FBGEMM_DISPATCH_INTEGRAL_TYPES(
evicted_indices_opt_->scalar_type(),
"l2_eviction_handling",
[&] {
using index_t = scalar_t;
if (data.context ==
facebook::cachelib::RemoveContext::kEviction) {
auto indices_data_ptr =
evicted_indices_opt_->data_ptr<index_t>();
auto weights_data_ptr =
evicted_weights_opt_->data_ptr<value_t>();
auto row_id = eviction_row_id++;
auto weight_dim = evicted_weights_opt_->size(1);
const auto key_ptr = reinterpret_cast<const index_t*>(
data.item.getKey().data());
indices_data_ptr[row_id] = *key_ptr;

std::copy(
reinterpret_cast<const scalar_t*>(data.item.getMemory()),
reinterpret_cast<const scalar_t*>(data.item.getMemory()) +
weight_dim,
&weights_data_ptr[row_id * weight_dim]); // dst_start
}
});
};
std::copy(
reinterpret_cast<const value_t*>(data.item.getMemory()),
reinterpret_cast<const value_t*>(
data.item.getMemory()) +
weight_dim,
&weights_data_ptr[row_id * weight_dim]); // dst_start
}
});
});
}
};
Cache::Config cacheLibConfig;
int64_t rough_num_items =
cache_config_.cache_size_bytes / cache_config_.item_size_bytes;
Expand All @@ -82,8 +79,9 @@ std::unique_ptr<Cache> CacheLibCache::initializeCacheLib(
unsigned int lock_power =
std::log(cache_config_.num_shards * 15) / std::log(2) + 1;
XLOG(INFO) << fmt::format(
"Setting up Cachelib for L2 cache, capacity: {}GB, "
"[TBE_ID{}] Setting up Cachelib for L2 cache, capacity: {}GB, "
"item_size: {}B, max_num_items: {}, bucket_power: {}, lock_power: {}",
unique_tbe_id_,
config.cache_size_bytes / 1024 / 1024 / 1024,
cache_config_.item_size_bytes,
rough_num_items,
Expand All @@ -106,14 +104,21 @@ std::unique_ptr<facebook::cachelib::CacheAdmin> CacheLibCache::createCacheAdmin(
cache, std::move(adminConfig));
}

std::optional<void*> CacheLibCache::get(int64_t key) {
auto key_str =
folly::StringPiece(reinterpret_cast<const char*>(&key), sizeof(int64_t));
auto item = cache_->find(key_str);
if (!item) {
return std::nullopt;
}
return const_cast<void*>(item->getMemory());
/// Look up a single embedding row in the L2 cachelib cache.
///
/// @param key_tensor tensor holding exactly one integral embedding index;
///   the dispatch below accepts any integral dtype (e.g. int32 or int64).
/// @return folly::none on a cache miss; on a hit, a pointer to the
///   cachelib-owned storage for the associated embedding row.
///
/// @note per the header docs, the returned pointer aliases cachelib's
///   internal storage and is not thread safe — the caller must finish
///   consuming it before further insertions may overwrite it.
folly::Optional<void*> CacheLibCache::get(const at::Tensor& key_tensor) {
folly::Optional<void*> res;
// Dispatch on the index dtype so that sizeof(index_t) matches the width
// used when the key was serialized at put() time: the cachelib key is the
// raw little-endian bytes of the index, so an int32 key and an int64 key
// with the same value are distinct cache entries.
FBGEMM_DISPATCH_INTEGRAL_TYPES(key_tensor.scalar_type(), "get", [&] {
using index_t = scalar_t;
// Assumes key_tensor is a CPU tensor with exactly one element —
// TODO(review): confirm callers guarantee this; data_ptr() is read raw.
auto key = *(key_tensor.data_ptr<index_t>());
auto key_str = folly::StringPiece(
reinterpret_cast<const char*>(&key), sizeof(index_t));
auto item = cache_->find(key_str);
if (!item) {
// Cache miss: leave the optional empty. `return` exits only the
// dispatch lambda, not the enclosing function.
res = folly::none;
return;
}
// const_cast is needed because ReadHandle::getMemory() returns const;
// downstream code copies/overwrites this region after a hit.
res = const_cast<void*>(item->getMemory());
});
return res;
}

size_t CacheLibCache::get_shard_id(int64_t key) {
Expand All @@ -136,55 +141,79 @@ void CacheLibCache::batchMarkUseful(
}
}

bool CacheLibCache::put(int64_t key, const at::Tensor& data) {
auto key_str =
folly::StringPiece(reinterpret_cast<const char*>(&key), sizeof(int64_t));
auto item = cache_->findToWrite(key_str);
if (!item) {
auto alloc_item =
cache_->allocate(get_pool_id(key), key_str, data.nbytes());
if (!alloc_item) {
XLOG(ERR) << fmt::format(
"Failed to allocate item {} in cache, skip", key);
return false;
}
std::memcpy(alloc_item->getMemory(), data.data_ptr(), data.nbytes());
cache_->insertOrReplace(std::move(alloc_item));
} else {
std::memcpy(item->getMemory(), data.data_ptr(), data.nbytes());
bool CacheLibCache::put(const at::Tensor& key_tensor, const at::Tensor& data) {
if (!index_dtype_.has_value()) {
index_dtype_ = key_tensor.scalar_type();
}
if (!weights_dtype_.has_value()) {
weights_dtype_ = data.scalar_type();
}
return true;
bool res;
FBGEMM_DISPATCH_INTEGRAL_TYPES(key_tensor.scalar_type(), "put", [&] {
using index_t = scalar_t;
auto key = *(key_tensor.data_ptr<index_t>());
auto key_str = folly::StringPiece(
reinterpret_cast<const char*>(&key), sizeof(index_t));
auto item = cache_->findToWrite(key_str);
if (!item) {
auto alloc_item =
cache_->allocate(get_pool_id(key), key_str, data.nbytes());
if (!alloc_item) {
XLOG(ERR) << fmt::format(
"[TBE_ID{}]Failed to allocate item {} in cache, skip",
unique_tbe_id_,
key);
res = false;
return;
}
std::memcpy(alloc_item->getMemory(), data.data_ptr(), data.nbytes());
cache_->insertOrReplace(std::move(alloc_item));
} else {
std::memcpy(item->getMemory(), data.data_ptr(), data.nbytes());
}
res = true;
});
return res;
}

std::tuple<at::Tensor, at::Tensor, at::Tensor> CacheLibCache::get_all_items() {
folly::Optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>>
CacheLibCache::get_all_items() {
if (!index_dtype_.has_value() || !weights_dtype_.has_value()) {
return folly::none;
}
int total_num_items = 0;
for (auto& pool_id : pool_ids_) {
total_num_items += cache_->getPoolStats(pool_id).numItems();
}
auto weight_dim = cache_config_.max_D_;
auto weights_dtype =
bytes_to_dtype(cache_config_.item_size_bytes / weight_dim);
auto indices = at::empty(
total_num_items, at::TensorOptions().dtype(at::kLong).device(at::kCPU));
total_num_items,
at::TensorOptions().dtype(index_dtype_.value()).device(at::kCPU));
auto weights = at::empty(
{total_num_items, weight_dim},
at::TensorOptions().dtype(weights_dtype).device(at::kCPU));
at::TensorOptions().dtype(weights_dtype_.value()).device(at::kCPU));
FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(
weights.scalar_type(), "get_all_items", [&] {
auto indices_data_ptr = indices.data_ptr<int64_t>();
auto weights_data_ptr = weights.data_ptr<scalar_t>();
int64_t item_idx = 0;
for (auto itr = cache_->begin(); itr != cache_->end(); ++itr) {
const auto key_ptr =
reinterpret_cast<const int64_t*>(itr->getKey().data());
indices_data_ptr[item_idx] = *key_ptr;
std::copy(
reinterpret_cast<const scalar_t*>(itr->getMemory()),
reinterpret_cast<const scalar_t*>(itr->getMemory()) + weight_dim,
&weights_data_ptr[item_idx * weight_dim]); // dst_start
item_idx++;
}
CHECK_EQ(total_num_items, item_idx);
using value_t = scalar_t;
FBGEMM_DISPATCH_INTEGRAL_TYPES(
indices.scalar_type(), "get_all_items", [&] {
using index_t = scalar_t;
auto indices_data_ptr = indices.data_ptr<index_t>();
auto weights_data_ptr = weights.data_ptr<value_t>();
int64_t item_idx = 0;
for (auto itr = cache_->begin(); itr != cache_->end(); ++itr) {
const auto key_ptr =
reinterpret_cast<const index_t*>(itr->getKey().data());
indices_data_ptr[item_idx] = *key_ptr;
std::copy(
reinterpret_cast<const value_t*>(itr->getMemory()),
reinterpret_cast<const value_t*>(itr->getMemory()) +
weight_dim,
&weights_data_ptr[item_idx * weight_dim]); // dst_start
item_idx++;
}
CHECK_EQ(total_num_items, item_idx);
});
});
return std::make_tuple(
indices,
Expand Down
Loading

0 comments on commit 12710ef

Please sign in to comment.