Skip to content

Commit

Permalink
Add support for int64_t indices and offsets in TBE inference [5/N] (#…
Browse files Browse the repository at this point in the history
…3129)

Summary:
Pull Request resolved: #3129

X-link: facebookresearch/FBGEMM#216

- Convert `pruned_hashmap_lookup_cpu` to use `index_t`

Differential Revision: D62472965
  • Loading branch information
q10 authored and facebook-github-bot committed Sep 26, 2024
1 parent 278dbe5 commit b6ffa02
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ inline uint32_t pruned_hash_function(uint32_t h) {
return h;
}

inline uint64_t pruned_hash_function(uint64_t k) {
// MurmorHash3 64-bit mixing function.
k ^= k >> 33;
k *= (0xff51afd7ed558ccd);
k ^= k >> 33;
k *= (0xc4ceb9fe1a85ec53);
k ^= k >> 33;
return k;
}

} // namespace

void pruned_hashmap_insert_{{ wdesc }}_cpu(
Expand Down Expand Up @@ -404,54 +414,67 @@ Tensor pruned_hashmap_lookup_{{ wdesc }}_cpu(
TENSOR_ON_CPU(offsets);
TENSOR_ON_CPU(hash_table);
TENSOR_ON_CPU(hash_table_offsets);
TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, hash_table);

int32_t T = hash_table_offsets.size(0) - 1;
int32_t B = (offsets.size(0) - 1) / T;
TORCH_CHECK(B > 0);

auto dense_indices = empty_like(indices);
const auto* indices_acc = indices.data_ptr<int32_t>();
auto* dense_indices_acc = dense_indices.data_ptr<int32_t>();

const auto* offsets_acc = offsets.data_ptr<int32_t>();
const auto hash_table_acc = hash_table.accessor<int32_t, 2>();
const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();
for (const auto t : c10::irange(T)) {
int64_t table_start = hash_table_offsets_acc[t];
int64_t table_end = hash_table_offsets_acc[t + 1];
int64_t capacity = table_end - table_start;
for (const auto b : c10::irange(B)) {
int32_t indices_start = offsets_acc[t * B + b];
int32_t indices_end = offsets_acc[t * B + b + 1];
int32_t L = indices_end - indices_start;
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_lookup_{{ wdesc }}_cpu", [&] {
using hash_t =
std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;

if (table_start == table_end) {
for (const auto l : c10::irange(L)) {
dense_indices_acc[indices_start + l] = indices_acc[indices_start + l];
}
} else {
for (const auto l : c10::irange(L)) {
int32_t idx = indices_acc[indices_start + l];
uint32_t slot = pruned_hash_function(static_cast<uint32_t>(idx)) % capacity;
while (true) {
int32_t slot_sparse_idx = hash_table_acc[table_start + static_cast<int64_t>(slot)][0];

// empty slot
if (slot_sparse_idx == -1) {
dense_indices_acc[indices_start + l] = -1;
break;
}
// already exists
if (slot_sparse_idx == idx) {
dense_indices_acc[indices_start + l] = hash_table_acc[table_start + static_cast<int64_t>(slot)][1];
break;
const auto* indices_acc = indices.data_ptr<index_t>();
auto* dense_indices_acc = dense_indices.data_ptr<index_t>();

const auto* offsets_acc = offsets.data_ptr<index_t>();
const auto hash_table_acc = hash_table.accessor<index_t, 2>();
const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();

for (const auto t : c10::irange(T)) {
const auto table_start = hash_table_offsets_acc[t];
const auto table_end = hash_table_offsets_acc[t + 1];
const auto capacity = table_end - table_start;

for (const auto b : c10::irange(B)) {
const auto indices_start = offsets_acc[t * B + b];
const auto indices_end = offsets_acc[t * B + b + 1];
const auto L = indices_end - indices_start;

if (table_start == table_end) {
for (const auto l : c10::irange(L)) {
dense_indices_acc[indices_start + l] = indices_acc[indices_start + l];
}

} else {
for (const auto l : c10::irange(L)) {
const auto idx = indices_acc[indices_start + l];
auto slot = pruned_hash_function(static_cast<hash_t>(idx)) % capacity;

while (true) {
const auto slot_sparse_idx = hash_table_acc[table_start + static_cast<int64_t>(slot)][0];

// empty slot
if (slot_sparse_idx == -1) {
dense_indices_acc[indices_start + l] = -1;
break;
}
// already exists
if (slot_sparse_idx == idx) {
dense_indices_acc[indices_start + l] = hash_table_acc[table_start + static_cast<int64_t>(slot)][1];
break;
}
// linear probe
slot = (slot + 1) % capacity;
}
// linear probe
slot = (slot + 1) % capacity;
}
}
}
}
}
});

return dense_indices;
}

Expand All @@ -466,6 +489,7 @@ Tensor pruned_array_lookup_cpu(
TENSOR_ON_CPU(offsets);
TENSOR_ON_CPU(index_remappings);
TENSOR_ON_CPU(index_remappings_offsets);
TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, index_remappings);

int32_t T = index_remappings_offsets.size(0) - 1;
int32_t B = (offsets.size(0) - 1) / T;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ Tensor pruned_hashmap_lookup_cuda(
TORCH_CHECK(hash_table.size(0) < std::numeric_limits<int32_t>::max());
constexpr size_t kForwardMaxThreads = 256;

AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_lookup", [&] {
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_lookup_cuda", [&] {
#ifdef FBGEMM_GPU_MEMCHECK
const auto func_name =
"int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel";
Expand Down Expand Up @@ -218,7 +218,7 @@ Tensor pruned_array_lookup_cuda(
TORCH_CHECK(dense_indices.dim() == 1, "Tensor dim: ", dense_indices.dim());
constexpr size_t kForwardMaxThreads = 256;
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup", [&] {
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup_cuda", [&] {
#ifdef FBGEMM_GPU_MEMCHECK
const auto func_name =
"int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel";
Expand Down

0 comments on commit b6ffa02

Please sign in to comment.