diff --git a/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py b/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py
index 365aebbfc..afb931bb3 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py
@@ -68,6 +68,19 @@ def get(self, name: str) -> int:
         return self.config[name]
 
 
+def sparse_type_to_int(sparse_type: "SparseType") -> int:
+    return {
+        SparseType.FP32.value: 0,
+        SparseType.FP16.value: 1,
+        SparseType.INT8.value: 2,
+        SparseType.INT4.value: 3,
+        SparseType.INT2.value: 4,
+        SparseType.BF16.value: 5,
+        SparseType.FP8.value: 6,
+        SparseType.MX4.value: 7,
+    }[sparse_type.value]
+
+
 @enum.unique
 class SparseType(enum.Enum):
     FP32 = "fp32"
@@ -104,16 +117,7 @@ def from_int(ty: int) -> "SparseType":
             raise ValueError(f"Unsupported sparse type: {ty}")
 
     def as_int(self) -> int:
-        return {
-            SparseType.FP32.value: 0,
-            SparseType.FP16.value: 1,
-            SparseType.INT8.value: 2,
-            SparseType.INT4.value: 3,
-            SparseType.INT2.value: 4,
-            SparseType.BF16.value: 5,
-            SparseType.FP8.value: 6,
-            SparseType.MX4.value: 7,
-        }[self.value]
+        return sparse_type_to_int(self)
 
     @staticmethod
     def from_dtype(dtype: torch.dtype, is_mx: bool = False) -> "SparseType":
diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
index 721e4f248..d988563ae 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
@@ -17,7 +17,7 @@
 import torch  # usort:skip
 from torch import nn, Tensor  # usort:skip
 
-from fbgemm_gpu.split_embedding_configs import SparseType
+from fbgemm_gpu.split_embedding_configs import sparse_type_to_int, SparseType
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
     BoundsCheckMode,
     CacheAlgorithm,
@@ -943,6 +943,7 @@ def reset_embedding_spec_location(
             for spec in self.embedding_specs
         ]
 
+    @torch.jit.export
     def recompute_module_buffers(self) -> None:
         """
         Compute module buffers that're on meta device and are not materialized in reset_weights_placements_and_offsets().
@@ -955,7 +956,7 @@ def recompute_module_buffers(self) -> None:
         ):
             return
 
-        weights_tys_int = [e[3].as_int() for e in self.embedding_specs]
+        weights_tys_int = [sparse_type_to_int(e[3]) for e in self.embedding_specs]
        self.weights_tys = torch.tensor(
             [weights_tys_int[t] for t in self.feature_table_map],
             device=self.current_device,
@@ -968,8 +969,9 @@ def recompute_module_buffers(self) -> None:
             dtype=torch.int64,
         )
         dims = [e[2] for e in self.embedding_specs]
-        D_offsets_list = [dims[t] for t in self.feature_table_map]
-        D_offsets_list = [0] + list(accumulate(D_offsets_list))
+        D_offsets_list = [0]
+        for t in self.feature_table_map:
+            D_offsets_list.append(dims[t] + D_offsets_list[-1])
         self.D_offsets = torch.tensor(
             D_offsets_list, device=self.current_device, dtype=torch.int32
         )
@@ -999,6 +1001,9 @@ def recompute_module_buffers(self) -> None:
         self.table_wise_cache_miss = torch.empty_like(
             self.table_wise_cache_miss, device=self.current_device
         )
+        self.weights_uvm = torch.empty_like(
+            self.weights_uvm, device=self.current_device
+        )
 
     def _apply_split(
         self,
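
Not part of the patch: a minimal sanity-check sketch, assuming fbgemm_gpu is importable, showing that the new module-level sparse_type_to_int helper returns the same integer codes as the SparseType.as_int() method that now delegates to it (hoisting the mapping out of the enum presumably keeps the newly @torch.jit.export'ed recompute_module_buffers simple to script).

    # Illustrative sketch only (not part of the diff); assumes fbgemm_gpu is installed.
    from fbgemm_gpu.split_embedding_configs import SparseType, sparse_type_to_int

    # The free function and the enum method agree on every type listed in the patch,
    # e.g. FP32 -> 0, FP16 -> 1, INT8 -> 2, ..., MX4 -> 7.
    for ty in (SparseType.FP32, SparseType.FP16, SparseType.INT8, SparseType.INT4,
               SparseType.INT2, SparseType.BF16, SparseType.FP8, SparseType.MX4):
        assert sparse_type_to_int(ty) == ty.as_int()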