From 9a46ccdb3e62f57a27aa3268578608ef8871610c Mon Sep 17 00:00:00 2001 From: Shawn Xu Date: Thu, 19 Sep 2024 22:42:00 -0700 Subject: [PATCH] support torch class binding in fbgemm (#3151) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3151 X-link: https://github.com/facebookresearch/FBGEMM/pull/243 * add torch class binding * add numeric tests usage see `KvTensorWrapperTest` Reviewed By: duduyi2013 Differential Revision: D62902959 --- .../ssd_split_table_batched_embeddings.cpp | 101 +++++++++++++++++- .../test/tbe/ssd/kv_tensor_wrapper_test.py | 98 +++++++++++++++++ 2 files changed, 198 insertions(+), 1 deletion(-) create mode 100644 fbgemm_gpu/test/tbe/ssd/kv_tensor_wrapper_test.py diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp index 6238fe1be4..42532294e8 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp @@ -16,6 +16,7 @@ #include "fbgemm_gpu/utils/ops_utils.h" using namespace at; +using namespace ssd; std::tuple ssd_cache_populate_actions_cuda( @@ -255,6 +256,16 @@ void compact_indices_cuda( Tensor count); namespace { +class KVTensorWrapper; + +struct EmbeddingSnapshotHandleWrapper : public torch::jit::CustomClassHolder { + explicit EmbeddingSnapshotHandleWrapper( + const EmbeddingRocksDB::SnapshotHandle* handle) + : handle(handle) {} + + const EmbeddingRocksDB::SnapshotHandle* handle; +}; + class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder { public: EmbeddingRocksDBWrapper( @@ -351,11 +362,77 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder { return impl_->wait_util_filling_work_done(); } + c10::intrusive_ptr create_snapshot() { + auto handle = impl_->create_snapshot(); + return c10::make_intrusive(handle); + } + + void release_snapshot( + c10::intrusive_ptr snapshot_handle) { + auto handle = snapshot_handle->handle; + CHECK_NE(handle, nullptr); + CHECK(impl_->is_valid_snapshot(handle)); + impl_->release_snapshot(handle); + } + + int64_t get_max_D() { + return impl_->get_max_D(); + } + private: + friend class KVTensorWrapper; + // shared pointer since we use shared_from_this() in callbacks. std::shared_ptr impl_; }; +class KVTensorWrapper : public torch::jit::CustomClassHolder { + public: + explicit KVTensorWrapper( + c10::intrusive_ptr db, + const c10::intrusive_ptr& snapshot_handle, + std::vector shape, + int64_t dtype) + : db_(std::move(db->impl_)), + snapshot_handle_(snapshot_handle->handle), + shape_(std::move(shape)) { + CHECK_EQ(shape_.size(), 2) << "Only 2D emb tensors are supported"; + options_ = at::TensorOptions() + .dtype(static_cast(dtype)) + .device(at::kCPU) + .layout(at::kStrided); + } + + at::Tensor narrow(int64_t dim, int64_t start, int64_t length) { + CHECK_EQ(dim, 0) << "Only narrow on dim 0 is supported"; + CHECK_EQ(db_->get_max_D(), shape_[1]); + auto t = at::empty(c10::IntArrayRef({length, db_->get_max_D()}), options_); + db_->get_range_from_snapshot(t, start, length, snapshot_handle_); + // TBE may have multiple embeddings in one table padded to max D + // narrow to the actual shape here before returning + return t.narrow(1, 0, shape_[1]); + } + + c10::IntArrayRef size() { + return shape_; + } + + c10::ScalarType dtype() { + return options_.dtype().toScalarType(); + } + + private: + std::shared_ptr db_; + const EmbeddingRocksDB::SnapshotHandle* snapshot_handle_; + at::TensorOptions options_; + std::vector shape_; +}; + +static auto embedding_snapshot_handle_wrapper = + torch::class_( + "fbgemm", + "EmbeddingSnapshotHandleWrapper"); + static auto embedding_rocks_db_wrapper = torch::class_("fbgemm", "EmbeddingRocksDBWrapper") .def( @@ -434,7 +511,29 @@ static auto embedding_rocks_db_wrapper = .def("reset_l2_cache", &EmbeddingRocksDBWrapper::reset_l2_cache) .def( "wait_util_filling_work_done", - &EmbeddingRocksDBWrapper::wait_util_filling_work_done); + &EmbeddingRocksDBWrapper::wait_util_filling_work_done) + .def("create_snapshot", &EmbeddingRocksDBWrapper::create_snapshot) + .def("release_snapshot", &EmbeddingRocksDBWrapper::release_snapshot); + +static auto kv_tensor_wrapper = + torch::class_("fbgemm", "KVTensorWrapper") + .def( + torch::init< + c10::intrusive_ptr, + c10::intrusive_ptr, + std::vector, + int64_t>(), + "", + {torch::arg("db"), + torch::arg("snapshot_handle"), + torch::arg("shape"), + torch::arg("dtype")}) + .def("narrow", &KVTensorWrapper::narrow) + .def_property( + "shape", + &KVTensorWrapper::size, + std::string( + "Returns the shape of the original tensor. Only the narrowed part is materialized.")); TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( diff --git a/fbgemm_gpu/test/tbe/ssd/kv_tensor_wrapper_test.py b/fbgemm_gpu/test/tbe/ssd/kv_tensor_wrapper_test.py new file mode 100644 index 0000000000..511a554e42 --- /dev/null +++ b/fbgemm_gpu/test/tbe/ssd/kv_tensor_wrapper_test.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import tempfile +import unittest +from unittest import TestCase + +import fbgemm_gpu +import torch +import torch.testing +from fbgemm_gpu.split_embedding_configs import SparseType +from fbgemm_gpu.utils.loader import load_torch_module + +# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. +open_source: bool = getattr(fbgemm_gpu, "open_source", False) + +if open_source: + from test_utils import running_on_github # @manual # pyre-ignore[21] +else: + from fbgemm_gpu.test.test_utils import ( # @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils + running_on_github, + ) + + load_torch_module( + "//deeplearning/fbgemm/fbgemm_gpu:ssd_split_table_batched_embeddings", + ) + + +@unittest.skipIf(*running_on_github) +class KvTensorWrapperTest(TestCase): + def test_basic(self) -> None: + E = int(1e4) + D = 128 + N = 1000 + weights_precision = SparseType.FP32 + weights_dtype = weights_precision.as_dtype() + + with tempfile.TemporaryDirectory() as ssd_directory: + # pyre-fixme[16]: Module `classes` has no attribute `fbgemm`. + ssd_db = torch.classes.fbgemm.EmbeddingRocksDBWrapper( + ssd_directory, + 8, # num_shards + 8, # num_threads + 0, # ssd_memtable_flush_period, + 0, # ssd_memtable_flush_offset, + 4, # ssd_l0_files_per_compact, + D, # embedding_dim + 0, # ssd_rate_limit_mbps, + 1, # ssd_size_ratio, + 8, # ssd_compaction_trigger, + 536870912, # 512MB ssd_write_buffer_size, + 8, # ssd_max_write_buffer_num, + -0.01, # ssd_uniform_init_lower + 0.01, # ssd_uniform_init_upper + 32, # row_storage_bitwidth + 10 * (2**20), # block cache size + ) + + # create random index tensor with size N + indices = torch.randperm(N) + # insert the weights with the corresponding indices into the table + weights = torch.arange(N * D, dtype=weights_dtype).view(N, D) + output_weights = torch.empty_like(weights) + count = torch.tensor([N]) + ssd_db.set(indices, weights, count) + + # force waiting for set to complete + ssd_db.get(indices, output_weights, torch.tensor(indices.shape[0])) + torch.testing.assert_close(weights, output_weights) + + # create a view tensor wrapper + snapshot = ssd_db.create_snapshot() + tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper( + ssd_db, snapshot, [E, D], weights.dtype + ) + self.assertEqual(tensor_wrapper.shape, [E, D]) + + # table has a total of E rows + # load 1000 rows at a time + step = 1000 + for i in range(0, E, step): + narrowed = tensor_wrapper.narrow(0, i, step) + for weight_ind, v in enumerate(indices): + j = v.item() + if j < i or j >= i + step: + continue + self.assertTrue( + torch.equal(narrowed[j % step], weights[weight_ind]), + msg=( + f"Tensor value mismatch at row {j}:\n" + f"actual\n{narrowed[j % step]}\n\nexpected\n{weights[weight_ind]}" + ), + )