From 0da72d45605c3c43c9f0cb8bb056f6bfdaed1a65 Mon Sep 17 00:00:00 2001 From: rhdong Date: Wed, 26 Jul 2023 22:02:33 +0800 Subject: [PATCH] [Feat] make allocator customizable --- include/merlin/allocator.cuh | 127 ++++++ include/merlin/core_kernels.cuh | 115 ++++-- include/merlin/core_kernels/kernel_utils.cuh | 7 +- include/merlin/group_lock.hpp | 19 +- include/merlin/initializers.cuh | 147 ------- include/merlin/memory_pool.cuh | 84 ++-- include/merlin/utils.cuh | 43 -- include/merlin_hashtable.cuh | 36 +- tests/memory_pool_test.cc.cu | 391 ++++++++++--------- tests/merlin_hashtable_test.cc.cu | 78 +++- 10 files changed, 578 insertions(+), 469 deletions(-) create mode 100644 include/merlin/allocator.cuh delete mode 100644 include/merlin/initializers.cuh diff --git a/include/merlin/allocator.cuh b/include/merlin/allocator.cuh new file mode 100644 index 000000000..5e45c584d --- /dev/null +++ b/include/merlin/allocator.cuh @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include "debug.hpp" +#include "utils.cuh" + +namespace nv { +namespace merlin { + +enum MemoryType { + Device, // HBM + Pinned, // Pinned Host Memory + Host, // Host Memory + Managed, // Pageable Host Memory(Not required) +}; + +/* This abstract class defines the allocator APIs required by HKV. + Any of the customized allocators should inherit from it. 
+ */ +class BaseAllocator { + public: + BaseAllocator(const BaseAllocator&) = delete; + BaseAllocator(BaseAllocator&&) = delete; + + BaseAllocator& operator=(const BaseAllocator&) = delete; + BaseAllocator& operator=(BaseAllocator&&) = delete; + + BaseAllocator() = default; + virtual ~BaseAllocator() = default; + + virtual void alloc(const MemoryType type, void** ptr, size_t size, + unsigned int pinned_flags = cudaHostAllocDefault) = 0; + + virtual void alloc_async(const MemoryType type, void** ptr, size_t size, + cudaStream_t stream) = 0; + + virtual void free(const MemoryType type, void* ptr) = 0; + + virtual void free_async(const MemoryType type, void* ptr, + cudaStream_t stream) = 0; +}; + +class DefaultAllocator : public virtual BaseAllocator { + public: + DefaultAllocator(){}; + ~DefaultAllocator() override{}; + + void alloc(const MemoryType type, void** ptr, size_t size, + unsigned int pinned_flags = cudaHostAllocDefault) override { + switch (type) { + case MemoryType::Device: + CUDA_CHECK(cudaMalloc(ptr, size)); + break; + case MemoryType::Pinned: + CUDA_CHECK(cudaMallocHost(ptr, size, pinned_flags)); + break; + case MemoryType::Host: + *ptr = std::malloc(size); + break; + } + return; + } + + void alloc_async(const MemoryType type, void** ptr, size_t size, + cudaStream_t stream) override { + if (type == MemoryType::Device) { + CUDA_CHECK(cudaMallocAsync(ptr, size, stream)); + } else { + MERLIN_CHECK(false, + "[DefaultAllocator] alloc_async is only support for " + "MemoryType::Device!"); + } + return; + } + + void free(const MemoryType type, void* ptr) override { + if (ptr == nullptr) { + return; + } + switch (type) { + case MemoryType::Pinned: + CUDA_CHECK(cudaFreeHost(ptr)); + break; + case MemoryType::Device: + CUDA_CHECK(cudaFree(ptr)); + break; + case MemoryType::Host: + std::free(ptr); + break; + } + return; + } + + void free_async(const MemoryType type, void* ptr, + cudaStream_t stream) override { + if (ptr == nullptr) { + return; + } + + if (type == MemoryType::Device) { + CUDA_CHECK(cudaFreeAsync(ptr, stream)); + } else { + MERLIN_CHECK(false, + "[DefaultAllocator] free_async is only support for " + "MemoryType::Device!"); + } + } +}; + +} // namespace merlin +} // namespace nv diff --git a/include/merlin/core_kernels.cuh b/include/merlin/core_kernels.cuh index 60b066641..6719071d5 100644 --- a/include/merlin/core_kernels.cuh +++ b/include/merlin/core_kernels.cuh @@ -16,8 +16,7 @@ #pragma once -#include -#include +#include "merlin/allocator.cuh" #include "merlin/core_kernels/find_or_insert.cuh" #include "merlin/core_kernels/find_ptr_or_insert.cuh" #include "merlin/core_kernels/kernel_utils.cuh" @@ -30,7 +29,6 @@ namespace nv { namespace merlin { - template __global__ void create_locks(S* __restrict mutex, const size_t start, const size_t end) { @@ -103,10 +101,55 @@ __global__ void get_bucket_others_address(Bucket* __restrict buckets, *address = buckets[index].digests_; } +template +void realloc(P* ptr, size_t old_size, size_t new_size, + BaseAllocator* allocator) { + // Truncate old_size to limit dowstream copy ops. + old_size = std::min(old_size, new_size); + + // Alloc new buffer and copy at old data. + char* new_ptr; + allocator->alloc(MemoryType::Device, (void**)&new_ptr, new_size); + if (*ptr != nullptr) { + CUDA_CHECK(cudaMemcpy(new_ptr, *ptr, old_size, cudaMemcpyDefault)); + allocator->free(MemoryType::Device, *ptr); + } + + // Zero-fill remainder. + CUDA_CHECK(cudaMemset(new_ptr + old_size, 0, new_size - old_size)); + + // Switch to new pointer. 
+ *ptr = reinterpret_cast<P>
(new_ptr); + return; +} + +template +void realloc_host(P* ptr, size_t old_size, size_t new_size, + BaseAllocator* allocator) { + // Truncate old_size to limit dowstream copy ops. + old_size = std::min(old_size, new_size); + + // Alloc new buffer and copy at old data. + char* new_ptr = nullptr; + allocator->alloc(MemoryType::Host, (void**)&new_ptr, new_size); + + if (*ptr != nullptr) { + std::memcpy(new_ptr, *ptr, old_size); + allocator->free(MemoryType::Host, *ptr); + } + + // Zero-fill remainder. + std::memset(new_ptr + old_size, 0, new_size - old_size); + + // Switch to new pointer. + *ptr = reinterpret_cast
<P>
(new_ptr); + return; +} + /* Initialize the buckets with index from start to end. */ template -void initialize_buckets(Table** table, const size_t start, - const size_t end) { +void initialize_buckets(Table** table, BaseAllocator* allocator, + const size_t start, const size_t end) { /* As testing results show us, when the number of buckets is greater than * the 4 million the performance will drop significantly, we believe the * to many pinned memory allocation causes this issue, so we change the @@ -127,7 +170,8 @@ void initialize_buckets(Table** table, const size_t start, realloc_host( &((*table)->slices), (*table)->num_of_memory_slices * sizeof(V*), - ((*table)->num_of_memory_slices + num_of_memory_slices) * sizeof(V*)); + ((*table)->num_of_memory_slices + num_of_memory_slices) * sizeof(V*), + allocator); for (size_t i = (*table)->num_of_memory_slices; i < (*table)->num_of_memory_slices + num_of_memory_slices; i++) { @@ -138,12 +182,13 @@ void initialize_buckets(Table** table, const size_t start, (*table)->bucket_max_size * sizeof(V) * (*table)->dim; if ((*table)->remaining_hbm_for_vectors >= slice_real_size) { - CUDA_CHECK(cudaMalloc(&((*table)->slices[i]), slice_real_size)); + allocator->alloc(MemoryType::Device, (void**)&((*table)->slices[i]), + slice_real_size); (*table)->remaining_hbm_for_vectors -= slice_real_size; } else { (*table)->is_pure_hbm = false; - CUDA_CHECK(cudaMallocHost(&((*table)->slices[i]), slice_real_size, - cudaHostAllocMapped)); + allocator->alloc(MemoryType::Pinned, (void**)&((*table)->slices[i]), + slice_real_size, cudaHostAllocMapped); } for (int j = 0; j < num_of_buckets_in_one_slice; j++) { if ((*table)->is_pure_hbm) { @@ -178,7 +223,8 @@ void initialize_buckets(Table** table, const size_t start, bucket_memory_size += reserve_size * sizeof(uint8_t); for (int i = start; i < end; i++) { uint8_t* address = nullptr; - CUDA_CHECK(cudaMalloc(&address, bucket_memory_size)); + allocator->alloc(MemoryType::Device, (void**)&(address), + bucket_memory_size); allocate_bucket_others<<<1, 1>>>((*table)->buckets, i, address, reserve_size, bucket_max_size); } @@ -244,14 +290,13 @@ size_t get_slice_size(Table** table) { DIM: Vector dimension. 
*/ template -void create_table(Table** table, const size_t dim, - const size_t init_size = 134217728, +void create_table(Table** table, BaseAllocator* allocator, + const size_t dim, const size_t init_size = 134217728, const size_t max_size = std::numeric_limits::max(), const size_t max_hbm_for_vectors = 0, const size_t bucket_max_size = 128, const size_t tile_size = 32, const bool primary = true) { - (*table) = - reinterpret_cast*>(std::malloc(sizeof(Table))); + allocator->alloc(MemoryType::Host, (void**)table, sizeof(Table)); std::memset(*table, 0, sizeof(Table)); (*table)->dim = dim; (*table)->bucket_max_size = bucket_max_size; @@ -277,39 +322,39 @@ void create_table(Table** table, const size_t dim, (*table)->remaining_hbm_for_vectors = max_hbm_for_vectors; (*table)->primary = primary; - CUDA_CHECK(cudaMalloc((void**)&((*table)->locks), - (*table)->buckets_num * sizeof(Mutex))); + allocator->alloc(MemoryType::Device, (void**)&((*table)->locks), + (*table)->buckets_num * sizeof(Mutex)); CUDA_CHECK( cudaMemset((*table)->locks, 0, (*table)->buckets_num * sizeof(Mutex))); - CUDA_CHECK(cudaMalloc((void**)&((*table)->buckets_size), - (*table)->buckets_num * sizeof(int))); + allocator->alloc(MemoryType::Device, (void**)&((*table)->buckets_size), + (*table)->buckets_num * sizeof(int)); CUDA_CHECK(cudaMemset((*table)->buckets_size, 0, (*table)->buckets_num * sizeof(int))); - CUDA_CHECK(cudaMalloc((void**)&((*table)->buckets), - (*table)->buckets_num * sizeof(Bucket))); + allocator->alloc(MemoryType::Device, (void**)&((*table)->buckets), + (*table)->buckets_num * sizeof(Bucket)); CUDA_CHECK(cudaMemset((*table)->buckets, 0, (*table)->buckets_num * sizeof(Bucket))); - initialize_buckets(table, 0, (*table)->buckets_num); + initialize_buckets(table, allocator, 0, (*table)->buckets_num); CudaCheckError(); } /* Double the capacity on storage, must be followed by calling the * rehash_kernel. */ template -void double_capacity(Table** table) { +void double_capacity(Table** table, BaseAllocator* allocator) { realloc(&((*table)->locks), (*table)->buckets_num * sizeof(Mutex), - (*table)->buckets_num * sizeof(Mutex) * 2); + (*table)->buckets_num * sizeof(Mutex) * 2, allocator); realloc(&((*table)->buckets_size), (*table)->buckets_num * sizeof(int), - (*table)->buckets_num * sizeof(int) * 2); + (*table)->buckets_num * sizeof(int) * 2, allocator); realloc*>( &((*table)->buckets), (*table)->buckets_num * sizeof(Bucket), - (*table)->buckets_num * sizeof(Bucket) * 2); + (*table)->buckets_num * sizeof(Bucket) * 2, allocator); - initialize_buckets(table, (*table)->buckets_num, + initialize_buckets(table, allocator, (*table)->buckets_num, (*table)->buckets_num * 2); (*table)->capacity *= 2; @@ -318,7 +363,7 @@ void double_capacity(Table** table) { /* free all of the resource of a Table. 
*/ template -void destroy_table(Table** table) { +void destroy_table(Table** table, BaseAllocator* allocator) { uint8_t** d_address = nullptr; CUDA_CHECK(cudaMalloc((void**)&d_address, sizeof(uint8_t*))); for (int i = 0; i < (*table)->buckets_num; i++) { @@ -327,15 +372,15 @@ void destroy_table(Table** table) { <<<1, 1>>>((*table)->buckets, i, d_address); CUDA_CHECK(cudaMemcpy(&h_address, d_address, sizeof(uint8_t*), cudaMemcpyDeviceToHost)); - CUDA_CHECK(cudaFree(h_address)); + allocator->free(MemoryType::Device, h_address); } CUDA_CHECK(cudaFree(d_address)); for (int i = 0; i < (*table)->num_of_memory_slices; i++) { if (is_on_device((*table)->slices[i])) { - CUDA_CHECK(cudaFree((*table)->slices[i])); + allocator->free(MemoryType::Device, (*table)->slices[i]); } else { - CUDA_CHECK(cudaFreeHost((*table)->slices[i])); + allocator->free(MemoryType::Pinned, (*table)->slices[i]); } } { @@ -345,11 +390,11 @@ void destroy_table(Table** table) { release_locks <<>>((*table)->locks, 0, (*table)->buckets_num); } - std::free((*table)->slices); - CUDA_CHECK(cudaFree((*table)->buckets_size)); - CUDA_CHECK(cudaFree((*table)->buckets)); - CUDA_CHECK(cudaFree((*table)->locks)); - std::free(*table); + allocator->free(MemoryType::Host, (*table)->slices); + allocator->free(MemoryType::Device, (*table)->buckets_size); + allocator->free(MemoryType::Device, (*table)->buckets); + allocator->free(MemoryType::Device, (*table)->locks); + allocator->free(MemoryType::Host, *table); CUDA_CHECK(cudaDeviceSynchronize()); CudaCheckError(); } diff --git a/include/merlin/core_kernels/kernel_utils.cuh b/include/merlin/core_kernels/kernel_utils.cuh index 8672a1373..a2369a0fa 100644 --- a/include/merlin/core_kernels/kernel_utils.cuh +++ b/include/merlin/core_kernels/kernel_utils.cuh @@ -54,14 +54,15 @@ __forceinline__ __device__ void LDGSTS(ElementType* dst, const ElementType* src); template <> -__forceinline__ __device__ void LDGSTS( - uint8_t* dst, const uint8_t* src) { +__forceinline__ __device__ void LDGSTS(uint8_t* dst, + const uint8_t* src) { uint8_t element = *src; *dst = element; } template <> -__forceinline__ __device__ void LDGSTS(uint16_t* dst, const uint16_t* src) { +__forceinline__ __device__ void LDGSTS(uint16_t* dst, + const uint16_t* src) { uint16_t element = *src; *dst = element; } diff --git a/include/merlin/group_lock.hpp b/include/merlin/group_lock.hpp index 703c2407e..6760db4a8 100644 --- a/include/merlin/group_lock.hpp +++ b/include/merlin/group_lock.hpp @@ -14,6 +14,17 @@ * limitations under the License. */ +#pragma once + +#include +#include +#include +#include +#include + +namespace nv { +namespace merlin { + /* * Implementing a group mutex and relative lock guard for better E2E * performance: @@ -24,14 +35,6 @@ * - The `write_read_lock` is used for special APIs (like `reserve` `erase` * `clear` etc.) */ -#include -#include -#include -#include -#include - -namespace nv { -namespace merlin { class group_shared_mutex { public: diff --git a/include/merlin/initializers.cuh b/include/merlin/initializers.cuh deleted file mode 100644 index 6df875688..000000000 --- a/include/merlin/initializers.cuh +++ /dev/null @@ -1,147 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include "curand_philox4x32_x.h" -#include "types.cuh" -#include "utils.cuh" - -namespace nv { -namespace merlin { -namespace initializers { - -inline void cuda_rand_check_(curandStatus_t val, const char* file, int line) { - if (val != CURAND_STATUS_SUCCESS) { - throw CudaException(std::string(file) + ":" + std::to_string(line) + - ": CURAND error " + std::to_string(val)); - } -} - -#define CURAND_CHECK(val) \ - { nv::merlin::initializers::cuda_rand_check_((val), __FILE__, __LINE__); } - -template -void zeros(T* d_data, const size_t len, cudaStream_t stream) { - CUDA_CHECK(cudaMemsetAsync(d_data, 0, len, stream)); -} - -template -void random_normal(T* d_data, const size_t len, cudaStream_t stream, - const T mean = 0.0, const T stddev = 0.05, - const unsigned long long seed = 2022ULL) { - curandGenerator_t generator; - CURAND_CHECK(curandCreateGenerator(&generator, CURAND_RNG_PSEUDO_DEFAULT)); - CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(generator, seed)); - CURAND_CHECK(curandGenerateNormal(generator, d_data, len, mean, stddev)); -} - -template -__global__ void adjust_max_min(T* d_data, const T minval, const T maxval, - const size_t N) { - int tid = (blockIdx.x * blockDim.x) + threadIdx.x; - if (tid < N) { - d_data[tid] = - d_data[tid] * (maxval - minval) + (0.5 * (maxval + minval) - 0.5); - } -} - -template -void random_uniform(T* d_data, const size_t len, cudaStream_t stream, - const T minval = 0.0, const T maxval = 1.0, - const unsigned long long seed = 2022ULL) { - curandGenerator_t generator; - - CURAND_CHECK(curandCreateGenerator(&generator, CURAND_RNG_PSEUDO_DEFAULT)); - CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(generator, seed)); - - int N = len; - int block_size = 256; - int grid_size = (N + block_size - 1) / block_size; - CURAND_CHECK(curandGenerateUniform(generator, d_data, N)); - adjust_max_min - <<>>(d_data, minval, maxval, N); -} - -template -__global__ void init_states(curandStatePhilox4_32_10_t* states, - const unsigned long long seed, const size_t N) { - int tid = (blockIdx.x * blockDim.x) + threadIdx.x; - if (tid < N) { - curand_init(seed, tid, 0, &states[tid]); - } -} - -template -__global__ void make_truncated_normal(T* d_data, - curandStatePhilox4_32_10_t* states, - const size_t N) { - int tid = (blockIdx.x * blockDim.x) + threadIdx.x; - if (tid < N) { - constexpr T truncated_val = T(2.0); - while (fabsf(d_data[tid]) > truncated_val) { - d_data[tid] = curand_normal(&states[tid]); - } - } -} - -template -void truncated_normal(T* d_data, const size_t len, cudaStream_t stream, - const T minval = 0.0, const T maxval = 1.0, - const unsigned long long seed = 2022ULL) { - curandGenerator_t generator; - - CURAND_CHECK(curandCreateGenerator(&generator, CURAND_RNG_PSEUDO_DEFAULT)); - CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(generator, seed)); - - int N = len; - int block_size = 256; - int grid_size = (N + block_size - 1) / block_size; - curandStatePhilox4_32_10_t* d_states; - CUDA_CHECK(cudaMallocAsync(&d_states, N, stream)); - - init_states<<>>(d_states, seed, N); - - make_truncated_normal 
- <<>>(d_data, d_states, N); - - adjust_max_min - <<>>(d_data, minval, maxval, N); - - CUDA_CHECK(cudaFreeAsync(d_states, stream)); -} - -template -class Initializer { - public: - virtual ~Initializer() {} - virtual void initialize(T* data, size_t len, cudaStream_t stream) {} -}; - -template -class Zeros final : public Initializer { - public: - void initialize(T* data, const size_t len, cudaStream_t stream) override { - zeros(data, len, stream); - } -}; - -} // namespace initializers -} // namespace merlin -} // namespace nv \ No newline at end of file diff --git a/include/merlin/memory_pool.cuh b/include/merlin/memory_pool.cuh index 271676a6e..0b78b13ae 100644 --- a/include/merlin/memory_pool.cuh +++ b/include/merlin/memory_pool.cuh @@ -24,6 +24,7 @@ #include #include #include +#include "allocator.cuh" #include "debug.hpp" namespace nv { @@ -36,25 +37,29 @@ namespace merlin { template struct AllocatorBase { using type = T; - using sync_unique_ptr = std::unique_ptr; + using sync_unique_ptr = std::unique_ptr>; using async_unique_ptr = std::unique_ptr>; using shared_ptr = std::shared_ptr; - inline static sync_unique_ptr make_unique(size_t n) { - return sync_unique_ptr(Allocator::alloc(n)); + inline static sync_unique_ptr make_unique(size_t n, + BaseAllocator* allocator) { + return {Allocator::alloc(n, allocator), + [allocator](type* p) { Allocator::free(p, allocator); }}; } - inline static async_unique_ptr make_unique(size_t n, cudaStream_t stream) { - return {Allocator::alloc(n, stream), - [stream](type* p) { Allocator::free(p); }}; + inline static async_unique_ptr make_unique(size_t n, BaseAllocator* allocator, + cudaStream_t stream) { + return {Allocator::alloc(n, allocator, stream), + [stream, allocator](type* p) { Allocator::free(p, allocator); }}; } - inline static shared_ptr make_shared(size_t n, cudaStream_t stream = 0) { - return {Allocator::alloc(n, stream), - [stream](type* p) { Allocator::free(p, stream); }}; + inline static shared_ptr make_shared(size_t n, BaseAllocator* allocator, + cudaStream_t stream = 0) { + return {Allocator::alloc(n, allocator, stream), + [stream, allocator](type* p) { + Allocator::free(p, allocator, stream); + }}; } - - inline void operator()(type* ptr) { Allocator::free(ptr); } }; /** @@ -68,11 +73,17 @@ struct StandardAllocator final : AllocatorBase> { static constexpr const char* name{"StandardAllocator"}; - inline static type* alloc(size_t n, cudaStream_t stream = 0) { - return new type[n]; + inline static type* alloc(size_t n, BaseAllocator* allocator, + cudaStream_t stream = 0) { + type* ptr; + allocator->alloc(MemoryType::Host, (void**)&ptr, n * sizeof(T)); + return ptr; } - inline static void free(type* ptr, cudaStream_t stream = 0) { delete[] ptr; } + inline static void free(type* ptr, BaseAllocator* allocator, + cudaStream_t stream = 0) { + allocator->free(MemoryType::Host, ptr); + } }; /** @@ -84,14 +95,16 @@ struct HostAllocator final : AllocatorBase> { static constexpr const char* name{"HostAllocator"}; - inline static type* alloc(size_t n, cudaStream_t stream = 0) { + inline static type* alloc(size_t n, BaseAllocator* allocator, + cudaStream_t stream = 0) { void* ptr; - CUDA_CHECK(cudaMallocHost(&ptr, sizeof(T) * n)); + allocator->alloc(MemoryType::Pinned, (void**)&ptr, n * sizeof(T)); return reinterpret_cast(ptr); } - inline static void free(type* ptr, cudaStream_t stream = 0) { - CUDA_CHECK(cudaFreeHost(ptr)); + inline static void free(type* ptr, BaseAllocator* allocator, + cudaStream_t stream = 0) { + allocator->free(MemoryType::Pinned, 
ptr); } }; @@ -105,26 +118,25 @@ struct DeviceAllocator final : AllocatorBase> { static constexpr const char* name{"DeviceAllocator"}; - inline static type* alloc(size_t n, cudaStream_t stream = 0) { + inline static type* alloc(size_t n, BaseAllocator* allocator, + cudaStream_t stream = 0) { void* ptr; - cudaError_t res; if (stream) { - res = cudaMallocAsync(&ptr, sizeof(T) * n, stream); + allocator->alloc_async(MemoryType::Device, (void**)&ptr, n * sizeof(T), + stream); } else { - res = cudaMalloc(&ptr, sizeof(T) * n); + allocator->alloc(MemoryType::Device, (void**)&ptr, n * sizeof(T)); } - CUDA_CHECK(res); return reinterpret_cast(ptr); } - inline static void free(type* ptr, cudaStream_t stream = 0) { - cudaError_t res; + inline static void free(type* ptr, BaseAllocator* allocator, + cudaStream_t stream = 0) { if (stream) { - res = cudaFreeAsync(ptr, stream); + allocator->free_async(MemoryType::Device, ptr, stream); } else { - res = cudaFree(ptr); + allocator->free(MemoryType::Device, ptr); } - CUDA_CHECK(res); } }; @@ -328,7 +340,8 @@ class MemoryPool final { } }; - MemoryPool(const MemoryPoolOptions& options) : options_{options} { + MemoryPool(const MemoryPoolOptions& options, BaseAllocator* allocator) + : options_{options}, allocator_{allocator} { // Create initial buffer stock. stock_.reserve(options_.max_stock); @@ -390,7 +403,7 @@ class MemoryPool final { void deplete_stock() { std::lock_guard lock(mutex_); for (auto& ptr : stock_) { - Allocator::free(ptr); + Allocator::free(ptr, allocator_); } stock_.clear(); } @@ -442,7 +455,7 @@ class MemoryPool final { std::get<1>(pending) == buffer_size_) { stock_.emplace_back(std::get<0>(pending)); } else { - Allocator::free(std::get<0>(pending), stream); + Allocator::free(std::get<0>(pending), allocator_, stream); } ready_events_.emplace_back(std::get<2>(pending)); return true; @@ -458,7 +471,7 @@ class MemoryPool final { inline void clear_stock_unsafe(cudaStream_t stream) { for (auto& ptr : stock_) { - Allocator::free(ptr, stream); + Allocator::free(ptr, allocator_, stream); } stock_.clear(); } @@ -500,7 +513,7 @@ class MemoryPool final { // Forge new buffers until request can be filled. for (; first != last; ++first) { - *first = Allocator::alloc(allocation_size, stream); + *first = Allocator::alloc(allocation_size, allocator_, stream); } return allocation_size; @@ -516,7 +529,7 @@ class MemoryPool final { // occured), the provided buffers are incompatible and have to be discarded. if (allocation_size != buffer_size_) { while (first != last) { - Allocator::free(*first++); + Allocator::free(*first++, allocator_); } return; } @@ -560,7 +573,7 @@ class MemoryPool final { if (stock_.size() < options_.max_stock) { stock_.emplace_back(*first); } else { - Allocator::free(*first); + Allocator::free(*first, allocator_); } } } @@ -574,6 +587,7 @@ class MemoryPool final { std::vector ready_events_; std::vector> pending_; + BaseAllocator* allocator_; }; template diff --git a/include/merlin/utils.cuh b/include/merlin/utils.cuh index 977ed4b2b..316f12bf4 100644 --- a/include/merlin/utils.cuh +++ b/include/merlin/utils.cuh @@ -24,7 +24,6 @@ #include #include #include "cuda_fp16.h" -#include "cuda_runtime_api.h" #include "debug.hpp" using namespace cooperative_groups; @@ -270,48 +269,6 @@ struct TypeConvertFunc { } }; -template -void realloc(P* ptr, size_t old_size, size_t new_size) { - // Truncate old_size to limit dowstream copy ops. - old_size = std::min(old_size, new_size); - - // Alloc new buffer and copy at old data. 
- char* new_ptr; - CUDA_CHECK(cudaMalloc(&new_ptr, new_size)); - if (*ptr != nullptr) { - CUDA_CHECK(cudaMemcpy(new_ptr, *ptr, old_size, cudaMemcpyDefault)); - CUDA_CHECK(cudaFree(*ptr)); - } - - // Zero-fill remainder. - CUDA_CHECK(cudaMemset(new_ptr + old_size, 0, new_size - old_size)); - - // Switch to new pointer. - *ptr = reinterpret_cast
<P>
(new_ptr); - return; -} - -template -void realloc_host(P* ptr, size_t old_size, size_t new_size) { - // Truncate old_size to limit dowstream copy ops. - old_size = std::min(old_size, new_size); - - // Alloc new buffer and copy at old data. - char* new_ptr = reinterpret_cast(std::malloc(new_size)); - - if (*ptr != nullptr) { - std::memcpy(new_ptr, *ptr, old_size); - std::free(*ptr); - } - - // Zero-fill remainder. - std::memset(new_ptr + old_size, 0, new_size - old_size); - - // Switch to new pointer. - *ptr = reinterpret_cast
<P>
(new_ptr); - return; -} - template __forceinline__ __device__ void lock( const cg::thread_block_tile& tile, mutex& set_mutex, diff --git a/include/merlin_hashtable.cuh b/include/merlin_hashtable.cuh index 55294b850..d4fd6b43f 100644 --- a/include/merlin_hashtable.cuh +++ b/include/merlin_hashtable.cuh @@ -22,9 +22,11 @@ #include #include #include +#include #include #include #include +#include "merlin/allocator.cuh" #include "merlin/array_kernels.cuh" #include "merlin/core_kernels.cuh" #include "merlin/flexible_buffer.cuh" @@ -156,6 +158,7 @@ class HashTable { using value_type = V; using score_type = S; using Pred = EraseIfPredict; + using allocator_type = BaseAllocator; private: using TableCore = nv::merlin::Table; @@ -186,10 +189,15 @@ class HashTable { CUDA_CHECK(cudaDeviceSynchronize()); initialized_ = false; - destroy_table(&table_); - CUDA_CHECK(cudaFree(d_table_)); + destroy_table(&table_, allocator_); + allocator_->free(MemoryType::Device, d_table_); dev_mem_pool_.reset(); host_mem_pool_.reset(); + + CUDA_CHECK(cudaDeviceSynchronize()); + if (default_allocator_ && allocator_ != nullptr) { + delete allocator_; + } } } @@ -205,12 +213,16 @@ class HashTable { * * @param options The configuration options. */ - void init(const HashTableOptions& options) { + void init(const HashTableOptions& options, + allocator_type* allocator = nullptr) { if (initialized_) { return; } options_ = options; + default_allocator_ = (allocator == nullptr); + allocator_ = (allocator == nullptr) ? (new DefaultAllocator()) : allocator; + if (options_.device_id >= 0) { CUDA_CHECK(cudaSetDevice(options_.device_id)); } else { @@ -232,22 +244,24 @@ class HashTable { CUDA_CHECK(cudaGetDeviceProperties(&deviceProp, options_.device_id)); shared_mem_size_ = deviceProp.sharedMemPerBlock; create_table( - &table_, options_.dim, options_.init_capacity, options_.max_capacity, - options_.max_hbm_for_vectors, options_.max_bucket_size); + &table_, allocator_, options_.dim, options_.init_capacity, + options_.max_capacity, options_.max_hbm_for_vectors, + options_.max_bucket_size); options_.block_size = SAFE_GET_BLOCK_SIZE(options_.block_size); reach_max_capacity_ = (options_.init_capacity * 2 > options_.max_capacity); MERLIN_CHECK((!(options_.io_by_cpu && options_.max_hbm_for_vectors != 0)), "[HierarchicalKV] `io_by_cpu` should not be true when " "`max_hbm_for_vectors` is not 0!"); - CUDA_CHECK(cudaMalloc((void**)&(d_table_), sizeof(TableCore))); + allocator_->alloc(MemoryType::Device, (void**)&(d_table_), + sizeof(TableCore)); sync_table_configuration(); // Create memory pools. dev_mem_pool_ = std::make_unique>>( - options_.device_memory_pool); + options_.device_memory_pool, allocator_); host_mem_pool_ = std::make_unique>>( - options_.host_memory_pool); + options_.host_memory_pool, allocator_); CUDA_CHECK(cudaDeviceSynchronize()); initialized_ = true; @@ -1439,7 +1453,7 @@ class HashTable { while (capacity() < new_capacity && capacity() * 2 <= options_.max_capacity) { - double_capacity(&table_); + double_capacity(&table_, allocator_); CUDA_CHECK(cudaDeviceSynchronize()); sync_table_configuration(); @@ -1617,7 +1631,7 @@ class HashTable { sizeof(value_type) * dim()}; MERLIN_CHECK(max_workspace_size >= tuple_size, "[HierarchicalKV] max_workspace_size is smaller than a single " - "`key + scoredata + value` tuple! Please set a larger value!"); + "`key + score + value` tuple! 
Please set a larger value!"); const size_type n{max_workspace_size / tuple_size}; const size_type ws_size{n * tuple_size}; @@ -1754,6 +1768,8 @@ class HashTable { const unsigned int kernel_select_interval_ = 7; std::unique_ptr dev_mem_pool_; std::unique_ptr host_mem_pool_; + allocator_type* allocator_; + bool default_allocator_ = true; }; } // namespace merlin diff --git a/tests/memory_pool_test.cc.cu b/tests/memory_pool_test.cc.cu index 75deb2db6..b56f1aed9 100644 --- a/tests/memory_pool_test.cc.cu +++ b/tests/memory_pool_test.cc.cu @@ -17,6 +17,7 @@ #include #include #include +#include "merlin/allocator.cuh" #include "merlin/memory_pool.cuh" using namespace nv::merlin; @@ -31,16 +32,18 @@ struct DebugAllocator final static constexpr const char* name{"DebugAllocator"}; - inline static type* alloc(size_t n, cudaStream_t stream = 0) { - type* ptr{Allocator::alloc(n, stream)}; + inline static type* alloc(size_t n, BaseAllocator* allocator, + cudaStream_t stream = 0) { + type* ptr{Allocator::alloc(n, allocator, stream)}; std::cout << Allocator::name << "[type_name = " << typeid(type).name() << "]: " << static_cast(ptr) << " allocated = " << n << " x " << sizeof(type) << " bytes, stream = " << stream << '\n'; return ptr; } - inline static void free(type* ptr, cudaStream_t stream = 0) { - Allocator::free(ptr, stream); + inline static void free(type* ptr, BaseAllocator* allocator, + cudaStream_t stream = 0) { + Allocator::free(ptr, allocator, stream); std::cout << Allocator::name << "[type_name = " << typeid(type).name() << "]: " << static_cast(ptr) << " freed, stream = " << stream << '\n'; @@ -92,9 +95,10 @@ std::ostream& operator<<(std::ostream& os, const SomeType& obj) { void test_standard_allocator() { using Allocator = DebugAllocator>; + std::shared_ptr default_allocator(new DefaultAllocator()); { - auto ptr{Allocator::make_unique(1)}; + auto ptr{Allocator::make_unique(1, default_allocator.get())}; ASSERT_NE(ptr.get(), nullptr); std::cout << "Sync UPtr after alloc: " << *ptr << std::endl; @@ -107,7 +111,7 @@ void test_standard_allocator() { } { - auto ptr{Allocator::make_unique(1, nullptr)}; + auto ptr{Allocator::make_unique(1, default_allocator.get(), nullptr)}; ASSERT_NE(ptr.get(), nullptr); std::cout << "Async UPtr after alloc: " << *ptr << std::endl; @@ -120,7 +124,7 @@ void test_standard_allocator() { } { - auto ptr{Allocator::make_shared(1)}; + auto ptr{Allocator::make_shared(1, default_allocator.get())}; ASSERT_NE(ptr.get(), nullptr); std::cout << "SPtr after alloc: " << *ptr << std::endl; @@ -135,9 +139,10 @@ void test_standard_allocator() { void test_host_allocator() { using Allocator = DebugAllocator>; + std::shared_ptr default_allocator(new DefaultAllocator()); { - auto ptr{Allocator::make_unique(1)}; + auto ptr{Allocator::make_unique(1, default_allocator.get())}; ASSERT_NE(ptr.get(), nullptr); std::cout << "Sync UPtr after alloc: " << *ptr << std::endl; @@ -150,7 +155,7 @@ void test_host_allocator() { } { - auto ptr{Allocator::make_unique(1, nullptr)}; + auto ptr{Allocator::make_unique(1, default_allocator.get(), nullptr)}; ASSERT_NE(ptr.get(), nullptr); std::cout << "Async UPtr after alloc: " << *ptr << std::endl; @@ -163,7 +168,7 @@ void test_host_allocator() { } { - auto ptr{Allocator::make_shared(1)}; + auto ptr{Allocator::make_shared(1, default_allocator.get())}; ASSERT_NE(ptr.get(), nullptr); std::cout << "SPtr after alloc: " << *ptr << std::endl; @@ -178,6 +183,7 @@ void test_host_allocator() { void test_device_allocator() { using Allocator = DebugAllocator>; + 
std::shared_ptr default_allocator(new DefaultAllocator()); int num_devices; CUDA_CHECK(cudaGetDeviceCount(&num_devices)); @@ -190,11 +196,14 @@ void test_device_allocator() { CUDA_CHECK(cudaStreamCreate(&stream)); { - auto ptr{Allocator::make_unique(1)}; + auto ptr{Allocator::make_unique(1, default_allocator.get())}; ASSERT_NE(ptr.get(), nullptr); std::cout << "Sync UPtr after alloc: " << *ptr << std::endl; const SomeType tmp{47, 11}; + + std::cout << "Sync UPtr after alloc get ptr: " << ptr.get() << std::endl; + CUDA_CHECK(cudaMemset(ptr.get(), 0, sizeof(SomeType))); CUDA_CHECK( cudaMemcpy(ptr.get(), &tmp, sizeof(SomeType), cudaMemcpyHostToDevice)); std::cout << "Sync UPtr after set: " << *ptr << std::endl; @@ -204,7 +213,7 @@ void test_device_allocator() { } { - auto ptr{Allocator::make_unique(1, stream)}; + auto ptr{Allocator::make_unique(1, default_allocator.get(), stream)}; ASSERT_NE(ptr.get(), nullptr); std::cout << "Async UPtr after alloc: " << *ptr << std::endl; @@ -218,7 +227,7 @@ void test_device_allocator() { } { - auto ptr{Allocator::make_shared(1, stream)}; + auto ptr{Allocator::make_shared(1, default_allocator.get(), stream)}; ASSERT_NE(ptr.get(), nullptr); std::cout << "SPtr after alloc: " << *ptr << std::endl; @@ -241,98 +250,102 @@ void test_borrow_return_no_context() { "Need at least one CUDA capable device for running this test."); CUDA_CHECK(cudaSetDevice(0)); - MemoryPool>> pool{opt}; - const size_t buffer_size{256L * 1024}; - - // Initial status. - std::cout << ".:: Initial state ::.\n" << pool << std::endl; - ASSERT_EQ(pool.current_stock(), 0); - ASSERT_EQ(pool.num_pending(), 0); - - // Borrow and return one buffer (unique ptr). + std::shared_ptr default_allocator(new DefaultAllocator()); { - auto buffer{pool.get_unique(buffer_size)}; - std::cout << ".:: Borrow 1 (unique) ::.\n" << pool << std::endl; + MemoryPool>> pool{ + opt, default_allocator.get()}; + const size_t buffer_size{256L * 1024}; + + // Initial status. + std::cout << ".:: Initial state ::.\n" << pool << std::endl; ASSERT_EQ(pool.current_stock(), 0); ASSERT_EQ(pool.num_pending(), 0); - } - std::cout << ".:: Return 1 (unique) ::.\n" << pool << std::endl; - ASSERT_EQ(pool.current_stock(), 1); - ASSERT_EQ(pool.num_pending(), 0); - // Borrow and return one buffer (shared ptr). - { - auto buffer{pool.get_shared(buffer_size)}; - std::cout << ".:: Borrow 1 (shared) ::.\n" << pool << std::endl; - ASSERT_EQ(pool.current_stock(), 0); + // Borrow and return one buffer (unique ptr). + { + auto buffer{pool.get_unique(buffer_size)}; + std::cout << ".:: Borrow 1 (unique) ::.\n" << pool << std::endl; + ASSERT_EQ(pool.current_stock(), 0); + ASSERT_EQ(pool.num_pending(), 0); + } + std::cout << ".:: Return 1 (unique) ::.\n" << pool << std::endl; + ASSERT_EQ(pool.current_stock(), 1); ASSERT_EQ(pool.num_pending(), 0); - } - std::cout << ".:: Return 1 (shared) ::.\n" << pool << std::endl; - ASSERT_EQ(pool.current_stock(), 1); - ASSERT_EQ(pool.num_pending(), 0); - // Borrow static workspace with less than `max_stock` buffers. - { - auto ws{pool.get_workspace<2>(buffer_size)}; - std::cout << ".:: Borrow 2 (static) ::.\n" << pool << std::endl; - ASSERT_EQ(pool.current_stock(), 0); + // Borrow and return one buffer (shared ptr). 
+ { + auto buffer{pool.get_shared(buffer_size)}; + std::cout << ".:: Borrow 1 (shared) ::.\n" << pool << std::endl; + ASSERT_EQ(pool.current_stock(), 0); + ASSERT_EQ(pool.num_pending(), 0); + } + std::cout << ".:: Return 1 (shared) ::.\n" << pool << std::endl; + ASSERT_EQ(pool.current_stock(), 1); ASSERT_EQ(pool.num_pending(), 0); - } - std::cout << ".:: Return 2 (static) ::.\n" << pool << std::endl; - ASSERT_EQ(pool.current_stock(), 2); - ASSERT_EQ(pool.num_pending(), 0); - // Borrow dynamic workspace with less than `max_stock` buffers. - { - auto ws{pool.get_workspace(2, buffer_size)}; - std::cout << ".:: Borrow 2 (dynamic) ::.\n" << pool << std::endl; - ASSERT_EQ(pool.current_stock(), 0); + // Borrow static workspace with less than `max_stock` buffers. + { + auto ws{pool.get_workspace<2>(buffer_size)}; + std::cout << ".:: Borrow 2 (static) ::.\n" << pool << std::endl; + ASSERT_EQ(pool.current_stock(), 0); + ASSERT_EQ(pool.num_pending(), 0); + } + std::cout << ".:: Return 2 (static) ::.\n" << pool << std::endl; + ASSERT_EQ(pool.current_stock(), 2); ASSERT_EQ(pool.num_pending(), 0); - } - std::cout << ".:: Return 2 (dynamic) ::.\n" << pool << std::endl; - ASSERT_EQ(pool.current_stock(), 2); - ASSERT_EQ(pool.num_pending(), 0); + // Borrow dynamic workspace with less than `max_stock` buffers. + { + auto ws{pool.get_workspace(2, buffer_size)}; + std::cout << ".:: Borrow 2 (dynamic) ::.\n" << pool << std::endl; + ASSERT_EQ(pool.current_stock(), 0); + ASSERT_EQ(pool.num_pending(), 0); + } - // Await unfinished GPU work (shouldn't change anything). - pool.await_pending(); - std::cout << ".:: Await pending (shouldn't change anything) ::.\n" - << pool << std::endl; - ASSERT_EQ(pool.current_stock(), 2); - ASSERT_EQ(pool.num_pending(), 0); + std::cout << ".:: Return 2 (dynamic) ::.\n" << pool << std::endl; + ASSERT_EQ(pool.current_stock(), 2); + ASSERT_EQ(pool.num_pending(), 0); - // Borrow workspace that exceeds base pool size. - { - auto ws{pool.get_workspace<6>(buffer_size)}; - std::cout << ".:: Borrow 6 (static) ::.\n" << pool << std::endl; - ASSERT_EQ(pool.current_stock(), 0); + // Await unfinished GPU work (shouldn't change anything). + pool.await_pending(); + std::cout << ".:: Await pending (shouldn't change anything) ::.\n" + << pool << std::endl; + ASSERT_EQ(pool.current_stock(), 2); ASSERT_EQ(pool.num_pending(), 0); - } - std::cout << ".:: Return 6 (static) ::.\n" << pool << std::endl; - ASSERT_EQ(pool.current_stock(), opt.max_stock); - ASSERT_EQ(pool.num_pending(), 0); - // Borrow a buffer that is smaller than the current buffer size. - { - auto ws{pool.get_unique(buffer_size / 2)}; - std::cout << ".:: Borrow 1 (smaller) ::.\n" << pool << std::endl; - ASSERT_EQ(pool.current_stock(), opt.max_stock - 1); + // Borrow workspace that exceeds base pool size. + { + auto ws{pool.get_workspace<6>(buffer_size)}; + std::cout << ".:: Borrow 6 (static) ::.\n" << pool << std::endl; + ASSERT_EQ(pool.current_stock(), 0); + ASSERT_EQ(pool.num_pending(), 0); + } + std::cout << ".:: Return 6 (static) ::.\n" << pool << std::endl; + ASSERT_EQ(pool.current_stock(), opt.max_stock); ASSERT_EQ(pool.num_pending(), 0); - } - std::cout << ".:: Return 1 (smaller) ::.\n" << pool << std::endl; - ASSERT_EQ(pool.current_stock(), opt.max_stock); - ASSERT_EQ(pool.num_pending(), 0); - // Borrow a buffer that is bigger than the current buffer size. 
- { - auto ws{pool.get_unique(buffer_size + 37)}; - std::cout << ".:: Borrow 1 (bigger) ::.\n" << pool << std::endl; - ASSERT_EQ(pool.current_stock(), 0); + // Borrow a buffer that is smaller than the current buffer size. + { + auto ws{pool.get_unique(buffer_size / 2)}; + std::cout << ".:: Borrow 1 (smaller) ::.\n" << pool << std::endl; + ASSERT_EQ(pool.current_stock(), opt.max_stock - 1); + ASSERT_EQ(pool.num_pending(), 0); + } + std::cout << ".:: Return 1 (smaller) ::.\n" << pool << std::endl; + ASSERT_EQ(pool.current_stock(), opt.max_stock); + ASSERT_EQ(pool.num_pending(), 0); + + // Borrow a buffer that is bigger than the current buffer size. + { + auto ws{pool.get_unique(buffer_size + 37)}; + std::cout << ".:: Borrow 1 (bigger) ::.\n" << pool << std::endl; + ASSERT_EQ(pool.current_stock(), 0); + ASSERT_EQ(pool.num_pending(), 0); + } + std::cout << ".:: Return 1 (smaller) ::.\n" << pool << std::endl; + ASSERT_EQ(pool.current_stock(), 1); ASSERT_EQ(pool.num_pending(), 0); } - std::cout << ".:: Return 1 (smaller) ::.\n" << pool << std::endl; - ASSERT_EQ(pool.current_stock(), 1); - ASSERT_EQ(pool.num_pending(), 0); } void test_borrow_return_with_context() { @@ -345,120 +358,124 @@ void test_borrow_return_with_context() { cudaStream_t stream; CUDA_CHECK(cudaStreamCreate(&stream)); - MemoryPool>> pool(opt); - const size_t buffer_size{256L * 1024}; - - // Initial status. - std::cout << ".:: Initial state ::.\n" << pool << std::endl; - ASSERT_EQ(pool.current_stock(), 0); - ASSERT_EQ(pool.num_pending(), 0); - - // Borrow and return one buffer (unique ptr). + std::shared_ptr default_allocator(new DefaultAllocator()); { - auto buffer{pool.get_unique(buffer_size, stream)}; - std::cout << ".:: Borrow 1 (unique) ::.\n" << pool << std::endl; + MemoryPool>> pool( + opt, default_allocator.get()); + const size_t buffer_size{256L * 1024}; + + // Initial status. + std::cout << ".:: Initial state ::.\n" << pool << std::endl; ASSERT_EQ(pool.current_stock(), 0); ASSERT_EQ(pool.num_pending(), 0); - } - std::cout << ".:: Return 1 (unique) ::.\n" << pool << std::endl; - ASSERT_EQ(pool.current_stock(), 0); - ASSERT_EQ(pool.num_pending(), 1); - // Borrow and return one buffer (shared ptr). - { - auto buffer{pool.get_shared(buffer_size, stream)}; - std::cout << ".:: Borrow 1 (shared) ::.\n" << pool << std::endl; + // Borrow and return one buffer (unique ptr). + { + auto buffer{pool.get_unique(buffer_size, stream)}; + std::cout << ".:: Borrow 1 (unique) ::.\n" << pool << std::endl; + ASSERT_EQ(pool.current_stock(), 0); + ASSERT_EQ(pool.num_pending(), 0); + } + std::cout << ".:: Return 1 (unique) ::.\n" << pool << std::endl; ASSERT_EQ(pool.current_stock(), 0); - ASSERT_EQ(pool.num_pending(), 0); - } - std::cout << ".:: Return 1 (shared) ::.\n" << pool << std::endl; - ASSERT_EQ(pool.current_stock(), 0); - ASSERT_EQ(pool.num_pending(), 1); + ASSERT_EQ(pool.num_pending(), 1); - // Borrow static workspace with less than `max_stock` buffers. - { - auto ws{pool.get_workspace<2>(buffer_size, stream)}; - std::cout << ".:: Borrow 2 (static) ::.\n" << pool << std::endl; + // Borrow and return one buffer (shared ptr). 
+ { + auto buffer{pool.get_shared(buffer_size, stream)}; + std::cout << ".:: Borrow 1 (shared) ::.\n" << pool << std::endl; + ASSERT_EQ(pool.current_stock(), 0); + ASSERT_EQ(pool.num_pending(), 0); + } + std::cout << ".:: Return 1 (shared) ::.\n" << pool << std::endl; ASSERT_EQ(pool.current_stock(), 0); - ASSERT_EQ(pool.num_pending(), 0); - } - std::cout << ".:: Return 2 (static) ::.\n" << pool << std::endl; - ASSERT_EQ(pool.current_stock(), 0); - ASSERT_EQ(pool.num_pending(), 2); - - // Await unfinished GPU work. - pool.await_pending(stream); - std::cout << ".:: Await pending ::.\n" << pool << std::endl; - ASSERT_EQ(pool.current_stock(), 2); - ASSERT_EQ(pool.num_pending(), 0); - - // Borrow workspace that exceeds base pool size. Possible results: - // 1. If this thread is slower than the driver. - // Upon return we will see a partial deallocation before inserting the last - // buffer into the pending queue. - // 2. If this the driver is slower than this thread queuing/querying events. - // Either 0-3 buffers in stock partial dallocation - // 1-5 buffers pending. Hence there is no good way to check. - { - auto ws{pool.get_workspace<6>(buffer_size, stream)}; - std::cout << ".:: Borrow 6 (static) ::.\n" << pool << std::endl; + ASSERT_EQ(pool.num_pending(), 1); + + // Borrow static workspace with less than `max_stock` buffers. + { + auto ws{pool.get_workspace<2>(buffer_size, stream)}; + std::cout << ".:: Borrow 2 (static) ::.\n" << pool << std::endl; + ASSERT_EQ(pool.current_stock(), 0); + ASSERT_EQ(pool.num_pending(), 0); + } + std::cout << ".:: Return 2 (static) ::.\n" << pool << std::endl; ASSERT_EQ(pool.current_stock(), 0); + ASSERT_EQ(pool.num_pending(), 2); + + // Await unfinished GPU work. + pool.await_pending(stream); + std::cout << ".:: Await pending ::.\n" << pool << std::endl; + ASSERT_EQ(pool.current_stock(), 2); ASSERT_EQ(pool.num_pending(), 0); - } - std::cout << ".:: Return 6 (static) ::.\n" << pool << std::endl; - ASSERT_GE(pool.num_pending(), 1); - - // Ensure stable situation by - // - ensuring that all pending buffers dealt with. - // - pinning 3 buffers, while clearing the remaining stock - // - Then we pin 1 of the 3 buffers and release it to make it pending. - // - Result: 2 stock buffers, 1 pending. - pool.await_pending(); - ASSERT_EQ(pool.num_pending(), 0); - { - auto ws{pool.get_workspace<3>(buffer_size, stream)}; - pool.deplete_stock(); - ASSERT_EQ(pool.current_stock(), 0); - } - pool.await_pending(stream); - { auto ws{pool.get_workspace<1>(buffer_size, stream)}; } - ASSERT_EQ(pool.current_stock(), 2); - ASSERT_EQ(pool.num_pending(), 1); - std::cout << ".:: Ensure 2 stock + 1 pending situation ::.\n" - << pool << std::endl; - - // Borrow a buffer that is smaller than the current buffer size. - { - auto ws{pool.get_unique(buffer_size / 2, stream)}; - std::cout << ".:: Borrow 1 (smaller) ::.\n" << pool << std::endl; - ASSERT_EQ(pool.current_stock(), 1); + + // Borrow workspace that exceeds base pool size. Possible results: + // 1. If this thread is slower than the driver. + // Upon return we will see a partial deallocation before inserting the + // last buffer into the pending queue. + // 2. If this the driver is slower than this thread queuing/querying events. + // Either 0-3 buffers in stock partial dallocation + // 1-5 buffers pending. Hence there is no good way to check. 
+ { + auto ws{pool.get_workspace<6>(buffer_size, stream)}; + std::cout << ".:: Borrow 6 (static) ::.\n" << pool << std::endl; + ASSERT_EQ(pool.current_stock(), 0); + ASSERT_EQ(pool.num_pending(), 0); + } + std::cout << ".:: Return 6 (static) ::.\n" << pool << std::endl; + ASSERT_GE(pool.num_pending(), 1); + + // Ensure stable situation by + // - ensuring that all pending buffers dealt with. + // - pinning 3 buffers, while clearing the remaining stock + // - Then we pin 1 of the 3 buffers and release it to make it pending. + // - Result: 2 stock buffers, 1 pending. + pool.await_pending(); + ASSERT_EQ(pool.num_pending(), 0); + { + auto ws{pool.get_workspace<3>(buffer_size, stream)}; + pool.deplete_stock(); + ASSERT_EQ(pool.current_stock(), 0); + } + pool.await_pending(stream); + { auto ws{pool.get_workspace<1>(buffer_size, stream)}; } + ASSERT_EQ(pool.current_stock(), 2); ASSERT_EQ(pool.num_pending(), 1); - } - std::cout << ".:: Return 1 (smaller) ::.\n" << pool << std::endl; - ASSERT_EQ(pool.current_stock(), 1); - ASSERT_EQ(pool.num_pending(), 2); + std::cout << ".:: Ensure 2 stock + 1 pending situation ::.\n" + << pool << std::endl; + + // Borrow a buffer that is smaller than the current buffer size. + { + auto ws{pool.get_unique(buffer_size / 2, stream)}; + std::cout << ".:: Borrow 1 (smaller) ::.\n" << pool << std::endl; + ASSERT_EQ(pool.current_stock(), 1); + ASSERT_EQ(pool.num_pending(), 1); + } + std::cout << ".:: Return 1 (smaller) ::.\n" << pool << std::endl; + ASSERT_EQ(pool.current_stock(), 1); + ASSERT_EQ(pool.num_pending(), 2); - // Borrow a buffer that is bigger than the current buffer size. This will - // evict the stock buffers which are smaller, but will not concern the buffers - // that are still pending. - { - auto ws{pool.get_unique(buffer_size + 37, stream)}; - std::cout << ".:: Borrow 1 (bigger) ::.\n" << pool << std::endl; + // Borrow a buffer that is bigger than the current buffer size. This will + // evict the stock buffers which are smaller, but will not concern the + // buffers that are still pending. + { + auto ws{pool.get_unique(buffer_size + 37, stream)}; + std::cout << ".:: Borrow 1 (bigger) ::.\n" << pool << std::endl; + ASSERT_EQ(pool.current_stock(), 0); + ASSERT_EQ(pool.num_pending(), 2); + } + std::cout << ".:: Return 1 (bigger) ::.\n" << pool << std::endl; ASSERT_EQ(pool.current_stock(), 0); - ASSERT_EQ(pool.num_pending(), 2); - } - std::cout << ".:: Return 1 (bigger) ::.\n" << pool << std::endl; - ASSERT_EQ(pool.current_stock(), 0); - ASSERT_EQ(pool.num_pending(), 3); + ASSERT_EQ(pool.num_pending(), 3); - // Because there are now pending buffers that are too small, they will be - // cleared once the associated work has been completed. - pool.await_pending(stream); - std::cout << ".:: Await pending ::.\n" << pool << std::endl; - ASSERT_EQ(pool.current_stock(), 1); - ASSERT_EQ(pool.num_pending(), 0); + // Because there are now pending buffers that are too small, they will be + // cleared once the associated work has been completed. 
+ pool.await_pending(stream); + std::cout << ".:: Await pending ::.\n" << pool << std::endl; + ASSERT_EQ(pool.current_stock(), 1); + ASSERT_EQ(pool.num_pending(), 0); - CUDA_CHECK(cudaStreamDestroy(stream)); + CUDA_CHECK(cudaStreamDestroy(stream)); + } } TEST(MemoryPoolTest, standard_allocator) { test_standard_allocator(); } diff --git a/tests/merlin_hashtable_test.cc.cu b/tests/merlin_hashtable_test.cc.cu index 19e665f86..2401ce7b8 100644 --- a/tests/merlin_hashtable_test.cc.cu +++ b/tests/merlin_hashtable_test.cc.cu @@ -34,6 +34,8 @@ using V = float; using S = uint64_t; using Table = nv::merlin::HashTable; using TableOptions = nv::merlin::HashTableOptions; +using BaseAllocator = nv::merlin::BaseAllocator; +using MemoryType = nv::merlin::MemoryType; template struct EraseIfPredFunctor { @@ -53,6 +55,77 @@ struct ExportIfPredFunctor { } }; +class CustomizedAllocator : public virtual BaseAllocator { + public: + CustomizedAllocator(){}; + ~CustomizedAllocator() override{}; + + void alloc(const MemoryType type, void** ptr, size_t size, + unsigned int pinned_flags = cudaHostAllocDefault) override { + switch (type) { + case MemoryType::Device: + CUDA_CHECK(cudaMalloc(ptr, size)); + break; + case MemoryType::Managed: + CUDA_CHECK(cudaMallocManaged(ptr, size, cudaMemAttachGlobal)); + break; + case MemoryType::Pinned: + CUDA_CHECK(cudaMallocHost(ptr, size, pinned_flags)); + break; + case MemoryType::Host: + *ptr = std::malloc(size); + break; + } + return; + } + + void alloc_async(const MemoryType type, void** ptr, size_t size, + cudaStream_t stream) override { + if (type == MemoryType::Device) { + CUDA_CHECK(cudaMallocAsync(ptr, size, stream)); + } else { + MERLIN_CHECK(false, + "[CustomizedAllocator] alloc_async is only support for " + "MemoryType::Device!"); + } + return; + } + + void free(const MemoryType type, void* ptr) override { + if (ptr == nullptr) { + return; + } + switch (type) { + case MemoryType::Pinned: + CUDA_CHECK(cudaFreeHost(ptr)); + break; + case MemoryType::Device: + case MemoryType::Managed: + CUDA_CHECK(cudaFree(ptr)); + break; + case MemoryType::Host: + std::free(ptr); + break; + } + return; + } + + void free_async(const MemoryType type, void* ptr, + cudaStream_t stream) override { + if (ptr == nullptr) { + return; + } + + if (type == MemoryType::Device) { + CUDA_CHECK(cudaFreeAsync(ptr, stream)); + } else { + MERLIN_CHECK(false, + "[CustomizedAllocator] free_async is only support for " + "MemoryType::Device!"); + } + } +}; + void test_basic(size_t max_hbm_for_vectors) { constexpr uint64_t BUCKET_MAX_SIZE = 128; constexpr uint64_t INIT_CAPACITY = 64 * 1024 * 1024UL - (128 + 1); @@ -397,6 +470,9 @@ void test_basic_when_full(size_t max_hbm_for_vectors) { options.max_hbm_for_vectors = nv::merlin::GB(max_hbm_for_vectors); options.evict_strategy = nv::merlin::EvictStrategy::kCustomized; + std::unique_ptr customized_allocator = + std::make_unique(); + CUDA_CHECK(cudaMallocHost(&h_keys, KEY_NUM * sizeof(K))); CUDA_CHECK(cudaMallocHost(&h_scores, KEY_NUM * sizeof(S))); CUDA_CHECK(cudaMallocHost(&h_vectors, KEY_NUM * sizeof(V) * options.dim)); @@ -437,7 +513,7 @@ void test_basic_when_full(size_t max_hbm_for_vectors) { uint64_t total_size = 0; for (int i = 0; i < TEST_TIMES; i++) { std::unique_ptr table = std::make_unique
<Table>(); + table->init(options, customized_allocator.get()); total_size = table->size(stream); CUDA_CHECK(cudaStreamSynchronize(stream)); ASSERT_EQ(total_size, 0);
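
Usage sketch (a minimal example, not part of the patch): assuming the `BaseAllocator` interface, the `DefaultAllocator` implementation, and the `HashTable::init(options, allocator)` overload introduced above, a caller could plug in a custom allocator roughly as follows. The `CountingAllocator` name, the byte counter, and the capacity/dim values are illustrative only; the table keeps a raw pointer to an externally supplied allocator and does not take ownership of it, so the allocator must outlive the table.

#include <atomic>
#include <cstdint>
#include "merlin_hashtable.cuh"

// Illustrative allocator: delegates to DefaultAllocator and tracks allocated bytes.
class CountingAllocator final : public nv::merlin::DefaultAllocator {
 public:
  std::atomic<size_t> allocated_bytes{0};

  void alloc(const nv::merlin::MemoryType type, void** ptr, size_t size,
             unsigned int pinned_flags = cudaHostAllocDefault) override {
    allocated_bytes += size;
    nv::merlin::DefaultAllocator::alloc(type, ptr, size, pinned_flags);
  }

  void alloc_async(const nv::merlin::MemoryType type, void** ptr, size_t size,
                   cudaStream_t stream) override {
    allocated_bytes += size;
    nv::merlin::DefaultAllocator::alloc_async(type, ptr, size, stream);
  }
};

int main() {
  using Table = nv::merlin::HashTable<uint64_t, float, uint64_t>;

  nv::merlin::HashTableOptions options;
  options.init_capacity = 1024 * 1024;  // illustrative sizes
  options.max_capacity = 1024 * 1024;
  options.dim = 16;

  CountingAllocator allocator;
  Table table;
  // The table stores the raw pointer; keep `allocator` alive for the table's lifetime.
  table.init(options, &allocator);
  return 0;
}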