Skip to content

Commit

Permalink
Add support for int64_t indices and offsets in TBE inference [6/N] (#3182)
Browse files Browse the repository at this point in the history

Summary:
Pull Request resolved: #3182

X-link: facebookresearch/FBGEMM#278

- Convert `PrunedMapCPU::lookup` to use `index_t`

Differential Revision: D62602764
  • Loading branch information
q10 authored and facebook-github-bot committed Sep 26, 2024
1 parent 681c82e commit 948a8bb
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/utils/dispatch_macros.h"
#include "fbgemm_gpu/utils/ops_utils.h"
#include "fbgemm_gpu/utils/tensor_utils.h"

using Tensor = at::Tensor;
using namespace fbgemm_gpu;
Expand Down Expand Up @@ -374,29 +375,37 @@ class PrunedMapCPU : public torch::jit::CustomClassHolder {
}

// Remaps each sparse embedding index through the per-table pruning maps,
// producing the dense (post-pruning) row index, or -1 for rows that were
// pruned away.
//
// indices: 1-D tensor of sparse indices; int32 or int64 (dispatched below).
// offsets: 1-D tensor of offsets into `indices`, laid out per (table, batch);
//          must have the same scalar type as `indices` and size T * B + 1.
// Returns: a tensor shaped/typed like `indices` holding the dense indices.
Tensor lookup(Tensor indices, Tensor offsets) const {
  // Both inputs must share one index type so a single dispatch covers them.
  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets);

  const int32_t T = maps_.size();
  TORCH_CHECK(T > 0);
  // Batch size is derived from the offsets length; offsets has T*B+1 entries.
  const int32_t B = (offsets.size(0) - 1) / T;
  TORCH_CHECK(B > 0);

  auto dense_indices = empty_like(indices);

  // Dispatch on the index scalar type (int32_t / int64_t) so the same code
  // path serves both; `index_t` is bound by the macro.
  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "PrunedMapCPU::lookup", [&] {
    const auto* indices_acc = indices.data_ptr<index_t>();
    auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
    const auto* offsets_acc = offsets.data_ptr<index_t>();

    for (const auto t : c10::irange(T)) {
      auto& map = maps_[t];
      for (const auto b : c10::irange(B)) {
        // [indices_start, indices_end) is the bag for (table t, sample b).
        const auto indices_start = offsets_acc[t * B + b];
        const auto indices_end = offsets_acc[t * B + b + 1];
        const auto L = indices_end - indices_start;
        for (const auto l : c10::irange(L)) {
          const auto slot_sparse_index = indices_acc[indices_start + l];
          const auto it = map.find(slot_sparse_index);
          // Pruned (absent) rows are signalled to the caller as -1.
          dense_indices_acc[indices_start + l] =
              it != map.end() ? it->second : -1;
        }
      }
    }
  });

  return dense_indices;
}

Expand Down
38 changes: 20 additions & 18 deletions fbgemm_gpu/codegen/utils/embedding_bounds_check.cu
Original file line number Diff line number Diff line change
Expand Up @@ -233,22 +233,24 @@ void bounds_check_indices_cuda(
constexpr size_t kNumThreads = 256;
const auto max_B_ = vbe ? max_B : B;

AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "bounds_check_indices", [&] {
const auto bounds_check_kernel =
(vbe ? bounds_check_indices_kernel<index_t, true>
: bounds_check_indices_kernel<index_t, false>);
TORCH_DSA_KERNEL_LAUNCH(
bounds_check_kernel,
div_round_up(max_B_ * T, kNumThreads / fbgemm_gpu::kWarpSize),
dim3(fbgemm_gpu::kWarpSize, kNumThreads / fbgemm_gpu::kWarpSize),
0,
at::cuda::getCurrentCUDAStream(),
rows_per_table.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
indices.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
offsets.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
vbe ? B_offsets.value().data_ptr<int32_t>() : nullptr,
bounds_check_mode_,
warning.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
FixedDivisor(max_B_));
});
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(), "bounds_check_indices_cuda", [&] {
const auto bounds_check_kernel =
(vbe ? bounds_check_indices_kernel<index_t, true>
: bounds_check_indices_kernel<index_t, false>);
TORCH_DSA_KERNEL_LAUNCH(
bounds_check_kernel,
div_round_up(max_B_ * T, kNumThreads / fbgemm_gpu::kWarpSize),
dim3(fbgemm_gpu::kWarpSize, kNumThreads / fbgemm_gpu::kWarpSize),
0,
at::cuda::getCurrentCUDAStream(),
rows_per_table
.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
indices.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
offsets.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
vbe ? B_offsets.value().data_ptr<int32_t>() : nullptr,
bounds_check_mode_,
warning.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
FixedDivisor(max_B_));
});
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ void bounds_check_indices_cpu(
const auto rows_per_table_acc = rows_per_table.accessor<int64_t, 1>();
auto warning_acc = warning.data_ptr<int64_t>();

AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "bounds_check_indices", [&] {
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "bounds_check_indices_cpu", [&] {
auto offsets_acc = offsets.accessor<index_t, 1>();
auto indices_acc = indices.accessor<index_t, 1>();
auto num_indices = indices.numel();
Expand Down

0 comments on commit 948a8bb

Please sign in to comment.