Skip to content

Commit

Permalink
Add generate_vbe_metadata CPU fallback (#3183)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookresearch/FBGEMM#279

Pull Request resolved: #3183

Add generate_vbe_metadata for CPU

Reviewed By: q10

Differential Revision: D63494418

fbshipit-source-id: 3936e2546ccf4fb89632df5c49148141adbabe71
  • Loading branch information
spcyppt authored and facebook-github-bot committed Sep 27, 2024
1 parent a9b7ae8 commit d056aa3
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 25 deletions.
26 changes: 1 addition & 25 deletions fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,7 @@
#include <cuda.h>
#include <cuda_runtime.h>
#include "fbgemm_gpu/embedding_common.h"

// These values are adjusted in backward based on B and T
constexpr int DEFAULT_INFO_NUM_BITS = 32;
constexpr int DEFAULT_INFO_B_NUM_BITS = 26;
constexpr uint32_t DEFAULT_INFO_B_MASK = (1u << DEFAULT_INFO_B_NUM_BITS) - 1;
constexpr uint32_t MAX_T =
(1u << (DEFAULT_INFO_NUM_BITS - DEFAULT_INFO_B_NUM_BITS)) - 1;
constexpr uint32_t MAX_B = (1u << DEFAULT_INFO_B_NUM_BITS) - 1;
#include "fbgemm_gpu/split_embeddings_utils.h"

/**
* "Transpose" embedding inputs by sorting indices by their values.
Expand Down Expand Up @@ -50,11 +43,6 @@ transpose_embedding_input(
const int64_t fixed_L_per_warp = 0,
const int64_t num_warps_per_feature = 0);

std::tuple<int64_t, int64_t>
get_infos_metadata(at::Tensor unused, int64_t B, int64_t T);

std::tuple<int32_t, uint32_t> adjust_info_B_num_bits(int32_t B, int32_t T);

// Use these functions instead of directly calling cub functions
// to reduce code size and compilation time.
// Arguments are the same as cub::DeviceRadixSort::SortPairs
Expand All @@ -77,15 +65,3 @@ DECL_RADIX_SORT_PAIRS_FN(int64_t, int64_t);
DECL_RADIX_SORT_PAIRS_FN(int64_t, int32_t);

#undef DECL_RADIX_SORT_PAIRS_FN

std::tuple<at::Tensor /*row_output_offsets*/, at::Tensor /*b_t_map*/>
generate_vbe_metadata(
const at::Tensor& B_offsets,
const at::Tensor& B_offsets_rank_per_feature,
const at::Tensor& output_offsets_feature_rank,
const at::Tensor& D_offsets,
const int64_t D,
const bool nobag,
const int64_t max_B_feature_rank,
const int64_t info_B_num_bits,
const int64_t total_B);
36 changes: 36 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <ATen/ATen.h>

#include <cstdint> // uint32_t used by the constants below (include-what-you-use)
#include <tuple> // std::tuple used in the declarations below

// Bit-packing layout for the "info" value that encodes (feature index T,
// batch index B) in a single 32-bit word.
// These values are adjusted in backward based on B and T.
constexpr int DEFAULT_INFO_NUM_BITS = 32;
constexpr int DEFAULT_INFO_B_NUM_BITS = 26;
// Mask selecting the low DEFAULT_INFO_B_NUM_BITS bits (the B field).
constexpr uint32_t DEFAULT_INFO_B_MASK = (1u << DEFAULT_INFO_B_NUM_BITS) - 1;
// Largest representable T with the remaining high bits.
constexpr uint32_t MAX_T =
    (1u << (DEFAULT_INFO_NUM_BITS - DEFAULT_INFO_B_NUM_BITS)) - 1;
// Largest representable B with DEFAULT_INFO_B_NUM_BITS bits.
constexpr uint32_t MAX_B = (1u << DEFAULT_INFO_B_NUM_BITS) - 1;

// Returns a (int64_t, int64_t) pair derived from B and T.
// NOTE(review): presumably (info_B_num_bits, info_B_mask) given the constants
// above and the sibling adjust_info_B_num_bits — confirm against the
// implementation; `unused` appears to exist only for dispatcher plumbing.
std::tuple<int64_t, int64_t>
get_infos_metadata(at::Tensor unused, int64_t B, int64_t T);

// Recomputes the B-field bit width (and corresponding mask) for the given
// batch count B and feature count T.
std::tuple<int32_t, uint32_t> adjust_info_B_num_bits(int32_t B, int32_t T);

// Generates metadata for variable-batch-size embedding (VBE) output:
// per-row output offsets and the b_t (batch, feature) index map.
std::tuple<at::Tensor /*row_output_offsets*/, at::Tensor /*b_t_map*/>
generate_vbe_metadata(
    const at::Tensor& B_offsets,
    const at::Tensor& B_offsets_rank_per_feature,
    const at::Tensor& output_offsets_feature_rank,
    const at::Tensor& D_offsets,
    const int64_t D,
    const bool nobag,
    const int64_t max_B_feature_rank,
    const int64_t info_B_num_bits,
    const int64_t total_B);
17 changes: 17 additions & 0 deletions fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,22 @@ generate_vbe_metadata_meta(
return {row_output_offsets, b_t_map};
}

// CPU fallback for generate_vbe_metadata.
// No metadata is computed on this path: the pre-existing
// output_offsets_feature_rank and B_offsets_rank_per_feature tensors are
// forwarded unchanged as (row_output_offsets, b_t_map).
// NOTE(review): the remaining parameters are accepted only to satisfy the
// dispatcher schema shared with the CUDA/meta kernels — confirm CPU callers
// expect this pass-through behavior.
std::tuple<Tensor /*row_output_offsets*/, Tensor /*b_t_map*/>
generate_vbe_metadata_cpu(
    const Tensor& B_offsets,
    const Tensor& B_offsets_rank_per_feature,
    const Tensor& output_offsets_feature_rank,
    const Tensor& D_offsets,
    const int64_t D,
    const bool nobag,
    const c10::SymInt max_B_feature_rank,
    const int64_t info_B_num_bits,
    const c10::SymInt total_B) {
  // Return the input tensors directly; the tuple construction copies the
  // handles, exactly as the intermediate locals in the original did.
  return {output_offsets_feature_rank, B_offsets_rank_per_feature};
}

} // namespace

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
Expand Down Expand Up @@ -68,6 +84,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
DISPATCH_TO_CUDA("transpose_embedding_input", transpose_embedding_input);
DISPATCH_TO_CUDA("get_infos_metadata", get_infos_metadata);
DISPATCH_TO_CUDA("generate_vbe_metadata", generate_vbe_metadata);
DISPATCH_TO_CPU("generate_vbe_metadata", generate_vbe_metadata_cpu);
}

TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
Expand Down

0 comments on commit d056aa3

Please sign in to comment.