Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wip: implement a periodic tree that maps points to "mirrors" #193

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@ There are currently three types of trees available:
* `KDTree`: Recursively splits points into groups using hyper-planes.
* `BallTree`: Recursively splits points into groups bounded by hyper-spheres.
* `BruteTree`: Not actually a tree. It linearly searches all points in a brute force manner.
* `PeriodicTree`: Wraps one of the trees above and allows for queries assuming the points repeat periodically.

These trees can be created using the following syntax:

```julia
KDTree(data, metric; leafsize, reorder)
BallTree(data, metric; leafsize, reorder)
BruteTree(data, metric; leafsize, reorder) # leafsize and reorder are unused for BruteTree

PeriodicTree(tree, bounds_min, bounds_max)
```

* `data`: The points to build the tree from, either as
Expand All @@ -29,6 +30,7 @@ BruteTree(data, metric; leafsize, reorder) # leafsize and reorder are unused for
* `metric`: The `Metric` (from `Distances.jl`) to use, defaults to `Euclidean`. `KDTree` works with axis-aligned metrics: `Euclidean`, `Chebyshev`, `Minkowski`, and `Cityblock` while for `BallTree` and `BruteTree` other pre-defined `Metric`s can be used as well as custom metrics (that are subtypes of `Metric`).
* `leafsize`: Determines the number of points (default 10) at which to stop splitting the tree. There is a trade-off between tree traversal and evaluating the metric for an increasing number of points.
* `reorder`: If `true` (default), during tree construction this rearranges points to improve cache locality during querying. This will create a copy of the original data.
* `bounds_min`, `bounds_max`: Coordinates for the two bounds for which the points are assumed to be periodic.

