Add support for int32_t indices (pytorch#3319)
Summary:
X-link: facebookresearch/FBGEMM#414


- Update template instantiations of the v2 and nobag kernels to address undefined symbol errors (see the first sketch after this list)

- Fix `hash_size_cumsum` in `transpose_embedding_input` to accept `int64_t` only (see the second sketch below)
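
For background on the first bullet: these kernels are defined in generated `.cu` files and exported through explicit template instantiations, so a caller that launches a `(grad_t, index_t)` combination with no matching instantiation compiles cleanly but fails at link time with an undefined-symbol error. Below is a minimal sketch of that failure mode; the kernel body and signature are hypothetical stand-ins, not the FBGEMM source:

```cpp
#include <cstdint>

// kernels.cu: the template is defined here, so other translation units
// link only against the explicit instantiations emitted below.
template <typename grad_t, typename index_t>
__global__ void grad_mean_sketch_kernel(grad_t* out, const grad_t* grad,
                                        const index_t* offsets,
                                        int32_t n_rows) {
  const int32_t row = blockIdx.x * blockDim.x + threadIdx.x;
  if (row >= n_rows) {
    return;
  }
  const index_t start = offsets[row];
  const index_t end = offsets[row + 1];
  grad_t acc = 0;
  for (index_t j = start; j < end; ++j) {
    acc += grad[j];
  }
  out[row] = (end > start) ? acc / static_cast<grad_t>(end - start) : acc;
}

// Before this change, only the int64_t forms existed; launching
// grad_mean_sketch_kernel<float, int32_t> from another .cu file would
// compile but fail at link time with "undefined symbol".
template __global__ void grad_mean_sketch_kernel<float, int32_t>(
    float*, const float*, const int32_t*, int32_t);
template __global__ void grad_mean_sketch_kernel<float, int64_t>(
    float*, const float*, const int64_t*, int32_t);
```

Adding the `index_type` loops to the instantiation macros in the hunks below closes exactly this gap for `int32_t`.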
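
For the second bullet: the host side binds `index_t` to the runtime dtype of `indices` through PyTorch's `AT_DISPATCH_INDEX_TYPES` (visible in the `grad_indice_weights` hunk below), while `hash_size_cumsum` stays outside that dispatch as `int64_t`, presumably because its cumulative values can overflow `int32_t`. A hedged sketch of that pattern, with a simplified signature rather than the actual `transpose_embedding_input`:

```cpp
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Sketch only: indices may be int32 or int64, but hash_size_cumsum is
// required to be int64 regardless of the index dtype.
void transpose_input_sketch(const at::Tensor& indices,
                            const at::Tensor& hash_size_cumsum) {
  TORCH_CHECK(hash_size_cumsum.scalar_type() == at::kLong,
              "hash_size_cumsum must be int64");
  AT_DISPATCH_INDEX_TYPES(
      indices.scalar_type(), "transpose_input_sketch", [&] {
        // Inside the lambda, index_t is bound to int32_t or int64_t.
        const auto* idx = indices.data_ptr<index_t>();
        const auto* cumsum = hash_size_cumsum.data_ptr<int64_t>();
        // ... a kernel templated on index_t would be launched here ...
        (void)idx;
        (void)cumsum;
      });
}
```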

Differential Revision: D62794566
q10 authored and facebook-github-bot committed Nov 4, 2024
1 parent 646d173 commit 16a237a
Showing 15 changed files with 153 additions and 81 deletions.
@@ -140,14 +140,14 @@ void split_embedding_backward_count_unique_indices_kernel

{% for vbe in [True, False] %}
{% set vdesc = "_vbe" if vbe else "" %}
-template <typename grad_t>
+template <typename grad_t, typename index_t>
__global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vdesc }}_kernel(
pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits>
grad_output_mean,
const pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits>
grad_output,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
-const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets,
+const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets,
{% if vbe %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> row_grad_offsets,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> b_t_map,
@@ -212,6 +212,7 @@ __global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vdesc }}_kernel(
////////////////////////////////////////////////////////////////////////////////

{% for grad_type in ['at::Half', 'float', 'at::BFloat16'] %}
+{% for index_type in ['int32_t', 'int64_t'] %}
template __global__ __launch_bounds__(kMaxThreads)
void grad_mean{{ vdesc }}_kernel
<{{ grad_type }}> (
@@ -220,7 +221,7 @@ void grad_mean{{ vdesc }}_kernel
const pta::PackedTensorAccessor64<{{ grad_type }}, 2, at::RestrictPtrTraits>
grad_output,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
-const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets,
+const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> offsets,
{% if vbe %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> row_grad_offsets,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> b_t_map,
@@ -230,6 +231,7 @@ void grad_mean{{ vdesc }}_kernel
FixedDivisor fd_B
{% endif %}
);
+{% endfor %} // for index_type in ['int32_t', 'int64_t']
{% endfor %} // for grad_type in ['at::Half', 'float']
{% endfor %} // for vbe in [True, False]

@@ -64,6 +64,7 @@ template <
typename emb_t,
typename grad_t,
typename cache_t,
+typename index_t,
int32_t kFixedMaxVecsPerThread
>
__global__ __launch_bounds__(kForwardMaxThreads) void
@@ -78,8 +79,8 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
{%- endif %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> weights_offsets,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
-const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> indices, // [N = \sum_{b,t} L_{b,t} total indices, i.e. flattened [B][T][L]
-const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets, // [B x T + 1]
+const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> indices, // [N = \sum_{b,t} L_{b,t} total indices, i.e. flattened [B][T][L]
+const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets, // [B x T + 1]
{%- if not dense %}
const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> {{ locs_or_addrs_tensor }},
{%- endif %}
@@ -117,8 +118,8 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
int32_t D_start = D_offsets[t];
int32_t D_end = D_offsets[t + 1];
int32_t D = D_end - D_start;
-int64_t indices_start = offsets[b_t];
-int64_t indices_end = offsets[b_t + 1];
+index_t indices_start = offsets[b_t];
+index_t indices_end = offsets[b_t + 1];
int32_t L = indices_end - indices_start;
if (feature_requires_grad.size(0) > 0 && !feature_requires_grad[t]) {
// If the table does not require gradient computation, we set the gradient to zero.
@@ -173,14 +174,14 @@ __global__ __launch_bounds__(kForwardMaxThreads) void

for (int32_t l_start = 0; l_start < L; l_start += kWarpSize) {
int32_t l = l_start + threadIdx.x;
-int64_t idx = l < L ? indices[indices_start + l] : 0;
+auto idx = l < L ? indices[indices_start + l] : 0;
{%- if not dense %}
const auto {{ locs_or_addrs_idx }} =
(placement == PlacementType::MANAGED_CACHING && l < L)
? {{ locs_or_addrs_tensor }}[indices_start + l] : 0;
{%- endif %}
for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) {
-int64_t idx_j = shfl_sync(idx, j);
+auto idx_j = shfl_sync(idx, j);
{%- if not dense %}
const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j);
{%- endif %}
@@ -354,6 +355,7 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
const uint32_t info_B_mask = info_B_mask_int64;
{%- endif %}

+AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "split_embedding_codegen_grad_indice_weights{{ vdesc }}_kernel_1", [&] {
DISPATCH_EMB_GRAD_CACHE_TYPES(
dev_weights.scalar_type(),
aligned_grad_output.scalar_type(),
@@ -386,6 +388,7 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
emb_t,
grad_t,
cache_t,
+index_t,
kFixedMaxVecsPerThread><<<
div_round_up(total_B, kForwardMaxThreads / kWarpSize),
dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
@@ -400,8 +403,8 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
{%- endif %}
MAKE_PTA_WITH_NAME(func_name, weights_offsets, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, D_offsets, int32_t, 1, 32),
-MAKE_PTA_WITH_NAME(func_name, indices, int64_t, 1, 32),
-MAKE_PTA_WITH_NAME(func_name, offsets, int64_t, 1, 32),
+MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
+MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
{%- if not dense %}
MAKE_PTA_WITH_NAME(func_name, {{ locs_or_addrs_tensor }}, {{ locs_or_addrs_type }}, 1, 32),
{%- endif %}
@@ -421,6 +424,7 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
});
{%- endfor %} {# /* for use_vec_blocking */ #}
});
+});
C10_CUDA_KERNEL_LAUNCH_CHECK();
return grad_indice_weights;
@@ -77,6 +77,7 @@ template <
typename emb_t,
typename grad_t,
typename cache_t,
+typename index_t,
{%- for ph_name in args.placeholder_tensor_names %}
typename {{ ph_name + "_ph_t" }},
{%- endfor %}
@@ -104,8 +105,8 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row(
{%- else %}
int64_t D,
{%- endif %}
+const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
-const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> long_run_ids,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> num_long_run_ids,
@@ -430,6 +431,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row(
emb_type,
grad_type,
cache_type,
+index_type,
ph_type_combo,
kFixedMaxVecsPerThread,
kThreadGroupSize,
@@ -446,6 +448,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row
< {{ emb_type }},
{{ grad_type }},
{{ cache_type }},
+{{ index_type }},
{%- for ph_name in args.placeholder_tensor_names %}
{{ ph_type_combo[ph_name].primitive_type }},
{%- endfor %}
@@ -469,8 +472,8 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row
{%- else %}
int64_t D,
{%- endif %}
+const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
-const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> long_run_ids,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> num_long_run_ids,
@@ -538,11 +541,13 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row
{%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %}
{%- for emb_type in ['float', 'at::Half'] %}
{%- for cache_type in ['float', 'at::Half'] %}
+{%- for index_type in ['int32_t', 'int64_t'] %}
{%- for ph_type_combo in args.placeholder_type_combos %}
{{ template_instantiation(
emb_type,
grad_type,
cache_type,
+index_type,
ph_type_combo,
kFixedMaxVecsPerThread,
kThreadGroupSize,
@@ -552,6 +557,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row
{%- endfor %}
{%- endfor %}
{%- endfor %}
+{%- endfor %}
{%- endmacro %}


@@ -62,6 +62,7 @@ template <
typename emb_t,
typename grad_t,
typename cache_t,
+typename index_t,
{%- for ph_name in args.placeholder_tensor_names %}
typename {{ ph_name + "_ph_t"}},
{%- endfor %}
@@ -89,8 +90,8 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
{%- else %}
int64_t D,
{%- endif %}
+const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
-const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
{%- if not nobag %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
@@ -341,6 +342,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
emb_type,
grad_type,
cache_type,
+index_type,
ph_type_combo,
kFixedMaxVecsPerThread,
kThreadGroupSize,
@@ -358,6 +360,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row
< {{ emb_type }},
{{ grad_type }},
{{ cache_type }},
+{{ index_type }},
{%- for ph_name in args.placeholder_tensor_names %}
{{ ph_type_combo[ph_name].primitive_type }},
{%- endfor %}
@@ -380,8 +383,8 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row
{%- else %}
int64_t D,
{%- endif %}
+const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
-const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
{%- if not nobag %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
@@ -441,11 +444,13 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row
{%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %}
{%- for emb_type in ['float', 'at::Half'] %}
{%- for cache_type in ['float', 'at::Half'] %}
+{%- for index_type in ['int32_t', 'int64_t'] %}
{%- for ph_type_combo in args.placeholder_type_combos %}
{{ template_instantiation(
emb_type,
grad_type,
cache_type,
+index_type,
ph_type_combo,
kFixedMaxVecsPerThread,
kThreadGroupSize,
@@ -456,6 +461,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row
{%- endfor %}
{%- endfor %}
{%- endfor %}
+{%- endfor %}
{%- endmacro %}

