Add support for int64_t indices in TBE inference [2/N] (#3125)
Summary:
X-link: facebookresearch/FBGEMM#214

Pull Request resolved: #3125

- Add support for int64_t indices in TBE inference [2/N]

- Convert `pruned_array_lookup_cuda` to use index_t
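
The index_t conversion follows the usual ATen index-type dispatch: the kernel is templated on index_t, and the host wrapper instantiates it for both int32_t and int64_t through AT_DISPATCH_INDEX_TYPES. Below is a minimal sketch of that pattern with illustrative names and a placeholder kernel body; it is not the FBGEMM code itself.

// Minimal sketch (illustrative only, not the FBGEMM sources): a kernel templated
// on index_t, instantiated for int32_t and int64_t index tensors via ATen's
// AT_DISPATCH_INDEX_TYPES macro. The kernel body is a placeholder.
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

template <typename index_t>
__global__ void example_lookup_kernel(
    const index_t* indices,
    index_t* dense_indices,
    const int64_t n) {
  const int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    dense_indices[i] = indices[i];  // a real kernel would remap the index here
  }
}

void example_lookup(const at::Tensor& indices, at::Tensor& dense_indices) {
  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "example_lookup", [&] {
    // Inside the lambda, index_t is int32_t or int64_t, matching indices.dtype().
    const int64_t n = indices.numel();
    example_lookup_kernel<index_t><<<(n + 255) / 256, 256>>>(
        indices.data_ptr<index_t>(), dense_indices.data_ptr<index_t>(), n);
  });
}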

Reviewed By: jianyuh

Differential Revision: D62271409
q10 authored and facebook-github-bot committed Sep 26, 2024
1 parent ebbe476 commit 0466806
Showing 2 changed files with 66 additions and 27 deletions.
@@ -89,42 +89,43 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
   }
 }
 
+template <typename index_t>
 __global__
 __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel(
-    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         indices,
-    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         offsets,
-    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         index_remappings,
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         index_remappings_offsets,
     const int32_t B,
     const int32_t T,
-    pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         dense_indices) {
   const int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y;
   const int32_t t = b_t / B;
   const int32_t b = b_t % B;
   if (b_t >= B * T) {
     return;
   }
-  const int32_t indices_start = offsets[t * B + b];
-  const int32_t indices_end = offsets[t * B + b + 1];
-  const int32_t L = indices_end - indices_start;
+  const index_t indices_start = offsets[t * B + b];
+  const index_t indices_end = offsets[t * B + b + 1];
+  const index_t L = indices_end - indices_start;
 
   const int64_t index_remappings_start = index_remappings_offsets[t];
   const int64_t index_remappings_end = index_remappings_offsets[t + 1];
   const int64_t capacity = index_remappings_end - index_remappings_start;
 
   if (capacity > 0) {
-    for (int32_t l = threadIdx.x; l < L; l += blockDim.x) {
-      int32_t idx = indices[indices_start + l];
+    for (index_t l = threadIdx.x; l < L; l += blockDim.x) {
+      index_t idx = indices[indices_start + l];
       dense_indices[indices_start + l] =
           index_remappings[index_remappings_start + idx];
     }
   } else {
-    for (int32_t l = threadIdx.x; l < L; l += blockDim.x) {
+    for (index_t l = threadIdx.x; l < L; l += blockDim.x) {
       dense_indices[indices_start + l] = indices[indices_start + l];
     }
   }
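
For reference, the remapping the kernel performs can be written as a small host-side sketch (written for this summary, not part of the commit); it mirrors the capacity > 0 lookup branch and the pass-through branch for either index width.

#include <cstdint>
#include <vector>

// Host-side reference of the kernel's per-(t, b) work; illustrative only.
template <typename index_t>
std::vector<index_t> pruned_array_lookup_ref(
    const std::vector<index_t>& indices,
    const std::vector<index_t>& offsets,                    // B * T + 1 entries
    const std::vector<index_t>& index_remappings,
    const std::vector<int64_t>& index_remappings_offsets,   // T + 1 entries
    const int B,
    const int T) {
  std::vector<index_t> dense_indices(indices.size());
  for (int t = 0; t < T; ++t) {
    const int64_t remap_start = index_remappings_offsets[t];
    const int64_t capacity = index_remappings_offsets[t + 1] - remap_start;
    for (int b = 0; b < B; ++b) {
      for (index_t l = offsets[t * B + b]; l < offsets[t * B + b + 1]; ++l) {
        // With a non-empty remapping table, remap the index; otherwise pass it through.
        dense_indices[l] = capacity > 0
            ? index_remappings[remap_start + indices[l]]
            : indices[l];
      }
    }
  }
  return dense_indices;
}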
@@ -178,6 +179,7 @@ Tensor pruned_array_lookup_cuda(
     Tensor index_remappings_offsets) {
   TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
       indices, offsets, index_remappings, index_remappings_offsets);
+  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, index_remappings);
 
   CUDA_DEVICE_GUARD(indices);
@@ -204,23 +206,26 @@ Tensor pruned_array_lookup_cuda(
   TORCH_CHECK(dense_indices.dim() == 1, "Tensor dim: ", dense_indices.dim());
 
   constexpr size_t kForwardMaxThreads = 256;
 
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup", [&] {
 #ifdef FBGEMM_GPU_MEMCHECK
-  const auto func_name =
-      "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel";
+    const auto func_name =
+        "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel";
 #endif
 
-  nbit::int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel<<<
-      nbit::div_round_up(offsets.size(0), kForwardMaxThreads / kWarpSize),
-      dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
-      0,
-      at::cuda::getCurrentCUDAStream()>>>(
-      MAKE_PTA_WITH_NAME(func_name, indices, int32_t, 1, 32),
-      MAKE_PTA_WITH_NAME(func_name, offsets, int32_t, 1, 32),
-      MAKE_PTA_WITH_NAME(func_name, index_remappings, int32_t, 1, 32),
-      MAKE_PTA_WITH_NAME(func_name, index_remappings_offsets, int64_t, 1, 32),
-      B,
-      T,
-      MAKE_PTA_WITH_NAME(func_name, dense_indices, int32_t, 1, 32));
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
+    nbit::int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel<<<
+        nbit::div_round_up(offsets.size(0), kForwardMaxThreads / kWarpSize),
+        dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
+        0,
+        at::cuda::getCurrentCUDAStream()>>>(
+        MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
+        MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
+        MAKE_PTA_WITH_NAME(func_name, index_remappings, index_t, 1, 32),
+        MAKE_PTA_WITH_NAME(func_name, index_remappings_offsets, int64_t, 1, 32),
+        B,
+        T,
+        MAKE_PTA_WITH_NAME(func_name, dense_indices, index_t, 1, 32));
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  });
 
   return dense_indices;
 }
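
A hypothetical call site, to show what the change means for callers: int64 (or int32) index tensors are now accepted, provided indices, offsets, and index_remappings share one scalar type, which the new TENSORS_HAVE_SAME_SCALAR_TYPE check enforces. The names and toy values below are assumptions for illustration; pruned_array_lookup_cuda is assumed to be declared via the corresponding FBGEMM header.

// Hypothetical caller (not from the diff); uses the C++ frontend's torch::tensor
// for brevity. One table (T = 1), two bags (B = 2), four lookups in total.
#include <torch/torch.h>

at::Tensor lookup_with_int64_indices() {
  const auto opts = torch::dtype(torch::kLong).device(torch::kCUDA);
  const auto indices = torch::tensor({0, 2, 1, 3}, opts);
  const auto offsets = torch::tensor({0, 2, 4}, opts);              // B * T + 1 entries
  const auto index_remappings = torch::tensor({3, 2, 1, 0}, opts);  // same dtype as indices
  const auto index_remappings_offsets = torch::tensor({0, 4}, opts);  // T + 1 entries, int64
  return pruned_array_lookup_cuda(
      indices, offsets, index_remappings, index_remappings_offsets);
}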
38 changes: 36 additions & 2 deletions fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h
@@ -306,8 +306,6 @@ std::string tensor_scalar_type_is_one_of(
     const ScalarTypes&... ttypes) {
   auto has_match = false;
 
-  // Collect the GPU index of the first non-empty optional tensor and make sure
-  // that all tensors are on this same index.
   (
       [&](const auto& ttype) {
         if (ten.scalar_type() == ttype) {
@@ -339,3 +337,39 @@ std::string tensor_scalar_type_is_one_of(
     const auto has_match = tensor_scalar_type_is_one_of(__VA_ARGS__); \
     TORCH_CHECK(has_match.empty(), has_match);                        \
   } while (false)
+
+template <typename... Tensors>
+std::string tensors_have_same_scalar_type(const Tensors&... tensors) {
+  std::optional<at::ScalarType> dtype;
+  bool have_same_type = true;
+
+  (
+      [&](const auto& tensor) {
+        if (!dtype) {
+          dtype = tensor.scalar_type();
+        } else if (*dtype != tensor.scalar_type()) {
+          have_same_type = false;
+        }
+      }(tensors),
+      ...);
+
+  if (have_same_type) {
+    return "";
+  }
+
+  std::string msg = "Tensors' scalar types (";
+  (
+      [&](const auto& tensor) {
+        msg.append(toString(tensor.scalar_type()));
+        msg.append(", ");
+      }(tensors),
+      ...);
+  msg.append(") are not one and the same!");
+  return msg;
+}
+
+#define TENSORS_HAVE_SAME_SCALAR_TYPE(...)                                  \
+  do {                                                                      \
+    const auto have_same_type = tensors_have_same_scalar_type(__VA_ARGS__); \
+    TORCH_CHECK(have_same_type.empty(), have_same_type);                    \
+  } while (false)
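
An illustrative (assumed) usage of the new check: the helper folds a lambda over the parameter pack, so any number of tensors can be validated in one call, and TORCH_CHECK raises with the composed message on a mismatch.

// Illustrative usage, not part of the diff.
#include <ATen/ATen.h>

void check_dtypes_example() {
  const auto a = at::zeros({4}, at::kInt);
  const auto b = at::zeros({4}, at::kInt);
  const auto c = at::zeros({4}, at::kLong);

  TENSORS_HAVE_SAME_SCALAR_TYPE(a, b);     // passes: both tensors are Int
  TENSORS_HAVE_SAME_SCALAR_TYPE(a, b, c);  // throws: "Tensors' scalar types (Int, Int, Long, ) are not one and the same!"
}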
