Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace element_neighbors_locks_ locks with read locks where possible #401

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 74 additions & 47 deletions src/VecSim/algorithms/hnsw/hnsw.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ using graphNodeType = pair<idType, ushort>; // represented as: (element_id, leve

////////////////////////////////////// Auxiliary HNSW structs //////////////////////////////////////

using elem_mutex_t = std::shared_mutex;
// Vectors flags (for marking a specific vector)
typedef enum {
DELETE_MARK = 0x1, // element is logically deleted, but still exists in the graph
Expand All @@ -74,8 +75,8 @@ struct ElementMetaData {

ElementMetaData(labelType label = SIZE_MAX) noexcept : label(label), flags(IN_PROCESS) {}
};
#pragma pack() // restore default packing

#pragma pack() // restore default packing
struct LevelData {
vecsim_stl::vector<idType> *incomingEdges;
linkListSize numLinks;
Expand All @@ -94,7 +95,7 @@ struct LevelData {

struct ElementGraphData {
size_t toplevel;
std::mutex neighborsGuard;
elem_mutex_t neighborsGuard;
LevelData *others;
LevelData level0;

Expand Down Expand Up @@ -293,10 +294,14 @@ class HNSWIndex : public VecSimIndexAbstract<DistType>,
inline auto safeGetEntryPointState() const;
inline void lockIndexDataGuard() const;
inline void unlockIndexDataGuard() const;
inline void lockNodeLinks(idType node_id) const;
inline void unlockNodeLinks(idType node_id) const;
inline void lockNodeLinks(ElementGraphData *node_data) const;
inline void unlockNodeLinks(ElementGraphData *node_data) const;
inline void writeLockNodeLinks(idType node_id) const;
inline void writeUnlockNodeLinks(idType node_id) const;
inline void writeLockNodeLinks(ElementGraphData *node_data) const;
inline void writeUnlockNodeLinks(ElementGraphData *node_data) const;
inline void readLockNodeLinks(idType node_id) const;
inline void readUnlockNodeLinks(idType node_id) const;
inline void readLockNodeLinks(ElementGraphData *node_data) const;
inline void readUnlockNodeLinks(ElementGraphData *node_data) const;
inline VisitedNodesHandler *getVisitedList() const;
inline void returnVisitedList(VisitedNodesHandler *visited_nodes_handler) const;
VecSimIndexInfo info() const override;
Expand Down Expand Up @@ -502,24 +507,46 @@ void HNSWIndex<DataType, DistType>::unlockIndexDataGuard() const {
indexDataGuard.unlock();
}

/////////////// WRITE LOCKS /////////////////
template <typename DataType, typename DistType>
void HNSWIndex<DataType, DistType>::lockNodeLinks(ElementGraphData *node_data) const {
void HNSWIndex<DataType, DistType>::writeLockNodeLinks(ElementGraphData *node_data) const {
node_data->neighborsGuard.lock();
}

template <typename DataType, typename DistType>
void HNSWIndex<DataType, DistType>::unlockNodeLinks(ElementGraphData *node_data) const {
void HNSWIndex<DataType, DistType>::writeUnlockNodeLinks(ElementGraphData *node_data) const {
node_data->neighborsGuard.unlock();
}

template <typename DataType, typename DistType>
void HNSWIndex<DataType, DistType>::lockNodeLinks(idType node_id) const {
lockNodeLinks(getGraphDataByInternalId(node_id));
void HNSWIndex<DataType, DistType>::writeLockNodeLinks(idType node_id) const {
writeLockNodeLinks(getGraphDataByInternalId(node_id));
}

template <typename DataType, typename DistType>
void HNSWIndex<DataType, DistType>::writeUnlockNodeLinks(idType node_id) const {
writeUnlockNodeLinks(getGraphDataByInternalId(node_id));
}

/////////////// READ LOCKS /////////////////
/**
 * Acquires the given node's neighbors guard in shared (reader) mode.
 *
 * @param node_data graph data of the node whose `neighborsGuard` is locked.
 *
 * Must be paired with readUnlockNodeLinks() on the same node.
 */
template <typename DataType, typename DistType>
void HNSWIndex<DataType, DistType>::readLockNodeLinks(ElementGraphData *node_data) const {
    auto &links_guard = node_data->neighborsGuard;
    links_guard.lock_shared();
}

/**
 * Releases a shared (reader) hold on the given node's neighbors guard.
 *
 * @param node_data graph data of the node whose `neighborsGuard` is released.
 *
 * Counterpart of readLockNodeLinks() for the same node.
 */
template <typename DataType, typename DistType>
void HNSWIndex<DataType, DistType>::readUnlockNodeLinks(ElementGraphData *node_data) const {
    auto &links_guard = node_data->neighborsGuard;
    links_guard.unlock_shared();
}

/**
 * Convenience overload: looks up the node's graph data by its internal id and
 * acquires its neighbors guard in shared (reader) mode.
 *
 * @param node_id internal id of the node to read-lock.
 */
template <typename DataType, typename DistType>
void HNSWIndex<DataType, DistType>::readLockNodeLinks(idType node_id) const {
    auto *node_graph_data = getGraphDataByInternalId(node_id);
    readLockNodeLinks(node_graph_data);
}

template <typename DataType, typename DistType>
void HNSWIndex<DataType, DistType>::unlockNodeLinks(idType node_id) const {
unlockNodeLinks(getGraphDataByInternalId(node_id));
void HNSWIndex<DataType, DistType>::readUnlockNodeLinks(idType node_id) const {
readUnlockNodeLinks(getGraphDataByInternalId(node_id));
}

/**
Expand Down Expand Up @@ -579,7 +606,7 @@ void HNSWIndex<DataType, DistType>::processCandidate(
candidatesMaxHeap<DistType> &candidate_set, DistType &lowerBound) const {

ElementGraphData *cur_element = getGraphDataByInternalId(curNodeId);
lockNodeLinks(cur_element);
readLockNodeLinks(cur_element);
LevelData &node_level = getLevelData(cur_element, layer);

if (node_level.numLinks > 0) {
Expand Down Expand Up @@ -653,7 +680,7 @@ void HNSWIndex<DataType, DistType>::processCandidate(
}
}
}
unlockNodeLinks(cur_element);
readUnlockNodeLinks(cur_element);
}

template <typename DataType, typename DistType>
Expand All @@ -664,7 +691,7 @@ void HNSWIndex<DataType, DistType>::processCandidate_RangeSearch(
candidatesMaxHeap<DistType> &candidate_set, DistType dyn_range, DistType radius) const {

auto *cur_element = getGraphDataByInternalId(curNodeId);
lockNodeLinks(cur_element);
readLockNodeLinks(cur_element);
LevelData &node_level = getLevelData(cur_element, layer);
if (node_level.numLinks > 0) {

Expand Down Expand Up @@ -719,7 +746,7 @@ void HNSWIndex<DataType, DistType>::processCandidate_RangeSearch(
}
}
}
unlockNodeLinks(cur_element);
readUnlockNodeLinks(cur_element);
}

template <typename DataType, typename DistType>
Expand Down Expand Up @@ -865,16 +892,16 @@ void HNSWIndex<DataType, DistType>::revisitNeighborConnections(
// Acquire all relevant locks for making the updates for the selected neighbor - all its removed
// neighbors, along with the neighbor itself and the current node.
// But first, we release the node's and the neighbor's locks to avoid deadlocks.
unlockNodeLinks(new_node_id);
unlockNodeLinks(selected_neighbor);
writeUnlockNodeLinks(new_node_id);
writeUnlockNodeLinks(selected_neighbor);

nodes_to_update.push_back(selected_neighbor);
nodes_to_update.push_back(new_node_id);

std::sort(nodes_to_update.begin(), nodes_to_update.end());
size_t nodes_to_update_count = nodes_to_update.size();
for (size_t i = 0; i < nodes_to_update_count; i++) {
lockNodeLinks(nodes_to_update[i]);
writeLockNodeLinks(nodes_to_update[i]);
}
size_t neighbour_neighbours_idx = 0;
bool update_cur_node_required = true;
Expand Down Expand Up @@ -923,7 +950,7 @@ void HNSWIndex<DataType, DistType>::revisitNeighborConnections(
// Done updating the neighbor's neighbors.
neighbor_level.numLinks = neighbour_neighbours_idx;
for (size_t i = 0; i < nodes_to_update_count; i++) {
unlockNodeLinks(nodes_to_update[i]);
writeUnlockNodeLinks(nodes_to_update[i]);
}
}

Expand Down Expand Up @@ -959,11 +986,11 @@ idType HNSWIndex<DataType, DistType>::mutuallyConnectNewElement(
idType selected_neighbor = neighbor_data.second; // neighbor's id
auto *neighbor_graph_data = getGraphDataByInternalId(selected_neighbor);
if (new_node_id < selected_neighbor) {
lockNodeLinks(new_node_level);
lockNodeLinks(neighbor_graph_data);
writeLockNodeLinks(new_node_level);
writeLockNodeLinks(neighbor_graph_data);
} else {
lockNodeLinks(neighbor_graph_data);
lockNodeLinks(new_node_level);
writeLockNodeLinks(neighbor_graph_data);
writeLockNodeLinks(new_node_level);
}

// validations...
Expand All @@ -975,15 +1002,15 @@ idType HNSWIndex<DataType, DistType>::mutuallyConnectNewElement(
if (new_node_level_data.numLinks == max_M_cur) {
// The new node cannot add more neighbors
this->log("Couldn't add all chosen neighbors upon inserting a new node");
unlockNodeLinks(new_node_level);
unlockNodeLinks(neighbor_graph_data);
writeUnlockNodeLinks(new_node_level);
writeUnlockNodeLinks(neighbor_graph_data);
break;
}

// If one of the two nodes has already been deleted - skip the operation.
if (isMarkedDeleted(new_node_id) || isMarkedDeleted(selected_neighbor)) {
unlockNodeLinks(new_node_level);
unlockNodeLinks(neighbor_graph_data);
writeUnlockNodeLinks(new_node_level);
writeUnlockNodeLinks(neighbor_graph_data);
continue;
}

Expand All @@ -994,8 +1021,8 @@ idType HNSWIndex<DataType, DistType>::mutuallyConnectNewElement(
if (neighbor_level_data.numLinks < max_M_cur) {
new_node_level_data.links[new_node_level_data.numLinks++] = selected_neighbor;
neighbor_level_data.links[neighbor_level_data.numLinks++] = new_node_id;
unlockNodeLinks(new_node_level);
unlockNodeLinks(neighbor_graph_data);
writeUnlockNodeLinks(new_node_level);
writeUnlockNodeLinks(neighbor_graph_data);
continue;
}

Expand Down Expand Up @@ -1105,15 +1132,15 @@ void HNSWIndex<DataType, DistType>::replaceEntryPoint() {
volatile idType candidate_in_process = INVALID_ID;

// Go over the entry point's neighbors at the top level.
lockNodeLinks(old_entry_point);
readLockNodeLinks(old_entry_point);
LevelData &old_ep_level = getLevelData(old_entry_point, maxLevel);
// Tries to set the (arbitrary) first neighbor as the entry point which is not deleted,
// if exists.
for (size_t i = 0; i < old_ep_level.numLinks; i++) {
if (!isMarkedDeleted(old_ep_level.links[i])) {
if (!isInProcess(old_ep_level.links[i])) {
entrypointNode = old_ep_level.links[i];
unlockNodeLinks(old_entry_point);
readUnlockNodeLinks(old_entry_point);
return;
} else {
// Store this candidate which is currently being inserted into the graph in
Expand All @@ -1122,7 +1149,7 @@ void HNSWIndex<DataType, DistType>::replaceEntryPoint() {
}
}
}
unlockNodeLinks(old_entry_point);
readUnlockNodeLinks(old_entry_point);

// If there is no neighbors in the current level, check for any vector at
// this level to be the new entry point.
Expand Down Expand Up @@ -1273,7 +1300,7 @@ void HNSWIndex<DataType, DistType>::greedySearchLevel(const void *vector_data, s

changed = false;
auto *element = getGraphDataByInternalId(bestCand);
lockNodeLinks(element);
readLockNodeLinks(element);
LevelData &node_level_data = getLevelData(element, level);

for (int i = 0; i < node_level_data.numLinks; i++) {
Expand All @@ -1295,7 +1322,7 @@ void HNSWIndex<DataType, DistType>::greedySearchLevel(const void *vector_data, s
}
}
}
unlockNodeLinks(element);
readUnlockNodeLinks(element);
} while (changed);
if (!running_query) {
bestCand = bestNonDeletedCand;
Expand All @@ -1311,18 +1338,18 @@ HNSWIndex<DataType, DistType>::safeCollectAllNodeIncomingNeighbors(idType node_i
for (size_t level = 0; level <= element->toplevel; level++) {
// Save the node neighbor's in the current level while holding its neighbors lock.
std::vector<idType> neighbors_copy;
lockNodeLinks(element);
readLockNodeLinks(element);
auto &node_level_data = getLevelData(element, level);
// Store the deleted element's neighbours.
neighbors_copy.assign(node_level_data.links,
node_level_data.links + node_level_data.numLinks);
unlockNodeLinks(element);
readUnlockNodeLinks(element);

// Go over the neighbours and collect the ones that also point back to the removed node.
for (auto neighbour_id : neighbors_copy) {
// Hold the neighbor's lock while we are going over its neighbors.
auto *neighbor = getGraphDataByInternalId(neighbour_id);
lockNodeLinks(neighbor);
readLockNodeLinks(neighbor);
LevelData &neighbour_level_data = getLevelData(neighbor, level);

for (size_t j = 0; j < neighbour_level_data.numLinks; j++) {
Expand All @@ -1332,16 +1359,16 @@ HNSWIndex<DataType, DistType>::safeCollectAllNodeIncomingNeighbors(idType node_i
break;
}
}
unlockNodeLinks(neighbor);
readUnlockNodeLinks(neighbor);
}

// Next, collect the rest of incoming edges (the ones that are not bidirectional) in the
// current level to repair them.
lockNodeLinks(element);
readLockNodeLinks(element);
for (auto incoming_edge : *node_level_data.incomingEdges) {
incoming_neighbors.emplace_back(incoming_edge, (ushort)level);
}
unlockNodeLinks(element);
readUnlockNodeLinks(element);
}
return incoming_neighbors;
}
Expand Down Expand Up @@ -1402,7 +1429,7 @@ void HNSWIndex<DataType, DistType>::mutuallyUpdateForRepairedNode(
std::sort(nodes_to_update.begin(), nodes_to_update.end());
size_t nodes_to_update_count = nodes_to_update.size();
for (size_t i = 0; i < nodes_to_update_count; i++) {
lockNodeLinks(nodes_to_update[i]);
writeLockNodeLinks(nodes_to_update[i]);
}

LevelData &node_level = getLevelData(node_id, level);
Expand Down Expand Up @@ -1477,7 +1504,7 @@ void HNSWIndex<DataType, DistType>::mutuallyUpdateForRepairedNode(
// Done updating the node's neighbors.
node_level.numLinks = node_neighbors_idx;
for (size_t i = 0; i < nodes_to_update_count; i++) {
unlockNodeLinks(nodes_to_update[i]);
writeUnlockNodeLinks(nodes_to_update[i]);
}
}

Expand All @@ -1499,7 +1526,7 @@ void HNSWIndex<DataType, DistType>::repairNodeConnections(idType node_id, size_t
// after the repair as well.
const void *node_data = getDataByInternalId(node_id);
auto *element = getGraphDataByInternalId(node_id);
lockNodeLinks(element);
readLockNodeLinks(element);
LevelData &node_level_data = getLevelData(element, level);
for (size_t j = 0; j < node_level_data.numLinks; j++) {
node_orig_neighbours_set[node_level_data.links[j]] = true;
Expand All @@ -1513,7 +1540,7 @@ void HNSWIndex<DataType, DistType>::repairNodeConnections(idType node_id, size_t
this->distFunc(node_data, getDataByInternalId(node_level_data.links[j]), this->dim),
node_level_data.links[j]);
}
unlockNodeLinks(element);
readUnlockNodeLinks(element);

// If there are no deleted neighbors at this point, the repair job has already been done by
// another parallel job, and there is no need to repair the node anymore.
Expand All @@ -1534,7 +1561,7 @@ void HNSWIndex<DataType, DistType>::repairNodeConnections(idType node_id, size_t
neighbors_to_remove.push_back(deleted_neighbor_id);

auto *neighbor = getGraphDataByInternalId(deleted_neighbor_id);
lockNodeLinks(neighbor);
readLockNodeLinks(neighbor);
LevelData &neighbor_level_data = getLevelData(neighbor, level);

for (size_t j = 0; j < neighbor_level_data.numLinks; j++) {
Expand All @@ -1551,7 +1578,7 @@ void HNSWIndex<DataType, DistType>::repairNodeConnections(idType node_id, size_t
this->dim),
neighbor_level_data.links[j]);
}
unlockNodeLinks(neighbor);
readUnlockNodeLinks(neighbor);
}

// Copy the original candidates, and run the heuristics. Afterwards, neighbors_candidates will
Expand Down
4 changes: 2 additions & 2 deletions src/VecSim/algorithms/hnsw/hnsw_batch_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ VecSimQueryResult_Code HNSW_BatchIterator<DataType, DistType>::scanGraphInternal
// Take the current node out of the candidates queue and go over its neighbours.
candidates.pop();
auto *node_graph_data = this->index->getGraphDataByInternalId(curr_node_id);
this->index->lockNodeLinks(node_graph_data);
this->index->readLockNodeLinks(node_graph_data);
LevelData &node_level_data = this->index->getLevelData(node_graph_data, 0);
if (node_level_data.numLinks > 0) {

Expand Down Expand Up @@ -159,7 +159,7 @@ VecSimQueryResult_Code HNSW_BatchIterator<DataType, DistType>::scanGraphInternal
candidates.emplace(candidate_dist, candidate_id);
}
}
this->index->unlockNodeLinks(curr_node_id);
this->index->readUnlockNodeLinks(curr_node_id);
}
return VecSim_QueryResult_OK;
}
Expand Down