Skip to content

Commit

Permalink
allows that the hints procedure also populates several points into th…
Browse files Browse the repository at this point in the history
…e beam
  • Loading branch information
sadit committed Oct 17, 2024
1 parent da534ed commit b84b79e
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
8 changes: 4 additions & 4 deletions src/searchgraph/beamsearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ Base.copy(bsearch::BeamSearch; bsize=bsearch.bsize, Δ=bsearch.Δ, maxvisits=bse

### local search algorithm

function beamsearch_init(bs::BeamSearch, index::SearchGraph, q, res::KnnResult, hints, vstate, bsize)
visited_ = approx_by_hints(index, q, hints, res, vstate)
function beamsearch_init(bs::BeamSearch, index::SearchGraph, q, res::KnnResult, hints, vstate, bsize, beam)
visited_ = approx_by_hints(index, q, hints, res, vstate, beam)

if length(res) == 0
_range = 1:length(index)
for _ in 1:bsize
objID = rand(_range)
visited_ += enqueue_item!(index, q, database(index, objID), res, objID, vstate)
visited_ += enqueue_item!(index, q, database(index, objID), res, objID, vstate, beam)
end
end

Expand Down Expand Up @@ -86,7 +86,7 @@ Optional arguments (defaults to values in `bs`)
function search(bs::BeamSearch, index::SearchGraph, context::SearchGraphContext, q, res, hints; bsize::Int32=bs.bsize, Δ::Float32=bs.Δ, maxvisits::Int=bs.maxvisits, vstate::Vector{UInt64}=getvstate(length(index), context))
# k is the number of neighbors in res
vstate = PtrArray(vstate)
visited_ = beamsearch_init(bs, index, q, res, hints, vstate, bsize)
beam = getbeam(bsize, context)
visited_ = beamsearch_init(bs, index, q, res, hints, vstate, bsize, beam)
beamsearch_inner(bs, index, q, res, vstate, beam, Δ, maxvisits, visited_)
end
8 changes: 5 additions & 3 deletions src/searchgraph/graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,16 @@ Base.copy(G::SearchGraph;
@inline Base.length(g::SearchGraph)::Int64 = g.len[]

"""
enqueue_item!(index::SearchGraph, q, obj, res::KnnResult, objID, vstate)
enqueue_item!(index::SearchGraph, q, obj, res::KnnResult, objID, vstate, beam)
Internal function that evaluates the distance between a database object `obj` with id `objID` and the query `q`.
It helps to evaluate, mark as visited, and enqueue in the result set.
"""
@inline function enqueue_item!(index::SearchGraph, q, obj, res::KnnResult, objID, vstate)::Int
@inline function enqueue_item!(index::SearchGraph, q, obj, res::KnnResult, objID, vstate, beam)::Int
check_visited_and_visit!(vstate, convert(UInt64, objID)) && return 0
push_item!(res, objID, evaluate(distance(index), q, database(index, objID)))
d = evaluate(distance(index), q, database(index, objID))
push_item!(res, objID, d)
beam !== nothing && push_item!(beam, objID, d)
1
end

Expand Down
10 changes: 5 additions & 5 deletions src/searchgraph/hints.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
# This file is a part of SimilaritySearch.jl
#
"""
approx_by_hints(index::SearchGraph, q, hints, res::KnnResult, vstate)
approx_by_hints(index::SearchGraph, q, hints, res::KnnResult, vstate, beam)
Approximate the result using a set of hints (the set of identifiers (integers)) behints `hints`
"""
function approx_by_hints(index::SearchGraph, q, hints::T, res::KnnResult, vstate) where T<:Union{AbstractVector,Tuple,Integer,Set}
function approx_by_hints(index::SearchGraph, q, hints::T, res::KnnResult, vstate, beam) where T<:Union{AbstractVector,Tuple,Integer,Set}
visited = 0
for objID in hints
obj = database(index, objID)
visited += enqueue_item!(index, q, obj, res, objID, vstate)
visited += enqueue_item!(index, q, obj, res, objID, vstate, beam)
end

visited
Expand All @@ -30,11 +30,11 @@ end



function approx_by_hints(index::SearchGraph, q, h::AdjacentStoredHints, res::KnnResult, vstate)
function approx_by_hints(index::SearchGraph, q, h::AdjacentStoredHints, res::KnnResult, vstate, beam)
visited = 0
for (i, objID) in enumerate(h.map)
obj = h.hints[i]
visited += enqueue_item!(index, q, obj, res, objID, vstate)
visited += enqueue_item!(index, q, obj, res, objID, vstate, beam)
end

visited
Expand Down

0 comments on commit b84b79e

Please sign in to comment.