Skip to content

Commit

Permalink
fix hnsw cpu search ids type
Browse files Browse the repository at this point in the history
Signed-off-by: yusheng.ma <[email protected]>
  • Loading branch information
Presburger committed Oct 14, 2024
1 parent ec12aba commit 1594f91
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 25 deletions.
5 changes: 2 additions & 3 deletions cpp/include/raft/neighbors/detail/hnsw.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ template <typename T>
void get_search_knn_results(hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type> const* idx,
const T* query,
int k,
uint64_t* indices,
int64_t* indices,
float* distances)
{
auto result = idx->searchKnn(query, k);
Expand All @@ -50,14 +50,13 @@ void search(raft::resources const& res,
const search_params& params,
const index<T>& idx,
raft::host_matrix_view<const T, int64_t, row_major> queries,
raft::host_matrix_view<uint64_t, int64_t, row_major> neighbors,
raft::host_matrix_view<int64_t, int64_t, row_major> neighbors,
raft::host_matrix_view<float, int64_t, row_major> distances)
{
idx.set_ef(params.ef);
auto const* hnswlib_index =
reinterpret_cast<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type> const*>(
idx.get_index());

// when num_threads == 0, automatically maximize parallelism
if (params.num_threads) {
#pragma omp parallel for num_threads(params.num_threads)
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/neighbors/hnsw.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ void search(raft::resources const& res,
const search_params& params,
const index<T>& idx,
raft::host_matrix_view<const T, int64_t, row_major> queries,
raft::host_matrix_view<uint64_t, int64_t, row_major> neighbors,
raft::host_matrix_view<int64_t, int64_t, row_major> neighbors,
raft::host_matrix_view<float, int64_t, row_major> distances)
{
RAFT_EXPECTS(
Expand Down
40 changes: 20 additions & 20 deletions cpp/include/raft_runtime/neighbors/hnsw.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,26 @@

namespace raft::runtime::neighbors::hnsw {

#define RAFT_INST_HNSW_FUNCS(T, IdxT) \
std::unique_ptr<raft::neighbors::hnsw::index<T>> from_cagra( \
raft::resources const& res, raft::neighbors::cagra::index<T, IdxT>); \
void search(raft::resources const& handle, \
raft::neighbors::hnsw::search_params const& params, \
raft::neighbors::hnsw::index<T> const& index, \
raft::host_matrix_view<const T, int64_t, row_major> queries, \
raft::host_matrix_view<uint64_t, int64_t, row_major> neighbors, \
raft::host_matrix_view<float, int64_t, row_major> distances); \
template <typename DType> \
std::unique_ptr<raft::neighbors::hnsw::index<DType>> deserialize_file( \
raft::resources const& handle, \
const std::string& filename, \
int dim, \
raft::distance::DistanceType metric); \
template <> \
std::unique_ptr<raft::neighbors::hnsw::index<T>> deserialize_file( \
raft::resources const& handle, \
const std::string& filename, \
int dim, \
#define RAFT_INST_HNSW_FUNCS(T, IdxT) \
std::unique_ptr<raft::neighbors::hnsw::index<T>> from_cagra( \
raft::resources const& res, raft::neighbors::cagra::index<T, IdxT>); \
void search(raft::resources const& handle, \
raft::neighbors::hnsw::search_params const& params, \
raft::neighbors::hnsw::index<T> const& index, \
raft::host_matrix_view<const T, int64_t, row_major> queries, \
raft::host_matrix_view<int64_t, int64_t, row_major> neighbors, \
raft::host_matrix_view<float, int64_t, row_major> distances); \
template <typename DType> \
std::unique_ptr<raft::neighbors::hnsw::index<DType>> deserialize_file( \
raft::resources const& handle, \
const std::string& filename, \
int dim, \
raft::distance::DistanceType metric); \
template <> \
std::unique_ptr<raft::neighbors::hnsw::index<T>> deserialize_file( \
raft::resources const& handle, \
const std::string& filename, \
int dim, \
raft::distance::DistanceType metric);

RAFT_INST_HNSW_FUNCS(float, uint32_t);
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/raft_runtime/neighbors/hnsw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ namespace raft::runtime::neighbors::hnsw {
raft::neighbors::hnsw::search_params const& params, \
const raft::neighbors::hnsw::index<T>& index, \
raft::host_matrix_view<const T, int64_t, row_major> queries, \
raft::host_matrix_view<uint64_t, int64_t, row_major> neighbors, \
raft::host_matrix_view<int64_t, int64_t, row_major> neighbors, \
raft::host_matrix_view<float, int64_t, row_major> distances) \
{ \
raft::neighbors::hnsw::search<T>(handle, params, index, queries, neighbors, distances); \
Expand Down

0 comments on commit 1594f91

Please sign in to comment.