support torch class binding in fbgemm (#3151)
Summary:
Pull Request resolved: #3151

X-link: facebookresearch/FBGEMM#243

* add torch class binding
* add numeric tests

For usage, see `KvTensorWrapperTest`.

Reviewed By: duduyi2013

Differential Revision: D62902959
xunnanxu authored and facebook-github-bot committed Sep 20, 2024
1 parent 7628ba2 commit 9a46ccd
Showing 2 changed files with 198 additions and 1 deletion.
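For orientation, a minimal sketch of how the new bindings are driven from Python. The 16 positional constructor arguments of EmbeddingRocksDBWrapper are elided here; `KvTensorWrapperTest` at the bottom of the diff shows the full call, which this sketch otherwise mirrors (`E`, `D`, and the dtype below are placeholder values):

    import torch
    from fbgemm_gpu.utils.loader import load_torch_module

    load_torch_module(
        "//deeplearning/fbgemm/fbgemm_gpu:ssd_split_table_batched_embeddings"
    )

    E, D = 10000, 128  # placeholder table dimensions
    # constructor arguments elided -- see KvTensorWrapperTest below for all 16
    ssd_db = torch.classes.fbgemm.EmbeddingRocksDBWrapper(...)
    snapshot = ssd_db.create_snapshot()  # point-in-time view of the DB

    # wraps rows of the table as a virtual 2D tensor; no data is read yet
    t = torch.classes.fbgemm.KVTensorWrapper(ssd_db, snapshot, [E, D], torch.float32)
    assert t.shape == [E, D]  # shape of the full (virtual) tensor

    rows = t.narrow(0, 0, 1000)  # materializes rows [0, 1000) from the snapshot
    ssd_db.release_snapshot(snapshot)  # release once all reads are done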
@@ -16,6 +16,7 @@
#include "fbgemm_gpu/utils/ops_utils.h"

using namespace at;
using namespace ssd;

std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor>
ssd_cache_populate_actions_cuda(
@@ -255,6 +256,16 @@ void compact_indices_cuda(
    Tensor count);

namespace {
class KVTensorWrapper;

struct EmbeddingSnapshotHandleWrapper : public torch::jit::CustomClassHolder {
  explicit EmbeddingSnapshotHandleWrapper(
      const EmbeddingRocksDB::SnapshotHandle* handle)
      : handle(handle) {}

  const EmbeddingRocksDB::SnapshotHandle* handle;
};

class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
 public:
  EmbeddingRocksDBWrapper(
@@ -351,11 +362,77 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
    return impl_->wait_util_filling_work_done();
  }

  c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper> create_snapshot() {
    auto handle = impl_->create_snapshot();
    return c10::make_intrusive<EmbeddingSnapshotHandleWrapper>(handle);
  }

  void release_snapshot(
      c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper> snapshot_handle) {
    auto handle = snapshot_handle->handle;
    CHECK_NE(handle, nullptr);
    CHECK(impl_->is_valid_snapshot(handle));
    impl_->release_snapshot(handle);
  }

  int64_t get_max_D() {
    return impl_->get_max_D();
  }

 private:
  friend class KVTensorWrapper;

  // shared pointer since we use shared_from_this() in callbacks.
  std::shared_ptr<ssd::EmbeddingRocksDB> impl_;
};

class KVTensorWrapper : public torch::jit::CustomClassHolder {
 public:
  explicit KVTensorWrapper(
      c10::intrusive_ptr<EmbeddingRocksDBWrapper> db,
      const c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper>& snapshot_handle,
      std::vector<int64_t> shape,
      int64_t dtype)
      : db_(std::move(db->impl_)),
        snapshot_handle_(snapshot_handle->handle),
        shape_(std::move(shape)) {
    CHECK_EQ(shape_.size(), 2) << "Only 2D emb tensors are supported";
    options_ = at::TensorOptions()
                   .dtype(static_cast<c10::ScalarType>(dtype))
                   .device(at::kCPU)
                   .layout(at::kStrided);
  }

  at::Tensor narrow(int64_t dim, int64_t start, int64_t length) {
    CHECK_EQ(dim, 0) << "Only narrow on dim 0 is supported";
    CHECK_EQ(db_->get_max_D(), shape_[1]);
    auto t = at::empty(c10::IntArrayRef({length, db_->get_max_D()}), options_);
    db_->get_range_from_snapshot(t, start, length, snapshot_handle_);
    // TBE may have multiple embeddings in one table, padded to max D;
    // narrow to the actual shape here before returning.
    return t.narrow(1, 0, shape_[1]);
  }

  c10::IntArrayRef size() {
    return shape_;
  }

  c10::ScalarType dtype() {
    return options_.dtype().toScalarType();
  }

 private:
  std::shared_ptr<EmbeddingRocksDB> db_;
  const EmbeddingRocksDB::SnapshotHandle* snapshot_handle_;
  at::TensorOptions options_;
  std::vector<int64_t> shape_;
};

static auto embedding_snapshot_handle_wrapper =
    torch::class_<EmbeddingSnapshotHandleWrapper>(
        "fbgemm",
        "EmbeddingSnapshotHandleWrapper");

static auto embedding_rocks_db_wrapper =
    torch::class_<EmbeddingRocksDBWrapper>("fbgemm", "EmbeddingRocksDBWrapper")
        .def(
@@ -434,7 +511,29 @@ static auto embedding_rocks_db_wrapper =
.def("reset_l2_cache", &EmbeddingRocksDBWrapper::reset_l2_cache)
.def(
"wait_util_filling_work_done",
&EmbeddingRocksDBWrapper::wait_util_filling_work_done);
&EmbeddingRocksDBWrapper::wait_util_filling_work_done)
.def("create_snapshot", &EmbeddingRocksDBWrapper::create_snapshot)
.def("release_snapshot", &EmbeddingRocksDBWrapper::release_snapshot);

static auto kv_tensor_wrapper =
torch::class_<KVTensorWrapper>("fbgemm", "KVTensorWrapper")
.def(
torch::init<
c10::intrusive_ptr<EmbeddingRocksDBWrapper>,
c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper>,
std::vector<int64_t>,
int64_t>(),
"",
{torch::arg("db"),
torch::arg("snapshot_handle"),
torch::arg("shape"),
torch::arg("dtype")})
.def("narrow", &KVTensorWrapper::narrow)
.def_property(
"shape",
&KVTensorWrapper::size,
std::string(
"Returns the shape of the original tensor. Only the narrowed part is materialized."));

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
  m.def(
98 changes: 98 additions & 0 deletions fbgemm_gpu/test/tbe/ssd/kv_tensor_wrapper_test.py
@@ -0,0 +1,98 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import tempfile
import unittest
from unittest import TestCase

import fbgemm_gpu
import torch
import torch.testing
from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.utils.loader import load_torch_module

# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
open_source: bool = getattr(fbgemm_gpu, "open_source", False)

if open_source:
    from test_utils import running_on_github  # @manual # pyre-ignore[21]
else:
    from fbgemm_gpu.test.test_utils import (  # @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils
        running_on_github,
    )

load_torch_module(
    "//deeplearning/fbgemm/fbgemm_gpu:ssd_split_table_batched_embeddings",
)


@unittest.skipIf(*running_on_github)
class KvTensorWrapperTest(TestCase):
    def test_basic(self) -> None:
        E = int(1e4)
        D = 128
        N = 1000
        weights_precision = SparseType.FP32
        weights_dtype = weights_precision.as_dtype()

        with tempfile.TemporaryDirectory() as ssd_directory:
            # pyre-fixme[16]: Module `classes` has no attribute `fbgemm`.
            ssd_db = torch.classes.fbgemm.EmbeddingRocksDBWrapper(
                ssd_directory,
                8,  # num_shards
                8,  # num_threads
                0,  # ssd_memtable_flush_period,
                0,  # ssd_memtable_flush_offset,
                4,  # ssd_l0_files_per_compact,
                D,  # embedding_dim
                0,  # ssd_rate_limit_mbps,
                1,  # ssd_size_ratio,
                8,  # ssd_compaction_trigger,
                536870912,  # 512MB ssd_write_buffer_size,
                8,  # ssd_max_write_buffer_num,
                -0.01,  # ssd_uniform_init_lower
                0.01,  # ssd_uniform_init_upper
                32,  # row_storage_bitwidth
                10 * (2**20),  # block cache size
            )

            # create random index tensor with size N
            indices = torch.randperm(N)
            # insert the weights with the corresponding indices into the table
            weights = torch.arange(N * D, dtype=weights_dtype).view(N, D)
            output_weights = torch.empty_like(weights)
            count = torch.tensor([N])
            ssd_db.set(indices, weights, count)

            # force waiting for set to complete
            ssd_db.get(indices, output_weights, torch.tensor(indices.shape[0]))
            torch.testing.assert_close(weights, output_weights)

            # create a view tensor wrapper
            snapshot = ssd_db.create_snapshot()
            tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
                ssd_db, snapshot, [E, D], weights.dtype
            )
            self.assertEqual(tensor_wrapper.shape, [E, D])

            # table has a total of E rows
            # load 1000 rows at a time
            step = 1000
            for i in range(0, E, step):
                narrowed = tensor_wrapper.narrow(0, i, step)
                for weight_ind, v in enumerate(indices):
                    j = v.item()
                    if j < i or j >= i + step:
                        continue
                    self.assertTrue(
                        torch.equal(narrowed[j % step], weights[weight_ind]),
                        msg=(
                            f"Tensor value mismatch at row {j}:\n"
                            f"actual\n{narrowed[j % step]}\n\nexpected\n{weights[weight_ind]}"
                        ),
                    )
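One lifetime note, not exercised by the test above: the diff also binds `release_snapshot`, so a caller can drop the RocksDB snapshot explicitly once no `KVTensorWrapper` still reads from it (a hedged sketch, reusing the `ssd_db` and `snapshot` names from the test):

    # after the last narrow() against this snapshot has completed
    ssd_db.release_snapshot(snapshot)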
