[Feat] make allocator customizable
rhdong committed Jul 27, 2023
1 parent f7d36f7 commit 0da72d4
Showing 10 changed files with 578 additions and 469 deletions.
127 changes: 127 additions & 0 deletions include/merlin/allocator.cuh
@@ -0,0 +1,127 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cstdlib>
#include "debug.hpp"
#include "utils.cuh"

namespace nv {
namespace merlin {

enum MemoryType {
Device, // HBM
Pinned, // Pinned Host Memory
Host, // Host Memory
Managed, // Pageable Host Memory (not required)
};

/* This abstract class defines the allocator APIs required by HKV.
Any customized allocator should inherit from it.
*/
class BaseAllocator {
public:
BaseAllocator(const BaseAllocator&) = delete;
BaseAllocator(BaseAllocator&&) = delete;

BaseAllocator& operator=(const BaseAllocator&) = delete;
BaseAllocator& operator=(BaseAllocator&&) = delete;

BaseAllocator() = default;
virtual ~BaseAllocator() = default;

virtual void alloc(const MemoryType type, void** ptr, size_t size,
unsigned int pinned_flags = cudaHostAllocDefault) = 0;

virtual void alloc_async(const MemoryType type, void** ptr, size_t size,
cudaStream_t stream) = 0;

virtual void free(const MemoryType type, void* ptr) = 0;

virtual void free_async(const MemoryType type, void* ptr,
cudaStream_t stream) = 0;
};

class DefaultAllocator : public virtual BaseAllocator {
public:
DefaultAllocator(){};
~DefaultAllocator() override{};

void alloc(const MemoryType type, void** ptr, size_t size,
unsigned int pinned_flags = cudaHostAllocDefault) override {
switch (type) {
case MemoryType::Device:
CUDA_CHECK(cudaMalloc(ptr, size));
break;
case MemoryType::Pinned:
CUDA_CHECK(cudaMallocHost(ptr, size, pinned_flags));
break;
case MemoryType::Host:
*ptr = std::malloc(size);
break;
}
return;
}

void alloc_async(const MemoryType type, void** ptr, size_t size,
cudaStream_t stream) override {
if (type == MemoryType::Device) {
CUDA_CHECK(cudaMallocAsync(ptr, size, stream));
} else {
MERLIN_CHECK(false,
"[DefaultAllocator] alloc_async is only support for "
"MemoryType::Device!");
}
return;
}

void free(const MemoryType type, void* ptr) override {
if (ptr == nullptr) {
return;
}
switch (type) {
case MemoryType::Pinned:
CUDA_CHECK(cudaFreeHost(ptr));
break;
case MemoryType::Device:
CUDA_CHECK(cudaFree(ptr));
break;
case MemoryType::Host:
std::free(ptr);
break;
}
return;
}

void free_async(const MemoryType type, void* ptr,
cudaStream_t stream) override {
if (ptr == nullptr) {
return;
}

if (type == MemoryType::Device) {
CUDA_CHECK(cudaFreeAsync(ptr, stream));
} else {
MERLIN_CHECK(false,
"[DefaultAllocator] free_async is only support for "
"MemoryType::Device!");
}
}
};

} // namespace merlin
} // namespace nv
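
The BaseAllocator interface in this new header is the extension point of the commit: a custom allocator only has to implement the four virtual methods and route each MemoryType to its own backing store. Below is a minimal sketch of such an allocator, assuming the include path above and the CUDA_CHECK macro the header already uses; the TrackingAllocator name and the byte counter are illustrative, not part of the commit.

#include <atomic>
#include <cstdlib>

#include "merlin/allocator.cuh"

// Illustrative only: delegates to the same CUDA/host calls as
// DefaultAllocator, but counts the bytes requested through it.
class TrackingAllocator : public virtual nv::merlin::BaseAllocator {
 public:
  void alloc(const nv::merlin::MemoryType type, void** ptr, size_t size,
             unsigned int pinned_flags = cudaHostAllocDefault) override {
    requested_.fetch_add(size, std::memory_order_relaxed);
    switch (type) {
      case nv::merlin::MemoryType::Device:
        CUDA_CHECK(cudaMalloc(ptr, size));
        break;
      case nv::merlin::MemoryType::Pinned:
        CUDA_CHECK(cudaMallocHost(ptr, size, pinned_flags));
        break;
      default:  // Host; the library never requests Managed.
        *ptr = std::malloc(size);
        break;
    }
  }

  void alloc_async(const nv::merlin::MemoryType type, void** ptr, size_t size,
                   cudaStream_t stream) override {
    // This sketch assumes async allocation is only requested for device
    // memory, as in DefaultAllocator.
    requested_.fetch_add(size, std::memory_order_relaxed);
    CUDA_CHECK(cudaMallocAsync(ptr, size, stream));
  }

  void free(const nv::merlin::MemoryType type, void* ptr) override {
    if (ptr == nullptr) return;
    switch (type) {
      case nv::merlin::MemoryType::Device:
        CUDA_CHECK(cudaFree(ptr));
        break;
      case nv::merlin::MemoryType::Pinned:
        CUDA_CHECK(cudaFreeHost(ptr));
        break;
      default:
        std::free(ptr);
        break;
    }
  }

  void free_async(const nv::merlin::MemoryType type, void* ptr,
                  cudaStream_t stream) override {
    if (ptr != nullptr) CUDA_CHECK(cudaFreeAsync(ptr, stream));
  }

  size_t requested_bytes() const {
    return requested_.load(std::memory_order_relaxed);
  }

 private:
  std::atomic<size_t> requested_{0};
};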
115 changes: 80 additions & 35 deletions include/merlin/core_kernels.cuh
@@ -16,8 +16,7 @@

#pragma once

#include <cstdlib>
#include <cstring>
#include "merlin/allocator.cuh"
#include "merlin/core_kernels/find_or_insert.cuh"
#include "merlin/core_kernels/find_ptr_or_insert.cuh"
#include "merlin/core_kernels/kernel_utils.cuh"
@@ -30,7 +29,6 @@
namespace nv {
namespace merlin {


template <class S>
__global__ void create_locks(S* __restrict mutex, const size_t start,
const size_t end) {
Expand Down Expand Up @@ -103,10 +101,55 @@ __global__ void get_bucket_others_address(Bucket<K, V, S>* __restrict buckets,
*address = buckets[index].digests_;
}

template <class P>
void realloc(P* ptr, size_t old_size, size_t new_size,
BaseAllocator* allocator) {
// Truncate old_size to limit downstream copy ops.
old_size = std::min(old_size, new_size);

// Alloc a new buffer and copy the old data.
char* new_ptr;
allocator->alloc(MemoryType::Device, (void**)&new_ptr, new_size);
if (*ptr != nullptr) {
CUDA_CHECK(cudaMemcpy(new_ptr, *ptr, old_size, cudaMemcpyDefault));
allocator->free(MemoryType::Device, *ptr);
}

// Zero-fill remainder.
CUDA_CHECK(cudaMemset(new_ptr + old_size, 0, new_size - old_size));

// Switch to new pointer.
*ptr = reinterpret_cast<P>(new_ptr);
return;
}

template <class P>
void realloc_host(P* ptr, size_t old_size, size_t new_size,
BaseAllocator* allocator) {
// Truncate old_size to limit downstream copy ops.
old_size = std::min(old_size, new_size);

// Alloc a new buffer and copy the old data.
char* new_ptr = nullptr;
allocator->alloc(MemoryType::Host, (void**)&new_ptr, new_size);

if (*ptr != nullptr) {
std::memcpy(new_ptr, *ptr, old_size);
allocator->free(MemoryType::Host, *ptr);
}

// Zero-fill remainder.
std::memset(new_ptr + old_size, 0, new_size - old_size);

// Switch to new pointer.
*ptr = reinterpret_cast<P>(new_ptr);
return;
}
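
Both helpers above grow a buffer through the caller-supplied allocator, copy as much of the old contents as fits, zero-fill the tail, and swap the pointer. A small usage sketch of the device-side variant follows; the function, buffer name, and sizes are made up for illustration, and inside the library this path is driven by double_capacity further down rather than called directly.

#include "merlin/core_kernels.cuh"

// Sketch: grow a device buffer of ints from 256 to 1024 elements through an
// allocator instance.
void grow_scores_buffer() {
  nv::merlin::DefaultAllocator default_allocator;
  nv::merlin::BaseAllocator* allocator = &default_allocator;

  int* d_scores = nullptr;
  allocator->alloc(nv::merlin::MemoryType::Device, (void**)&d_scores,
                   256 * sizeof(int));

  // Copies the first 256 ints into the new 1024-int buffer, zero-fills the
  // remaining 768 ints, frees the old buffer, and swaps the pointer.
  nv::merlin::realloc<int*>(&d_scores, 256 * sizeof(int), 1024 * sizeof(int),
                            allocator);

  allocator->free(nv::merlin::MemoryType::Device, d_scores);
}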

/* Initialize the buckets with index from start to end. */
template <class K, class V, class S>
void initialize_buckets(Table<K, V, S>** table, const size_t start,
const size_t end) {
void initialize_buckets(Table<K, V, S>** table, BaseAllocator* allocator,
const size_t start, const size_t end) {
/* As testing results show us, when the number of buckets is greater than
* 4 million, the performance drops significantly. We believe too many
* pinned memory allocations cause this issue, so we change the
@@ -127,7 +170,8 @@ void initialize_buckets(Table<K, V, S>** table, const size_t start,

realloc_host<V**>(
&((*table)->slices), (*table)->num_of_memory_slices * sizeof(V*),
((*table)->num_of_memory_slices + num_of_memory_slices) * sizeof(V*));
((*table)->num_of_memory_slices + num_of_memory_slices) * sizeof(V*),
allocator);

for (size_t i = (*table)->num_of_memory_slices;
i < (*table)->num_of_memory_slices + num_of_memory_slices; i++) {
@@ -138,12 +182,13 @@ void initialize_buckets(Table<K, V, S>** table, const size_t start,
(*table)->bucket_max_size * sizeof(V) *
(*table)->dim;
if ((*table)->remaining_hbm_for_vectors >= slice_real_size) {
CUDA_CHECK(cudaMalloc(&((*table)->slices[i]), slice_real_size));
allocator->alloc(MemoryType::Device, (void**)&((*table)->slices[i]),
slice_real_size);
(*table)->remaining_hbm_for_vectors -= slice_real_size;
} else {
(*table)->is_pure_hbm = false;
CUDA_CHECK(cudaMallocHost(&((*table)->slices[i]), slice_real_size,
cudaHostAllocMapped));
allocator->alloc(MemoryType::Pinned, (void**)&((*table)->slices[i]),
slice_real_size, cudaHostAllocMapped);
}
for (int j = 0; j < num_of_buckets_in_one_slice; j++) {
if ((*table)->is_pure_hbm) {
@@ -178,7 +223,8 @@ void initialize_buckets(Table<K, V, S>** table, const size_t start,
bucket_memory_size += reserve_size * sizeof(uint8_t);
for (int i = start; i < end; i++) {
uint8_t* address = nullptr;
CUDA_CHECK(cudaMalloc(&address, bucket_memory_size));
allocator->alloc(MemoryType::Device, (void**)&(address),
bucket_memory_size);
allocate_bucket_others<K, V, S><<<1, 1>>>((*table)->buckets, i, address,
reserve_size, bucket_max_size);
}
@@ -244,14 +290,13 @@ size_t get_slice_size(Table<K, V, S>** table) {
DIM: Vector dimension.
*/
template <class K, class V, class S>
void create_table(Table<K, V, S>** table, const size_t dim,
const size_t init_size = 134217728,
void create_table(Table<K, V, S>** table, BaseAllocator* allocator,
const size_t dim, const size_t init_size = 134217728,
const size_t max_size = std::numeric_limits<size_t>::max(),
const size_t max_hbm_for_vectors = 0,
const size_t bucket_max_size = 128,
const size_t tile_size = 32, const bool primary = true) {
(*table) =
reinterpret_cast<Table<K, V, S>*>(std::malloc(sizeof(Table<K, V, S>)));
allocator->alloc(MemoryType::Host, (void**)table, sizeof(Table<K, V, S>));
std::memset(*table, 0, sizeof(Table<K, V, S>));
(*table)->dim = dim;
(*table)->bucket_max_size = bucket_max_size;
@@ -277,39 +322,39 @@ void create_table(Table<K, V, S>** table, const size_t dim,
(*table)->remaining_hbm_for_vectors = max_hbm_for_vectors;
(*table)->primary = primary;

CUDA_CHECK(cudaMalloc((void**)&((*table)->locks),
(*table)->buckets_num * sizeof(Mutex)));
allocator->alloc(MemoryType::Device, (void**)&((*table)->locks),
(*table)->buckets_num * sizeof(Mutex));
CUDA_CHECK(
cudaMemset((*table)->locks, 0, (*table)->buckets_num * sizeof(Mutex)));

CUDA_CHECK(cudaMalloc((void**)&((*table)->buckets_size),
(*table)->buckets_num * sizeof(int)));
allocator->alloc(MemoryType::Device, (void**)&((*table)->buckets_size),
(*table)->buckets_num * sizeof(int));
CUDA_CHECK(cudaMemset((*table)->buckets_size, 0,
(*table)->buckets_num * sizeof(int)));

CUDA_CHECK(cudaMalloc((void**)&((*table)->buckets),
(*table)->buckets_num * sizeof(Bucket<K, V, S>)));
allocator->alloc(MemoryType::Device, (void**)&((*table)->buckets),
(*table)->buckets_num * sizeof(Bucket<K, V, S>));
CUDA_CHECK(cudaMemset((*table)->buckets, 0,
(*table)->buckets_num * sizeof(Bucket<K, V, S>)));

initialize_buckets<K, V, S>(table, 0, (*table)->buckets_num);
initialize_buckets<K, V, S>(table, allocator, 0, (*table)->buckets_num);
CudaCheckError();
}

/* Double the capacity of storage; must be followed by calling the
* rehash_kernel. */
template <class K, class V, class S>
void double_capacity(Table<K, V, S>** table) {
void double_capacity(Table<K, V, S>** table, BaseAllocator* allocator) {
realloc<Mutex*>(&((*table)->locks), (*table)->buckets_num * sizeof(Mutex),
(*table)->buckets_num * sizeof(Mutex) * 2);
(*table)->buckets_num * sizeof(Mutex) * 2, allocator);
realloc<int*>(&((*table)->buckets_size), (*table)->buckets_num * sizeof(int),
(*table)->buckets_num * sizeof(int) * 2);
(*table)->buckets_num * sizeof(int) * 2, allocator);

realloc<Bucket<K, V, S>*>(
&((*table)->buckets), (*table)->buckets_num * sizeof(Bucket<K, V, S>),
(*table)->buckets_num * sizeof(Bucket<K, V, S>) * 2);
(*table)->buckets_num * sizeof(Bucket<K, V, S>) * 2, allocator);

initialize_buckets<K, V, S>(table, (*table)->buckets_num,
initialize_buckets<K, V, S>(table, allocator, (*table)->buckets_num,
(*table)->buckets_num * 2);

(*table)->capacity *= 2;
Expand All @@ -318,7 +363,7 @@ void double_capacity(Table<K, V, S>** table) {

/* Free all of the resources of a Table. */
template <class K, class V, class S>
void destroy_table(Table<K, V, S>** table) {
void destroy_table(Table<K, V, S>** table, BaseAllocator* allocator) {
uint8_t** d_address = nullptr;
CUDA_CHECK(cudaMalloc((void**)&d_address, sizeof(uint8_t*)));
for (int i = 0; i < (*table)->buckets_num; i++) {
@@ -327,15 +372,15 @@ void destroy_table(Table<K, V, S>** table) {
<<<1, 1>>>((*table)->buckets, i, d_address);
CUDA_CHECK(cudaMemcpy(&h_address, d_address, sizeof(uint8_t*),
cudaMemcpyDeviceToHost));
CUDA_CHECK(cudaFree(h_address));
allocator->free(MemoryType::Device, h_address);
}
CUDA_CHECK(cudaFree(d_address));

for (int i = 0; i < (*table)->num_of_memory_slices; i++) {
if (is_on_device((*table)->slices[i])) {
CUDA_CHECK(cudaFree((*table)->slices[i]));
allocator->free(MemoryType::Device, (*table)->slices[i]);
} else {
CUDA_CHECK(cudaFreeHost((*table)->slices[i]));
allocator->free(MemoryType::Pinned, (*table)->slices[i]);
}
}
{
Expand All @@ -345,11 +390,11 @@ void destroy_table(Table<K, V, S>** table) {
release_locks<Mutex>
<<<grid_size, block_size>>>((*table)->locks, 0, (*table)->buckets_num);
}
std::free((*table)->slices);
CUDA_CHECK(cudaFree((*table)->buckets_size));
CUDA_CHECK(cudaFree((*table)->buckets));
CUDA_CHECK(cudaFree((*table)->locks));
std::free(*table);
allocator->free(MemoryType::Host, (*table)->slices);
allocator->free(MemoryType::Device, (*table)->buckets_size);
allocator->free(MemoryType::Device, (*table)->buckets);
allocator->free(MemoryType::Device, (*table)->locks);
allocator->free(MemoryType::Host, *table);
CUDA_CHECK(cudaDeviceSynchronize());
CudaCheckError();
}
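Taken together, these core_kernels changes thread a BaseAllocator* through create_table, initialize_buckets, double_capacity, and destroy_table, so locks, bucket metadata, vector slices, and the table struct itself are all obtained from and returned to the same allocator. A rough end-to-end sketch of that flow follows; the key/value/score types and sizes are placeholders, and these are internal functions normally driven by the public hash table class rather than called by users.

#include <cstdint>

#include "merlin/core_kernels.cuh"

// Sketch of the table lifecycle with the new allocator parameter.
void table_lifecycle_sketch() {
  using K = uint64_t;  // placeholder key type
  using V = float;     // placeholder value type
  using S = uint64_t;  // placeholder score type

  nv::merlin::DefaultAllocator default_allocator;
  nv::merlin::BaseAllocator* allocator = &default_allocator;

  nv::merlin::Table<K, V, S>* table = nullptr;

  // Locks, bucket arrays, and vector slices are all allocated via `allocator`.
  nv::merlin::create_table<K, V, S>(&table, allocator, /*dim=*/64,
                                    /*init_size=*/1 << 20);

  // Doubling reallocates the same structures through `allocator` (the real
  // resize path also runs a rehash kernel afterwards).
  nv::merlin::double_capacity<K, V, S>(&table, allocator);

  // Teardown frees each block with the MemoryType it was allocated with, so
  // a custom allocator sees symmetric alloc/free calls.
  nv::merlin::destroy_table<K, V, S>(&table, allocator);
}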
7 changes: 4 additions & 3 deletions include/merlin/core_kernels/kernel_utils.cuh
@@ -54,14 +54,15 @@ __forceinline__ __device__ void LDGSTS(ElementType* dst,
const ElementType* src);

template <>
__forceinline__ __device__ void LDGSTS<uint8_t>(
uint8_t* dst, const uint8_t* src) {
__forceinline__ __device__ void LDGSTS<uint8_t>(uint8_t* dst,
const uint8_t* src) {
uint8_t element = *src;
*dst = element;
}

template <>
__forceinline__ __device__ void LDGSTS<uint16_t>(uint16_t* dst, const uint16_t* src) {
__forceinline__ __device__ void LDGSTS<uint16_t>(uint16_t* dst,
const uint16_t* src) {
uint16_t element = *src;
*dst = element;
}
