Re-organize the backward split code generation, pt.2 (pytorch#1873)
Summary:
Pull Request resolved: pytorch#1873

- Use Jinja macros to simplify code generation in `embedding_backward_split_kernel_cta_template.cu` and `embedding_backward_split_kernel_warp_template.cu` (the macro pattern is sketched after this list)

- Update code generation of the kernel instantiations for the experimental optimizer cases
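
A condensed sketch of the pattern, with macro names taken from the cta template in the diff below (the warp template is organized the same way); the full kernel argument lists are elided here but appear verbatim in the diff:

{#- /* Render one explicit instantiation for a single type/size combination */ #}
{%- macro template_instantiation(emb_type, grad_type, cache_type, kMaxVecsPerThread, kThreadGroupSize) %}
template __global__ __launch_bounds__(kMaxThreads)
void split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_cta_per_row_1
< {{ emb_type }}, {{ grad_type }}, {{ cache_type }}, {{ kMaxVecsPerThread }}, {{ kThreadGroupSize }} >
( /* kernel arguments elided; see the diff below for the full list */ );
{%- endmacro %}

{#- /* Expand that macro over every (grad, emb, cache) type combination */ #}
{%- macro bulk_template_instantiations(kMaxVecsPerThread, kThreadGroupSize) %}
{%- for grad_type in ['float', 'at::Half'] %}
{%- for emb_type in ['uint8_t', 'float', 'at::Half'] %}
{%- for cache_type in ['float', 'at::Half'] %}
{{ template_instantiation(emb_type, grad_type, cache_type, kMaxVecsPerThread, kThreadGroupSize) }}
{%- endfor %}
{%- endfor %}
{%- endfor %}
{%- endmacro %}

{%- if is_experimental_optimizer %}
{#- /* Experimental optimizers: a single, maximally-sized configuration */ #}
{{ bulk_template_instantiations(max_embedding_dim // items_per_warp, 'kWarpSize') }}
{%- else %}
{#- /* Standard optimizers: one bulk expansion per unique
       (kMaxVecsPerThread, kThreadGroupSize) tuple, computed as in the diff */ #}
{%- for (kMaxVecsPerThread, kThreadGroupSize) in tuples | unique %}
{{ bulk_template_instantiations(kMaxVecsPerThread, kThreadGroupSize) }}
{%- endfor %}
{%- endif %}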

Reviewed By: sryap

Differential Revision: D47385306

fbshipit-source-id: 9992efa24f953890bfb1bc87fb1e8912615b1193
q10 authored and facebook-github-bot committed Jul 14, 2023
1 parent 07c3f6b commit 9e24d2a
Showing 5 changed files with 258 additions and 318 deletions.
156 changes: 55 additions & 101 deletions fbgemm_gpu/codegen/embedding_backward_split_kernel_cta_template.cu
@@ -19,6 +19,10 @@
using Tensor = at::Tensor;
using namespace fbgemm_gpu;

////////////////////////////////////////////////////////////////////////////////
// Kernel Template Definition
////////////////////////////////////////////////////////////////////////////////

template <
typename emb_t,
typename grad_t,
@@ -384,42 +388,17 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
} // for each run
}

////////////////////////////////////////////////////////////////////////////////
// Explicit Template Instantiations
////////////////////////////////////////////////////////////////////////////////

/*
Explicitly instantiate the kernel function template. The instantiations are
based on the types enumerated by DISPATCH_EMB_GRAD_CACHE_TYPES macro used in
embedding_backward_split_template.cu
*/

{%- for grad_type in ['float', 'at::Half'] %}
{%- for emb_type in ['uint8_t', 'float', 'at::Half'] %}
{%- for cache_type in ['float', 'at::Half'] %}

////////////////////////////////////////////////////////////////////////////////
#ifdef FBGEMM_USE_SUBWARP_SHUFFLE
////////////////////////////////////////////////////////////////////////////////

{#- /*
Compute the Cartesian product of (kMaxVecsPerThread, kThreadGroupSize)
in the FBGEMM_USE_SUBWARP_SHUFFLE case
constexpr int kMaxVecsPerThread = std::max({{ kMaxElemPerThread }} / 4, 1);
constexpr int kThreadGroupSize = kWarpSize / std::max(4 / {{ kMaxElemPerThread }}, 1);
This is needed to compute the unique tuples to use for explicit instantiation,
so that we can avoid duplicate template instantiations.
*/ #}
{%- set tuples = [] %}
{%- for kMaxElemPerThread in range(1, max_embedding_dim // (items_per_warp // 4) + 1) %}
{%- if kMaxElemPerThread in [1, 2] or kMaxElemPerThread % 4 == 0 %}
{%- set t0 = [ (kMaxElemPerThread // 4), 1 ] | max %}
{%- set t1 = [ 4 // kMaxElemPerThread, 1] | max %}
{%- set temp = tuples.append((t0, "(kWarpSize / " ~ t1 ~ ")")) %}
{%- endif %}
{%- endfor %}

{#- /* Enumerate over the unique tuples */ #}
{%- for (kMaxVecsPerThread, kThreadGroupSize) in tuples | unique %}

{%- macro template_instantiation(emb_type, grad_type, cache_type, kMaxVecsPerThread, kThreadGroupSize) %}
template __global__ __launch_bounds__(kMaxThreads)
void split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_cta_per_row_1
< {{ emb_type }},
@@ -484,7 +463,51 @@ void split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimize
const int32_t max_segment_length_per_cta,
const bool use_deterministic_algorithms,
{{ args.split_kernel_args_no_defaults | replace_pta_namespace() | join(",\n ") | replace("cache_t", cache_type) }});
{%- endmacro %}

{%- macro bulk_template_instantiations(kMaxVecsPerThread, kThreadGroupSize) %}
{%- for grad_type in ['float', 'at::Half'] %}
{%- for emb_type in ['uint8_t', 'float', 'at::Half'] %}
{%- for cache_type in ['float', 'at::Half'] %}
{{ template_instantiation(emb_type, grad_type, cache_type, kMaxVecsPerThread, kThreadGroupSize) }}
{%- endfor %}
{%- endfor %}
{%- endfor %}
{%- endmacro %}


{%- if is_experimental_optimizer %}

{{ bulk_template_instantiations(max_embedding_dim // items_per_warp, 'kWarpSize') }}

{%- else %}

////////////////////////////////////////////////////////////////////////////////
#ifdef FBGEMM_USE_SUBWARP_SHUFFLE
////////////////////////////////////////////////////////////////////////////////

{#- /*
Compute the Cartesian product of (kMaxVecsPerThread, kThreadGroupSize)
in the FBGEMM_USE_SUBWARP_SHUFFLE case
constexpr int kMaxVecsPerThread = std::max({{ kMaxElemPerThread }} / 4, 1);
constexpr int kThreadGroupSize = kWarpSize / std::max(4 / {{ kMaxElemPerThread }}, 1);
This is needed to compute the unique tuples to use for explicit instantiation,
so that we can avoid duplicate template instantiations.
*/ #}
{%- set tuples = [] %}
{%- for kMaxElemPerThread in range(1, max_embedding_dim // (items_per_warp // 4) + 1) %}
{%- if kMaxElemPerThread in [1, 2] or kMaxElemPerThread % 4 == 0 %}
{%- set t0 = [ (kMaxElemPerThread // 4), 1 ] | max %}
{%- set t1 = [ 4 // kMaxElemPerThread, 1] | max %}
{%- set temp = tuples.append((t0, "(kWarpSize / " ~ t1 ~ ")")) %}
{%- endif %}
{%- endfor %}

{#- /* Enumerate over the unique tuples */ #}
{%- for (kMaxVecsPerThread, kThreadGroupSize) in tuples | unique %}
{{ bulk_template_instantiations(kMaxVecsPerThread, kThreadGroupSize) }}
{%- endfor %}

////////////////////////////////////////////////////////////////////////////////
@@ -502,87 +525,18 @@ void split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimize
{%- for kMaxElemPerThread in range(1, max_embedding_dim // (items_per_warp // 4) + 1) %}
{%- if kMaxElemPerThread in [1, 2] or kMaxElemPerThread % 4 == 0 %}
{%- set t0 = [ (kMaxElemPerThread // 4), 1 ] | max %}
{%- set t1 = [ 4 // kMaxElemPerThread, 1] | max %}
{%- set temp = tuples.append((t0, "kWarpSize")) %}
{%- endif %}
{%- endfor %}

{#- /* Enumerate over the unique tuples */ #}
{%- for (kMaxVecsPerThread, kThreadGroupSize) in tuples | unique %}

template __global__ __launch_bounds__(kMaxThreads)
void split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_cta_per_row_1
< {{ emb_type }},
{{ grad_type }},
{{ cache_type }},
{{ kMaxVecsPerThread }},
{{ kThreadGroupSize }}
> (
const pta::PackedTensorAccessor64<{{ grad_type }}, 2, at::RestrictPtrTraits> grad_output,
{%- if optimizer != "none" %}
pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> dev_weights,
{%- if not dense %}
pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> uvm_weights,
pta::PackedTensorAccessor64<{{ cache_type }}, 2, at::RestrictPtrTraits> lxu_cache_weights,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
weights_placements,
{%- endif %}
{%- endif %} // if optimizer != "none"
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> weights_offsets,
{%- if not nobag %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
{%- else %}
int64_t D,
{%- endif %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> long_run_ids,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> num_long_run_ids,
{%- if not nobag %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
{%- else %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_infos,
{%- endif %}
{%- if not dense %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
sorted_lxu_cache_locations,
{%- endif %}
{%- if weighted %}
const pta::PackedTensorAccessor32<at::acc_type<{{ cache_type }}, true>, 1, at::RestrictPtrTraits> sorted_indice_weights,
{%- endif %}
{%- if not dense and optimizer != "none" %}
bool stochastic_rounding,
at::PhiloxCudaState stochastic_rounding_philox_args,
{%- else %}
pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> grad_dev_weights,
{%- if optimizer == "none" %}
const int32_t max_D,
{%- endif %}
{%- endif %} // if not dense and optimizer != "none"
{%- if not nobag and vbe %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> B_offsets,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> output_offsets,
{%- endif %}
{%- if not nobag %}
const int32_t info_B_num_bits,
const uint32_t info_B_mask,
{%- endif %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> long_run_id_to_really_long_run_ids,
pta::PackedTensorAccessor32<at::acc_type<{{ cache_type }}, true>, 2, at::RestrictPtrTraits> temp_grad_accum,
pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> grad_accum_counter,
const int32_t max_segment_length_per_cta,
const bool use_deterministic_algorithms,
{{ args.split_kernel_args_no_defaults | replace_pta_namespace() | join(",\n ") | replace("cache_t", cache_type) }});

{{ bulk_template_instantiations(kMaxVecsPerThread, kThreadGroupSize) }}
{%- endfor %}

////////////////////////////////////////////////////////////////////////////////
#endif
////////////////////////////////////////////////////////////////////////////////

{%- endfor %}
{%- endfor %}
{%- endfor %}

{%- endif %}
// clang-format on
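
The `| unique` filter applied before the bulk expansion matters because different kMaxElemPerThread values can map to the same (kMaxVecsPerThread, kThreadGroupSize) tuple. A worked example for the branch without FBGEMM_USE_SUBWARP_SHUFFLE, where kThreadGroupSize is pinned to kWarpSize (the rows below follow the t0 formula in the template and are illustrative only):

{#- /*
   kMaxElemPerThread = 1   ->  t0 = max(1 // 4, 1)  = 1  ->  (1, kWarpSize)
   kMaxElemPerThread = 2   ->  t0 = max(2 // 4, 1)  = 1  ->  (1, kWarpSize)   // duplicate
   kMaxElemPerThread = 4   ->  t0 = max(4 // 4, 1)  = 1  ->  (1, kWarpSize)   // duplicate
   kMaxElemPerThread = 8   ->  t0 = max(8 // 4, 1)  = 2  ->  (2, kWarpSize)
   kMaxElemPerThread = 12  ->  t0 = max(12 // 4, 1) = 3  ->  (3, kWarpSize)

   Without `| unique`, the first three rows would emit the same explicit
   instantiation three times; deduplicating the tuples keeps each
   specialization to a single instantiation.
*/ #}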
149 changes: 55 additions & 94 deletions fbgemm_gpu/codegen/embedding_backward_split_kernel_warp_template.cu
@@ -19,6 +19,10 @@
using Tensor = at::Tensor;
using namespace fbgemm_gpu;

////////////////////////////////////////////////////////////////////////////////
// Kernel Template Definition
////////////////////////////////////////////////////////////////////////////////

template <
typename emb_t,
typename grad_t,
@@ -228,42 +232,17 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
}


////////////////////////////////////////////////////////////////////////////////
// Explicit Template Instantiations
////////////////////////////////////////////////////////////////////////////////

/*
Explicitly instantiate the kernel function template. The instantiations are
based on the types enumerated by DISPATCH_EMB_GRAD_CACHE_TYPES macro used in
embedding_backward_split_template.cu
*/

{%- for grad_type in ['float', 'at::Half'] %}
{%- for emb_type in ['uint8_t', 'float', 'at::Half'] %}
{%- for cache_type in ['float', 'at::Half'] %}

////////////////////////////////////////////////////////////////////////////////
#ifdef FBGEMM_USE_SUBWARP_SHUFFLE
////////////////////////////////////////////////////////////////////////////////

{#- /*
Compute the Cartesian product of (kMaxVecsPerThread, kThreadGroupSize)
in the FBGEMM_USE_SUBWARP_SHUFFLE case
constexpr int kMaxVecsPerThread = std::max({{ kMaxElemPerThread }} / 4, 1);
constexpr int kThreadGroupSize = kWarpSize / std::max(4 / {{ kMaxElemPerThread }}, 1);
This is needed to compute the unique tuples to use for explicit instantiation,
so that we can avoid duplicate template instantiations.
*/ #}
{%- set tuples = [] %}
{%- for kMaxElemPerThread in range(1, max_embedding_dim // (items_per_warp // 4) + 1) %}
{%- if kMaxElemPerThread in [1, 2] or kMaxElemPerThread % 4 == 0 %}
{%- set t0 = [ (kMaxElemPerThread // 4), 1 ] | max %}
{%- set t1 = [ 4 // kMaxElemPerThread, 1] | max %}
{%- set temp = tuples.append((t0, "(kWarpSize / " ~ t1 ~ ")")) %}
{%- endif %}
{%- endfor %}

{#- /* Enumerate over the unique tuples */ #}
{%- for (kMaxVecsPerThread, kThreadGroupSize) in tuples | unique %}

{%- macro template_instantiation(emb_type, grad_type, cache_type, kMaxVecsPerThread, kThreadGroupSize) %}
template __global__ __launch_bounds__(kBackwardMaxThreads)
void split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_warp_per_row_1
< {{ emb_type }},
@@ -321,7 +300,51 @@ void split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimize
const uint32_t info_B_mask,
{%- endif %}
{{ args.split_kernel_args_no_defaults | replace_pta_namespace() | join(",\n ") | replace("cache_t", cache_type) }});
{%- endmacro %}

{%- macro bulk_template_instantiations(kMaxVecsPerThread, kThreadGroupSize) %}
{%- for grad_type in ['float', 'at::Half'] %}
{%- for emb_type in ['uint8_t', 'float', 'at::Half'] %}
{%- for cache_type in ['float', 'at::Half'] %}
{{ template_instantiation(emb_type, grad_type, cache_type, kMaxVecsPerThread, kThreadGroupSize) }}
{%- endfor %}
{%- endfor %}
{%- endfor %}
{%- endmacro %}


{%- if is_experimental_optimizer %}

{{ bulk_template_instantiations(max_embedding_dim // items_per_warp, 'kWarpSize') }}

{%- else %}

////////////////////////////////////////////////////////////////////////////////
#ifdef FBGEMM_USE_SUBWARP_SHUFFLE
////////////////////////////////////////////////////////////////////////////////

{#- /*
Compute the Cartesian product of (kMaxVecsPerThread, kThreadGroupSize)
in the FBGEMM_USE_SUBWARP_SHUFFLE case
constexpr int kMaxVecsPerThread = std::max({{ kMaxElemPerThread }} / 4, 1);
constexpr int kThreadGroupSize = kWarpSize / std::max(4 / {{ kMaxElemPerThread }}, 1);
This is needed to compute the unique tuples to use for explicit instantiation,
so that we can avoid duplicate template instantiations.
*/ #}
{%- set tuples = [] %}
{%- for kMaxElemPerThread in range(1, max_embedding_dim // (items_per_warp // 4) + 1) %}
{%- if kMaxElemPerThread in [1, 2] or kMaxElemPerThread % 4 == 0 %}
{%- set t0 = [ (kMaxElemPerThread // 4), 1 ] | max %}
{%- set t1 = [ 4 // kMaxElemPerThread, 1] | max %}
{%- set temp = tuples.append((t0, "(kWarpSize / " ~ t1 ~ ")")) %}
{%- endif %}
{%- endfor %}

{#- /* Enumerate over the unique tuples */ #}
{%- for (kMaxVecsPerThread, kThreadGroupSize) in tuples | unique %}
{{ bulk_template_instantiations(kMaxVecsPerThread, kThreadGroupSize) }}
{%- endfor %}

////////////////////////////////////////////////////////////////////////////////
@@ -339,80 +362,18 @@ void split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimize
{%- for kMaxElemPerThread in range(1, max_embedding_dim // (items_per_warp // 4) + 1) %}
{%- if kMaxElemPerThread in [1, 2] or kMaxElemPerThread % 4 == 0 %}
{%- set t0 = [ (kMaxElemPerThread // 4), 1 ] | max %}
{%- set t1 = [ 4 // kMaxElemPerThread, 1] | max %}
{%- set temp = tuples.append((t0, "kWarpSize")) %}
{%- endif %}
{%- endfor %}

{#- /* Enumerate over the unique tuples */ #}
{%- for (kMaxVecsPerThread, kThreadGroupSize) in tuples | unique %}

template __global__ __launch_bounds__(kBackwardMaxThreads)
void split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_warp_per_row_1
< {{ emb_type }},
{{ grad_type }},
{{ cache_type }},
{{ kMaxVecsPerThread }},
{{ kThreadGroupSize }}
> (
const pta::PackedTensorAccessor64<{{ grad_type }}, 2, at::RestrictPtrTraits> grad_output,
{%- if optimizer != "none" %}
pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> dev_weights,
{%- if not dense %}
pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> uvm_weights,
pta::PackedTensorAccessor64<{{ cache_type }}, 2, at::RestrictPtrTraits> lxu_cache_weights,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> weights_placements,
{%- endif %}
{%- endif %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> weights_offsets,
{%- if not nobag %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
{%- else %}
int64_t D,
{%- endif %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
{%- if not nobag %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
{%- else %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_infos,
{%- endif %}
{%- if not dense %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_lxu_cache_locations,
{%- endif %}
{%- if weighted %}
const pta::PackedTensorAccessor32<at::acc_type<{{ cache_type }}, true>, 1, at::RestrictPtrTraits> sorted_indice_weights,
{%- endif %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_num_runs,
int32_t max_segment_length_per_warp,
{%- if not dense and optimizer != "none" %}
bool stochastic_rounding,
at::PhiloxCudaState stochastic_rounding_philox_args,
{%- else %}
pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> grad_dev_weights,
{%- if optimizer == "none" %}
const int32_t max_D,
{%- endif %}
{%- endif %} // if not dense and optimizer != "none"
{%- if not nobag and vbe %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> B_offsets,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> output_offsets,
{%- endif %}
{%- if not nobag %}
const int32_t info_B_num_bits,
const uint32_t info_B_mask,
{%- endif %}
{{ args.split_kernel_args_no_defaults | replace_pta_namespace() | join(",\n ") | replace("cache_t", cache_type) }});

{{ bulk_template_instantiations(kMaxVecsPerThread, kThreadGroupSize) }}
{%- endfor %}

////////////////////////////////////////////////////////////////////////////////
#endif
////////////////////////////////////////////////////////////////////////////////

{%- endfor %}
{%- endfor %}
{%- endfor %}

{%- endif %}
// clang-format on
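
For the experimental-optimizer branch in both templates, only the maximal configuration is instantiated: kMaxVecsPerThread = max_embedding_dim // items_per_warp with a full-warp thread group, which matches the largest kMaxVecsPerThread the standard enumeration would produce. As an illustration only, with hypothetical generator settings max_embedding_dim = 1024 and items_per_warp = 128 (neither value is stated in this diff):

{#- /*
   Hypothetical values, for illustration only:
     max_embedding_dim // items_per_warp = 1024 // 128 = 8
   so the experimental branch renders just
     bulk_template_instantiations(8, 'kWarpSize')
   i.e. one explicit instantiation per (grad, emb, cache) type combination,
   instead of one per unique (kMaxVecsPerThread, kThreadGroupSize) tuple.
*/ #}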
