Enable VBE support on CPU #3174

Closed · wants to merge 2 commits
6 changes: 4 additions & 2 deletions fbgemm_gpu/FbgemmGpu.cmake
@@ -263,10 +263,10 @@ list(APPEND gen_gpu_host_source_files
foreach(optimizer ${ALL_OPTIMIZERS})
list(APPEND gen_cpu_source_files
"gen_embedding_backward_split_${optimizer}_cpu.cpp"
"gen_embedding_backward_split_${optimizer}_pt2_cpu_wrapper.cpp")
"gen_embedding_backward_split_${optimizer}_pt2_cpu_wrapper.cpp"
"gen_embedding_split_${optimizer}_pt2_autograd.cpp")
list(APPEND gen_gpu_host_source_files
"gen_embedding_backward_split_${optimizer}.cpp"
"gen_embedding_split_${optimizer}_pt2_autograd.cpp"
"gen_embedding_backward_split_${optimizer}_pt2_cuda_wrapper.cpp")
endforeach()

@@ -454,6 +454,7 @@ set(fbgemm_gpu_sources_static_cpu
codegen/training/forward/embedding_forward_split_cpu.cpp
codegen/inference/embedding_forward_quantized_host_cpu.cpp
codegen/training/backward/embedding_backward_dense_host_cpu.cpp
codegen/training/pt2/pt2_autograd_utils.cpp
codegen/utils/embedding_bounds_check_host_cpu.cpp
src/config/feature_gates.cpp
src/memory_utils/memory_utils.cpp
@@ -480,6 +481,7 @@ set(fbgemm_gpu_sources_static_cpu
src/split_embeddings_cache/lru_cache_populate_byte.cpp
src/split_embeddings_cache/lxu_cache.cpp
src/split_embeddings_cache/split_embeddings_cache_ops.cpp
src/split_embeddings_utils/split_embeddings_utils_cpu.cpp
codegen/training/index_select/batch_index_select_dim0_ops.cpp
codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp)

1 change: 1 addition & 0 deletions fbgemm_gpu/codegen/genscript/generate_forward_split.py
@@ -83,6 +83,7 @@ def generate_pt2_wrappers() -> None:
f"gen_embedding_forward_split_pt2_cpu_wrapper.cpp",
has_cpu_support=True,
is_forward=True,
has_vbe_support=True,
)

# Generate PT2 forward wrapper (CUDA)
@@ -35,8 +35,12 @@
#include "fbgemm_gpu/utils/ops_utils.h"
#include <torch/script.h>
#include "fbgemm_gpu/utils/dispatch_macros.h"
#include "fbgemm_gpu/split_embeddings_utils.cuh"
#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/split_embeddings_utils.h"
#include "fbgemm_gpu/config/feature_gates.h"
{%- if has_vbe_support %}
#include "fbgemm_gpu/utils/pt2_autograd_utils.h"
{%- endif %}

using Tensor = at::Tensor;

@@ -236,9 +240,9 @@ enum SSDTensor {
const Tensor& /*prev_iter_dev*/,
{%- endif %}
{%- if "iter" not in args_pt2.split_function_arg_names %}
const int64_t iter,
const int64_t /*iter*/,
{%- endif %}
const double gwd_lower_bound,
const double /*gwd_lower_bound*/,
{%- endif %} {# /* if is_gwd */ #}
{%- for arg_type in args_pt2.split_function_args %}
{{ arg_type.split(' ')[0]}}{%- if not loop.last %}{{ "," }}{%- endif %}
@@ -617,7 +621,6 @@ class {{ autograd_func }} :
const c10::SymInt,
const int64_t,
const c10::SymInt)>();

auto [
vbe_row_output_offsets,
vbe_b_t_map
@@ -850,6 +853,11 @@ static torch::autograd::variable_list backward(
// {{ fwd_mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cuda)
weights_dev = weights_dev.flatten();
{%- endif %}
{%- if vbe %}
if (weights_host.numel() > 1){
grad_output = reshape_vbe_output(grad_output, B_offsets, vbe_b_t_map, D_offsets);
}
{%- endif %}

{%- set grad_indice_weights_op =
"{}_embedding_codegen_grad_indice_weights{}_pt2_wrapper".format(fwd_mdesc, vdesc)
@@ -883,7 +891,7 @@ static torch::autograd::variable_list backward(
{%- else %}
const Tensor& /*feature_requires_grad*/
{%- endif %}
)>();

const auto grad_indice_weights = !indice_weights.defined() ?
Variable() :
@@ -1014,7 +1022,7 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2(
{%- if not ssd %}
{%- if has_vbe_support %}
// has vbe support and on gpu
if (B_offsets.has_value() && !(weights[0].numel() > 0)) {
if (B_offsets.has_value()) {
{%- if has_global_weight_decay_support %}
// vbe and has gwd support
if (apply_global_weight_decay && weight_decay > 0) {
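A brief illustration of the behavioural change in the hunk above, as a hedged sketch rather than generated code: the PT2 lookup previously skipped the VBE autograd path whenever host (CPU) weights were populated, and after this change the path is taken whenever `B_offsets` is supplied; in the generated backward, the matching change is the `reshape_vbe_output` call guarded by `weights_host.numel() > 1`. The snippet below only mirrors the gate; `use_vbe_path` is an illustrative name, not an op in the codebase.

```cpp
#include <ATen/ATen.h>
#include <c10/util/Optional.h>

// Illustrative distillation of the changed dispatch gate (not generated code):
// with this PR, VBE metadata alone selects the VBE path, even when the
// embedding tables live in host (CPU) memory.
bool use_vbe_path(
    const c10::optional<at::Tensor>& B_offsets,
    const at::Tensor& host_weights) {
  (void)host_weights; // previously: also required host weights to be empty
  return B_offsets.has_value();
}

int main() {
  const auto host_weights = at::ones({8});
  const auto B_offsets = c10::make_optional(at::arange(3, at::kInt));
  TORCH_CHECK(use_vbe_path(B_offsets, host_weights)); // CPU + VBE now supported
  return 0;
}
```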
@@ -30,9 +30,12 @@
using Tensor = at::Tensor;
using namespace fbgemm_gpu;

{%- for vbe in ([True, False] if has_vbe_support else [False]) %}
{%- set vdesc = "_vbe" if vbe else "" %}

{%- if is_forward %}
{#-/* PT2 wrapper function for backward grad_indice_weights CPU */#}
Tensor split_embedding_codegen_grad_indice_weights_pt2_cpu_wrapper(
Tensor split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cpu_wrapper(
const Tensor& grad_output,
const Tensor& host_weights,
const Tensor& /*dev_weights*/,
@@ -45,7 +48,16 @@ Tensor split_embedding_codegen_grad_indice_weights_pt2_cpu_wrapper(
const Tensor& indices,
const Tensor& offsets,
const Tensor& /*lxu_cache_locations*/,
const Tensor& feature_requires_grad) {
{%- if vbe %}
const Tensor& feature_requires_grad,
const Tensor& vbe_row_output_offsets,
const Tensor& vbe_b_t_map,
const int64_t info_B_num_bits,
const int64_t info_B_mask_int64
{%- else %}
const Tensor& feature_requires_grad
{%- endif %}
) {
static auto op =
torch::Dispatcher::singleton()
.findSchemaOrThrow(
@@ -67,7 +79,7 @@ Tensor split_embedding_codegen_grad_indice_weights_pt2_cpu_wrapper(

{% if is_forward %}
{#-/* PT2 wrapper function for forward CPU */#}
Tensor split_embedding_codegen_forward_{{ wdesc }}_pt2_cpu_wrapper(
Tensor split_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper(
const Tensor& host_weights,
const Tensor& /*dev_weights*/,
const Tensor& /*uvm_weights*/,
@@ -84,30 +96,77 @@ Tensor split_embedding_codegen_forward_{{ wdesc }}_pt2_cpu_wrapper(
const Tensor& indice_weights,
const Tensor& /*lxu_cache_locations*/,
const Tensor& /*uvm_cache_stats*/,
{%- if vbe %}
const Tensor& vbe_row_output_offsets, /*vbe_output_offsets_feature_rank*/
const Tensor& vbe_b_t_map, /*vbe_B_offsets_rank_per_feature*/
const c10::SymInt vbe_output_size,
const int64_t info_B_num_bits,
const int64_t info_B_mask_int64,
{%- endif %}
const bool /*is_experimental = false*/,
const int64_t output_dtype = static_cast<int64_t>(SparseType::FP32)) {
static auto op =
torch::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::split_embedding_codegen_forward_cpu", "")
.typed<Tensor(
Tensor, Tensor, Tensor, c10::SymInt, Tensor, Tensor, Tensor, int64_t, Tensor, int64_t
)>();
{%- if vbe %}
// TODO: remove this after vbe is implemented for CPU kernel
Tensor vbe_B_offsets_rank_per_feature = vbe_b_t_map;
Tensor vbe_output_offsets_feature_rank = vbe_row_output_offsets;
const auto output = op.call(
host_weights,
weights_offsets,
D_offsets,
total_D,
hash_size_cumsum,
indices,
offsets,
pooling_mode,
indice_weights,
output_dtype);
auto options = at::TensorOptions()
.dtype(output.options().dtype())
.device(host_weights.options().device());
const int64_t vbe_output_size_ = vbe_output_size.guard_int(__FILE__, __LINE__);
Tensor output_new = at::empty({vbe_output_size_}, options);
const int32_t T = D_offsets.numel() - 1;
const int32_t R = vbe_B_offsets_rank_per_feature.size(1) - 1;

for (int32_t r = 0; r < R; r++){
auto D_offset = 0;
for (int32_t t = 0; t < T; t++){
const int32_t o_begin = vbe_output_offsets_feature_rank[r * T + t].item<int32_t>();
const int32_t o_end = vbe_output_offsets_feature_rank[r * T + t + 1].item<int32_t>();
const int32_t D = D_offsets[t + 1].item<int32_t>() - D_offsets[t].item<int32_t>();
const int32_t b_begin = vbe_B_offsets_rank_per_feature[t][r].item<int32_t>();
const int32_t b_end = vbe_B_offsets_rank_per_feature[t][r + 1].item<int32_t>();

TORCH_CHECK((o_end - o_begin) == ((b_end - b_begin) * D));
auto values = output.index({torch::indexing::Slice(b_begin, b_end), torch::indexing::Slice(D_offset, D_offset + D)}).flatten();
output_new.index_put_({torch::indexing::Slice(o_begin, o_end)}, values);
D_offset += D;
}
}
return output_new;
{%- else %}
return op.call(
host_weights,
weights_offsets,
D_offsets,
total_D,
hash_size_cumsum,
indices,
offsets,
pooling_mode,
indice_weights,
output_dtype);
{%- endif %}
}
{% else %}
{#-/* PT2 wrapper function for backward CPU */#}
Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}_pt2_cpu_wrapper(
Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper(
const Tensor& grad_output,
const Tensor& host_weights,
const Tensor& /*dev_weights*/,
Expand All @@ -127,8 +186,13 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}_pt2_cpu_wrap
const int64_t /*BT_block_size*/,
const int64_t /*max_segment_length_per_warp*/,
const bool stochastic_rounding,
const int64_t /*info_B_num_bits*/,
const int64_t /*info_B_mask_int64*/,
const int64_t info_B_num_bits,
const int64_t info_B_mask_int64,
{%- if vbe %}
const Tensor& B_offsets,
const Tensor& vbe_row_output_offsets,
const Tensor& vbe_b_t_map,
{%- endif %}
const bool /*use_uniq_cache_locations*/,
const bool /*use_homogeneous_placements*/,
{{ args_pt2.split_function_args | join(", ") }}
@@ -194,29 +258,30 @@ namespace {
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
{%- if is_forward %}
DISPATCH_TO_CPU(
"split_embedding_codegen_grad_indice_weights_pt2_wrapper",
split_embedding_codegen_grad_indice_weights_pt2_cpu_wrapper);
"split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_wrapper",
split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cpu_wrapper);
{%- endif %}

{%- for weighted in [True, False] %}
{%- set wdesc = "weighted" if weighted else "unweighted" %}
{%- if is_forward %}
{%- set embedding_codegen_forward_op = "split_embedding_codegen_forward_{}_pt2".format(
wdesc
{%- set embedding_codegen_forward_op = "split_embedding_codegen_forward_{}{}_pt2".format(
wdesc, vdesc
)
%}
DISPATCH_TO_CPU("{{ embedding_codegen_forward_op }}_wrapper", {{ embedding_codegen_forward_op }}_cpu_wrapper);
{%- else %}

{%- set embedding_codegen_backward_op = "split_embedding_backward_codegen_{}_{}_pt2".format(
optimizer, wdesc
{%- set embedding_codegen_backward_op = "split_embedding_backward_codegen_{}_{}{}_pt2".format(
optimizer, wdesc, vdesc
)
%}
DISPATCH_TO_CPU("{{ embedding_codegen_backward_op }}_wrapper", {{ embedding_codegen_backward_op }}_cpu_wrapper);
{%- endif %}
{%- endfor %} {#-/*for weighted*/#}
}

} // namespace
{%- endfor %} {#-/* for vbe in [True, False] */#}

{% endif %} // if has_cpu_support
// clang-format on
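For readers following the new VBE branch in the forward CPU wrapper above: it runs the existing dense CPU kernel and then scatters rows into the flat, per-rank VBE output. Below is a small self-contained sketch of that same per-rank, per-feature copy (plain C++ with synthetic offset tables; `dense_to_vbe` and the literal values are illustrative, not part of the PR):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative mirror of the VBE scatter loop in the forward CPU wrapper:
// copy each rank's rows of each feature's columns out of the dense
// [B, total_D] output into a flat, rank-major VBE buffer.
std::vector<float> dense_to_vbe(
    const std::vector<std::vector<float>>& dense,                         // [B][total_D]
    const std::vector<int32_t>& D_offsets,                                // [T + 1]
    const std::vector<std::vector<int32_t>>& B_offsets_rank_per_feature,  // [T][R + 1]
    const std::vector<int32_t>& output_offsets_feature_rank,              // [R * T + 1]
    int64_t vbe_output_size) {
  const int32_t T = static_cast<int32_t>(D_offsets.size()) - 1;
  const int32_t R = static_cast<int32_t>(B_offsets_rank_per_feature[0].size()) - 1;
  std::vector<float> out(static_cast<size_t>(vbe_output_size), 0.0f);
  for (int32_t r = 0; r < R; ++r) {
    int32_t D_offset = 0;
    for (int32_t t = 0; t < T; ++t) {
      const int32_t o_begin = output_offsets_feature_rank[r * T + t];
      const int32_t o_end = output_offsets_feature_rank[r * T + t + 1];
      const int32_t D = D_offsets[t + 1] - D_offsets[t];
      const int32_t b_begin = B_offsets_rank_per_feature[t][r];
      const int32_t b_end = B_offsets_rank_per_feature[t][r + 1];
      // Same invariant the wrapper enforces with TORCH_CHECK.
      assert(o_end - o_begin == (b_end - b_begin) * D);
      int32_t o = o_begin;
      for (int32_t b = b_begin; b < b_end; ++b) {
        for (int32_t d = 0; d < D; ++d) {
          out[o++] = dense[b][D_offset + d];
        }
      }
      D_offset += D;
    }
  }
  return out;
}

int main() {
  // Toy case: 2 features (dims 2 and 4), 2 ranks with per-feature batch
  // sizes {1, 2} and {2, 1}; the flat VBE output holds 18 values.
  const std::vector<int32_t> D_offsets = {0, 2, 6};
  const std::vector<std::vector<int32_t>> B_rank = {{0, 1, 3}, {0, 2, 3}};
  const std::vector<int32_t> out_offsets = {0, 2, 10, 14, 18};
  const std::vector<std::vector<float>> dense(3, std::vector<float>(6, 1.0f));
  const auto vbe = dense_to_vbe(dense, D_offsets, B_rank, out_offsets, 18);
  assert(vbe.size() == 18);
  return 0;
}
```

This is the inverse of `reshape_vbe_output` introduced below, which gathers the flat VBE gradient back into the dense `[max_B, total_D]` layout for the backward pass.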
62 changes: 62 additions & 0 deletions fbgemm_gpu/codegen/training/pt2/pt2_autograd_utils.cpp
@@ -0,0 +1,62 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>
#include <ATen/TypeDefault.h>

using Tensor = at::Tensor;

namespace fbgemm_gpu {

////////////////////////////////////////////////////////////////////////////////
// Helper Functions
////////////////////////////////////////////////////////////////////////////////

Tensor reshape_vbe_output(
const Tensor& grad_output,
const Tensor& B_offsets,
const Tensor& B_offsets_rank_per_feature,
const Tensor& D_offsets) {
/* FOR CPU VBE to use the same backend */
const auto T = D_offsets.numel() - 1;
int32_t max_B = 0;
int32_t total_D = 0;
// find max_B, total_D to create output [max_B, total_D]
for (int32_t t = 0; t < T; t++) {
auto b = B_offsets[t + 1].item<int32_t>() - B_offsets[t].item<int32_t>();
max_B = std::max(max_B, b);
total_D += D_offsets[t + 1].item<int32_t>() - D_offsets[t].item<int32_t>();
}
auto grad_output_ = at::empty({max_B, total_D}, grad_output.options());
// for each feature
auto offset = 0;

const int32_t R = B_offsets_rank_per_feature.size(1) - 1;
for (int32_t r = 0; r < R; r++) {
auto D_offset = 0;
for (int32_t t = 0; t < T; t++) {
const int32_t b_begin = B_offsets_rank_per_feature[t][r].item<int32_t>();
const int32_t b_end =
B_offsets_rank_per_feature[t][r + 1].item<int32_t>();
const int32_t D =
D_offsets[t + 1].item<int32_t>() - D_offsets[t].item<int32_t>();
const int32_t b = b_end - b_begin;
const int32_t num_elm = b * D;
auto values = grad_output.slice(0, offset, offset + num_elm);
values = values.reshape({b, D});
grad_output_.index_put_(
{at::indexing::Slice(b_begin, b_end),
at::indexing::Slice(D_offset, D_offset + D)},
values);
D_offset += D;
offset += num_elm;
}
}
return grad_output_;
}
} // namespace fbgemm_gpu
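A minimal usage sketch of the helper above on a toy layout (two features with dims 2 and 4, two ranks with per-feature batch sizes {1, 2} and {2, 1}); the values and the `main` harness are illustrative only and assume the new header is on the include path:

```cpp
#include <torch/torch.h>
#include "fbgemm_gpu/utils/pt2_autograd_utils.h"

int main() {
  // Feature dims 2 and 4 -> D_offsets = [0, 2, 6], total_D = 6.
  const auto D_offsets = torch::tensor({0, 2, 6}, torch::kInt);
  // Per-feature total batch sizes 3 and 3 -> B_offsets = [0, 3, 6], max_B = 3.
  const auto B_offsets = torch::tensor({0, 3, 6}, torch::kInt);
  // Cumulative per-rank batch offsets per feature, shape [T, R + 1]:
  // feature 0 -> {1, 2} samples per rank, feature 1 -> {2, 1}.
  const auto B_offsets_rank_per_feature =
      torch::tensor({0, 1, 3, 0, 2, 3}, torch::kInt).reshape({2, 3});
  // Flat VBE gradient, laid out rank-major then feature-major:
  // rank 0 holds 1*2 + 2*4 = 10 values, rank 1 holds 2*2 + 1*4 = 8.
  const auto grad_output = torch::arange(18, torch::kFloat);

  const auto dense = fbgemm_gpu::reshape_vbe_output(
      grad_output, B_offsets, B_offsets_rank_per_feature, D_offsets);
  // Dense layout expected by the CPU backward: [max_B, total_D] = [3, 6],
  // with each feature's rows ordered by rank.
  TORCH_CHECK(dense.size(0) == 3 && dense.size(1) == 6);
  return 0;
}
```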
31 changes: 31 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/utils/pt2_autograd_utils.h
@@ -0,0 +1,31 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>
#include <ATen/TypeDefault.h>
// #include <ATen/core/op_registration/op_registration.h>
// #include <torch/script.h>
// #include "fbgemm_gpu/embedding_common.h"
// #include "fbgemm_gpu/utils/dispatch_macros.h"
// #include "fbgemm_gpu/utils/ops_utils.h"
// #include "fbgemm_gpu/utils/tensor_utils.h"

using Tensor = at::Tensor;

namespace fbgemm_gpu {

////////////////////////////////////////////////////////////////////////////////
// Helper Functions
////////////////////////////////////////////////////////////////////////////////

Tensor reshape_vbe_output(
const Tensor& grad_output,
const Tensor& B_offsets,
const Tensor& B_offsets_rank_per_feature,
const Tensor& D_offsets);
} // namespace fbgemm_gpu