From 5921557966d90a8fb839ea666b892db4b32c98d0 Mon Sep 17 00:00:00 2001 From: Shiyan Deng Date: Mon, 29 Jan 2024 12:23:31 -0800 Subject: [PATCH] make fill_random_weights work for MTIA tbe module (#2286) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2286 att, just a couple places left over previously when adding MTIA EmbeddingLocation. Reviewed By: jspark1105 Differential Revision: D53062844 fbshipit-source-id: 051aafc0199cf7cfb99417f5ef39c36fd450fc6a --- .../split_table_batched_embeddings_ops_inference.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py index 4015c9353..4e86ebacd 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py @@ -101,7 +101,7 @@ def nbit_construct_split_state( offsets.append(host_size) host_size += state_size elif location == EmbeddingLocation.DEVICE or location == EmbeddingLocation.MTIA: - placements.append(EmbeddingLocation.DEVICE) + placements.append(location) offsets.append(dev_size) dev_size += state_size else: @@ -1176,7 +1176,10 @@ def split_embedding_weights_with_scale_bias( splits: List[Tuple[Tensor, Optional[Tensor], Optional[Tensor]]] = [] for t, (_, rows, dim, weight_ty, _) in enumerate(self.embedding_specs): placement = self.weights_physical_placements[t] - if placement == EmbeddingLocation.DEVICE.value: + if ( + placement == EmbeddingLocation.DEVICE.value + or placement == EmbeddingLocation.MTIA.value + ): weights = self.weights_dev elif placement == EmbeddingLocation.HOST.value: weights = self.weights_host