reserve a method during torchscript (#3152)
Summary:
Pull Request resolved: #3152

X-link: facebookresearch/FBGEMM#247

This method is useful to help recompute buffers for a torchscripted model in Python.

Reviewed By: seanx92

Differential Revision: D63000116

fbshipit-source-id: 43ad420b22ac2e06b59e2189dada1a0b30befcad
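For context on the summary above, the intended flow is roughly: script (or load) the model, then call the newly exported method from eager Python to rebuild its buffers. A minimal sketch follows; the file name and call site are hypothetical, and only recompute_module_buffers and its @torch.jit.export decorator come from this diff.

import torch

# Hypothetical usage sketch (not part of this commit): rebuild the buffers of a
# TorchScript-ed TBE module from eager Python. The path is illustrative only.
scripted = torch.jit.load("tbe_scripted.pt")

# Because recompute_module_buffers is decorated with @torch.jit.export in the
# diff below, it is preserved on the ScriptModule and callable from Python.
scripted.recompute_module_buffers()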
842974287 authored and facebook-github-bot committed Sep 19, 2024
1 parent 0377308 commit ebbebd4
Showing 2 changed files with 23 additions and 14 deletions.
24 changes: 14 additions & 10 deletions fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py
@@ -68,6 +68,19 @@ def get(self, name: str) -> int:
        return self.config[name]


+def sparse_type_to_int(sparse_type: "SparseType") -> int:
+    return {
+        SparseType.FP32.value: 0,
+        SparseType.FP16.value: 1,
+        SparseType.INT8.value: 2,
+        SparseType.INT4.value: 3,
+        SparseType.INT2.value: 4,
+        SparseType.BF16.value: 5,
+        SparseType.FP8.value: 6,
+        SparseType.MX4.value: 7,
+    }[sparse_type.value]


@enum.unique
class SparseType(enum.Enum):
    FP32 = "fp32"
@@ -104,16 +117,7 @@ def from_int(ty: int) -> "SparseType":
            raise ValueError(f"Unsupported sparse type: {ty}")

    def as_int(self) -> int:
-        return {
-            SparseType.FP32.value: 0,
-            SparseType.FP16.value: 1,
-            SparseType.INT8.value: 2,
-            SparseType.INT4.value: 3,
-            SparseType.INT2.value: 4,
-            SparseType.BF16.value: 5,
-            SparseType.FP8.value: 6,
-            SparseType.MX4.value: 7,
-        }[self.value]
+        return sparse_type_to_int(self)

    @staticmethod
    def from_dtype(dtype: torch.dtype, is_mx: bool = False) -> "SparseType":
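The two hunks above only move code: the type-to-int mapping is hoisted out of SparseType.as_int() into the module-level sparse_type_to_int() so it can also be called outside the enum (see the second file below), while as_int() keeps its behavior. A quick equivalence check one could run, assuming fbgemm_gpu is installed:

from fbgemm_gpu.split_embedding_configs import SparseType, sparse_type_to_int

# The refactor is expected to be behavior-preserving: the free function and the
# method agree on every sparse type, with the same FP32=0 ... MX4=7 mapping.
all_types = [
    SparseType.FP32, SparseType.FP16, SparseType.INT8, SparseType.INT4,
    SparseType.INT2, SparseType.BF16, SparseType.FP8, SparseType.MX4,
]
for i, st in enumerate(all_types):
    assert sparse_type_to_int(st) == st.as_int() == i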
13 changes: 9 additions & 4 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
@@ -17,7 +17,7 @@
import torch # usort:skip
from torch import nn, Tensor # usort:skip

-from fbgemm_gpu.split_embedding_configs import SparseType
+from fbgemm_gpu.split_embedding_configs import sparse_type_to_int, SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
    BoundsCheckMode,
    CacheAlgorithm,
@@ -943,6 +943,7 @@ def reset_embedding_spec_location(
            for spec in self.embedding_specs
        ]

+    @torch.jit.export
    def recompute_module_buffers(self) -> None:
        """
        Compute module buffers that are on meta device and are not materialized in reset_weights_placements_and_offsets().
@@ -955,7 +956,7 @@ def recompute_module_buffers(self) -> None:
        ):
            return

-        weights_tys_int = [e[3].as_int() for e in self.embedding_specs]
+        weights_tys_int = [sparse_type_to_int(e[3]) for e in self.embedding_specs]
        self.weights_tys = torch.tensor(
            [weights_tys_int[t] for t in self.feature_table_map],
            device=self.current_device,
@@ -968,8 +969,9 @@ def recompute_module_buffers(self) -> None:
            dtype=torch.int64,
        )
        dims = [e[2] for e in self.embedding_specs]
-        D_offsets_list = [dims[t] for t in self.feature_table_map]
-        D_offsets_list = [0] + list(accumulate(D_offsets_list))
+        D_offsets_list = [0]
+        for t in self.feature_table_map:
+            D_offsets_list.append(dims[t] + D_offsets_list[-1])
        self.D_offsets = torch.tensor(
            D_offsets_list, device=self.current_device, dtype=torch.int32
        )
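The D_offsets change above swaps the itertools.accumulate()-based prefix sum for an explicit running-sum loop, presumably because TorchScript cannot compile itertools. The resulting values are identical; a standalone illustration with made-up table dims:

from itertools import accumulate

# Made-up sizes, only to show the two formulations agree.
dims = [64, 32, 128]              # embedding dim per table (e[2])
feature_table_map = [0, 0, 1, 2]  # feature -> table

# Old formulation (removed above).
old = [dims[t] for t in feature_table_map]
old = [0] + list(accumulate(old))

# New loop-based formulation (added above).
new = [0]
for t in feature_table_map:
    new.append(dims[t] + new[-1])

assert old == new == [0, 64, 128, 160, 288]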
@@ -999,6 +1001,9 @@ def recompute_module_buffers(self) -> None:
        self.table_wise_cache_miss = torch.empty_like(
            self.table_wise_cache_miss, device=self.current_device
        )
+        self.weights_uvm = torch.empty_like(
+            self.weights_uvm, device=self.current_device
+        )

    def _apply_split(
        self,
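The key line in this file is the added @torch.jit.export decorator: it tells the TorchScript compiler to compile and keep a method even though forward() never calls it, which is what makes the Python-side call sketched near the top possible. A self-contained toy example of the decorator (standard PyTorch API, unrelated to FBGEMM):

import torch
from torch import Tensor, nn


class Toy(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("scale", torch.ones(1))

    def forward(self, x: Tensor) -> Tensor:
        return x * self.scale

    @torch.jit.export
    def reset_scale(self) -> None:
        # Buffer reassignment is allowed in TorchScript, just like the buffer
        # rebuilds in recompute_module_buffers above.
        self.scale = torch.ones(1)


scripted = torch.jit.script(Toy())
scripted.reset_scale()  # compiled and kept only because it is exported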
