-
Notifications
You must be signed in to change notification settings - Fork 479
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
support torch class binding in fbgemm (#3151)
Summary: Pull Request resolved: #3151 X-link: facebookresearch/FBGEMM#243 * add torch class binding * add numeric tests usage see `KvTensorWrapperTest` Reviewed By: duduyi2013 Differential Revision: D62902959
- Loading branch information
1 parent
7628ba2
commit 9a46ccd
Showing
2 changed files
with
198 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
import tempfile | ||
import unittest | ||
from unittest import TestCase | ||
|
||
import fbgemm_gpu | ||
import torch | ||
import torch.testing | ||
from fbgemm_gpu.split_embedding_configs import SparseType | ||
from fbgemm_gpu.utils.loader import load_torch_module | ||
|
||
# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. | ||
open_source: bool = getattr(fbgemm_gpu, "open_source", False) | ||
|
||
if open_source: | ||
from test_utils import running_on_github # @manual # pyre-ignore[21] | ||
else: | ||
from fbgemm_gpu.test.test_utils import ( # @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils | ||
running_on_github, | ||
) | ||
|
||
load_torch_module( | ||
"//deeplearning/fbgemm/fbgemm_gpu:ssd_split_table_batched_embeddings", | ||
) | ||
|
||
|
||
@unittest.skipIf(*running_on_github) | ||
class KvTensorWrapperTest(TestCase): | ||
def test_basic(self) -> None: | ||
E = int(1e4) | ||
D = 128 | ||
N = 1000 | ||
weights_precision = SparseType.FP32 | ||
weights_dtype = weights_precision.as_dtype() | ||
|
||
with tempfile.TemporaryDirectory() as ssd_directory: | ||
# pyre-fixme[16]: Module `classes` has no attribute `fbgemm`. | ||
ssd_db = torch.classes.fbgemm.EmbeddingRocksDBWrapper( | ||
ssd_directory, | ||
8, # num_shards | ||
8, # num_threads | ||
0, # ssd_memtable_flush_period, | ||
0, # ssd_memtable_flush_offset, | ||
4, # ssd_l0_files_per_compact, | ||
D, # embedding_dim | ||
0, # ssd_rate_limit_mbps, | ||
1, # ssd_size_ratio, | ||
8, # ssd_compaction_trigger, | ||
536870912, # 512MB ssd_write_buffer_size, | ||
8, # ssd_max_write_buffer_num, | ||
-0.01, # ssd_uniform_init_lower | ||
0.01, # ssd_uniform_init_upper | ||
32, # row_storage_bitwidth | ||
10 * (2**20), # block cache size | ||
) | ||
|
||
# create random index tensor with size N | ||
indices = torch.randperm(N) | ||
# insert the weights with the corresponding indices into the table | ||
weights = torch.arange(N * D, dtype=weights_dtype).view(N, D) | ||
output_weights = torch.empty_like(weights) | ||
count = torch.tensor([N]) | ||
ssd_db.set(indices, weights, count) | ||
|
||
# force waiting for set to complete | ||
ssd_db.get(indices, output_weights, torch.tensor(indices.shape[0])) | ||
torch.testing.assert_close(weights, output_weights) | ||
|
||
# create a view tensor wrapper | ||
snapshot = ssd_db.create_snapshot() | ||
tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper( | ||
ssd_db, snapshot, [E, D], weights.dtype | ||
) | ||
self.assertEqual(tensor_wrapper.shape, [E, D]) | ||
|
||
# table has a total of E rows | ||
# load 1000 rows at a time | ||
step = 1000 | ||
for i in range(0, E, step): | ||
narrowed = tensor_wrapper.narrow(0, i, step) | ||
for weight_ind, v in enumerate(indices): | ||
j = v.item() | ||
if j < i or j >= i + step: | ||
continue | ||
self.assertTrue( | ||
torch.equal(narrowed[j % step], weights[weight_ind]), | ||
msg=( | ||
f"Tensor value mismatch at row {j}:\n" | ||
f"actual\n{narrowed[j % step]}\n\nexpected\n{weights[weight_ind]}" | ||
), | ||
) |