reserve a method during torchscript #3152

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py (14 additions, 10 deletions)
@@ -68,6 +68,19 @@ def get(self, name: str) -> int:
         return self.config[name]
 
 
+def sparse_type_to_int(sparse_type: "SparseType") -> int:
+    return {
+        SparseType.FP32.value: 0,
+        SparseType.FP16.value: 1,
+        SparseType.INT8.value: 2,
+        SparseType.INT4.value: 3,
+        SparseType.INT2.value: 4,
+        SparseType.BF16.value: 5,
+        SparseType.FP8.value: 6,
+        SparseType.MX4.value: 7,
+    }[sparse_type.value]
+
+
 @enum.unique
 class SparseType(enum.Enum):
     FP32 = "fp32"
@@ -104,16 +117,7 @@ def from_int(ty: int) -> "SparseType":
             raise ValueError(f"Unsupported sparse type: {ty}")
 
     def as_int(self) -> int:
-        return {
-            SparseType.FP32.value: 0,
-            SparseType.FP16.value: 1,
-            SparseType.INT8.value: 2,
-            SparseType.INT4.value: 3,
-            SparseType.INT2.value: 4,
-            SparseType.BF16.value: 5,
-            SparseType.FP8.value: 6,
-            SparseType.MX4.value: 7,
-        }[self.value]
+        return sparse_type_to_int(self)
 
     @staticmethod
     def from_dtype(dtype: torch.dtype, is_mx: bool = False) -> "SparseType":
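The integer encoding of SparseType is hoisted out of the enum method into the module-level sparse_type_to_int, with as_int kept as a thin wrapper, so TorchScript-compiled code can call the free function directly. A minimal sketch of the intended equivalence, assuming an fbgemm_gpu build that includes this change:

    from fbgemm_gpu.split_embedding_configs import SparseType, sparse_type_to_int

    # The free function and the enum method return the same integer code;
    # per the table above, FP16 maps to 1.
    assert sparse_type_to_int(SparseType.FP16) == 1
    assert SparseType.FP16.as_int() == sparse_type_to_int(SparseType.FP16)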
(second changed file)
@@ -17,7 +17,7 @@
 import torch  # usort:skip
 from torch import nn, Tensor  # usort:skip
 
-from fbgemm_gpu.split_embedding_configs import SparseType
+from fbgemm_gpu.split_embedding_configs import sparse_type_to_int, SparseType
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
     BoundsCheckMode,
     CacheAlgorithm,
@@ -943,6 +943,7 @@ def reset_embedding_spec_location(
             for spec in self.embedding_specs
         ]
 
+    @torch.jit.export
     def recompute_module_buffers(self) -> None:
         """
         Compute module buffers that're on meta device and are not materialized in reset_weights_placements_and_offsets().
@@ -955,7 +956,7 @@ def recompute_module_buffers(self) -> None:
         ):
             return
 
-        weights_tys_int = [e[3].as_int() for e in self.embedding_specs]
+        weights_tys_int = [sparse_type_to_int(e[3]) for e in self.embedding_specs]
         self.weights_tys = torch.tensor(
             [weights_tys_int[t] for t in self.feature_table_map],
             device=self.current_device,
@@ -968,8 +969,9 @@ def recompute_module_buffers(self) -> None:
             dtype=torch.int64,
         )
         dims = [e[2] for e in self.embedding_specs]
-        D_offsets_list = [dims[t] for t in self.feature_table_map]
-        D_offsets_list = [0] + list(accumulate(D_offsets_list))
+        D_offsets_list = [0]
+        for t in self.feature_table_map:
+            D_offsets_list.append(dims[t] + D_offsets_list[-1])
         self.D_offsets = torch.tensor(
             D_offsets_list, device=self.current_device, dtype=torch.int32
         )
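The hunk above rewrites the cumulative-offset computation as an explicit running sum instead of itertools.accumulate, presumably because TorchScript cannot compile the itertools helpers. A standalone sketch showing the two formulations produce the same D_offsets prefix sums (the dims and feature_table_map values are made up for illustration):

    from itertools import accumulate

    dims = [4, 8, 16]                  # hypothetical per-table embedding dims
    feature_table_map = [0, 1, 1, 2]   # hypothetical feature -> table mapping

    # Old formulation (eager-only): prefix sum via itertools.accumulate.
    old = [0] + list(accumulate(dims[t] for t in feature_table_map))

    # New formulation (TorchScript-friendly): explicit running sum.
    new = [0]
    for t in feature_table_map:
        new.append(dims[t] + new[-1])

    assert old == new == [0, 4, 12, 20, 36]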
Expand Down Expand Up @@ -999,6 +1001,9 @@ def recompute_module_buffers(self) -> None:
self.table_wise_cache_miss = torch.empty_like(
self.table_wise_cache_miss, device=self.current_device
)
self.weights_uvm = torch.empty_like(
self.weights_uvm, device=self.current_device
)

def _apply_split(
self,
Expand Down
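The @torch.jit.export decorator is what the PR title refers to: it reserves recompute_module_buffers during scripting, so the method is compiled and remains callable on the resulting ScriptModule even though forward never invokes it. A minimal, self-contained sketch of that mechanism, using a hypothetical Demo module rather than the FBGEMM class:

    import torch
    from torch import nn

    class Demo(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # A buffer that, as in the diff above, is rebuilt outside of forward().
            self.register_buffer("D_offsets", torch.zeros(1, dtype=torch.int32))

        @torch.jit.export
        def recompute_module_buffers(self) -> None:
            # Reassign the buffer inside the exported method, mirroring the diff.
            self.D_offsets = torch.tensor([0, 4, 12], dtype=torch.int32)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x

    scripted = torch.jit.script(Demo())
    scripted.recompute_module_buffers()  # available because of @torch.jit.export
    print(scripted.D_offsets)            # tensor([ 0,  4, 12], dtype=torch.int32)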