Skip to content

Commit

Permalink
Add maximum key length check, safe string compare, and byte packing.
Browse files Browse the repository at this point in the history
  • Loading branch information
ehpor committed Dec 13, 2024
1 parent 85c5443 commit 2d6905d
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 12 deletions.
17 changes: 14 additions & 3 deletions benchmarks/hash_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

int main(int argc, char **argv)
{
typedef HashMap<int, 16384, 32> MyHashMap;
typedef HashMap<uint16_t, 16384, 13> MyHashMap;

std::size_t buffer_size = MyHashMap::CalculateBufferSize();
std::cout << "Buffer size: " << buffer_size << " bytes" << std::endl;
char *buffer = new char[buffer_size];

MyHashMap map(buffer);
Expand All @@ -21,9 +22,14 @@ int main(int argc, char **argv)
std::string key = "key" + std::to_string(i);

auto start = GetTimeStamp();
map.Insert(key, i);
bool success = map.Insert(key, uint16_t(i));
auto end = GetTimeStamp();

if (!success)
{
std::cout << "Insertion failed." << std::endl;
}

total_time += end - start;
}

Expand All @@ -36,9 +42,14 @@ int main(int argc, char **argv)
std::string key = "key" + std::to_string(i);

auto start = GetTimeStamp();
const int *value = map.Find(key);
auto *value = map.Find(key);
auto end = GetTimeStamp();

if (value == nullptr || *value != i)
{
std::cout << "Key not found." << std::endl;
}

total_time += end - start;
}

Expand Down
30 changes: 21 additions & 9 deletions catkit_core/HashMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <atomic>
#include <string>
#include <cstring>
#include <limits>

// MurmurHash3 32-bit version
uint32_t murmurhash3(const std::string &key, uint32_t seed = 0)
Expand Down Expand Up @@ -69,7 +70,7 @@ template <typename Value, std::size_t Size, std::size_t MaxKeyLength>
class HashMap
{
private:
enum EntryFlags
enum EntryFlags : uint8_t
{
UNOCCUPIED = 0,
INITIALIZING = 1,
Expand All @@ -78,10 +79,10 @@ class HashMap

// One slot of the table: a state flag, a fixed-size key buffer and the stored value.
// NOTE(review): `Value value;` is listed both before and after the other members.
// This looks like the removed and the added line of a rendered diff shown together
// (the member being moved below `key`) — confirm against the actual header.
struct Entry
{
Value value;

// State of this slot; a fresh entry starts out as UNOCCUPIED.
std::atomic<EntryFlags> flags = EntryFlags::UNOCCUPIED;
// Key storage of exactly MaxKeyLength characters.
char key[MaxKeyLength];

Value value;
};

Entry *m_Data;
Expand Down Expand Up @@ -114,6 +115,12 @@ class HashMap

bool Insert(const std::string &key, const Value &value)
{
if (key.size() > MaxKeyLength)
{
// Key is too long to fit in the fixed-size buffer.
return false;
}

size_t index = hash(key);

for (size_t i = 0; i < Size; ++i)
Expand All @@ -139,7 +146,7 @@ class HashMap
if (flags == EntryFlags::OCCUPIED)
{
// Check if the key is our key.
if (std::strcmp(m_Data[probe].key, key.c_str()) == 0)
if (AreKeysTheSame(m_Data[probe].key, key.c_str()))
{
// Key already exists.
return false;
Expand All @@ -148,10 +155,9 @@ class HashMap
}
else
{
// Copy key ensuring null-termination.
std::size_t key_length = std::min(key.size(), MaxKeyLength - 1);
// Copy key.
std::size_t key_length = std::min(key.size(), MaxKeyLength);
key.copy(m_Data[probe].key, key_length);
m_Data[probe].key[MaxKeyLength - 1] = '\0';

// Copy m_Data.
m_Data[probe].value = value;
Expand All @@ -169,7 +175,7 @@ class HashMap

const Value *Find(const std::string &key) const
{
if (key.size() >= MaxKeyLength)
if (key.size() > MaxKeyLength)
{
// Key is too long to fit in the fixed-size buffer.
return nullptr;
Expand All @@ -183,20 +189,26 @@ class HashMap

EntryFlags flags = m_Data[probe].flags.load(std::memory_order_acquire);

if (flags == EntryFlags::OCCUPIED && std::strcmp(m_Data[probe].key, key.c_str()) == 0)
if (flags == EntryFlags::OCCUPIED && AreKeysTheSame(m_Data[probe].key, key.c_str()))
{
return &m_Data[probe].value;
}

if (flags != EntryFlags::OCCUPIED)
{
// Key not found.
break;
}
}

// Key not found.
return nullptr;
}

// Tells whether two keys are equal, looking at no more than MaxKeyLength
// characters of each. std::strncmp stops at the first null character, so
// shorter keys are compared up to their terminator, while the length bound
// keeps the comparison from reading past MaxKeyLength characters when no
// terminator is present within that range.
//
// ky1, ky2: the keys to compare.
// Returns true when the compared characters are identical, false otherwise.
bool AreKeysTheSame(const char *ky1, const char *ky2) const
{
const int difference = std::strncmp(ky1, ky2, MaxKeyLength);
return difference == 0;
}
};

#endif // HASH_MAP_H

0 comments on commit 2d6905d

Please sign in to comment.