Add support for int32_t indices in TBE training (2D/N) #3374

Open · wants to merge 3 commits into base: main
3 changes: 1 addition & 2 deletions fbgemm_gpu/FbgemmGpu.cmake
@@ -725,8 +725,7 @@ endif()

# Silence warnings in asmjit
target_compile_options(fbgemm_gpu_py PRIVATE
-Wno-deprecated-anon-enum-enum-conversion)
target_compile_options(fbgemm_gpu_py PRIVATE
-Wno-deprecated-anon-enum-enum-conversion
-Wno-deprecated-declarations)


@@ -140,14 +140,14 @@ void split_embedding_backward_count_unique_indices_kernel

{% for vbe in [True, False] %}
{% set vdesc = "_vbe" if vbe else "" %}
template <typename grad_t>
template <typename grad_t, typename offset_t>
__global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vdesc }}_kernel(
pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits>
grad_output_mean,
const pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits>
grad_output,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets,
const pta::PackedTensorAccessor32<offset_t, 1, at::RestrictPtrTraits> offsets,
{% if vbe %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> row_grad_offsets,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> b_t_map,
@@ -212,15 +212,16 @@ __global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vdesc }}_kernel(
////////////////////////////////////////////////////////////////////////////////

{% for grad_type in ['at::Half', 'float', 'at::BFloat16'] %}
{% for offset_type in ['int32_t', 'int64_t'] %}
template __global__ __launch_bounds__(kMaxThreads)
void grad_mean{{ vdesc }}_kernel
<{{ grad_type }}> (
<{{ grad_type }}, {{ offset_type }}> (
pta::PackedTensorAccessor64<{{ grad_type }}, 2, at::RestrictPtrTraits>
grad_output_mean,
const pta::PackedTensorAccessor64<{{ grad_type }}, 2, at::RestrictPtrTraits>
grad_output,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets,
const pta::PackedTensorAccessor32<{{ offset_type }}, 1, at::RestrictPtrTraits> offsets,
{% if vbe %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> row_grad_offsets,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> b_t_map,
Expand All @@ -230,6 +231,7 @@ void grad_mean{{ vdesc }}_kernel
FixedDivisor fd_B
{% endif %}
);
{% endfor %} // for offset_type in ['int32_t', 'int64_t']
{% endfor %} // for grad_type in ['at::Half', 'float']
{% endfor %} // for vbe in [True, False]
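Note: the grad_mean kernel is now templated on the offsets element type in addition to the gradient type, and the extra Jinja loop above emits explicit instantiations for both int32_t and int64_t offsets. A minimal sketch of the pattern, with illustrative names and raw pointers instead of packed accessors (grad_mean_sketch_kernel is not an fbgemm symbol), assuming a [B*T, D] gradient layout and mean pooling that divides by the bag length:

```cuda
// Sketch only: one kernel body that compiles for both offset widths.
template <typename offset_t>
__global__ void grad_mean_sketch_kernel(
    float* grad_output_mean,   // [B*T, D] output
    const float* grad_output,  // [B*T, D] input
    const offset_t* offsets,   // [B*T + 1] bag boundaries
    int D) {
  const int b_t = blockIdx.x;
  const auto L = offsets[b_t + 1] - offsets[b_t];  // bag length
  for (int d = threadIdx.x; d < D; d += blockDim.x) {
    const float g = grad_output[b_t * D + d];
    grad_output_mean[b_t * D + d] = (L > 0) ? g / static_cast<float>(L) : g;
  }
}

// Explicit instantiations, analogous to the offset_type loop above.
template __global__ void grad_mean_sketch_kernel<int32_t>(
    float*, const float*, const int32_t*, int);
template __global__ void grad_mean_sketch_kernel<int64_t>(
    float*, const float*, const int64_t*, int);
```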

@@ -232,13 +232,13 @@ split_embedding_backward_codegen_find_long_segments(
const bool use_deterministic_algorithms);


template <typename grad_t>
template <typename grad_t, typename offset_t>
__global__ __launch_bounds__(kMaxThreads) void
grad_mean{{ vdesc }}_kernel(
pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits> grad_output_mean,
const pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits> grad_output,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets,
const pta::PackedTensorAccessor32<offset_t, 1, at::RestrictPtrTraits> offsets,
{%- if vbe %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_offsets,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> b_t_map,
@@ -742,6 +742,7 @@ Tensor {{ embedding_cuda_op }}(
else {
{{ locs_or_addrs_tensor }}_sorted = at::empty_like({{ locs_or_addrs_tensor }});
size_t temp_storage_bytes = 0;
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ embedding_cuda_op }}_1", [&] {
AT_CUDA_CHECK(radix_sort_pairs(
nullptr,
temp_storage_bytes,
@@ -753,9 +754,11 @@
0,
total_hash_size_bits,
at::cuda::getCurrentCUDAStream()));

auto temp_storage = at::empty(
{static_cast<int64_t>(temp_storage_bytes)},
indices.options().dtype(at::kByte));

AT_CUDA_CHECK(radix_sort_pairs(
temp_storage.data_ptr(),
temp_storage_bytes,
@@ -767,6 +770,7 @@ Tensor {{ embedding_cuda_op }}(
0,
total_hash_size_bits,
at::cuda::getCurrentCUDAStream()));
});
}
}
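Note: the radix_sort_pairs calls above follow the usual two-phase CUB pattern: the first call passes a null temp-storage pointer and only computes the scratch size, and the second call performs the sort over the low total_hash_size_bits key bits. A minimal sketch of that pattern, written against cub::DeviceRadixSort::SortPairs directly on the assumption that the fbgemm wrapper forwards to it (sort_pairs_sketch and its arguments are illustrative):

```cuda
#include <cub/cub.cuh>

// Query-then-sort sketch with int32_t keys; error handling omitted.
void sort_pairs_sketch(
    const int32_t* keys_in, int32_t* keys_out,
    const int64_t* vals_in, int64_t* vals_out,
    int num_items, int end_bit, cudaStream_t stream) {
  size_t temp_bytes = 0;
  // Pass 1: null temp storage, so only temp_bytes is written.
  cub::DeviceRadixSort::SortPairs(
      nullptr, temp_bytes, keys_in, keys_out, vals_in, vals_out,
      num_items, /*begin_bit=*/0, end_bit, stream);
  void* temp = nullptr;
  cudaMalloc(&temp, temp_bytes);
  // Pass 2: the actual sort, restricted to the low end_bit key bits.
  cub::DeviceRadixSort::SortPairs(
      temp, temp_bytes, keys_in, keys_out, vals_in, vals_out,
      num_items, /*begin_bit=*/0, end_bit, stream);
  cudaFree(temp);
}
```

In the generated code the scratch buffer is an at::kByte tensor rather than a raw cudaMalloc, so it comes from PyTorch's caching allocator.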

@@ -775,6 +779,7 @@ Tensor {{ embedding_cuda_op }}(
}
{%- endif %}

AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ embedding_cuda_op }}_2", [&] {
DISPATCH_EMB_GRAD_CACHE_TYPES(
dev_weights.scalar_type(),
aligned_grad_output.scalar_type(),
@@ -800,9 +805,11 @@ Tensor {{ embedding_cuda_op }}(
0,
total_hash_size_bits,
at::cuda::getCurrentCUDAStream()));

auto temp_storage = at::empty(
{static_cast<int64_t>(temp_storage_bytes)},
indices.options().dtype(at::kByte));

AT_CUDA_CHECK(radix_sort_pairs(
temp_storage.data_ptr(),
temp_storage_bytes,
@@ -853,7 +860,7 @@ Tensor {{ embedding_cuda_op }}(
MAKE_PTA_WITH_NAME(func_name1, grad_output_mean, grad_t, 2, 64),
MAKE_PTA_WITH_NAME(func_name1, grad_output_reshaped, grad_t, 2, 64),
MAKE_PTA_WITH_NAME(func_name1, D_offsets, int32_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name1, offsets, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name1, offsets, index_t, 1, 32),
{%- if vbe %}
MAKE_PTA_WITH_NAME(func_name1, vbe_row_output_offsets, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name1, vbe_b_t_map, int32_t, 1, 32),
Expand Down Expand Up @@ -1181,6 +1188,7 @@ Tensor {{ embedding_cuda_op }}(

}); // DISPATCH_OPTIMAL_KERNEL
}); // DISPATCH_EMB_GRAD_CACHE_TYPES
}); // AT_DISPATCH_INDEX_TYPES

{%- if dense %}
return grad_dev_weights;
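Note: the key mechanism here is AT_DISPATCH_INDEX_TYPES. It switches on the runtime dtype of indices (int32 or int64) and exposes the matching C++ type inside the lambda as index_t, which is why the offsets accessor can now be built with index_t instead of a hard-coded int64_t. A self-contained sketch of the pattern (offsets_sketch_kernel and launch_sketch are illustrative names, not fbgemm symbols):

```cuda
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>

// Illustrative kernel templated on the index/offset element type.
template <typename index_t>
__global__ void offsets_sketch_kernel(
    at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets) {
  // ... read offsets[i] here; index_t is int32_t or int64_t ...
}

void launch_sketch(const at::Tensor& offsets) {
  AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "launch_sketch", [&] {
    // Inside the lambda, `index_t` names the dispatched type.
    offsets_sketch_kernel<index_t>
        <<<1, 32, 0, at::cuda::getCurrentCUDAStream()>>>(
            offsets.packed_accessor32<index_t, 1, at::RestrictPtrTraits>());
  });
}
```

MAKE_PTA_WITH_NAME in the diff appears to build the same kind of packed accessor with a debug name attached, so changing its dtype argument from int64_t to index_t is what lets the offsets tensor be either width.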
8 changes: 6 additions & 2 deletions fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh
@@ -59,9 +59,13 @@
int end_bit = sizeof(KeyT) * 8, \
cudaStream_t stream = 0)

DECL_RADIX_SORT_PAIRS_FN(int64_t, int32_t);
DECL_RADIX_SORT_PAIRS_FN(int64_t, int64_t);
DECL_RADIX_SORT_PAIRS_FN(int64_t, float);
DECL_RADIX_SORT_PAIRS_FN(int64_t, double);
DECL_RADIX_SORT_PAIRS_FN(int64_t, int64_t);
DECL_RADIX_SORT_PAIRS_FN(int64_t, int32_t);
DECL_RADIX_SORT_PAIRS_FN(int32_t, int32_t);
DECL_RADIX_SORT_PAIRS_FN(int32_t, int64_t);
DECL_RADIX_SORT_PAIRS_FN(int32_t, float);
DECL_RADIX_SORT_PAIRS_FN(int32_t, double);

#undef DECL_RADIX_SORT_PAIRS_FN
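Note on the overload set: indices (and the sort keys derived from them) may now be either 32- or 64-bit, so the declaration matrix covers both key widths crossed with the existing value types. Inside an AT_DISPATCH_INDEX_TYPES block the call site is written once against index_t, and ordinary overload resolution picks the matching KeyT. A tiny illustration of that idea (sort_sketch/caller_sketch are made-up names, not the fbgemm API):

```cpp
#include <cstdint>
#include <cstdio>

// Two overloads, one per key width, mirroring the declarations above.
void sort_sketch(const int32_t* /*keys*/, int /*n*/) { std::printf("32-bit keys\n"); }
void sort_sketch(const int64_t* /*keys*/, int /*n*/) { std::printf("64-bit keys\n"); }

// A caller templated on index_t (as inside AT_DISPATCH_INDEX_TYPES) binds to
// whichever overload matches the dispatched type.
template <typename index_t>
void caller_sketch(const index_t* keys, int n) {
  sort_sketch(keys, n);
}

int main() {
  int32_t k32[4] = {3, 1, 2, 0};
  int64_t k64[4] = {3, 1, 2, 0};
  caller_sketch(k32, 4);  // prints "32-bit keys"
  caller_sketch(k64, 4);  // prints "64-bit keys"
}
```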
8 changes: 6 additions & 2 deletions fbgemm_gpu/src/split_embeddings_utils/radix_sort_pairs.cu
@@ -77,7 +77,11 @@ using namespace fbgemm_gpu;
}
#endif

DEF_RADIX_SORT_PAIRS_FN(int64_t, int32_t);
DEF_RADIX_SORT_PAIRS_FN(int64_t, int64_t);
DEF_RADIX_SORT_PAIRS_FN(int64_t, float);
DEF_RADIX_SORT_PAIRS_FN(int64_t, double);
DEF_RADIX_SORT_PAIRS_FN(int64_t, int64_t);
DEF_RADIX_SORT_PAIRS_FN(int64_t, int32_t);
DEF_RADIX_SORT_PAIRS_FN(int32_t, int32_t);
DEF_RADIX_SORT_PAIRS_FN(int32_t, int64_t);
DEF_RADIX_SORT_PAIRS_FN(int32_t, float);
DEF_RADIX_SORT_PAIRS_FN(int32_t, double);
30 changes: 14 additions & 16 deletions fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu
@@ -63,7 +63,7 @@ inline at::Tensor asynchronous_complete_cumsum(at::Tensor t_in) {

template <typename index_t, typename info_acc_t, bool nobag, bool vbe>
__global__ __launch_bounds__(kMaxThreads) void linearize_index_kernel(
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
hash_size_cumsum,
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
indices,
@@ -79,7 +79,7 @@ __global__ __launch_bounds__(kMaxThreads) void linearize_index_kernel(
// Use a raw pointer to avoid creating dummy PackedTensorAccessor
const uint32_t* const __restrict__ vbe_b_t_map,
FixedDivisor fd) {
const int32_t T = hash_size_cumsum.size(0) - 1;
const auto T = hash_size_cumsum.size(0) - 1;
auto b_t = blockIdx.x * blockDim.x + threadIdx.x;
int32_t b;
int32_t t;
@@ -97,17 +97,16 @@ __global__ __launch_bounds__(kMaxThreads) void linearize_index_kernel(
}

const index_t hash_offset = valid ? hash_size_cumsum[t] : -1;
const index_t indices_start = valid ? offsets[b_t] : -1;
const int32_t L = valid ? offsets[b_t + 1] - indices_start : 0;
const auto indices_start = valid ? offsets[b_t] : -1;
const auto L = valid ? offsets[b_t + 1] - indices_start : 0;
const int32_t lane_id = threadIdx.x % fbgemm_gpu::kWarpSize;

// Compile-time conditional
if (nobag) {
for (int32_t j = 0; j < fbgemm_gpu::kWarpSize; ++j) {
const index_t indices_start_warp =
fbgemm_gpu::shfl_sync(indices_start, j);
const int32_t t_warp = fbgemm_gpu::shfl_sync(t, j);
const int32_t L_warp = fbgemm_gpu::shfl_sync(L, j);
const auto indices_start_warp = fbgemm_gpu::shfl_sync(indices_start, j);
const auto t_warp = fbgemm_gpu::shfl_sync(t, j);
const auto L_warp = fbgemm_gpu::shfl_sync(L, j);
const index_t hash_offset_warp = fbgemm_gpu::shfl_sync(hash_offset, j);
for (int32_t i = lane_id; i < L_warp; i += fbgemm_gpu::kWarpSize) {
const index_t idx = __ldg(&indices[indices_start_warp + i]);
@@ -124,10 +123,9 @@ __global__ __launch_bounds__(kMaxThreads) void linearize_index_kernel(
reinterpret_cast<uint32_t*>(&b)[0];
}
for (int32_t j = 0; j < fbgemm_gpu::kWarpSize; ++j) {
const index_t indices_start_warp =
fbgemm_gpu::shfl_sync(indices_start, j);
const uint32_t info_warp = fbgemm_gpu::shfl_sync(info, j);
const int32_t L_warp = fbgemm_gpu::shfl_sync(L, j);
const auto indices_start_warp = fbgemm_gpu::shfl_sync(indices_start, j);
const auto info_warp = fbgemm_gpu::shfl_sync(info, j);
const auto L_warp = fbgemm_gpu::shfl_sync(L, j);
const index_t hash_offset_warp = fbgemm_gpu::shfl_sync(hash_offset, j);
for (int32_t i = lane_id; i < L_warp; i += fbgemm_gpu::kWarpSize) {
const index_t idx = __ldg(&indices[indices_start_warp + i]);
@@ -142,7 +140,7 @@ __global__ __launch_bounds__(kMaxThreads) void linearize_index_kernel(
template <typename index_t, typename info_acc_t>
__global__
__launch_bounds__(kMaxThreads) void linearize_index_index_select_kernel(
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
hash_size_cumsum,
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
indices,
Expand All @@ -153,7 +151,7 @@ __launch_bounds__(kMaxThreads) void linearize_index_index_select_kernel(
linear_indices,
FixedDivisor fd,
int32_t fixed_L_per_warp) {
const int32_t T = hash_size_cumsum.size(0) - 1;
const auto T = hash_size_cumsum.size(0) - 1;
auto b_t = blockIdx.x * blockDim.x + threadIdx.x;
int32_t b;
int32_t t;
@@ -258,7 +256,7 @@
kMaxThreads, \
0, \
at::cuda::getCurrentCUDAStream()>>>( \
MAKE_PTA_WITH_NAME(func_name, hash_size_cumsum, index_t, 1, 32), \
MAKE_PTA_WITH_NAME(func_name, hash_size_cumsum, int64_t, 1, 32), \
MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32), \
MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32), \
MAKE_PTA_WITH_NAME(func_name, infos, INFO_ACC_T, 1, 32), \
@@ -296,7 +294,7 @@ transpose_embedding_input(
0,
at::cuda::getCurrentCUDAStream()>>>(
MAKE_PTA_WITH_NAME(
func_name, hash_size_cumsum, index_t, 1, 32),
func_name, hash_size_cumsum, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
MAKE_PTA_WITH_NAME(
func_name, total_L_offsets.value(), index_t, 1, 32),
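Note: hash_size_cumsum keeps a hard int64_t accessor rather than following index_t, presumably because it is the running sum of per-table hash sizes; even when every table's own indices fit in int32, the cumulative offsets, and therefore the linearized indices, can exceed INT32_MAX. The switch to auto locals lets indices_start and L take whatever width the offsets tensor has. A simplified sketch of the arithmetic, using a plain per-thread loop instead of the warp-cooperative shfl_sync form above (names are illustrative):

```cuda
// Sketch: per-table indices may be int32_t, but table offsets stay int64_t,
// so the linearized index is computed in 64 bits.
template <typename index_t>
__global__ void linearize_sketch_kernel(
    const int64_t* hash_size_cumsum,  // [T + 1], always 64-bit
    const index_t* indices,           // 32- or 64-bit per-table indices
    const index_t* offsets,           // [B*T + 1] bag boundaries
    int64_t* linear_indices,          // same length as indices
    int B, int T) {
  const int b_t = blockIdx.x * blockDim.x + threadIdx.x;
  if (b_t >= B * T) {
    return;
  }
  const int t = b_t / B;                    // table id
  const int64_t hash_offset = hash_size_cumsum[t];
  const auto start = offsets[b_t];          // width follows index_t
  const auto L = offsets[b_t + 1] - start;  // bag length
  for (int64_t i = 0; i < L; ++i) {
    // int64_t + index_t promotes to int64_t, so int32 inputs cannot overflow.
    linear_indices[start + i] = hash_offset + indices[start + i];
  }
}
```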