From d056aa3689380f7decad83c90bfc36f5dcf04195 Mon Sep 17 00:00:00 2001 From: Supadchaya Puangpontip Date: Thu, 26 Sep 2024 23:18:38 -0700 Subject: [PATCH] Add generate_vbe_metadata CPU fallback (#3183) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/279 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3183 Add generate_vbe_metadata for CPU Reviewed By: q10 Differential Revision: D63494418 fbshipit-source-id: 3936e2546ccf4fb89632df5c49148141adbabe71 --- .../fbgemm_gpu/split_embeddings_utils.cuh | 26 +------------- .../fbgemm_gpu/split_embeddings_utils.h | 36 +++++++++++++++++++ .../split_embeddings_utils.cpp | 17 +++++++++ 3 files changed, 54 insertions(+), 25 deletions(-) create mode 100644 fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.h diff --git a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh index 42fe5eb4c..8351e046c 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh @@ -12,14 +12,7 @@ #include #include #include "fbgemm_gpu/embedding_common.h" - -// These values are adjusted in backward based on B and T -constexpr int DEFAULT_INFO_NUM_BITS = 32; -constexpr int DEFAULT_INFO_B_NUM_BITS = 26; -constexpr uint32_t DEFAULT_INFO_B_MASK = (1u << DEFAULT_INFO_B_NUM_BITS) - 1; -constexpr uint32_t MAX_T = - (1u << (DEFAULT_INFO_NUM_BITS - DEFAULT_INFO_B_NUM_BITS)) - 1; -constexpr uint32_t MAX_B = (1u << DEFAULT_INFO_B_NUM_BITS) - 1; +#include "fbgemm_gpu/split_embeddings_utils.h" /** * "Transpose" embedding inputs by sorting indices by their values. 
@@ -50,11 +43,6 @@ transpose_embedding_input( const int64_t fixed_L_per_warp = 0, const int64_t num_warps_per_feature = 0); -std::tuple -get_infos_metadata(at::Tensor unused, int64_t B, int64_t T); - -std::tuple adjust_info_B_num_bits(int32_t B, int32_t T); - // Use these functions instead of directly calling cub functions // to reduce code size and compilation time. // Arguments are the same as cub::DeviceRadixSort::SortPairs @@ -77,15 +65,3 @@ DECL_RADIX_SORT_PAIRS_FN(int64_t, int64_t); DECL_RADIX_SORT_PAIRS_FN(int64_t, int32_t); #undef DECL_RADIX_SORT_PAIRS_FN - -std::tuple -generate_vbe_metadata( - const at::Tensor& B_offsets, - const at::Tensor& B_offsets_rank_per_feature, - const at::Tensor& output_offsets_feature_rank, - const at::Tensor& D_offsets, - const int64_t D, - const bool nobag, - const int64_t max_B_feature_rank, - const int64_t info_B_num_bits, - const int64_t total_B); diff --git a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.h b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.h new file mode 100644 index 000000000..b41681012 --- /dev/null +++ b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include <ATen/ATen.h> + +// These values are adjusted in backward based on B and T +constexpr int DEFAULT_INFO_NUM_BITS = 32; +constexpr int DEFAULT_INFO_B_NUM_BITS = 26; +constexpr uint32_t DEFAULT_INFO_B_MASK = (1u << DEFAULT_INFO_B_NUM_BITS) - 1; +constexpr uint32_t MAX_T = + (1u << (DEFAULT_INFO_NUM_BITS - DEFAULT_INFO_B_NUM_BITS)) - 1; +constexpr uint32_t MAX_B = (1u << DEFAULT_INFO_B_NUM_BITS) - 1; + +std::tuple +get_infos_metadata(at::Tensor unused, int64_t B, int64_t T); + +std::tuple adjust_info_B_num_bits(int32_t B, int32_t T); + +std::tuple<at::Tensor, at::Tensor> +generate_vbe_metadata( + const at::Tensor& B_offsets, + const at::Tensor& B_offsets_rank_per_feature, + const at::Tensor& output_offsets_feature_rank, + const at::Tensor& D_offsets, + const int64_t D, + const bool nobag, + const int64_t max_B_feature_rank, + const int64_t info_B_num_bits, + const int64_t total_B); diff --git a/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp index 4e8407fb1..4ae9ae0f7 100644 --- a/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp +++ b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp @@ -33,6 +33,22 @@ generate_vbe_metadata_meta( return {row_output_offsets, b_t_map}; } +std::tuple<Tensor, Tensor> +generate_vbe_metadata_cpu( + const Tensor& B_offsets, + const Tensor& B_offsets_rank_per_feature, + const Tensor& output_offsets_feature_rank, + const Tensor& D_offsets, + const int64_t D, + const bool nobag, + const c10::SymInt max_B_feature_rank, + const int64_t info_B_num_bits, + const c10::SymInt total_B) { /* CPU fallback: performs no metadata computation. Returns output_offsets_feature_rank as row_output_offsets and B_offsets_rank_per_feature as b_t_map; B_offsets, D_offsets, D, nobag, max_B_feature_rank, info_B_num_bits and total_B are unused. */ + Tensor row_output_offsets = output_offsets_feature_rank; + Tensor b_t_map = B_offsets_rank_per_feature; + return {row_output_offsets, b_t_map}; +} + } // namespace TORCH_LIBRARY_FRAGMENT(fbgemm, m) { @@ -68,6 +84,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { DISPATCH_TO_CUDA("transpose_embedding_input", transpose_embedding_input); DISPATCH_TO_CUDA("get_infos_metadata", 
get_infos_metadata); DISPATCH_TO_CUDA("generate_vbe_metadata", generate_vbe_metadata); + DISPATCH_TO_CPU("generate_vbe_metadata", generate_vbe_metadata_cpu); } TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {