pass in kernel tbe id into rocksdb wrapper (pytorch#2930)
Summary:
X-link: facebookresearch/FBGEMM#32

Pull Request resolved: pytorch#2930

The reason we need this is that we constantly see port conflict errors during RocksDB initialization. Before this diff we call getFreePort to get an available port. Each SSD TBE creates 32 RocksDB shards, so in total 256 ports are needed per host.
This works fine with 4 hosts, but once we run a 16-host training job we need to make sure none of the 16 hosts hits the corner case where multiple DB shards get assigned the same free port.
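
As an illustration of the scheme this diff adopts (not part of the commit; the helper names below are invented for the example), each rank derives a unique TBE id from its global rank and a per-rank instance counter, and the RocksDB monitor port is computed from that id instead of asking the OS for a free port:

    # Hypothetical sketch of the deterministic id/port scheme; make_tbe_unique_id
    # and monitor_port are illustrative names, not real FBGEMM APIs.
    BASE_PORT = 136000  # mirrors the base_port constant added in the C++ change

    def make_tbe_unique_id(rank: int, local_instance_index: int) -> int:
        # Mirrors the Python change below: up to 1024 TBE instances per rank.
        assert local_instance_index < 1024
        return (rank << 10) | local_instance_index

    def monitor_port(rank: int, local_instance_index: int) -> int:
        return BASE_PORT + make_tbe_unique_id(rank, local_instance_index)

    # Example: rank 2, local instance index 2
    # id = (2 << 10) | 2 = 2050, port = 136000 + 2050 = 138050

Because the id is a pure function of (rank, local index), no two TBE instances in the job can end up on the same port, no matter how many hosts participate.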

Reviewed By: sryap

Differential Revision: D60635718

fbshipit-source-id: 606216a4a2d5a43f82f7bd681477537413bd372a
Joe Wang authored and facebook-github-bot committed Aug 6, 2024
1 parent a28e2e0 commit 6607072
Showing 3 changed files with 29 additions and 18 deletions.
30 changes: 18 additions & 12 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -257,9 +257,25 @@ def __init__(
prefix="ssd_table_batched_embeddings", dir=ssd_storage_directory
)
# logging.info("DEBUG: weights_precision {}".format(weights_precision))

# create tbe unique id using rank index | local tbe idx
if tbe_unique_id == -1:
SSDTableBatchedEmbeddingBags._local_instance_index += 1
if dist.is_initialized():
assert (
SSDTableBatchedEmbeddingBags._local_instance_index < 1024
), f"{SSDTableBatchedEmbeddingBags._local_instance_index}, more than 1024 TBE instance is created in one rank, the tbe unique id won't be unique in this case."
tbe_unique_id = (
dist.get_rank() << 10
| SSDTableBatchedEmbeddingBags._local_instance_index
)
else:
logging.warning("dist is not initialized, treating as single gpu cases")
tbe_unique_id = SSDTableBatchedEmbeddingBags._local_instance_index
logging.info(f"tbe_unique_id: {tbe_unique_id}")
if not ps_hosts:
logging.info(
f"Logging SSD offloading setup "
f"Logging SSD offloading setup, tbe_unique_id:{tbe_unique_id},"
f"passed_in_path={ssd_directory}, num_shards={ssd_rocksdb_shards},num_threads={ssd_rocksdb_shards},"
f"memtable_flush_period={ssd_memtable_flush_period},memtable_flush_offset={ssd_memtable_flush_offset},"
f"l0_files_per_compact={ssd_l0_files_per_compact},max_D={self.max_D},rate_limit_mbps={ssd_rate_limit_mbps},"
@@ -289,19 +305,9 @@ def __init__(
weights_precision.bit_rate(), # row_storage_bitwidth
ssd_block_cache_size_per_tbe,
use_passed_in_path,
tbe_unique_id,
)
else:
# create tbe unique id using rank index | local tbe idx
if tbe_unique_id == -1:
SSDTableBatchedEmbeddingBags._local_instance_index += 1
assert (
SSDTableBatchedEmbeddingBags._local_instance_index < 8
), f"{SSDTableBatchedEmbeddingBags._local_instance_index}, more than 8 TBE instance is created in one rank, the tbe unique id won't be unique in this case."
tbe_unique_id = (
dist.get_rank() << 3
| SSDTableBatchedEmbeddingBags._local_instance_index
)
logging.info(f"tbe_unique_id: {tbe_unique_id}")
# pyre-fixme[4]: Attribute must be annotated.
# pyre-ignore[16]
self.ssd_db = torch.classes.fbgemm.EmbeddingParameterServerWrapper(
@@ -168,7 +168,8 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
double uniform_init_upper,
int64_t row_storage_bitwidth = 32,
int64_t cache_size = 0,
bool use_passed_in_path = false)
bool use_passed_in_path = false,
int64_t tbe_unique_id = 0)
: impl_(std::make_shared<ssd::EmbeddingRocksDB>(
path,
num_shards,
@@ -186,7 +187,8 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
uniform_init_upper,
row_storage_bitwidth,
cache_size,
use_passed_in_path)) {}
use_passed_in_path,
tbe_unique_id)) {}

void
set_cuda(Tensor indices, Tensor weights, Tensor count, int64_t timestep) {
@@ -238,7 +240,8 @@ static auto embedding_rocks_db_wrapper =
double,
int64_t,
int64_t,
bool>(),
bool,
int64_t>(),
"",
{
torch::arg("path"),
@@ -258,6 +261,7 @@
torch::arg("row_storage_bitwidth"),
torch::arg("cache_size"),
torch::arg("use_passed_in_path") = true,
torch::arg("tbe_unique_id") = 0,
})
.def("set_cuda", &EmbeddingRocksDBWrapper::set_cuda)
.def("get_cuda", &EmbeddingRocksDBWrapper::get_cuda)
@@ -12,7 +12,6 @@
#include <torch/nn/init.h>
#include <iostream>
#ifdef FBGEMM_FBCODE
#include "common/network/PortUtil.h"
#include "common/strings/UUID.h"
#include "fb_rocksdb/DBMonitor/DBMonitor.h"
#include "fb_rocksdb/FbRocksDb.h"
@@ -40,6 +39,7 @@ constexpr size_t kRowInitBufferSize = 32 * 1024;
#ifdef FBGEMM_FBCODE
constexpr size_t num_ssd_drives = 8;
const std::string ssd_mount_point = "/data00_nvidia";
const size_t base_port = 136000;
#endif

class Initializer {
@@ -132,7 +132,8 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
float uniform_init_upper,
int64_t row_storage_bitwidth = 32,
int64_t cache_size = 0,
bool use_passed_in_path = false) {
bool use_passed_in_path = false,
int64_t tbe_unique_id = 0) {
// TODO: lots of tunables. NNI or something for this?
rocksdb::Options options;
options.create_if_missing = true;
@@ -256,7 +257,7 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
rocksdb::DB* db;

#ifdef FBGEMM_FBCODE
db_monitor_options.port = facebook::network::getFreePort();
db_monitor_options.port = base_port + tbe_unique_id;
auto s = facebook::fb_rocksdb::openRocksDB(
options,
shard_path,
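
For context, a quick check (illustration only, not part of the diff) that the deterministic assignment above cannot collide across ranks or local TBE indices, which is the failure mode the summary describes for getFreePort:

    # rank << 10 | local_index is unique for local_index < 1024, so every
    # (rank, local_index) pair maps to a distinct monitor port.
    BASE_PORT = 136000  # assumed to match the base_port constant above

    ports = {
        BASE_PORT + ((rank << 10) | idx)
        for rank in range(128)  # e.g. 16 hosts x 8 ranks per host
        for idx in range(8)
    }
    assert len(ports) == 128 * 8  # no two instances share a port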
