add configs for 1) client thread num 2) max key per request 3) max local index length 4) ps hosts & ports (pytorch#2727)

Summary:
X-link: pytorch/torchrec#2118

Pull Request resolved: pytorch#2727

In this diff, we add the API to make 1) client thread num, 2) max key per request, 3) max local index length, and 4) ps hosts & ports for the Parameter Server configurable through the model config; an illustrative usage sketch follows the changed-files summary below.

Reviewed By: emlin

Differential Revision: D58372476

fbshipit-source-id: 0b9fb24e3574c5cc9b4181f68d98a0f0b7fbd9ec
Franco Mo authored and facebook-github-bot committed Jun 16, 2024
1 parent 8a938d6 commit e5d0c94
Showing 3 changed files with 40 additions and 8 deletions.
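
For context, a minimal usage sketch (not part of the commit) of the new keyword arguments added to SSDTableBatchedEmbeddingBags. The host names, ports, and values are placeholder assumptions; the existing required constructor arguments are elided because this commit does not change them, and the commented fallbacks mirror the defaults in the diff (54, 32, 500).

# Hypothetical usage sketch of the new Parameter Server configs.
# Host names, ports, and values are placeholders, not taken from this commit.
ps_kwargs = dict(
    ps_hosts=(("ps-host-0.example.com", 12345), ("ps-host-1.example.com", 12345)),
    tbe_unique_id=-1,              # left at -1 so it is derived from rank and TBE index
    ps_max_key_per_request=500,    # falls back to 500 when left as None
    ps_client_thread_num=32,       # falls back to 32 when left as None
    ps_max_local_index_length=54,  # falls back to 54 when left as None
)

# emb = SSDTableBatchedEmbeddingBags(
#     ...,  # existing required arguments, unchanged by this commit
#     **ps_kwargs,
# )
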
32 changes: 28 additions & 4 deletions fbgemm_gpu/fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py
@@ -127,8 +127,12 @@ def __init__(
CowClipDefinition
] = None, # used by Rowwise Adagrad
pooling_mode: PoolingMode = PoolingMode.SUM,
# Parameter Server Configs
ps_hosts: Optional[Tuple[Tuple[str, int]]] = None,
tbe_unique_id: int = -1,
ps_max_key_per_request: Optional[int] = None,
ps_client_thread_num: Optional[int] = None,
ps_max_local_index_length: Optional[int] = None,
) -> None:
super(SSDTableBatchedEmbeddingBags, self).__init__()

@@ -285,14 +289,22 @@ def __init__(
| SSDTableBatchedEmbeddingBags._local_instance_index
)
logging.info(f"tbe_unique_id: {tbe_unique_id}")
logging.info(f"ps_max_local_index_length: {ps_max_local_index_length}")
logging.info(f"ps_client_thread_num: {ps_client_thread_num}")
logging.info(f"ps_max_key_per_request: {ps_max_key_per_request}")
# pyre-fixme[4]: Attribute must be annotated.
# pyre-ignore[16]
self.ssd_db = torch.classes.fbgemm.EmbeddingParameterServerWrapper(
[host[0] for host in ps_hosts],
[host[1] for host in ps_hosts],
tbe_unique_id,
54,
32,
(
ps_max_local_index_length
if ps_max_local_index_length is not None
else 54
),
ps_client_thread_num if ps_client_thread_num is not None else 32,
ps_max_key_per_request if ps_max_key_per_request is not None else 500,
)
# pyre-fixme[20]: Argument `self` expected.
(low_priority, high_priority) = torch.cuda.Stream.priority_range()
@@ -790,7 +802,11 @@ def __init__(
ssd_cache_location: EmbeddingLocation = EmbeddingLocation.MANAGED,
ssd_uniform_init_lower: float = -0.01,
ssd_uniform_init_upper: float = 0.01,
# Parameter Server Configs
ps_hosts: Optional[Tuple[Tuple[str, int]]] = None,
ps_max_key_per_request: Optional[int] = None,
ps_client_thread_num: Optional[int] = None,
ps_max_local_index_length: Optional[int] = None,
tbe_unique_id: int = -1, # unique id for this embedding, if not set, will derive based on current rank and tbe index id
) -> None: # noqa C901 # tuple of (rows, dims,)
super(SSDIntNBitTableBatchedEmbeddingBags, self).__init__()
@@ -1002,14 +1018,22 @@ def max_ty_D(ty: SparseType) -> int:
| SSDIntNBitTableBatchedEmbeddingBags._local_instance_index
)
logging.info(f"tbe_unique_id: {tbe_unique_id}")
logging.info(f"ps_max_local_index_length: {ps_max_local_index_length}")
logging.info(f"ps_client_thread_num: {ps_client_thread_num}")
logging.info(f"ps_max_key_per_request: {ps_max_key_per_request}")
# pyre-fixme[4]: Attribute must be annotated.
# pyre-ignore[16]
self.ssd_db = torch.classes.fbgemm.EmbeddingParameterServerWrapper(
[host[0] for host in ps_hosts],
[host[1] for host in ps_hosts],
tbe_unique_id,
54,
32,
(
ps_max_local_index_length
if ps_max_local_index_length is not None
else 54
),
ps_client_thread_num if ps_client_thread_num is not None else 32,
ps_max_key_per_request if ps_max_key_per_request is not None else 500,
)

# pyre-fixme[20]: Argument `self` expected.
@@ -22,7 +22,8 @@ class EmbeddingParameterServerWrapper : public torch::jit::CustomClassHolder {
const std::vector<int64_t>& tps_ports,
int64_t tbe_id,
int64_t maxLocalIndexLength = 54,
int64_t num_threads = 32) {
int64_t num_threads = 32,
int64_t maxKeysPerRequest = 500) {
TORCH_CHECK(
tps_ips.size() == tps_ports.size(),
"tps_ips and tps_ports must have the same size");
@@ -32,7 +33,11 @@ class EmbeddingParameterServerWrapper : public torch::jit::CustomClassHolder {
}

impl_ = std::make_shared<ps::EmbeddingParameterServer>(
std::move(tpsHosts), tbe_id, maxLocalIndexLength, num_threads);
std::move(tpsHosts),
tbe_id,
maxLocalIndexLength,
num_threads,
maxKeysPerRequest);
}

void
@@ -78,6 +83,7 @@ static auto embedding_parameter_server_wrapper =
const std::vector<int64_t>,
int64_t,
int64_t,
int64_t,
int64_t>())
.def("set_cuda", &EmbeddingParameterServerWrapper::set_cuda)
.def("get_cuda", &EmbeddingParameterServerWrapper::get_cuda)
@@ -20,14 +20,16 @@ class EmbeddingParameterServer : public kv_db::EmbeddingKVDB {
std::vector<std::pair<std::string, int>>&& tps_hosts,
int64_t tbe_id,
int64_t maxLocalIndexLength = 54,
int64_t num_threads = 32)
int64_t num_threads = 32,
int64_t maxKeysPerRequest = 500)
: tps_client_(
std::make_shared<mvai_infra::experimental::ps_training::tps_client::
TrainingParameterServiceClient>(
std::move(tps_hosts),
tbe_id,
maxLocalIndexLength,
num_threads)) {}
num_threads,
maxKeysPerRequest)) {}

void set(
const at::Tensor& indices,