All trees in `NearestNeighbors.jl` are static, meaning points cannot be added or removed after creation.
Note that this package is not suitable for very high dimensional points due to high compilation time and inefficient queries on the trees.
Expand All @@ -42,15 +44,16 @@ data = rand(3, 10^4)
kdtree = KDTree(data; leafsize = 25)
balltree = BallTree(data, Minkowski(3.5); reorder = false)
brutetree = BruteTree(data)
PeriodicTree(kdtree, [0.0, 0.0, 0.0], [1.0, 1.0, 1.0])
```

## k-Nearest Neighbor (kNN) Searches

A kNN search finds the `k` nearest neighbors to a given point or points. This is done with the methods:

```julia
knn(tree, point[s], k, skip = always_false) -> idxs, dists
knn!(idxs, dists, tree, point, k, skip = always_false)
knn(tree, point[s], k, skip = Returns(false)) -> idxs, dists
knn!(idxs, dists, tree, point, k, skip = Returns(false))
```

* `tree`: The tree instance.
Expand All @@ -61,7 +64,7 @@ knn!(idxs, dists, tree, point, k, skip = always_false)
For the single closest neighbor, you can use `nn`:

```julia
nn(tree, points, skip = always_false) -> idxs, dists
nn(tree, points, skip = Returns(false)) -> idxs, dists
```

Examples:
Expand Down
18 changes: 10 additions & 8 deletions src/NearestNeighbors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@ using Distances
import Distances: PreMetric, Metric, result_type, eval_reduce, eval_end, eval_op, eval_start, evaluate, parameters

using StaticArrays
import Base.show

export NNTree, BruteTree, KDTree, BallTree, DataFreeTree
export NNTree, BruteTree, KDTree, BallTree, DataFreeTree, PeriodicTree
export knn, knn!, nn, inrange, inrange!,inrangecount # TODOs? , allpairs, distmat, npairs
export injectdata

Expand Down Expand Up @@ -48,18 +47,21 @@ end
get_T(::Type{T}) where {T <: AbstractFloat} = T
get_T(::T) where {T} = Float64

include("evaluation.jl")
include("tree_data.jl")
include("datafreetree.jl")
include("knn.jl")
include("inrange.jl")
get_tree(tree::NNTree) = tree

include("hyperspheres.jl")
include("hyperrectangles.jl")
include("evaluation.jl")
include("utilities.jl")
include("tree_data.jl")
include("tree_ops.jl")
include("brute_tree.jl")
include("kd_tree.jl")
include("ball_tree.jl")
include("tree_ops.jl")
include("periodic_tree.jl")
include("datafreetree.jl")
include("knn.jl")
include("inrange.jl")

for dim in (2, 3)
tree = KDTree(rand(dim, 10))
Expand Down
34 changes: 17 additions & 17 deletions src/ball_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ function _knn(tree::BallTree,
best_idxs::AbstractVector{<:Integer},
best_dists::AbstractVector,
skip::F) where {F}
knn_kernel!(tree, 1, point, best_idxs, best_dists, skip)
knn_kernel!(tree, 1, point, best_idxs, best_dists, skip, false)
return
end

Expand All @@ -146,9 +146,9 @@ function knn_kernel!(tree::BallTree{V},
point::AbstractArray,
best_idxs::AbstractVector{<:Integer},
best_dists::AbstractVector,
skip::F) where {V, F}
skip::F, unique::Bool) where {V, F}
if isleaf(tree.tree_data.n_internal_nodes, index)
add_points_knn!(best_dists, best_idxs, tree, index, point, true, skip)
add_points_knn!(best_dists, best_idxs, tree, index, point, true, skip, unique)
return
end

Expand All @@ -160,14 +160,14 @@ function knn_kernel!(tree::BallTree{V},

if left_dist <= best_dists[1] || right_dist <= best_dists[1]
if left_dist < right_dist
knn_kernel!(tree, getleft(index), point, best_idxs, best_dists, skip)
knn_kernel!(tree, getleft(index), point, best_idxs, best_dists, skip, unique)
if right_dist <= best_dists[1]
knn_kernel!(tree, getright(index), point, best_idxs, best_dists, skip)
knn_kernel!(tree, getright(index), point, best_idxs, best_dists, skip, unique)
end
else
knn_kernel!(tree, getright(index), point, best_idxs, best_dists, skip)
knn_kernel!(tree, getright(index), point, best_idxs, best_dists, skip, unique)
if left_dist <= best_dists[1]
knn_kernel!(tree, getleft(index), point, best_idxs, best_dists, skip)
knn_kernel!(tree, getleft(index), point, best_idxs, best_dists, skip, unique)
end
end
end
Expand All @@ -177,16 +177,19 @@ end
function _inrange(tree::BallTree{V},
point::AbstractVector,
radius::Number,
idx_in_ball::Union{Nothing, Vector{<:Integer}}) where {V}
idx_in_ball::Union{Nothing, Vector{<:Integer}},
skip::F) where {V, F}
ball = HyperSphere(convert(V, point), convert(eltype(V), radius)) # The "query ball"
return inrange_kernel!(tree, 1, point, ball, idx_in_ball) # Call the recursive range finder
return inrange_kernel!(tree, 1, point, ball, idx_in_ball, skip, false) # Call the recursive range finder
end

function inrange_kernel!(tree::BallTree,
index::Int,
point::AbstractVector,
query_ball::HyperSphere,
idx_in_ball::Union{Nothing, Vector{<:Integer}})
idx_in_ball::Union{Nothing, Vector{<:Integer}},
skip::F,
unique::Bool) where {F}

if index > length(tree.hyper_spheres)
return 0
Expand All @@ -204,19 +207,16 @@ function inrange_kernel!(tree::BallTree,
# At a leaf node, check all points in the leaf node
if isleaf(tree.tree_data.n_internal_nodes, index)
r = tree.metric isa MinkowskiMetric ? eval_pow(tree.metric, query_ball.r) : query_ball.r
return add_points_inrange!(idx_in_ball, tree, index, point, r)
return add_points_inrange!(idx_in_ball, tree, index, point, r, skip, unique)
end

count = 0

# The query ball encloses the sub tree bounding sphere. Add all points in the
# sub tree without checking the distance function.
if encloses_fast(dist, tree.metric, sphere, query_ball)
count += addall(tree, index, idx_in_ball)
return addall(tree, index, idx_in_ball, skip, unique)
else
# Recursively call the left and right sub tree.
count += inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball)
count += inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball)
return inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball, skip, unique) +
inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball, skip, unique)
end
return count
end
21 changes: 16 additions & 5 deletions src/brute_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,25 @@ function _knn(tree::BruteTree{V},
best_dists::AbstractVector,
skip::F) where {V, F}

knn_kernel!(tree, point, best_idxs, best_dists, skip)
knn_kernel!(tree, point, best_idxs, best_dists, skip, false)
return
end

function knn_kernel!(tree::BruteTree{V},
point::AbstractVector,
best_idxs::AbstractVector{<:Integer},
best_dists::AbstractVector,
skip::F) where {V, F}
skip::F,
unique::Bool) where {V, F}
for i in 1:length(tree.data)
if skip(i)
continue
end

#if unique && i in best_idxs
# continue
#end

dist_d = evaluate(tree.metric, tree.data[i], point)
if dist_d <= best_dists[1]
best_dists[1] = dist_d
Expand All @@ -61,17 +66,23 @@ end
function _inrange(tree::BruteTree,
point::AbstractVector,
radius::Number,
idx_in_ball::Union{Nothing, Vector{<:Integer}})
return inrange_kernel!(tree, point, radius, idx_in_ball)
idx_in_ball::Union{Nothing, Vector{<:Integer}},
skip::F,) where {F}
return inrange_kernel!(tree, point, radius, idx_in_ball, skip, false)
end


function inrange_kernel!(tree::BruteTree,
point::AbstractVector,
r::Number,
idx_in_ball::Union{Nothing, Vector{<:Integer}})
idx_in_ball::Union{Nothing, Vector{<:Integer}},
skip::Function,
unique::Bool)
count = 0
for i in 1:length(tree.data)
if skip(i)
continue
end
d = evaluate(tree.metric, tree.data[i], point)
if d <= r
count += 1
Expand Down
32 changes: 18 additions & 14 deletions src/inrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,28 @@ See also: `inrange!`, `inrangecount`.
function inrange(tree::NNTree,
points::AbstractVector{T},
radius::Number,
sortres=false) where {T <: AbstractVector}
sortres=false,
skip::F = Returns(false)) where {T <: AbstractVector, F}
check_input(tree, points)
check_radius(radius)

idxs = [Vector{Int}() for _ in 1:length(points)]

for i in 1:length(points)
inrange_point!(tree, points[i], radius, sortres, idxs[i])
inrange_point!(tree, points[i], radius, sortres, idxs[i], skip)
end
return idxs
end

function inrange_point!(tree, point, radius, sortres, idx)
count = _inrange(tree, point, radius, idx)
inrange_point!(tree, point, radius, sortres, idx, skip::F) where {F} = _inrange_point!(tree, point, radius, sortres, idx, skip)

function _inrange_point!(tree, point, radius, sortres, idx, skip::F) where {F}
count = _inrange(tree, point, radius, idx, skip)
if idx !== nothing
if tree.reordered
inner_tree = get_tree(tree)
if inner_tree.reordered
@inbounds for j in 1:length(idx)
idx[j] = tree.indices[idx[j]]
idx[j] = inner_tree.indices[idx[j]]
end
end
sortres && sort!(idx)
Expand All @@ -44,11 +48,11 @@ Useful if one want to avoid allocations or specify the element type of the outpu

See also: `inrange`, `inrangecount`.
"""
function inrange!(idxs::AbstractVector, tree::NNTree{V}, point::AbstractVector{T}, radius::Number, sortres=false) where {V, T <: Number}
function inrange!(idxs::AbstractVector, tree::NNTree{V}, point::AbstractVector{T}, radius::Number, sortres=false, skip=Returns(false)) where {V, T <: Number}
check_input(tree, point)
check_radius(radius)
length(idxs) == 0 || throw(ArgumentError("idxs must be empty"))
inrange_point!(tree, point, radius, sortres, idxs)
inrange_point!(tree, point, radius, sortres, idxs, skip)
return idxs
end

Expand All @@ -61,7 +65,7 @@ function inrange(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, sor
inrange_matrix(tree, points, radius, Val(dim), sortres)
end

function inrange_matrix(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, ::Val{dim}, sortres) where {V, T <: Number, dim}
function inrange_matrix(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, ::Val{dim}, sortres, skip::F=Returns(false)) where {V, T <: Number, dim, F}
# TODO: DRY with inrange for AbstractVector
check_input(tree, points)
check_radius(radius)
Expand All @@ -70,7 +74,7 @@ function inrange_matrix(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Numb

for i in 1:n_points
point = SVector{dim,T}(ntuple(j -> points[j, i], Val(dim)))
inrange_point!(tree, point, radius, sortres, idxs[i])
inrange_point!(tree, point, radius, sortres, idxs[i], skip)
end
return idxs
end
Expand All @@ -80,18 +84,18 @@ end

Count all the points in the tree which are closer than `radius` to `points`.
"""
function inrangecount(tree::NNTree{V}, point::AbstractVector{T}, radius::Number) where {V, T <: Number}
function inrangecount(tree::NNTree{V}, point::AbstractVector{T}, radius::Number, skip::F=Returns(false)) where {V, T <: Number, F}
check_input(tree, point)
check_radius(radius)
return inrange_point!(tree, point, radius, false, nothing)
return inrange_point!(tree, point, radius, false, nothing, skip)
end

function inrangecount(tree::NNTree,
points::AbstractVector{T},
radius::Number) where {T <: AbstractVector}
radius::Number, skip::F=Returns(false)) where {T <: AbstractVector, F}
check_input(tree, points)
check_radius(radius)
return inrange_point!.(Ref(tree), points, radius, false, nothing)
return inrange_point!.(Ref(tree), points, radius, false, nothing, skip)
end

function inrangecount(tree::NNTree{V}, point::AbstractMatrix{T}, radius::Number) where {V, T <: Number}
Expand Down
Loading
Loading