From 1594f91d1dd1dbeb9db425018b142b421812ed5c Mon Sep 17 00:00:00 2001 From: "yusheng.ma" Date: Fri, 14 Jun 2024 14:38:14 +0800 Subject: [PATCH] fix hnsw cpu search ids type Signed-off-by: yusheng.ma --- cpp/include/raft/neighbors/detail/hnsw.hpp | 5 ++- cpp/include/raft/neighbors/hnsw.hpp | 2 +- cpp/include/raft_runtime/neighbors/hnsw.hpp | 40 ++++++++++----------- cpp/src/raft_runtime/neighbors/hnsw.cpp | 2 +- 4 files changed, 24 insertions(+), 25 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/hnsw.hpp b/cpp/include/raft/neighbors/detail/hnsw.hpp index bd4e6608de..569170e4f7 100644 --- a/cpp/include/raft/neighbors/detail/hnsw.hpp +++ b/cpp/include/raft/neighbors/detail/hnsw.hpp @@ -32,7 +32,7 @@ template void get_search_knn_results(hnswlib::HierarchicalNSW::type> const* idx, const T* query, int k, - uint64_t* indices, + int64_t* indices, float* distances) { auto result = idx->searchKnn(query, k); @@ -50,14 +50,13 @@ void search(raft::resources const& res, const search_params& params, const index& idx, raft::host_matrix_view queries, - raft::host_matrix_view neighbors, + raft::host_matrix_view neighbors, raft::host_matrix_view distances) { idx.set_ef(params.ef); auto const* hnswlib_index = reinterpret_cast::type> const*>( idx.get_index()); - // when num_threads == 0, automatically maximize parallelism if (params.num_threads) { #pragma omp parallel for num_threads(params.num_threads) diff --git a/cpp/include/raft/neighbors/hnsw.hpp b/cpp/include/raft/neighbors/hnsw.hpp index ee3f61e550..1214eee970 100644 --- a/cpp/include/raft/neighbors/hnsw.hpp +++ b/cpp/include/raft/neighbors/hnsw.hpp @@ -122,7 +122,7 @@ void search(raft::resources const& res, const search_params& params, const index& idx, raft::host_matrix_view queries, - raft::host_matrix_view neighbors, + raft::host_matrix_view neighbors, raft::host_matrix_view distances) { RAFT_EXPECTS( diff --git a/cpp/include/raft_runtime/neighbors/hnsw.hpp b/cpp/include/raft_runtime/neighbors/hnsw.hpp index e8b932d490..a16bab785d 100644 --- a/cpp/include/raft_runtime/neighbors/hnsw.hpp +++ b/cpp/include/raft_runtime/neighbors/hnsw.hpp @@ -23,26 +23,26 @@ namespace raft::runtime::neighbors::hnsw { -#define RAFT_INST_HNSW_FUNCS(T, IdxT) \ - std::unique_ptr> from_cagra( \ - raft::resources const& res, raft::neighbors::cagra::index); \ - void search(raft::resources const& handle, \ - raft::neighbors::hnsw::search_params const& params, \ - raft::neighbors::hnsw::index const& index, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ - raft::host_matrix_view distances); \ - template \ - std::unique_ptr> deserialize_file( \ - raft::resources const& handle, \ - const std::string& filename, \ - int dim, \ - raft::distance::DistanceType metric); \ - template <> \ - std::unique_ptr> deserialize_file( \ - raft::resources const& handle, \ - const std::string& filename, \ - int dim, \ +#define RAFT_INST_HNSW_FUNCS(T, IdxT) \ + std::unique_ptr> from_cagra( \ + raft::resources const& res, raft::neighbors::cagra::index); \ + void search(raft::resources const& handle, \ + raft::neighbors::hnsw::search_params const& params, \ + raft::neighbors::hnsw::index const& index, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbors, \ + raft::host_matrix_view distances); \ + template \ + std::unique_ptr> deserialize_file( \ + raft::resources const& handle, \ + const std::string& filename, \ + int dim, \ + raft::distance::DistanceType metric); \ + template <> \ + std::unique_ptr> deserialize_file( \ + raft::resources const& handle, \ + const std::string& filename, \ + int dim, \ raft::distance::DistanceType metric); RAFT_INST_HNSW_FUNCS(float, uint32_t); diff --git a/cpp/src/raft_runtime/neighbors/hnsw.cpp b/cpp/src/raft_runtime/neighbors/hnsw.cpp index 5356e708d2..e6d21744bf 100644 --- a/cpp/src/raft_runtime/neighbors/hnsw.cpp +++ b/cpp/src/raft_runtime/neighbors/hnsw.cpp @@ -55,7 +55,7 @@ namespace raft::runtime::neighbors::hnsw { raft::neighbors::hnsw::search_params const& params, \ const raft::neighbors::hnsw::index& index, \ raft::host_matrix_view queries, \ - raft::host_matrix_view neighbors, \ + raft::host_matrix_view neighbors, \ raft::host_matrix_view distances) \ { \ raft::neighbors::hnsw::search(handle, params, index, queries, neighbors, distances); \