diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_grad_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_grad_template.cu
index f20b1b97b..032ef7e86 100644
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_grad_template.cu
+++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_grad_template.cu
@@ -140,14 +140,14 @@ void split_embedding_backward_count_unique_indices_kernel
 {% for vbe in [True, False] %}
 {% set vdesc = "_vbe" if vbe else "" %}
 
-template <typename grad_t>
+template <typename grad_t, typename index_t>
 __global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vdesc }}_kernel(
     pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits> grad_output_mean,
     const pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits> grad_output,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets,
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets,
     {% if vbe %}
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> row_grad_offsets,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> b_t_map,
@@ -212,15 +212,16 @@ __global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vdesc }}_kernel(
 ////////////////////////////////////////////////////////////////////////////////
 
 {% for grad_type in ['at::Half', 'float', 'at::BFloat16'] %}
+{% for offset_type in ['int32_t', 'int64_t'] %}
 template __global__ __launch_bounds__(kMaxThreads)
 void grad_mean{{ vdesc }}_kernel
-<{{ grad_type }}> (
+<{{ grad_type }}, {{ offset_type }}> (
     pta::PackedTensorAccessor64<{{ grad_type }}, 2, at::RestrictPtrTraits> grad_output_mean,
     const pta::PackedTensorAccessor64<{{ grad_type }}, 2, at::RestrictPtrTraits> grad_output,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets,
+    const pta::PackedTensorAccessor32<{{ offset_type }}, 1, at::RestrictPtrTraits> offsets,
     {% if vbe %}
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> row_grad_offsets,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> b_t_map,
@@ -230,6 +231,7 @@ void grad_mean{{ vdesc }}_kernel
     FixedDivisor fd_B
     {% endif %}
 );
+{% endfor %} // for offset_type in ['int32_t', 'int64_t']
 {% endfor %} // for grad_type in ['at::Half', 'float', 'at::BFloat16']
 {% endfor %} // for vbe in [True, False]
diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu
index 7c4d85fc3..f168bf79f 100644
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu
+++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu
@@ -232,13 +232,13 @@ split_embedding_backward_codegen_find_long_segments(
     const bool use_deterministic_algorithms);
 
-template <typename grad_t>
+template <typename grad_t, typename index_t>
 __global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vdesc }}_kernel(
     pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits> grad_output_mean,
     const pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits> grad_output,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets,
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets,
     {%- if vbe %}
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_offsets,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> b_t_map,
@@ -860,7 +860,7 @@ Tensor {{ embedding_cuda_op }}(
         MAKE_PTA_WITH_NAME(func_name1, grad_output_mean, grad_t, 2, 64),
         MAKE_PTA_WITH_NAME(func_name1, grad_output_reshaped, grad_t, 2, 64),
         MAKE_PTA_WITH_NAME(func_name1, D_offsets, int32_t, 1, 32),
-        MAKE_PTA_WITH_NAME(func_name1, offsets, int64_t, 1, 32),
+        MAKE_PTA_WITH_NAME(func_name1, offsets, index_t, 1, 32),
         {%- if vbe %}
         MAKE_PTA_WITH_NAME(func_name1, vbe_row_output_offsets, int64_t, 1, 32),
         MAKE_PTA_WITH_NAME(func_name1, vbe_b_t_map, int32_t, 1, 32),
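Note on the change above: the diff stops hardcoding int64_t for the offsets accessor and instead takes the offsets dtype as a second template parameter (index_t), with the Jinja instantiation loop widened to emit both int32_t and int64_t variants of grad_mean{{ vdesc }}_kernel. The MAKE_PTA_WITH_NAME change in the host file suggests the caller resolves index_t from the offsets tensor's runtime dtype, presumably via an AT_DISPATCH_INDEX_TYPES-style dispatch (which defines index_t for int32/int64 inside its lambda). Below is a minimal standalone sketch of the same pattern; all names in it (grad_mean_demo_kernel and its parameters) are illustrative stand-ins, not FBGEMM code:

#include <cstdint>

// Templated on both the gradient dtype and the offsets dtype, mirroring the
// <grad_t, index_t> signature the diff introduces.
template <typename grad_t, typename index_t>
__global__ void grad_mean_demo_kernel(
    grad_t* grad_output_mean,
    const grad_t* grad_output,
    const index_t* offsets, // B + 1 entries; int32_t or int64_t
    int32_t B) {
  const int32_t b = blockIdx.x * blockDim.x + threadIdx.x;
  if (b >= B) {
    return;
  }
  // The bag length is read from the offsets array, whatever its integer width.
  const index_t L = offsets[b + 1] - offsets[b];
  grad_output_mean[b] =
      L > 0 ? grad_output[b] / static_cast<grad_t>(L) : grad_output[b];
}

// Explicit instantiations, analogous to the {% for offset_type %} loop the
// diff wraps around the existing {% for grad_type %} instantiation loop.
template __global__ void grad_mean_demo_kernel<float, int32_t>(
    float*, const float*, const int32_t*, int32_t);
template __global__ void grad_mean_demo_kernel<float, int64_t>(
    float*, const float*, const int64_t*, int32_t);

Compiling one kernel body per offsets dtype avoids converting int32 offsets tensors to int64 on the host before every backward call, at the cost of one extra instantiation per grad_type.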