Add support for int32_t indices in TBE training (2/N) (#3326)
Summary:

X-link: facebookresearch/FBGEMM#420

- Add `index_t` support to TBE training backward kernels
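
To make the change concrete, the sketch below shows the dispatch pattern these kernels now follow: `AT_DISPATCH_INDEX_TYPES` maps the runtime dtype of the index/offset tensors (int32 or int64) onto a compile-time `index_t`, and the kernel template consumes that type instead of a hard-coded `int64_t`. This is a minimal illustration, not FBGEMM code; `bag_lengths_kernel` and `launch_bag_lengths` are invented names and the real kernels take many more arguments.

```cpp
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>

// Kernel templated on index_t; offsets may be int32 or int64.
template <typename index_t>
__global__ void bag_lengths_kernel(
    const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets,  // [B + 1]
    at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> lengths) {      // [B]
  const int32_t b = blockIdx.x * blockDim.x + threadIdx.x;
  if (b < lengths.size(0)) {
    lengths[b] = static_cast<int32_t>(offsets[b + 1] - offsets[b]);
  }
}

// Host launcher: offsets is an int32 or int64 tensor, lengths is int32.
void launch_bag_lengths(const at::Tensor& offsets, at::Tensor& lengths) {
  const auto B = static_cast<int32_t>(lengths.size(0));
  if (B == 0) {
    return;
  }
  // AT_DISPATCH_INDEX_TYPES binds `index_t` to int32_t or int64_t inside the
  // lambda based on offsets.scalar_type(), so one launcher handles both widths.
  AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "bag_lengths", [&] {
    bag_lengths_kernel<index_t>
        <<<(B + 255) / 256, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
            offsets.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
            lengths.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>());
  });
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
```

The diff applies the same idea at scale: the backward kernels gain an `index_t` (or `offset_t`) template parameter, the accessors for `indices`/`offsets` switch from `int64_t` to that parameter, and the host code wraps the existing type dispatch in `AT_DISPATCH_INDEX_TYPES`.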

Reviewed By: basilwong

Differential Revision: D65464554
q10 authored and facebook-github-bot committed Nov 13, 2024
1 parent d9d4066 commit f13c09f
Showing 8 changed files with 166 additions and 130 deletions.
==== Changed file 1 of 8 ====
@@ -140,14 +140,14 @@ void split_embedding_backward_count_unique_indices_kernel

{% for vbe in [True, False] %}
{% set vdesc = "_vbe" if vbe else "" %}
template <typename grad_t>
template <typename grad_t, typename offset_t>
__global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vdesc }}_kernel(
pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits>
grad_output_mean,
const pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits>
grad_output,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets,
const pta::PackedTensorAccessor32<offset_t, 1, at::RestrictPtrTraits> offsets,
{% if vbe %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> row_grad_offsets,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> b_t_map,
@@ -175,12 +175,12 @@ __global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vdesc }}_kernel(
fd_B.DivMod(b_t, &t, &b);
{% endif %}

int32_t D_start = D_offsets[t];
int32_t D_end = D_offsets[t + 1];
int32_t D = D_end - D_start;
int64_t indices_start = offsets[b_t];
int64_t indices_end = offsets[b_t + 1];
int32_t L = indices_end - indices_start;
const auto D_start = D_offsets[t];
const auto D_end = D_offsets[t + 1];
const auto D = D_end - D_start;
const auto indices_start = offsets[b_t];
const auto indices_end = offsets[b_t + 1];
const auto L = indices_end - indices_start;

{% if vbe %}
const auto grad_offset = row_grad_offsets[b_t];
@@ -212,6 +212,7 @@ __global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vdesc }}_kernel(
////////////////////////////////////////////////////////////////////////////////

{% for grad_type in ['at::Half', 'float', 'at::BFloat16'] %}
{% for offset_type in ['int32_t', 'int64_t'] %}
template __global__ __launch_bounds__(kMaxThreads)
void grad_mean{{ vdesc }}_kernel
<{{ grad_type }}> (
@@ -220,7 +221,7 @@ void grad_mean{{ vdesc }}_kernel
const pta::PackedTensorAccessor64<{{ grad_type }}, 2, at::RestrictPtrTraits>
grad_output,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets,
const pta::PackedTensorAccessor32<{{ offset_type }}, 1, at::RestrictPtrTraits> offsets,
{% if vbe %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> row_grad_offsets,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> b_t_map,
@@ -230,6 +231,7 @@ void grad_mean{{ vdesc }}_kernel
FixedDivisor fd_B
{% endif %}
);
{% endfor %} // for offset_type in ['int32_t', 'int64_t']
{% endfor %} // for grad_type in ['at::Half', 'float']
{% endfor %} // for vbe in [True, False]

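The grad_mean kernel above gains an `offset_t` template parameter, and the instantiation section gains a matching `{% for offset_type in ['int32_t', 'int64_t'] %}` loop. The sketch below shows why that second step is needed; `example_mean_kernel` is an invented stand-in, not the real signature. Templated CUDA kernels defined in a separately compiled (here, template-generated) `.cu` file are only usable from other translation units for the type combinations explicitly instantiated there, so adding an int32 offset path means widening the instantiation list as well.

```cpp
#include <cstdint>

// Invented stand-in for a bag-mean kernel templated on both the gradient
// type and the offset type.
template <typename grad_t, typename offset_t>
__global__ void example_mean_kernel(
    grad_t* out, const grad_t* grad, const offset_t* offsets, int B) {
  const int b = blockIdx.x * blockDim.x + threadIdx.x;
  if (b < B) {
    const auto start = offsets[b];
    const auto end = offsets[b + 1];
    grad_t acc = grad_t(0);
    for (auto i = start; i < end; ++i) {
      acc += grad[i];
    }
    // Mean over the bag; guard against empty bags.
    out[b] = (end > start) ? acc / static_cast<grad_t>(end - start) : grad_t(0);
  }
}

// One explicit instantiation per type combination. The Jinja loops in the
// template file emit the analogous list for the real kernel, which is what
// the added `{% for offset_type in [...] %}` accomplishes.
template __global__ void example_mean_kernel<float, int32_t>(
    float*, const float*, const int32_t*, int);
template __global__ void example_mean_kernel<float, int64_t>(
    float*, const float*, const int64_t*, int);
```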
==== Changed file 2 of 8 ====
@@ -64,6 +64,7 @@ template <
typename emb_t,
typename grad_t,
typename cache_t,
typename index_t,
int32_t kFixedMaxVecsPerThread
>
__global__ __launch_bounds__(kForwardMaxThreads) void
@@ -78,8 +79,8 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
{%- endif %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> weights_offsets,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> indices, // [N = \sum_{b,t} L_{b,t} total indices, i.e. flattened [B][T][L]
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets, // [B x T + 1]
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> indices, // [N = \sum_{b,t} L_{b,t} total indices, i.e. flattened [B][T][L]
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets, // [B x T + 1]
{%- if not dense %}
const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> {{ locs_or_addrs_tensor }},
{%- endif %}
@@ -113,13 +114,13 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
fd_B.DivMod(b_t, &t, &b);
{%- endif %}

int64_t weights_offset = weights_offsets[t];
int32_t D_start = D_offsets[t];
int32_t D_end = D_offsets[t + 1];
int32_t D = D_end - D_start;
int64_t indices_start = offsets[b_t];
int64_t indices_end = offsets[b_t + 1];
int32_t L = indices_end - indices_start;
const auto weights_offset = weights_offsets[t];
const auto D_start = D_offsets[t];
const auto D_end = D_offsets[t + 1];
const auto D = D_end - D_start;
const auto indices_start = offsets[b_t];
const auto indices_end = offsets[b_t + 1];
const auto L = indices_end - indices_start;
if (feature_requires_grad.size(0) > 0 && !feature_requires_grad[t]) {
// If the table does not require gradient computation, we set the gradient to zero.
for (int32_t l_start = 0; l_start < L; l_start += kWarpSize) {
@@ -173,14 +174,14 @@ __global__ __launch_bounds__(kForwardMaxThreads) void

for (int32_t l_start = 0; l_start < L; l_start += kWarpSize) {
int32_t l = l_start + threadIdx.x;
int64_t idx = l < L ? indices[indices_start + l] : 0;
auto idx = l < L ? indices[indices_start + l] : 0;
{%- if not dense %}
const auto {{ locs_or_addrs_idx }} =
(placement == PlacementType::MANAGED_CACHING && l < L)
? {{ locs_or_addrs_tensor }}[indices_start + l] : 0;
{%- endif %}
for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) {
int64_t idx_j = shfl_sync(idx, j);
auto idx_j = shfl_sync(idx, j);
{%- if not dense %}
const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j);
{%- endif %}
@@ -354,72 +355,74 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
const uint32_t info_B_mask = info_B_mask_int64;
{%- endif %}

    DISPATCH_EMB_GRAD_CACHE_TYPES(
        dev_weights.scalar_type(),
        aligned_grad_output.scalar_type(),
        {%- if not dense %}
        lxu_cache_weights.scalar_type(),
        {%- else %}
        dev_weights.scalar_type(),
        {%- endif %}
        "split_embedding_codegen_grad_indice_weights{{ vdesc }}_kernel",
        [&] {
          {%- if vbe %}
          const auto& grad_output_reshaped = aligned_grad_output.reshape({1, -1});
          {%- else %}
          const auto& grad_output_reshaped = aligned_grad_output;
          {%- endif %}
    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "split_embedding_codegen_grad_indice_weights{{ vdesc }}_kernel_1", [&] {
      DISPATCH_EMB_GRAD_CACHE_TYPES(
          dev_weights.scalar_type(),
          aligned_grad_output.scalar_type(),
          {%- if not dense %}
          lxu_cache_weights.scalar_type(),
          {%- else %}
          dev_weights.scalar_type(),
          {%- endif %}
          "split_embedding_codegen_grad_indice_weights{{ vdesc }}_kernel_2",
          [&] {
            {%- if vbe %}
            const auto& grad_output_reshaped = aligned_grad_output.reshape({1, -1});
            {%- else %}
            const auto& grad_output_reshaped = aligned_grad_output;
            {%- endif %}

{%- for use_vec_blocking in [False, True] %}
{%- set vbdesc = "vec_blocking_" if use_vec_blocking else "" %}
{%- set dpdesc = "NON_" if not use_vec_blocking else "" %}
DISPATCH_{{ dpdesc }}VEC_BLOCKING_KERNEL(max_D, [&] {
{%- set kernel_name =
"{}_embedding_codegen_grad_indice_weights{}_{}kernel".format(
mdesc, vdesc, vbdesc)
%}
#ifdef FBGEMM_GPU_MEMCHECK
const auto func_name =
"{{ kernel_name }}";
#endif
{{ kernel_name }}<
emb_t,
grad_t,
cache_t,
kFixedMaxVecsPerThread><<<
div_round_up(total_B, kForwardMaxThreads / kWarpSize),
dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
0,
at::cuda::getCurrentCUDAStream()>>>(
MAKE_PTA_WITH_NAME(func_name, grad_output_reshaped, grad_t, 2, 64),
MAKE_PTA_WITH_NAME(func_name, dev_weights, emb_t, 1, 64),
{%- if not dense %}
MAKE_PTA_WITH_NAME(func_name, uvm_weights, emb_t, 1, 64),
MAKE_PTA_WITH_NAME(func_name, lxu_cache_weights, cache_t, 2, 64),
MAKE_PTA_WITH_NAME(func_name, weights_placements, int32_t, 1, 32),
{%- endif %}
MAKE_PTA_WITH_NAME(func_name, weights_offsets, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, D_offsets, int32_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, indices, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, offsets, int64_t, 1, 32),
{%- if not dense %}
MAKE_PTA_WITH_NAME(func_name, {{ locs_or_addrs_tensor }}, {{ locs_or_addrs_type }}, 1, 32),
{%- endif %}
MAKE_PTA_WITH_NAME(func_name, feature_requires_grad_, int32_t, 1, 32),
MAKE_PTA_ACC_WITH_NAME(func_name, grad_indice_weights, grad_t, 1, 32),
{%- if vbe %}
MAKE_PTA_WITH_NAME(func_name, vbe_row_output_offsets, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, vbe_b_t_map, int32_t, 1, 32),
info_B_num_bits,
info_B_mask
{%- else %}
FixedDivisor(total_B / T)
{%- endif %}
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return;
{%- for use_vec_blocking in [False, True] %}
{%- set vbdesc = "vec_blocking_" if use_vec_blocking else "" %}
{%- set dpdesc = "NON_" if not use_vec_blocking else "" %}
DISPATCH_{{ dpdesc }}VEC_BLOCKING_KERNEL(max_D, [&] {
{%- set kernel_name =
"{}_embedding_codegen_grad_indice_weights{}_{}kernel".format(
mdesc, vdesc, vbdesc)
%}
#ifdef FBGEMM_GPU_MEMCHECK
const auto func_name = "{{ kernel_name }}";
#endif
{{ kernel_name }}<
emb_t,
grad_t,
cache_t,
index_t,
kFixedMaxVecsPerThread><<<
div_round_up(total_B, kForwardMaxThreads / kWarpSize),
dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
0,
at::cuda::getCurrentCUDAStream()>>>(
MAKE_PTA_WITH_NAME(func_name, grad_output_reshaped, grad_t, 2, 64),
MAKE_PTA_WITH_NAME(func_name, dev_weights, emb_t, 1, 64),
{%- if not dense %}
MAKE_PTA_WITH_NAME(func_name, uvm_weights, emb_t, 1, 64),
MAKE_PTA_WITH_NAME(func_name, lxu_cache_weights, cache_t, 2, 64),
MAKE_PTA_WITH_NAME(func_name, weights_placements, int32_t, 1, 32),
{%- endif %}
MAKE_PTA_WITH_NAME(func_name, weights_offsets, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, D_offsets, int32_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
{%- if not dense %}
MAKE_PTA_WITH_NAME(func_name, {{ locs_or_addrs_tensor }}, {{ locs_or_addrs_type }}, 1, 32),
{%- endif %}
MAKE_PTA_WITH_NAME(func_name, feature_requires_grad_, int32_t, 1, 32),
MAKE_PTA_ACC_WITH_NAME(func_name, grad_indice_weights, grad_t, 1, 32),
{%- if vbe %}
MAKE_PTA_WITH_NAME(func_name, vbe_row_output_offsets, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, vbe_b_t_map, int32_t, 1, 32),
info_B_num_bits,
info_B_mask
{%- else %}
FixedDivisor(total_B / T)
{%- endif %}
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return;
});
{%- endfor %} {# /* for use_vec_blocking */ #}
});
{%- endfor %} {# /* for use_vec_blocking */ #}
});
C10_CUDA_KERNEL_LAUNCH_CHECK();
==== Changed file 3 of 8 ====
@@ -77,6 +77,7 @@ template <
typename emb_t,
typename grad_t,
typename cache_t,
typename index_t,
{%- for ph_name in args.placeholder_tensor_names %}
typename {{ ph_name + "_ph_t" }},
{%- endfor %}
@@ -105,7 +106,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row(
int64_t D,
{%- endif %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> long_run_ids,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> num_long_run_ids,
@@ -430,6 +431,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row(
emb_type,
grad_type,
cache_type,
index_type,
ph_type_combo,
kFixedMaxVecsPerThread,
kThreadGroupSize,
@@ -446,6 +448,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row
< {{ emb_type }},
{{ grad_type }},
{{ cache_type }},
{{ index_type }},
{%- for ph_name in args.placeholder_tensor_names %}
{{ ph_type_combo[ph_name].primitive_type }},
{%- endfor %}
@@ -470,7 +473,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row
int64_t D,
{%- endif %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> long_run_ids,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> num_long_run_ids,
@@ -538,11 +541,13 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row
{%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %}
{%- for emb_type in ['float', 'at::Half'] %}
{%- for cache_type in ['float', 'at::Half'] %}
{%- for index_type in ['int32_t', 'int64_t'] %}
{%- for ph_type_combo in args.placeholder_type_combos %}
{{ template_instantiation(
emb_type,
grad_type,
cache_type,
index_type,
ph_type_combo,
kFixedMaxVecsPerThread,
kThreadGroupSize,
@@ -552,6 +557,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row
{%- endfor %}
{%- endfor %}
{%- endfor %}
{%- endfor %}
{%- endmacro %}


==== Changed file 4 of 8 ====
@@ -62,6 +62,7 @@ template <
typename emb_t,
typename grad_t,
typename cache_t,
typename index_t,
{%- for ph_name in args.placeholder_tensor_names %}
typename {{ ph_name + "_ph_t"}},
{%- endfor %}
@@ -90,7 +91,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
int64_t D,
{%- endif %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
{%- if not nobag %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
@@ -341,6 +342,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
emb_type,
grad_type,
cache_type,
index_type,
ph_type_combo,
kFixedMaxVecsPerThread,
kThreadGroupSize,
@@ -358,6 +360,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row
< {{ emb_type }},
{{ grad_type }},
{{ cache_type }},
{{ index_type }},
{%- for ph_name in args.placeholder_tensor_names %}
{{ ph_type_combo[ph_name].primitive_type }},
{%- endfor %}
@@ -381,7 +384,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row
int64_t D,
{%- endif %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
{%- if not nobag %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
@@ -441,11 +444,13 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row
{%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %}
{%- for emb_type in ['float', 'at::Half'] %}
{%- for cache_type in ['float', 'at::Half'] %}
{%- for index_type in ['int32_t', 'int64_t'] %}
{%- for ph_type_combo in args.placeholder_type_combos %}
{{ template_instantiation(
emb_type,
grad_type,
cache_type,
index_type,
ph_type_combo,
kFixedMaxVecsPerThread,
kThreadGroupSize,
@@ -456,6 +461,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row
{%- endfor %}
{%- endfor %}
{%- endfor %}
{%- endfor %}
{%- endmacro %}


==== Remaining 4 changed files not shown ====
