Enable VBE support on CPU (#3174)
Summary:
Pull Request resolved: #3174

Previously, VBE on CPU was enabled in lookup_{{ optimizer }}.py.

To support MTIA ops, VBE output handling needs to happen after the call to torch.ops.fbgemm.{{ mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2.

This diff follows the same implementation but moves it into C++ so that CPU goes through the same PT2 pipeline (i.e., lookup -> VBE autograd -> CPU wrapper (*VBE output handling is done here*) -> CPU kernel).
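
For illustration, the reshaping the CPU wrapper now performs can be sketched with plain libtorch ops: the dense CPU kernel always produces a [max_B, total_D] output, and the VBE path scatters each per-feature block of it into a flat output. The offsets below are toy values chosen for this sketch only, not taken from the commit.

#include <torch/torch.h>
#include <iostream>

int main() {
  // Toy setup: two features (T = 2) with embedding dims 4 and 8, one rank.
  // Feature 0 has batch size 3 and feature 1 has batch size 2, so the dense
  // CPU kernel output is [max_B, total_D] = [3, 12].
  const auto dense = torch::rand({3, 12});

  // Per-feature row ranges in the dense output, column ranges (from
  // D_offsets), and destination ranges in the flat VBE output.
  const int64_t b_begin[] = {0, 0}, b_end[] = {3, 2};
  const int64_t d_begin[] = {0, 4}, d_end[] = {4, 12};
  const int64_t o_begin[] = {0, 12}, o_end[] = {12, 28};

  // Scatter each per-feature block into the flat VBE output, mirroring what
  // the new CPU wrapper does for every (rank, feature) pair.
  auto vbe_output = torch::empty({28}, dense.options());
  for (int t = 0; t < 2; ++t) {
    auto values = dense
        .index({torch::indexing::Slice(b_begin[t], b_end[t]),
                torch::indexing::Slice(d_begin[t], d_end[t])})
        .flatten();
    vbe_output.index_put_(
        {torch::indexing::Slice(o_begin[t], o_end[t])}, values);
  }
  std::cout << vbe_output.sizes() << std::endl; // [28] = 3*4 + 2*8
  return 0;
}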

Reviewed By: q10

Differential Revision: D63410944
spcyppt authored and facebook-github-bot committed Sep 27, 2024
1 parent 5582d23 commit b2c45b5
Showing 7 changed files with 219 additions and 39 deletions.
5 changes: 3 additions & 2 deletions fbgemm_gpu/FbgemmGpu.cmake
@@ -265,10 +265,10 @@ list(APPEND gen_gpu_host_source_files
foreach(optimizer ${ALL_OPTIMIZERS})
list(APPEND gen_cpu_source_files
"gen_embedding_backward_split_${optimizer}_cpu.cpp"
"gen_embedding_backward_split_${optimizer}_pt2_cpu_wrapper.cpp")
"gen_embedding_backward_split_${optimizer}_pt2_cpu_wrapper.cpp"
"gen_embedding_split_${optimizer}_pt2_autograd.cpp")
list(APPEND gen_gpu_host_source_files
"gen_embedding_backward_split_${optimizer}.cpp"
"gen_embedding_split_${optimizer}_pt2_autograd.cpp"
"gen_embedding_backward_split_${optimizer}_pt2_cuda_wrapper.cpp")
endforeach()

@@ -456,6 +456,7 @@ set(fbgemm_gpu_sources_static_cpu
codegen/training/forward/embedding_forward_split_cpu.cpp
codegen/inference/embedding_forward_quantized_host_cpu.cpp
codegen/training/backward/embedding_backward_dense_host_cpu.cpp
codegen/training/pt2/pt2_autograd_utils.cpp
codegen/utils/embedding_bounds_check_host_cpu.cpp
src/config/feature_gates.cpp
src/memory_utils/memory_utils.cpp
1 change: 1 addition & 0 deletions fbgemm_gpu/codegen/genscript/generate_forward_split.py
@@ -83,6 +83,7 @@ def generate_pt2_wrappers() -> None:
f"gen_embedding_forward_split_pt2_cpu_wrapper.cpp",
has_cpu_support=True,
is_forward=True,
has_vbe_support=True,
)

# Generate PT2 forward wrapper (CUDA)
@@ -35,8 +35,12 @@
#include "fbgemm_gpu/utils/ops_utils.h"
#include <torch/script.h>
#include "fbgemm_gpu/utils/dispatch_macros.h"
#include "fbgemm_gpu/split_embeddings_utils.cuh"
#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/split_embeddings_utils.h"
#include "fbgemm_gpu/config/feature_gates.h"
{%- if has_vbe_support %}
#include "fbgemm_gpu/utils/pt2_autograd_utils.h"
{%- endif %}

using Tensor = at::Tensor;

@@ -236,9 +240,9 @@ enum SSDTensor {
const Tensor& /*prev_iter_dev*/,
{%- endif %}
{%- if "iter" not in args_pt2.split_function_arg_names %}
const int64_t iter,
const int64_t /*iter*/,
{%- endif %}
const double gwd_lower_bound,
const double /*gwd_lower_bound*/,
{%- endif %} {# /* if is_gwd */ #}
{%- for arg_type in args_pt2.split_function_args %}
{{ arg_type.split(' ')[0]}}{%- if not loop.last %}{{ "," }}{%- endif %}
@@ -617,7 +621,6 @@ class {{ autograd_func }} :
const c10::SymInt,
const int64_t,
const c10::SymInt)>();

auto [
vbe_row_output_offsets,
vbe_b_t_map
@@ -850,6 +853,11 @@ static torch::autograd::variable_list backward(
// {{ fwd_mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cuda)
weights_dev = weights_dev.flatten();
{%- endif %}
{%- if vbe %}
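// On the CPU path the host weights tensor is materialized, so the VBE
// grad_output arrives as a flat tensor; reshape it to the dense
// [max_B, total_D] layout that the CPU backward kernels expect.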
if (weights_host.numel() > 1){
grad_output = reshape_vbe_output(grad_output, B_offsets, vbe_b_t_map, D_offsets);
}
{%- endif %}

{%- set grad_indice_weights_op =
"{}_embedding_codegen_grad_indice_weights{}_pt2_wrapper".format(fwd_mdesc, vdesc)
@@ -883,7 +891,7 @@ static torch::autograd::variable_list backward(
{%- else %}
const Tensor& /*feature_requires_grad*/
{%- endif %}
)>();
)>();

const auto grad_indice_weights = !indice_weights.defined() ?
Variable() :
@@ -1014,7 +1022,7 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2(
{%- if not ssd %}
{%- if has_vbe_support %}
// has vbe support and on gpu
if (B_offsets.has_value() && !(weights[0].numel() > 0)) {
if (B_offsets.has_value()) {
{%- if has_global_weight_decay_support %}
// vbe and has gwd support
if (apply_global_weight_decay && weight_decay > 0) {
@@ -30,9 +30,12 @@
using Tensor = at::Tensor;
using namespace fbgemm_gpu;

{%- for vbe in ([True, False] if has_vbe_support else [False]) %}
{%- set vdesc = "_vbe" if vbe else "" %}

{%- if is_forward %}
{#-/* PT2 wrapper function for backward grad_indice_weights CPU */#}
Tensor split_embedding_codegen_grad_indice_weights_pt2_cpu_wrapper(
Tensor split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cpu_wrapper(
const Tensor& grad_output,
const Tensor& host_weights,
const Tensor& /*dev_weights*/,
@@ -45,7 +48,16 @@ Tensor split_embedding_codegen_grad_indice_weights_pt2_cpu_wrapper(
const Tensor& indices,
const Tensor& offsets,
const Tensor& /*lxu_cache_locations*/,
const Tensor& feature_requires_grad) {
{%- if vbe %}
const Tensor& feature_requires_grad,
const Tensor& vbe_row_output_offsets,
const Tensor& vbe_b_t_map,
const int64_t info_B_num_bits,
const int64_t info_B_mask_int64
{%- else %}
const Tensor& feature_requires_grad
{%- endif %}
) {
static auto op =
torch::Dispatcher::singleton()
.findSchemaOrThrow(
@@ -67,7 +79,7 @@ Tensor split_embedding_codegen_grad_indice_weights_pt2_cpu_wrapper(

{% if is_forward %}
{#-/* PT2 wrapper function for forward CPU */#}
Tensor split_embedding_codegen_forward_{{ wdesc }}_pt2_cpu_wrapper(
Tensor split_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper(
const Tensor& host_weights,
const Tensor& /*dev_weights*/,
const Tensor& /*uvm_weights*/,
@@ -84,30 +96,77 @@ Tensor split_embedding_codegen_forward_{{ wdesc }}_pt2_cpu_wrapper(
const Tensor& indice_weights,
const Tensor& /*lxu_cache_locations*/,
const Tensor& /*uvm_cache_stats*/,
{%- if vbe %}
const Tensor& vbe_row_output_offsets, /*vbe_output_offsets_feature_rank*/
const Tensor& vbe_b_t_map, /*vbe_B_offsets_rank_per_feature*/
const c10::SymInt vbe_output_size,
const int64_t info_B_num_bits,
const int64_t info_B_mask_int64,
{%- endif %}
const bool /*is_experimental = false*/,
const int64_t output_dtype = static_cast<int64_t>(SparseType::FP32)) {
static auto op =
torch::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::split_embedding_codegen_forward_cpu", "")
.typed<Tensor(
Tensor, Tensor, Tensor, c10::SymInt, Tensor, Tensor, Tensor, int64_t, Tensor, int64_t
)>();
static auto op =
torch::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::split_embedding_codegen_forward_cpu", "")
.typed<Tensor(
Tensor, Tensor, Tensor, c10::SymInt, Tensor, Tensor, Tensor, int64_t, Tensor, int64_t
)>();
{%- if vbe %}
// TODO: remove this after vbe is implemented for CPU kernel
Tensor vbe_B_offsets_rank_per_feature = vbe_b_t_map;
Tensor vbe_output_offsets_feature_rank = vbe_row_output_offsets;
const auto output = op.call(
host_weights,
weights_offsets,
D_offsets,
total_D,
hash_size_cumsum,
indices,
offsets,
pooling_mode,
indice_weights,
output_dtype);
auto options = at::TensorOptions()
.dtype(output.options().dtype())
.device(host_weights.options().device());
const int64_t vbe_output_size_ = vbe_output_size.guard_int(__FILE__, __LINE__);
Tensor output_new = at::empty({vbe_output_size_}, options);
const int32_t T = D_offsets.numel() - 1;
const int32_t R = vbe_B_offsets_rank_per_feature.size(1) - 1;

return op.call(
host_weights,
weights_offsets,
D_offsets,
total_D,
hash_size_cumsum,
indices,
offsets,
pooling_mode,
indice_weights,
output_dtype);
}
for (int32_t r = 0; r < R; r++){
auto D_offset = 0;
for (int32_t t = 0; t < T; t++){
const int32_t o_begin = vbe_output_offsets_feature_rank[r * T + t].item<int32_t>();
const int32_t o_end = vbe_output_offsets_feature_rank[r * T + t + 1].item<int32_t>();
const int32_t D = D_offsets[t + 1].item<int32_t>() - D_offsets[t].item<int32_t>();
const int32_t b_begin = vbe_B_offsets_rank_per_feature[t][r].item<int32_t>();
const int32_t b_end = vbe_B_offsets_rank_per_feature[t][r + 1].item<int32_t>();

TORCH_CHECK((o_end - o_begin) == ((b_end - b_begin) * D));
auto values = output.index({torch::indexing::Slice(b_begin, b_end), torch::indexing::Slice(D_offset, D_offset + D)}).flatten();
output_new.index_put_({torch::indexing::Slice(o_begin, o_end)}, values);
D_offset += D;
}
}
return output_new;
{%- else %}
return op.call(
host_weights,
weights_offsets,
D_offsets,
total_D,
hash_size_cumsum,
indices,
offsets,
pooling_mode,
indice_weights,
output_dtype);
{%- endif %}
}
{% else %}
{#-/* PT2 wrapper function for backward CPU */#}
Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}_pt2_cpu_wrapper(
Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper(
const Tensor& grad_output,
const Tensor& host_weights,
const Tensor& /*dev_weights*/,
@@ -127,8 +186,13 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}_pt2_cpu_wrap
const int64_t /*BT_block_size*/,
const int64_t /*max_segment_length_per_warp*/,
const bool stochastic_rounding,
const int64_t /*info_B_num_bits*/,
const int64_t /*info_B_mask_int64*/,
const int64_t info_B_num_bits,
const int64_t info_B_mask_int64,
{%- if vbe %}
const Tensor& B_offsets,
const Tensor& vbe_row_output_offsets,
const Tensor& vbe_b_t_map,
{%- endif %}
const bool /*use_uniq_cache_locations*/,
const bool /*use_homogeneous_placements*/,
{{ args_pt2.split_function_args | join(", ") }}
@@ -194,29 +258,30 @@ namespace {
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
{%- if is_forward %}
DISPATCH_TO_CPU(
"split_embedding_codegen_grad_indice_weights_pt2_wrapper",
split_embedding_codegen_grad_indice_weights_pt2_cpu_wrapper);
"split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_wrapper",
split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cpu_wrapper);
{%- endif %}

{%- for weighted in [True, False] %}
{%- set wdesc = "weighted" if weighted else "unweighted" %}
{%- if is_forward %}
{%- set embedding_codegen_forward_op = "split_embedding_codegen_forward_{}_pt2".format(
wdesc
{%- set embedding_codegen_forward_op = "split_embedding_codegen_forward_{}{}_pt2".format(
wdesc, vdesc
)
%}
DISPATCH_TO_CPU("{{ embedding_codegen_forward_op }}_wrapper", {{ embedding_codegen_forward_op }}_cpu_wrapper);
{%- else %}

{%- set embedding_codegen_backward_op = "split_embedding_backward_codegen_{}_{}_pt2".format(
optimizer, wdesc
{%- set embedding_codegen_backward_op = "split_embedding_backward_codegen_{}_{}{}_pt2".format(
optimizer, wdesc, vdesc
)
%}
DISPATCH_TO_CPU("{{ embedding_codegen_backward_op }}_wrapper", {{ embedding_codegen_backward_op }}_cpu_wrapper);
{%- endif %}
{%- endfor %} {#-/*for weighted*/#}
}

} // namespace
{%- endfor %} {#-/* for vbe in [True, False] */#}

{% endif %} // if has_cpu_support
// clang-format on
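
For reference, a standalone sketch of the offset conventions used in the wrapper above, with toy tensors constructed for this sketch only: the per-feature batch offsets are indexed as [feature][rank], the flat output offsets as [rank * T + feature], and each output segment must hold exactly b * D values, which is the invariant the wrapper's TORCH_CHECK asserts.

#include <torch/torch.h>
#include <cassert>

int main() {
  const int T = 2; // features, with embedding dims 4 and 8
  const int R = 2; // ranks
  // Cumulative batch offsets per feature across ranks: feature 0 has
  // per-rank batches {3, 1}, feature 1 has {2, 2}.
  const auto B_offsets_rank_per_feature =
      torch::tensor({{0, 3, 4}, {0, 2, 4}}, torch::kInt32); // [T][R + 1]
  const auto D_offsets = torch::tensor({0, 4, 12}, torch::kInt32);
  // Flat output offsets, one entry per (rank, feature) plus the total size:
  // rank 0 -> 3*4 = 12 then 2*8 = 16 values; rank 1 -> 1*4 = 4 then 2*8 = 16.
  const auto output_offsets_feature_rank =
      torch::tensor({0, 12, 28, 32, 48}, torch::kInt64); // [R * T + 1]

  for (int r = 0; r < R; ++r) {
    for (int t = 0; t < T; ++t) {
      const auto o_begin =
          output_offsets_feature_rank[r * T + t].item<int64_t>();
      const auto o_end =
          output_offsets_feature_rank[r * T + t + 1].item<int64_t>();
      const auto D = (D_offsets[t + 1] - D_offsets[t]).item<int32_t>();
      const auto b = (B_offsets_rank_per_feature[t][r + 1] -
                      B_offsets_rank_per_feature[t][r]).item<int32_t>();
      // The same invariant the wrapper enforces with TORCH_CHECK.
      assert(o_end - o_begin == int64_t(b) * D);
    }
  }
  return 0;
}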
62 changes: 62 additions & 0 deletions fbgemm_gpu/codegen/training/pt2/pt2_autograd_utils.cpp
@@ -0,0 +1,62 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>
#include <ATen/TypeDefault.h>

using Tensor = at::Tensor;

namespace fbgemm_gpu {

////////////////////////////////////////////////////////////////////////////////
// Helper Functions
////////////////////////////////////////////////////////////////////////////////

Tensor reshape_vbe_output(
const Tensor& grad_output,
const Tensor& B_offsets,
const Tensor& B_offsets_rank_per_feature,
const Tensor& D_offsets) {
/* FOR CPU VBE to use the same backend */
const auto T = D_offsets.numel() - 1;
int32_t max_B = 0;
int32_t total_D = 0;
// find max_B, total_D to create output [max_B, total_D]
for (int32_t t = 0; t < T; t++) {
auto b = B_offsets[t + 1].item<int32_t>() - B_offsets[t].item<int32_t>();
max_B = std::max(max_B, b);
total_D += D_offsets[t + 1].item<int32_t>() - D_offsets[t].item<int32_t>();
}
auto grad_output_ = at::empty({max_B, total_D}, grad_output.options());
// for each feature
auto offset = 0;

const int32_t R = B_offsets_rank_per_feature.size(1) - 1;
for (int32_t r = 0; r < R; r++) {
auto D_offset = 0;
for (int32_t t = 0; t < T; t++) {
const int32_t b_begin = B_offsets_rank_per_feature[t][r].item<int32_t>();
const int32_t b_end =
B_offsets_rank_per_feature[t][r + 1].item<int32_t>();
const int32_t D =
D_offsets[t + 1].item<int32_t>() - D_offsets[t].item<int32_t>();
const int32_t b = b_end - b_begin;
const int32_t num_elm = b * D;
auto values = grad_output.slice(0, offset, offset + num_elm);
values = values.reshape({b, D});
grad_output_.index_put_(
{at::indexing::Slice(b_begin, b_end),
at::indexing::Slice(D_offset, D_offset + D)},
values);
D_offset += D;
offset += num_elm;
}
}
return grad_output_;
}
} // namespace fbgemm_gpu
31 changes: 31 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/utils/pt2_autograd_utils.h
@@ -0,0 +1,31 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>
#include <ATen/TypeDefault.h>
// #include <ATen/core/op_registration/op_registration.h>
// #include <torch/script.h>
// #include "fbgemm_gpu/embedding_common.h"
// #include "fbgemm_gpu/utils/dispatch_macros.h"
// #include "fbgemm_gpu/utils/ops_utils.h"
// #include "fbgemm_gpu/utils/tensor_utils.h"

using Tensor = at::Tensor;

namespace fbgemm_gpu {

////////////////////////////////////////////////////////////////////////////////
// Helper Functions
////////////////////////////////////////////////////////////////////////////////

Tensor reshape_vbe_output(
const Tensor& grad_output,
const Tensor& B_offsets,
const Tensor& B_offsets_rank_per_feature,
const Tensor& D_offsets);
} // namespace fbgemm_gpu
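
A usage sketch of reshape_vbe_output with toy shapes; it assumes the fbgemm_gpu include directory (for fbgemm_gpu/utils/pt2_autograd_utils.h) is on the include path and the library is linked. The flat VBE gradient is gathered back into the dense [max_B, total_D] layout that the CPU backward kernels expect, the inverse of the forward scatter.

#include <torch/torch.h>
#include <iostream>
#include "fbgemm_gpu/utils/pt2_autograd_utils.h"

int main() {
  // Toy VBE layout: two features with dims 4 and 8, two ranks.
  // Feature 0 per-rank batches: {3, 1}; feature 1: {2, 2}.
  const auto B_offsets =
      torch::tensor({0, 4, 8}, torch::kInt32); // cumulative total batch per feature
  const auto B_offsets_rank_per_feature =
      torch::tensor({{0, 3, 4}, {0, 2, 4}}, torch::kInt32);
  const auto D_offsets = torch::tensor({0, 4, 12}, torch::kInt32);
  // Flat VBE gradient: (3*4 + 2*8) + (1*4 + 2*8) = 48 values.
  const auto grad_flat = torch::rand({48});

  const auto grad_dense = fbgemm_gpu::reshape_vbe_output(
      grad_flat, B_offsets, B_offsets_rank_per_feature, D_offsets);
  std::cout << grad_dense.sizes() << std::endl; // [max_B, total_D] = [4, 12]
  return 0;
}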