Skip to content

Commit

Permalink
Use unique cache locations in backward for pipeline prefetching (pyto…
Browse files Browse the repository at this point in the history
…rch#2151)

Summary:
Pull Request resolved: pytorch#2151

When pipeline prefetching is enabled (`prefetch_pipeline=True`) for
`EmbeddingLocation.MANAGED_CACHING`, TBE has to update
`lxu_cache_locations` to ensure cache consistency.  Prior to this
diff, TBE performs the full cache lookup when updating
`lxu_cache_locations` (i.e., looking up all indices although they are
duplicated).  The time complexity of such the lookup grows with
the number of indices.  Thus, the lookup can be expensive when the
number of indices is large.  This diff addresses this problem by
looking up only the unique indices (which is normally much smaller
than the full indices).  The number of unique indices tends to stay
more or less the same even when the total number of indices grows.
Thus, looking up only unique indices can reduce cost of cache lookup
effectively.  The unique `lxu_cache_locations` are fed directly to TBE
backward to consume.  Thus, there is no decompression cost.

Reviewed By: ehsanardestani, jianyuh

Differential Revision: D51339208

fbshipit-source-id: be117db4a4a516341a99210e4c4e32dd9526588c
  • Loading branch information
sryap authored and facebook-github-bot committed Nov 28, 2023
1 parent ca1da75 commit 035ed1f
Show file tree
Hide file tree
Showing 7 changed files with 222 additions and 29 deletions.
73 changes: 73 additions & 0 deletions fbgemm_gpu/codegen/embedding_backward_split_grad_template.cu
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,79 @@ split_embedding_backward_codegen_find_long_segments(
}
}

template <typename info_pta_t, typename info_t, bool nobag>
__global__ __launch_bounds__(kMaxThreads)
void split_embedding_backward_count_unique_indices_kernel(
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
sorted_linear_indices_num_runs,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
sorted_linear_indices_cumulative_run_lengths,
const pta::PackedTensorAccessor32<info_pta_t, 1, at::RestrictPtrTraits>
sorted_infos,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
weights_placements,
pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
dev_or_uvm_unique_indices,
const int info_B_num_bits
) {
const int32_t num_runs = sorted_linear_indices_num_runs[0];
const auto T = weights_placements.size(0);
for (auto run_id = blockIdx.x * blockDim.x + threadIdx.x;
run_id < num_runs;
run_id += blockDim.x * gridDim.x) {
// Obtain the associated table id of the run id
const auto segment_start = sorted_linear_indices_cumulative_run_lengths[run_id];
const auto info = reinterpret_cast<const info_t*>(&sorted_infos[0])[segment_start];
const auto t = nobag ? (info % T) : (info >> info_B_num_bits);

int32_t t_next = -1;
const auto unique_count_offset = run_id + 1;
if (unique_count_offset < num_runs) {
const auto segment_start_next = sorted_linear_indices_cumulative_run_lengths[unique_count_offset];
const auto info_next = reinterpret_cast<const info_t*>(&sorted_infos[0])[segment_start_next];
t_next = nobag ? (info_next % T) : (info_next >> info_B_num_bits);
}

if (t != t_next) {
const auto placement = static_cast<PlacementType>(weights_placements[t]);
if (placement != PlacementType::MANAGED_CACHING) {
// Record num unique indices for PlacementType::DEVICE from unique_count_offset
gpuAtomicAdd(&dev_or_uvm_unique_indices[t], unique_count_offset);
}
if (t_next != -1) {
const auto placement_next = static_cast<PlacementType>(weights_placements[t_next]);
if (placement_next != PlacementType::MANAGED_CACHING) {
// Record num unique indices for PlacementType::DEVICE from unique_count_offset
gpuAtomicAdd(&dev_or_uvm_unique_indices[t_next], -unique_count_offset);
}
}
}
}
}

