Re-organize the forward split code generation (pytorch#1879)
Summary:
Pull Request resolved: pytorch#1879

- Re-organize the forward split code generation

Reviewed By: sryap

Differential Revision: D47420155

fbshipit-source-id: 607088818d1801e5fbb8d307c96a75922e8e7188
q10 authored and facebook-github-bot committed Jul 17, 2023
1 parent 72c6035 commit 79b365b
Showing 3 changed files with 157 additions and 133 deletions.
8 changes: 4 additions & 4 deletions fbgemm_gpu/codegen/embedding_backward_split_template.cu
@@ -183,7 +183,7 @@ grad_mean{{ vbe_desc }}_kernel(
For the experimental optimizers, kMaxVecsPerThread and kThreadGroupSize are
fixed to 8 (1024 elements) and kWarpSize, respectively.
*/
-#define DISPATCH_KERNEL_BODY(MAX_D, ...) \
+#define DISPATCH_OPTIMAL_KERNEL(MAX_D, ...) \
[&] { \
constexpr auto kMaxVecsPerThread = {{ max_embedding_dim // items_per_warp }}; \
constexpr auto kThreadGroupSize = kWarpSize; \
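
Aside: DISPATCH_OPTIMAL_KERNEL expands to an immediately-invoked lambda that pins the compile-time kernel parameters before running the caller-supplied body. A minimal compilable sketch of that pattern (illustrative names and values, not the FBGEMM source):

#include <cstdio>

// Sketch only: fix the constexpr parameters inside an immediately-invoked
// lambda, then run the body the caller passed in. The preprocessor pastes
// the body in textually, so it can see the constexpr names directly.
#define DISPATCH_FIXED_KERNEL(...)                       \
  [&] {                                                  \
    constexpr int kMaxVecsPerThread = 8;                 \
    constexpr int kThreadGroupSize = 32; /* kWarpSize */ \
    return __VA_ARGS__();                                \
  }()

int main() {
  DISPATCH_FIXED_KERNEL([&] {
    std::printf("vecs=%d group=%d\n", kMaxVecsPerThread, kThreadGroupSize);
  });
  return 0;
}
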
@@ -208,7 +208,7 @@ grad_mean{{ vbe_desc }}_kernel(
is unable to use std::max in constexpr context.
*/
#ifdef FBGEMM_USE_SUBWARP_SHUFFLE
-#define DISPATCH_KERNEL_BODY(MAX_D, ...) \
+#define DISPATCH_OPTIMAL_KERNEL(MAX_D, ...) \
[&] { \
{%- for kMaxElemPerThread in range(1, max_embedding_dim // (items_per_warp // 4) + 1) %}
{%- if kMaxElemPerThread in [1, 2] or kMaxElemPerThread % 4 == 0 %}
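
The truncated comment above (its opening lines fall outside this hunk) refers to a compiler that rejects std::max in a constexpr context. The usual workaround, shown here as an assumption rather than a quote of the FBGEMM helper, is a hand-rolled constexpr max:

// A constexpr max that older toolchains accept where std::max is rejected
// (illustrative; not necessarily the helper FBGEMM uses).
constexpr int const_max(int a, int b) { return a > b ? a : b; }

static_assert(const_max(3, 5) == 5, "picks the larger value");
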
@@ -223,7 +223,7 @@ grad_mean{{ vbe_desc }}_kernel(
}()

#else
-#define DISPATCH_KERNEL_BODY(MAX_D, ...) \
+#define DISPATCH_OPTIMAL_KERNEL(MAX_D, ...) \
[&] { \
constexpr int kThreadGroupSize = kWarpSize; \
{%- for kMaxElemPerThread in range(1, max_embedding_dim // (items_per_warp // 4) + 1) %}
@@ -573,7 +573,7 @@ Tensor split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimi
}
{%- endif %}
-DISPATCH_KERNEL_BODY(max_D, [&] {
+DISPATCH_OPTIMAL_KERNEL(max_D, [&] {
// Stay under used_shared_kb of shared memory (V100: 64 KB; A100: 96 KB; H100: 144 KB), BT_block_size must be a power of two.
while (BT_block_size * sizeof(at::acc_type<cache_t, true>) * 4 * kWarpSize * kMaxVecsPerThread >= used_shared_bytes) {
BT_block_size /= 2;
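
The loop above halves BT_block_size until the block's shared-memory footprint fits the device budget; halving keeps it a power of two. A standalone sketch of the same sizing logic, with the footprint formula mirrored from the line above and illustrative parameter names:

#include <cstddef>

// Halve the block size until the per-block shared-memory footprint drops
// under the budget. Halving preserves the power-of-two invariant.
int fit_bt_block_size(int bt_block_size, std::size_t acc_type_bytes,
                      int warp_size, int max_vecs_per_thread,
                      std::size_t used_shared_bytes) {
  while (bt_block_size > 1 &&
         static_cast<std::size_t>(bt_block_size) * acc_type_bytes * 4 *
                 warp_size * max_vecs_per_thread >= used_shared_bytes) {
    bt_block_size /= 2;
  }
  return bt_block_size;
}
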
171 changes: 66 additions & 105 deletions fbgemm_gpu/codegen/embedding_forward_split_kernel_template.cu
@@ -335,15 +335,76 @@ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" }
}


////////////////////////////////////////////////////////////////////////////////
// Explicit Template Instantiations
////////////////////////////////////////////////////////////////////////////////

/*
Explicitly instantiate the kernel function template. The instantiations are
based on the types enumerated by DISPATCH_EMB_CACHE_TYPES macro used in
embedding_forward_split_template.cu
*/
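
For readers unfamiliar with the mechanism the comment describes, here is a toy example of explicit template instantiation -- a deliberately tiny stand-in, not the real kernel signature:

// Toy function template; the real kernel has many more parameters.
template <typename emb_t, typename output_t>
void forward_stub(const emb_t* weights, output_t* out, int n) {
  for (int i = 0; i < n; ++i) {
    out[i] = static_cast<output_t>(weights[i]);
  }
}

// Explicit instantiations: force the compiler to emit code for these
// concrete type pairs in this translation unit, just as the Jinja macros
// below emit one instantiation per enumerated type tuple.
template void forward_stub<float, float>(const float*, float*, int);
template void forward_stub<unsigned char, float>(const unsigned char*, float*, int);
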

-{%- for output_type in ['uint8_t', 'at::Half', 'float'] %}
-{%- for emb_type in ['uint8_t', 'float', 'at::Half'] %}
-{%- for cache_type in ['float', 'at::Half'] %}
+{%- macro template_instantiation(emb_type, cache_type, output_type, use_cache, kMaxVecsPerThread, kThreadGroupSize) %}
+template __launch_bounds__(kForwardMaxThreads) __global__
+void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}{{ vbe_desc }}_kernel
+<
+{{ emb_type }},
+{{ cache_type }},
+{{ output_type }},
+{%- if not dense %}
+{{ use_cache }},
+{%- endif %}
+int64_t,
+{%- if not nobag %}
+{{- kMaxVecsPerThread }},
+{%- endif %}
+{{ kThreadGroupSize }}
+> (
+const pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> dev_weights,
+{%- if not dense %}
+const pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> uvm_weights,
+const pta::PackedTensorAccessor64<{{ cache_type }}, 2, at::RestrictPtrTraits> lxu_cache_weights,
+const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> weights_placements,
+{%- endif %}
+const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> weights_offsets,
+{%- if not nobag %}
+const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
+{%- else %}
+int64_t D,
+{%- endif %}
+{%- if vbe %}
+const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> output_offsets,
+const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> b_t_map,
+const int32_t info_B_num_bits,
+const uint32_t info_B_mask,
+{%- else %}
+FixedDivisor fd_B,
+{%- endif %}
+const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> indices,
+const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets,
+{%- if not nobag %}
+int64_t pooling_mode,
+{%- endif %}
+{%- if weighted %}
+pta::PackedTensorAccessor32<at::acc_type<{{ cache_type }}, true>, 1, at::RestrictPtrTraits> indice_weights,
+{%- endif %}
+{%- if not dense %}
+const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> lxu_cache_locations,
+{%- endif %}
+pta::PackedTensorAccessor64<{{ output_type }}, 2, at::RestrictPtrTraits> output);
+{%- endmacro %}
+
+{%- macro bulk_template_instantiations(use_cache, kMaxVecsPerThread, kThreadGroupSize) %}
+{%- for emb_type in ['uint8_t', 'float', 'at::Half'] %}
+{%- for cache_type in ['float', 'at::Half'] %}
+{%- for output_type in ['uint8_t', 'at::Half', 'float'] %}
+{{ template_instantiation(emb_type, cache_type, output_type, use_cache, kMaxVecsPerThread, kThreadGroupSize) }}
+{%- endfor %}
+{%- endfor %}
+{%- endfor %}
+{%- endmacro %}
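
bulk_template_instantiations expands the full cross product of the three type lists: 3 embedding types x 2 cache types x 3 output types = 18 instantiations per (use_cache, kMaxVecsPerThread, kThreadGroupSize) tuple. The same enumeration in throwaway C++, printing names instead of emitting instantiations:

#include <cstdio>

// Walk the same 3 x 2 x 3 = 18 combinations the Jinja macro walks.
int main() {
  const char* emb_types[]    = {"uint8_t", "float", "at::Half"};
  const char* cache_types[]  = {"float", "at::Half"};
  const char* output_types[] = {"uint8_t", "at::Half", "float"};
  for (const char* e : emb_types)
    for (const char* c : cache_types)
      for (const char* o : output_types)
        std::printf("instantiate<%s, %s, %s>\n", e, c, o);
  return 0;
}
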


////////////////////////////////////////////////////////////////////////////////
#ifdef FBGEMM_USE_SUBWARP_SHUFFLE
@@ -419,55 +480,7 @@ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" }
(NULL, NULL, (kWarpSize / 1))
*/ #}
{%- for (use_cache, kMaxVecsPerThread, kThreadGroupSize) in tuples | unique %}

-template __launch_bounds__(kForwardMaxThreads) __global__
-void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}{{ vbe_desc }}_kernel
-<
-{{ emb_type }},
-{{ cache_type }},
-{{ output_type }},
-{%- if not dense %}
-{{ use_cache }},
-{%- endif %}
-int64_t,
-{%- if not nobag %}
-{{- kMaxVecsPerThread }},
-{%- endif %}
-{{ kThreadGroupSize }}
-> (
-const pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> dev_weights,
-{%- if not dense %}
-const pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> uvm_weights,
-const pta::PackedTensorAccessor64<{{ cache_type }}, 2, at::RestrictPtrTraits> lxu_cache_weights,
-const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> weights_placements,
-{%- endif %}
-const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> weights_offsets,
-{%- if not nobag %}
-const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
-{%- else %}
-int64_t D,
-{%- endif %}
-{%- if vbe %}
-const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> output_offsets,
-const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> b_t_map,
-const int32_t info_B_num_bits,
-const uint32_t info_B_mask,
-{%- else %}
-FixedDivisor fd_B,
-{%- endif %}
-const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> indices,
-const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets,
-{%- if not nobag %}
-int64_t pooling_mode,
-{%- endif %}
-{%- if weighted %}
-pta::PackedTensorAccessor32<at::acc_type<{{ cache_type }}, true>, 1, at::RestrictPtrTraits> indice_weights,
-{%- endif %}
-{%- if not dense %}
-const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> lxu_cache_locations,
-{%- endif %}
-pta::PackedTensorAccessor64<{{ output_type }}, 2, at::RestrictPtrTraits> output);

+{{ bulk_template_instantiations(use_cache, kMaxVecsPerThread, kThreadGroupSize) }}
{%- endfor %}
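
The FBGEMM_USE_SUBWARP_SHUFFLE branch instantiates thread-group sizes smaller than a full warp. As background, a subwarp-wide reduction in CUDA can look like the following -- a sketch of the general technique, not code from this file:

// Reduce within groups of kThreadGroupSize lanes (a power-of-two divisor of
// the warp size). The width argument of __shfl_down_sync confines the
// shuffle to each kThreadGroupSize-wide segment of the warp.
template <int kThreadGroupSize>
__device__ float group_sum(float v) {
  for (int offset = kThreadGroupSize / 2; offset > 0; offset /= 2) {
    v += __shfl_down_sync(0xffffffffu, v, offset, kThreadGroupSize);
  }
  return v;  // lane 0 of each group ends up holding that group's sum
}
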

////////////////////////////////////////////////////////////////////////////////
@@ -528,61 +541,9 @@ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" }
(NULL, NULL, kWarpSize)
*/ #}
{%- for (use_cache, kMaxVecsPerThread, kThreadGroupSize) in tuples | unique %}

-template __launch_bounds__(kForwardMaxThreads) __global__
-void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}{{ vbe_desc }}_kernel
-<
-{{ emb_type }},
-{{ cache_type }},
-{{ output_type }},
-{%- if not dense %}
-{{ use_cache }},
-{%- endif %}
-int64_t,
-{%- if not nobag %}
-{{- kMaxVecsPerThread }},
-{%- endif %}
-{{ kThreadGroupSize }}
-> (
-const pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> dev_weights,
-{%- if not dense %}
-const pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> uvm_weights,
-const pta::PackedTensorAccessor64<{{ cache_type }}, 2, at::RestrictPtrTraits> lxu_cache_weights,
-const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> weights_placements,
-{%- endif %}
-const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> weights_offsets,
-{%- if not nobag %}
-const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
-{%- else %}
-int64_t D,
-{%- endif %}
-{%- if vbe %}
-const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> output_offsets,
-const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> b_t_map,
-const int32_t info_B_num_bits,
-const uint32_t info_B_mask,
-{%- else %}
-FixedDivisor fd_B,
-{%- endif %}
-const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> indices,
-const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets,
-{%- if not nobag %}
-int64_t pooling_mode,
-{%- endif %}
-{%- if weighted %}
-pta::PackedTensorAccessor32<at::acc_type<{{ cache_type }}, true>, 1, at::RestrictPtrTraits> indice_weights,
-{%- endif %}
-{%- if not dense %}
-const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> lxu_cache_locations,
-{%- endif %}
-pta::PackedTensorAccessor64<{{ output_type }}, 2, at::RestrictPtrTraits> output);

+{{ bulk_template_instantiations(use_cache, kMaxVecsPerThread, kThreadGroupSize) }}
{%- endfor %}

////////////////////////////////////////////////////////////////////////////////
#endif
////////////////////////////////////////////////////////////////////////////////

{%- endfor %}
{%- endfor %}
{%- endfor %}
