Split up split_embeddings_cache_cuda.cu (pytorch#1881)
Summary:
Pull Request resolved: pytorch#1881

- Split up split_embeddings_cache_cuda.cu

Reviewed By: sryap, spcyppt

Differential Revision: D47491160

fbshipit-source-id: 9c025f06dbc003fe8734d64ae5a102b1a37fd5a5
q10 authored and facebook-github-bot committed Oct 11, 2023
1 parent 1e194b7 commit 1605d82
Showing 11 changed files with 3,306 additions and 3,202 deletions.
8 changes: 7 additions & 1 deletion fbgemm_gpu/CMakeLists.txt
@@ -675,7 +675,13 @@ if(NOT FBGEMM_CPU_ONLY)
src/sparse_ops/sparse_reorder_batched_ad.cu
src/sparse_ops/sparse_segment_sum_csr.cu
src/sparse_ops/sparse_zipf.cu
- src/split_embeddings_cache_cuda.cu
+ src/split_embeddings_cache/lfu_cache.cu
+ src/split_embeddings_cache/lru_cache_find.cu
+ src/split_embeddings_cache/lru_cache_populate.cu
+ src/split_embeddings_cache/lru_cache_populate_byte.cu
+ src/split_embeddings_cache/lxu_cache.cu
+ src/split_embeddings_cache/linearize_cache_indices.cu
+ src/split_embeddings_cache/reset_weight_momentum.cu
src/split_embeddings_utils.cu)

set_source_files_properties(${fbgemm_gpu_sources_static_gpu}
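For orientation, each of the new .cu files listed above presumably follows the same layout: it pulls in the shared common.cuh header (added further down in this diff) and keeps only the kernels and host wrappers for a single cache operation. A minimal sketch of that assumed structure; the file layout and the kernel/function names here are illustrative, not taken from this commit:

#include "common.cuh"  // shared includes, cache_slot(), and helper constants

namespace {

// Device kernel for one cache operation; only this operation's kernels
// live in this translation unit after the split (illustrative name).
__global__ void example_cache_kernel(/* tensor accessors, sizes, ... */) {}

} // anonymous namespace

// Host-side entry point (illustrative name). The public API declared in
// split_embeddings_cache_cuda.cuh is unchanged by the refactor.
void example_cache_cuda(/* at::Tensor arguments */) {
  // configure grid/block dimensions and launch example_cache_kernel here
}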
13 changes: 13 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh
@@ -13,6 +13,19 @@
///@defgroup table-batched-embed-cuda CUDA Operators
/// The following are CUDA Operators

+namespace fbgemm_gpu {
+
+enum uvm_cache_stats_index {
+  num_calls = 0,
+  num_requested_indices = 1,
+  num_unique_indices = 2,
+  num_unique_misses = 3,
+  num_conflict_unique_misses = 4,
+  num_conflict_misses = 5,
+};
+
+} // namespace fbgemm_gpu

///@ingroup table-batched-embed-cuda
/// Deduplicate indices.
std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>
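The uvm_cache_stats_index enumerators added above are offsets into a stats tensor that the cache operators populate. A host-side sketch of how such a tensor might be read back, assuming it is a 1-D int32 CPU tensor; the helper name and the dtype/layout are assumptions for illustration, not taken from this commit:

#include <ATen/ATen.h>
#include "fbgemm_gpu/split_embeddings_cache_cuda.cuh"

// Hypothetical helper: compute the unique-miss rate from a uvm_cache_stats
// tensor indexed by fbgemm_gpu::uvm_cache_stats_index. The int32 CPU layout
// is an assumption for this sketch.
double unique_miss_rate(const at::Tensor& uvm_cache_stats) {
  const auto stats = uvm_cache_stats.accessor<int32_t, 1>();
  const auto unique = stats[fbgemm_gpu::num_unique_indices];
  const auto misses = stats[fbgemm_gpu::num_unique_misses];
  return unique > 0 ? static_cast<double>(misses) / unique : 0.0;
}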
90 changes: 90 additions & 0 deletions fbgemm_gpu/src/split_embeddings_cache/common.cuh
@@ -0,0 +1,90 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

// clang-format off
#include "fbgemm_gpu/cub_namespace_prefix.cuh"
#include <cub/device/device_radix_sort.cuh>
#include <cub/device/device_run_length_encode.cuh>
#include <cub/device/device_select.cuh>
#include <cub/block/block_reduce.cuh>
#include "fbgemm_gpu/cub_namespace_postfix.cuh"
// clang-format on

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/TensorUtils.h>
#include <ATen/core/TensorAccessor.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <curand_kernel.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <limits>
#include <mutex>

#include "fbgemm_gpu/dispatch_macros.h"
#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/fbgemm_cuda_utils.cuh"
#include "fbgemm_gpu/fbgemm_tensor_accessor.h"
#include "fbgemm_gpu/ops_utils.h"
#include "fbgemm_gpu/sparse_ops_utils.h"
#include "fbgemm_gpu/split_embeddings_cache_cuda.cuh"
#include "fbgemm_gpu/split_embeddings_utils.cuh"

using Tensor = at::Tensor;
using namespace fbgemm_gpu;

namespace {

constexpr size_t kCacheMaxThreads = 512;
constexpr int32_t kCacheLocationMissing = -1;
constexpr int64_t kCacheStateInvalid = -1;

// // TODO: do we care about 64-bit indices? Currently we just ignore.
// __host__ DEVICE_INLINE uint32_t cache_slot(int32_t h_in, int32_t C) {
// // MurmurHash3 32-bit mixing function.
// uint32_t h = (uint32_t)h_in;
// h ^= h >> 16;
// h *= 0x85ebca6b;
// h ^= h >> 13;
// h *= 0xc2b2ae35;
// h ^= h >> 16;
// //
// https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
// return ((uint64_t)h * (uint64_t)C) >> 32;
// }

__host__ DEVICE_INLINE uint32_t
cache_slot(const int64_t h_in, const int32_t C) {
// MurmurHash3 64-bit mixing function.
uint64_t h = (uint64_t)h_in;
h ^= h >> 33;
h *= 0xff51afd7ed558ccd;
h ^= h >> 33;
h *= 0xc4ceb9fe1a85ec53;
h ^= h >> 33;

return h % (uint32_t)C;
}

// Experiments showed that performance of lru/lxu_cache_find_uncached_kernel is
// not sensitive to grid size as long as the number of thread blocks per SM is
// neither too small nor too big.
constexpr int MAX_THREAD_BLOCKS_PER_SM_FOR_CACHE_KERNELS = 16;

int get_max_thread_blocks_for_cache_kernels_() {
return get_device_sm_cnt_() * MAX_THREAD_BLOCKS_PER_SM_FOR_CACHE_KERNELS;
}

} // namespace
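As an aside (not part of the diff), the cache_slot() helper above is the standard MurmurHash3 64-bit finalizer followed by a modulo that picks one of C cache sets. A standalone host-only reproduction that can be compiled and run to sanity-check the slot distribution; every name below is local to this sketch:

#include <cstdint>
#include <cstdio>
#include <vector>

// Reproduction of the MurmurHash3 fmix64 finalizer used by cache_slot(),
// with the same constants; host-only, for illustration.
uint32_t cache_slot_host(int64_t h_in, int32_t C) {
  uint64_t h = static_cast<uint64_t>(h_in);
  h ^= h >> 33;
  h *= 0xff51afd7ed558ccd;
  h ^= h >> 33;
  h *= 0xc4ceb9fe1a85ec53;
  h ^= h >> 33;
  return h % static_cast<uint32_t>(C);
}

int main() {
  constexpr int32_t C = 8;           // number of cache sets
  std::vector<int> histogram(C, 0);  // slot occupancy counts
  for (int64_t idx = 0; idx < 80000; ++idx) {
    ++histogram[cache_slot_host(idx, C)];
  }
  // Linearized indices should spread roughly evenly across the sets.
  for (int32_t s = 0; s < C; ++s) {
    std::printf("slot %d: %d\n", s, histogram[s]);
  }
  return 0;
}

The finalizer decorrelates the (often skewed and clustered) linearized embedding indices before the modulo, so hot rows do not pile into a handful of cache sets. Similarly, get_max_thread_blocks_for_cache_kernels_() simply caps the launch grid at 16 thread blocks per SM, reflecting the measurement quoted in the comment above.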
