Fill embedding tables with randomized scales and bias in split-TBE benchmarks (pytorch#2031)

Summary:
Pull Request resolved: pytorch#2031

Same as title

Reviewed By: sryap

Differential Revision: D49433995

fbshipit-source-id: 8f7cc876a4284aabe36374d8e95ff2fa043e5ebe
Wei Su authored and facebook-github-bot committed Sep 29, 2023
1 parent d1b8766 commit 49f5794
Showing 2 changed files with 43 additions and 0 deletions.
35 changes: 35 additions & 0 deletions fbgemm_gpu/bench/bench_utils.py
@@ -12,13 +12,17 @@
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple

import numpy as np

import torch
from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.split_embedding_utils import (  # noqa: F401
    b_indices,
    generate_requests,  # noqa: F401
    get_device,  # noqa: F401
    round_up,  # noqa: F401
)
from torch import nn

logging.basicConfig(level=logging.DEBUG)

@@ -455,3 +459,34 @@ def benchmark_vbe(
    return VBEBenchmarkOutput(
        avg, fwd, bwd, compressed_avg, compressed_fwd, reindex, compressed_bwd
    )


def fill_random_scale_bias(
    emb: nn.Module,
    T: int,
    weights_precision: SparseType,
) -> None:
    for t in range(T):
        (weights, scale_shift) = emb.split_embedding_weights()[t]
        if scale_shift is not None:
            (E, R) = scale_shift.shape
            assert R == 4
            scales = None
            shifts = None
            if weights_precision == SparseType.INT8:
                scales = np.random.uniform(0.001, 0.01, size=(E,)).astype(np.float16)
                shifts = np.random.normal(-2, 2, size=(E,)).astype(np.float16)
            elif weights_precision == SparseType.INT4:
                scales = np.random.uniform(0.01, 0.1, size=(E,)).astype(np.float16)
                shifts = np.random.normal(-2, 2, size=(E,)).astype(np.float16)
            elif weights_precision == SparseType.INT2:
                scales = np.random.uniform(0.1, 1, size=(E,)).astype(np.float16)
                shifts = np.random.normal(-2, 2, size=(E,)).astype(np.float16)
            scale_shift.copy_(
                torch.tensor(
                    np.stack([scales, shifts], axis=1)
                    .astype(np.float16)
                    .view(np.uint8),
                    device=scale_shift.device,
                )
            )
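
A note on the packing above: each row of scale_shift holds one fp16 scale and one fp16 shift reinterpreted as four uint8 bytes, which is exactly what the R == 4 assertion guards. A minimal standalone sketch of that round trip (E here is a hypothetical row count, not taken from the diff):

import numpy as np

E = 8  # hypothetical number of embedding rows, for illustration only
scales = np.random.uniform(0.001, 0.01, size=(E,)).astype(np.float16)
shifts = np.random.normal(-2, 2, size=(E,)).astype(np.float16)

# Pack: (E, 2) fp16 -> (E, 4) uint8, the layout fill_random_scale_bias writes.
packed = np.stack([scales, shifts], axis=1).view(np.uint8)
assert packed.shape == (E, 4)

# Unpack: reinterpreting the bytes as fp16 recovers the original values.
unpacked = packed.view(np.float16)
assert np.array_equal(unpacked[:, 0], scales)
assert np.array_equal(unpacked[:, 1], shifts)

Randomizing these per-row values presumably keeps the benchmarks from exercising every kernel over identical, degenerate quantization metadata.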
8 changes: 8 additions & 0 deletions fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
@@ -62,6 +62,7 @@
        benchmark_requests_refer,
        benchmark_torch_function,
        benchmark_vbe,
        fill_random_scale_bias,
    )
else:
    from fbgemm_gpu.bench.bench_utils import (
@@ -70,6 +71,7 @@
        benchmark_requests_refer,
        benchmark_torch_function,
        benchmark_vbe,
        fill_random_scale_bias,
    )


@@ -815,6 +817,7 @@ def nbit_cpu(  # noqa C901
        fp8_exponent_bias=fp8_exponent_bias,
    ).cpu()
    emb.fill_random_weights()
    fill_random_scale_bias(emb, T, weights_precision)

    nparams_byte = sum(w.numel() for (w, _) in emb.split_embedding_weights())
    param_size_multiplier = weights_precision.bit_rate() / 8.0
@@ -987,6 +990,7 @@ def nbit_device(  # noqa C901
        fp8_exponent_bias=fp8_exponent_bias,
    ).cuda()
    emb.fill_random_weights()
    fill_random_scale_bias(emb, T, weights_precision)

    nparams_byte = sum(w.numel() for (w, _) in emb.split_embedding_weights())
    param_size_multiplier = weights_precision.bit_rate() / 8.0
@@ -1267,6 +1271,7 @@ def nbit_device_with_spec(  # noqa C901
    else:
        emb = emb.cuda()
    emb.fill_random_weights()
    fill_random_scale_bias(emb, T, weights_precision)

    nparams_byte = sum(w.numel() for (w, _) in emb.split_embedding_weights())
    param_size_multiplier = weights_precision.bit_rate() / 8.0
@@ -1843,6 +1848,7 @@ def bench_uvm_cls(
        uvm_host_mapped=uvm_host_mapped,
    ).cuda()
    emb.fill_random_weights()
    fill_random_scale_bias(emb, T, weights_precision)

    nvtx_range = (
        f"UVM-RECORD-CACHE-{name.upper()}"
@@ -2015,6 +2021,7 @@ def nbit_cache(  # noqa C901
        cache_assoc=cache_assoc,
    ).cuda()
    emb_nc.fill_random_weights()
    fill_random_scale_bias(emb_nc, T, weights_precision)

    emb = IntNBitTableBatchedEmbeddingBagsCodegen(
        [
@@ -2040,6 +2047,7 @@
        cache_assoc=cache_assoc,
    ).cuda()
    emb.fill_random_weights()
    fill_random_scale_bias(emb, T, weights_precision)

    nparams_byte = sum(w.numel() for (w, _) in emb.split_embedding_weights())
    param_size_multiplier = weights_precision.bit_rate() / 8.0
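
Every call site the commit touches follows the same pattern: construct the quantized table, fill random weights, then randomize the per-row scale/bias. A minimal sketch of that pattern, mirroring the nbit_cpu path above; the module paths and the embedding_specs tuple shape are assumptions that may vary across fbgemm_gpu versions, and T/E/D are hypothetical sizes:

from fbgemm_gpu.bench.bench_utils import fill_random_scale_bias
from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
    IntNBitTableBatchedEmbeddingBagsCodegen,
)

T, E, D = 2, 1000, 64  # hypothetical: tables, rows per table, embedding dim
weights_precision = SparseType.INT4

# One spec per table: (feature name, rows, dim, precision, location).
emb = IntNBitTableBatchedEmbeddingBagsCodegen(
    [("", E, D, weights_precision, EmbeddingLocation.HOST) for _ in range(T)]
).cpu()
emb.fill_random_weights()
fill_random_scale_bias(emb, T, weights_precision)  # the call this commit adds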