{% for nobag in [True, False] %}
{% set info_pta_t = "int64_t" if nobag else "int32_t" %}
template __global__ __launch_bounds__(kMaxThreads)
void split_embedding_backward_count_unique_indices_kernel
<
{{ info_pta_t }},
{{ "int64_t" if nobag else "uint32_t" }},
{{ "true" if nobag else "false" }}
> (
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
sorted_linear_indices_num_runs,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
sorted_linear_indices_cumulative_run_lengths,
const pta::PackedTensorAccessor32<{{ info_pta_t }}, 1, at::RestrictPtrTraits>
sorted_infos,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
weights_placements,
pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
dev_or_uvm_unique_indices,
const int info_B_num_bits
);
{% endfor %}

{% for vbe in [True, False] %}
{% set vbe_desc = "_vbe" if vbe else "" %}
template <typename grad_t>
Expand Down
28 changes: 26 additions & 2 deletions fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e
const Tensor& vbe_row_output_offsets,
const Tensor& vbe_b_t_map,
{%- endif %}
const bool use_uniq_cache_locations,
const bool use_homogeneous_placements,
{{ args.split_function_args | join(", ") }});

{%- endfor %} {#-/*for nobag*/#}
Expand Down Expand Up @@ -177,6 +179,8 @@ class {{ autograd_func }} :
const int64_t vbe_output_size,
{%- endif %}
const bool is_experimental,
const bool use_uniq_cache_locations_bwd,
const bool use_homogeneous_placements,
{{ args.split_function_args | join(", ") }}) {

const auto T = weights_offsets.sym_numel();
Expand Down Expand Up @@ -263,6 +267,8 @@ class {{ autograd_func }} :
ctx->saved_data["info_B_num_bits"] = info_B_num_bits;
const auto info_B_mask_int64 = static_cast<int64_t>(info_B_mask);
ctx->saved_data["info_B_mask"] = info_B_mask_int64;
ctx->saved_data["use_uniq_cache_locations_bwd"] = use_uniq_cache_locations_bwd;
ctx->saved_data["use_homogeneous_placements"] = use_homogeneous_placements;

{%- for (var, _) in args.saved_data %}
ctx->saved_data["{{ var }}"] = {{ var }};
Expand Down Expand Up @@ -392,6 +398,10 @@ class {{ autograd_func }} :
{%- endif %} {#-/* if optimizer != "none" */#}
const int32_t info_B_num_bits = ctx->saved_data["info_B_num_bits"].toInt();
const int64_t info_B_mask_int64 = ctx->saved_data["info_B_mask"].toInt();
const auto use_uniq_cache_locations_bwd =
ctx->saved_data["use_uniq_cache_locations_bwd"].toBool();
const auto use_homogeneous_placements =
ctx->saved_data["use_homogeneous_placements"].toBool();

{%- for (var, ivalue_cast) in args.saved_data %}
auto {{ var }} = ctx->saved_data["{{ var }}"].{{ ivalue_cast }}();
Expand Down Expand Up @@ -510,6 +520,8 @@ class {{ autograd_func }} :
vbe_row_output_offsets,
vbe_b_t_map,
{%- endif %}
use_uniq_cache_locations_bwd,
use_homogeneous_placements,
{{ args.split_function_arg_names | join(", ") }}
) {{ ":" if not weighted else ";" }}
{%- endfor %} {#-/* for weighted in [False, True] */#}
Expand Down Expand Up @@ -546,6 +558,8 @@ class {{ autograd_func }} :
Variable(), // vbe_output_size
{%- endif %}
Variable(), // is_experimental
Variable(), // use_uniq_cache_locations_bwd
Variable(), // use_homogeneous_placements
{{ args.split_variables | join(", ") }}
};
{%- else %}
Expand Down Expand Up @@ -585,6 +599,8 @@ class {{ autograd_func }} :
vbe_row_output_offsets,
vbe_b_t_map,
{%- endif %}
use_uniq_cache_locations_bwd,
use_homogeneous_placements,
{{ args.split_function_arg_names | join(", ") }}
);
return {
Expand Down Expand Up @@ -615,6 +631,8 @@ class {{ autograd_func }} :
Variable(), // vbe_output_size
{%- endif %}
Variable(), // is_experimental
Variable(), // use_uniq_cache_locations_bwd
Variable(), // use_homogeneous_placements
{{ args.split_variables | join(", ") }}
};
{%- endif %}
Expand Down Expand Up @@ -657,7 +675,9 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function(
const int64_t max_B = -1,
const int64_t max_B_feature_rank = -1,
const int64_t vbe_output_size = -1,
const bool is_experimental = false
const bool is_experimental = false,
const bool use_uniq_cache_locations_bwd = false,
const bool use_homogeneous_placements = false
) {
{%- if has_gpu_support %}
{%- for vbe in ([True, False] if has_vbe_support else [False]) %}
Expand Down Expand Up @@ -721,6 +741,8 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function(
vbe_output_size,
{%- endif %}
is_experimental,
use_uniq_cache_locations_bwd,
use_homogeneous_placements,
{{ args.split_function_arg_names | join(", ") }})[0];
}
{%- endfor %} {#-/* for nobag */#}
Expand Down Expand Up @@ -767,7 +789,9 @@ TORCH_LIBRARY_FRAGMENT({{ lib_name }}, m) {
" int max_B=-1, "
" int max_B_feature_rank=-1, "
" int vbe_output_size=-1, "
" bool is_experimental=False) -> Tensor");
" bool is_experimental=False, "
" bool use_uniq_cache_locations_bwd=False, "
" bool use_homogeneous_placements=False) -> Tensor");
// We're playing a funny trick here: we're using the autograd
// implementation of the operator at all the dispatch keys. This is OK
// because autograd.Function works even in a context where there is
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
{%- endif %}
{%- if not dense %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_lxu_cache_locations,
const bool use_uniq_cache_locations,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> table_unique_indices_offsets,
{%- endif %}
{%- if weighted %}
const pta::PackedTensorAccessor32<at::acc_type<cache_t, true>, 1, at::RestrictPtrTraits> sorted_indice_weights,
Expand Down Expand Up @@ -381,10 +383,10 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
stochastic_rounding,
stochastic_rounding_philox_args,
current_run_id,
use_uniq_cache_locations ? (current_run_id - table_unique_indices_offsets[t_0]) : segment_start,
D,
t_0,
idx,
segment_start,
shfl_sync_mask,
0, // shared_weight_offset
{{ args.split_function_arg_names | join(", ") }});
Expand Down Expand Up @@ -462,6 +464,8 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
{%- if not dense %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
sorted_lxu_cache_locations,
const bool use_uniq_cache_locations,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> table_unique_indices_offsets,
{%- endif %}
{%- if weighted %}
const pta::PackedTensorAccessor32<at::acc_type<{{ cache_type }}, true>, 1, at::RestrictPtrTraits> sorted_indice_weights,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
{%- endif %}
{%- if not dense %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_lxu_cache_locations,
const bool use_uniq_cache_locations,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> table_unique_indices_offsets,
{%- endif %}
{%- if weighted %}
const pta::PackedTensorAccessor32<at::acc_type<cache_t, true>, 1, at::RestrictPtrTraits> sorted_indice_weights,
Expand Down Expand Up @@ -222,10 +224,10 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
stochastic_rounding,
stochastic_rounding_philox_args,
run_id,
use_uniq_cache_locations ? (run_id - table_unique_indices_offsets[t_0]) : segment_start,
D,
t_0,
idx,
segment_start,
shfl_sync_mask,
threadIdx.y * kMaxVecsPerThread * kThreadGroupSize, // shared_weight_offset
{{ args.split_function_arg_names | join(", ") }});
Expand Down Expand Up @@ -301,6 +303,8 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
{%- endif %}
{%- if not dense %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_lxu_cache_locations,
const bool use_uniq_cache_locations,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> table_unique_indices_offsets,
{%- endif %}
{%- if weighted %}
const pta::PackedTensorAccessor32<at::acc_type<{{ cache_type }}, true>, 1, at::RestrictPtrTraits> sorted_indice_weights,
Expand Down
Loading

0 comments on commit 035ed1f

Please sign in to comment.