From 37889e0d37323bd340bed5a9a15ef49c6dfd72e3 Mon Sep 17 00:00:00 2001 From: KristofferC Date: Thu, 20 Jun 2024 17:27:07 +0200 Subject: [PATCH] Implement a periodic tree that can do queries assuming the input data is periodic --- README.md | 11 +++-- src/NearestNeighbors.jl | 18 ++++--- src/ball_tree.jl | 34 ++++++------- src/brute_tree.jl | 21 ++++++-- src/inrange.jl | 32 ++++++------ src/kd_tree.jl | 26 +++++----- src/knn.jl | 34 ++++++------- src/periodic_tree.jl | 107 ++++++++++++++++++++++++++++++++++++++++ src/tree_ops.jl | 47 +++++++++++++----- src/utilities.jl | 10 ++-- test/runtests.jl | 1 + test/test_knn.jl | 2 +- test/test_monkey.jl | 3 +- test/test_periodic.jl | 53 ++++++++++++++++++++ 14 files changed, 302 insertions(+), 97 deletions(-) create mode 100644 src/periodic_tree.jl create mode 100644 test/test_periodic.jl diff --git a/README.md b/README.md index 12814af..9fb917f 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ There are currently three types of trees available: * `KDTree`: Recursively splits points into groups using hyper-planes. * `BallTree`: Recursively splits points into groups bounded by hyper-spheres. * `BruteTree`: Not actually a tree. It linearly searches all points in a brute force manner. +* `PeriodicTree`: Wraps one of the trees above and allows for queries assuming the points repeat periodically. These trees can be created using the following syntax: @@ -20,7 +21,7 @@ These trees can be created using the following syntax: KDTree(data, metric; leafsize, reorder) BallTree(data, metric; leafsize, reorder) BruteTree(data, metric; leafsize, reorder) # leafsize and reorder are unused for BruteTree - +PeriodicTree(tree, bounds_min, bounds_max) ``` * `data`: The points to build the tree from, either as @@ -29,6 +30,7 @@ BruteTree(data, metric; leafsize, reorder) # leafsize and reorder are unused for * `metric`: The `Metric` (from `Distances.jl`) to use, defaults to `Euclidean`. `KDTree` works with axis-aligned metrics: `Euclidean`, `Chebyshev`, `Minkowski`, and `Cityblock` while for `BallTree` and `BruteTree` other pre-defined `Metric`s can be used as well as custom metrics (that are subtypes of `Metric`). * `leafsize`: Determines the number of points (default 10) at which to stop splitting the tree. There is a trade-off between tree traversal and evaluating the metric for an increasing number of points. * `reorder`: If `true` (default), during tree construction this rearranges points to improve cache locality during querying. This will create a copy of the original data. +* `bounds_min`, `bounds_max`: Coordinates for the two bounds for which the points are assumed to be periodic. All trees in `NearestNeighbors.jl` are static, meaning points cannot be added or removed after creation. Note that this package is not suitable for very high dimensional points due to high compilation time and inefficient queries on the trees. @@ -42,6 +44,7 @@ data = rand(3, 10^4) kdtree = KDTree(data; leafsize = 25) balltree = BallTree(data, Minkowski(3.5); reorder = false) brutetree = BruteTree(data) +PeriodicTree(kdtree, [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]) ``` ## k-Nearest Neighbor (kNN) Searches @@ -49,8 +52,8 @@ brutetree = BruteTree(data) A kNN search finds the `k` nearest neighbors to a given point or points. This is done with the methods: ```julia -knn(tree, point[s], k, skip = always_false) -> idxs, dists -knn!(idxs, dists, tree, point, k, skip = always_false) +knn(tree, point[s], k, skip = Returns(false)) -> idxs, dists +knn!(idxs, dists, tree, point, k, skip = Returns(false)) ``` * `tree`: The tree instance. @@ -61,7 +64,7 @@ knn!(idxs, dists, tree, point, k, skip = always_false) For the single closest neighbor, you can use `nn`: ```julia -nn(tree, points, skip = always_false) -> idxs, dists +nn(tree, points, skip = Returns(false)) -> idxs, dists ``` Examples: diff --git a/src/NearestNeighbors.jl b/src/NearestNeighbors.jl index 28d998d..efdafe2 100644 --- a/src/NearestNeighbors.jl +++ b/src/NearestNeighbors.jl @@ -4,9 +4,8 @@ using Distances import Distances: PreMetric, Metric, result_type, eval_reduce, eval_end, eval_op, eval_start, evaluate, parameters using StaticArrays -import Base.show -export NNTree, BruteTree, KDTree, BallTree, DataFreeTree +export NNTree, BruteTree, KDTree, BallTree, DataFreeTree, PeriodicTree export knn, knn!, nn, inrange, inrange!,inrangecount # TODOs? , allpairs, distmat, npairs export injectdata @@ -48,18 +47,21 @@ end get_T(::Type{T}) where {T <: AbstractFloat} = T get_T(::T) where {T} = Float64 -include("evaluation.jl") -include("tree_data.jl") -include("datafreetree.jl") -include("knn.jl") -include("inrange.jl") +get_tree(tree::NNTree) = tree + include("hyperspheres.jl") include("hyperrectangles.jl") +include("evaluation.jl") include("utilities.jl") +include("tree_data.jl") +include("tree_ops.jl") include("brute_tree.jl") include("kd_tree.jl") include("ball_tree.jl") -include("tree_ops.jl") +include("periodic_tree.jl") +include("datafreetree.jl") +include("knn.jl") +include("inrange.jl") for dim in (2, 3) tree = KDTree(rand(dim, 10)) diff --git a/src/ball_tree.jl b/src/ball_tree.jl index 1be8cc3..e8f60e1 100644 --- a/src/ball_tree.jl +++ b/src/ball_tree.jl @@ -136,7 +136,7 @@ function _knn(tree::BallTree, best_idxs::AbstractVector{<:Integer}, best_dists::AbstractVector, skip::F) where {F} - knn_kernel!(tree, 1, point, best_idxs, best_dists, skip) + knn_kernel!(tree, 1, point, best_idxs, best_dists, skip, false) return end @@ -146,9 +146,9 @@ function knn_kernel!(tree::BallTree{V}, point::AbstractArray, best_idxs::AbstractVector{<:Integer}, best_dists::AbstractVector, - skip::F) where {V, F} + skip::F, unique::Bool) where {V, F} if isleaf(tree.tree_data.n_internal_nodes, index) - add_points_knn!(best_dists, best_idxs, tree, index, point, true, skip) + add_points_knn!(best_dists, best_idxs, tree, index, point, true, skip, unique) return end @@ -160,14 +160,14 @@ function knn_kernel!(tree::BallTree{V}, if left_dist <= best_dists[1] || right_dist <= best_dists[1] if left_dist < right_dist - knn_kernel!(tree, getleft(index), point, best_idxs, best_dists, skip) + knn_kernel!(tree, getleft(index), point, best_idxs, best_dists, skip, unique) if right_dist <= best_dists[1] - knn_kernel!(tree, getright(index), point, best_idxs, best_dists, skip) + knn_kernel!(tree, getright(index), point, best_idxs, best_dists, skip, unique) end else - knn_kernel!(tree, getright(index), point, best_idxs, best_dists, skip) + knn_kernel!(tree, getright(index), point, best_idxs, best_dists, skip, unique) if left_dist <= best_dists[1] - knn_kernel!(tree, getleft(index), point, best_idxs, best_dists, skip) + knn_kernel!(tree, getleft(index), point, best_idxs, best_dists, skip, unique) end end end @@ -177,16 +177,19 @@ end function _inrange(tree::BallTree{V}, point::AbstractVector, radius::Number, - idx_in_ball::Union{Nothing, Vector{<:Integer}}) where {V} + idx_in_ball::Union{Nothing, Vector{<:Integer}}, + skip::F) where {V, F} ball = HyperSphere(convert(V, point), convert(eltype(V), radius)) # The "query ball" - return inrange_kernel!(tree, 1, point, ball, idx_in_ball) # Call the recursive range finder + return inrange_kernel!(tree, 1, point, ball, idx_in_ball, skip, false) # Call the recursive range finder end function inrange_kernel!(tree::BallTree, index::Int, point::AbstractVector, query_ball::HyperSphere, - idx_in_ball::Union{Nothing, Vector{<:Integer}}) + idx_in_ball::Union{Nothing, Vector{<:Integer}}, + skip::F, + unique::Bool) where {F} if index > length(tree.hyper_spheres) return 0 @@ -204,19 +207,16 @@ function inrange_kernel!(tree::BallTree, # At a leaf node, check all points in the leaf node if isleaf(tree.tree_data.n_internal_nodes, index) r = tree.metric isa MinkowskiMetric ? eval_pow(tree.metric, query_ball.r) : query_ball.r - return add_points_inrange!(idx_in_ball, tree, index, point, r) + return add_points_inrange!(idx_in_ball, tree, index, point, r, skip, unique) end - count = 0 - # The query ball encloses the sub tree bounding sphere. Add all points in the # sub tree without checking the distance function. if encloses_fast(dist, tree.metric, sphere, query_ball) - count += addall(tree, index, idx_in_ball) + return addall(tree, index, idx_in_ball, skip, unique) else # Recursively call the left and right sub tree. - count += inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball) - count += inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball) + return inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball, skip, unique) + + inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball, skip, unique) end - return count end diff --git a/src/brute_tree.jl b/src/brute_tree.jl index ed5ce9a..cfc713a 100644 --- a/src/brute_tree.jl +++ b/src/brute_tree.jl @@ -35,7 +35,7 @@ function _knn(tree::BruteTree{V}, best_dists::AbstractVector, skip::F) where {V, F} - knn_kernel!(tree, point, best_idxs, best_dists, skip) + knn_kernel!(tree, point, best_idxs, best_dists, skip, false) return end @@ -43,12 +43,17 @@ function knn_kernel!(tree::BruteTree{V}, point::AbstractVector, best_idxs::AbstractVector{<:Integer}, best_dists::AbstractVector, - skip::F) where {V, F} + skip::F, + unique::Bool) where {V, F} for i in 1:length(tree.data) if skip(i) continue end + #if unique && i in best_idxs + # continue + #end + dist_d = evaluate(tree.metric, tree.data[i], point) if dist_d <= best_dists[1] best_dists[1] = dist_d @@ -61,17 +66,23 @@ end function _inrange(tree::BruteTree, point::AbstractVector, radius::Number, - idx_in_ball::Union{Nothing, Vector{<:Integer}}) - return inrange_kernel!(tree, point, radius, idx_in_ball) + idx_in_ball::Union{Nothing, Vector{<:Integer}}, + skip::F,) where {F} + return inrange_kernel!(tree, point, radius, idx_in_ball, skip, false) end function inrange_kernel!(tree::BruteTree, point::AbstractVector, r::Number, - idx_in_ball::Union{Nothing, Vector{<:Integer}}) + idx_in_ball::Union{Nothing, Vector{<:Integer}}, + skip::Function, + unique::Bool) count = 0 for i in 1:length(tree.data) + if skip(i) + continue + end d = evaluate(tree.metric, tree.data[i], point) if d <= r count += 1 diff --git a/src/inrange.jl b/src/inrange.jl index d271639..0926fbf 100644 --- a/src/inrange.jl +++ b/src/inrange.jl @@ -11,24 +11,28 @@ See also: `inrange!`, `inrangecount`. function inrange(tree::NNTree, points::AbstractVector{T}, radius::Number, - sortres=false) where {T <: AbstractVector} + sortres=false, + skip::F = Returns(false)) where {T <: AbstractVector, F} check_input(tree, points) check_radius(radius) idxs = [Vector{Int}() for _ in 1:length(points)] for i in 1:length(points) - inrange_point!(tree, points[i], radius, sortres, idxs[i]) + inrange_point!(tree, points[i], radius, sortres, idxs[i], skip) end return idxs end -function inrange_point!(tree, point, radius, sortres, idx) - count = _inrange(tree, point, radius, idx) +inrange_point!(tree, point, radius, sortres, idx, skip::F) where {F} = _inrange_point!(tree, point, radius, sortres, idx, skip) + +function _inrange_point!(tree, point, radius, sortres, idx, skip::F) where {F} + count = _inrange(tree, point, radius, idx, skip) if idx !== nothing - if tree.reordered + inner_tree = get_tree(tree) + if inner_tree.reordered @inbounds for j in 1:length(idx) - idx[j] = tree.indices[idx[j]] + idx[j] = inner_tree.indices[idx[j]] end end sortres && sort!(idx) @@ -44,11 +48,11 @@ Useful if one want to avoid allocations or specify the element type of the outpu See also: `inrange`, `inrangecount`. """ -function inrange!(idxs::AbstractVector, tree::NNTree{V}, point::AbstractVector{T}, radius::Number, sortres=false) where {V, T <: Number} +function inrange!(idxs::AbstractVector, tree::NNTree{V}, point::AbstractVector{T}, radius::Number, sortres=false, skip=Returns(false)) where {V, T <: Number} check_input(tree, point) check_radius(radius) length(idxs) == 0 || throw(ArgumentError("idxs must be empty")) - inrange_point!(tree, point, radius, sortres, idxs) + inrange_point!(tree, point, radius, sortres, idxs, skip) return idxs end @@ -61,7 +65,7 @@ function inrange(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, sor inrange_matrix(tree, points, radius, Val(dim), sortres) end -function inrange_matrix(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, ::Val{dim}, sortres) where {V, T <: Number, dim} +function inrange_matrix(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, ::Val{dim}, sortres, skip::F=Returns(false)) where {V, T <: Number, dim, F} # TODO: DRY with inrange for AbstractVector check_input(tree, points) check_radius(radius) @@ -70,7 +74,7 @@ function inrange_matrix(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Numb for i in 1:n_points point = SVector{dim,T}(ntuple(j -> points[j, i], Val(dim))) - inrange_point!(tree, point, radius, sortres, idxs[i]) + inrange_point!(tree, point, radius, sortres, idxs[i], skip) end return idxs end @@ -80,18 +84,18 @@ end Count all the points in the tree which are closer than `radius` to `points`. """ -function inrangecount(tree::NNTree{V}, point::AbstractVector{T}, radius::Number) where {V, T <: Number} +function inrangecount(tree::NNTree{V}, point::AbstractVector{T}, radius::Number, skip::F=Returns(false)) where {V, T <: Number, F} check_input(tree, point) check_radius(radius) - return inrange_point!(tree, point, radius, false, nothing) + return inrange_point!(tree, point, radius, false, nothing, skip) end function inrangecount(tree::NNTree, points::AbstractVector{T}, - radius::Number) where {T <: AbstractVector} + radius::Number, skip::F=Returns(false)) where {T <: AbstractVector, F} check_input(tree, points) check_radius(radius) - return inrange_point!.(Ref(tree), points, radius, false, nothing) + return inrange_point!.(Ref(tree), points, radius, false, nothing, skip) end function inrangecount(tree::NNTree{V}, point::AbstractMatrix{T}, radius::Number) where {V, T <: Number} diff --git a/src/kd_tree.jl b/src/kd_tree.jl index 5518d7d..0b250d5 100644 --- a/src/kd_tree.jl +++ b/src/kd_tree.jl @@ -150,7 +150,7 @@ function _knn(tree::KDTree, best_dists::AbstractVector, skip::F) where {F} init_min = get_min_distance_no_end(tree.metric, tree.hyper_rec, point) - knn_kernel!(tree, 1, point, best_idxs, best_dists, init_min, tree.hyper_rec, skip) + knn_kernel!(tree, 1, point, best_idxs, best_dists, init_min, tree.hyper_rec, skip, false) @simd for i in eachindex(best_dists) @inbounds best_dists[i] = eval_end(tree.metric, best_dists[i]) end @@ -163,10 +163,11 @@ function knn_kernel!(tree::KDTree{V}, best_dists::AbstractVector, min_dist, hyper_rec::HyperRectangle, - skip::F) where {V, F} + skip::F, + unique::Bool) where {V, F} # At a leaf node. Go through all points in node and add those in range if isleaf(tree.tree_data.n_internal_nodes, index) - add_points_knn!(best_dists, best_idxs, tree, index, point, false, skip) + add_points_knn!(best_dists, best_idxs, tree, index, point, false, skip, unique) return end @@ -192,14 +193,14 @@ function knn_kernel!(tree::KDTree{V}, ddiff = max(zero(eltype(V)), lo - p_dim) end # Always call closer sub tree - knn_kernel!(tree, close, point, best_idxs, best_dists, min_dist, hyper_rec_close, skip) + knn_kernel!(tree, close, point, best_idxs, best_dists, min_dist, hyper_rec_close, skip, unique) split_diff_pow = eval_pow(M, split_diff) ddiff_pow = eval_pow(M, ddiff) diff_tot = eval_diff(M, split_diff_pow, ddiff_pow, split_dim) new_min = eval_reduce(M, min_dist, diff_tot) if new_min < best_dists[1] - knn_kernel!(tree, far, point, best_idxs, best_dists, new_min, hyper_rec_far, skip) + knn_kernel!(tree, far, point, best_idxs, best_dists, new_min, hyper_rec_far, skip, unique) end return end @@ -207,10 +208,11 @@ end function _inrange(tree::KDTree, point::AbstractVector, radius::Number, - idx_in_ball::Union{Nothing, Vector{<:Integer}} = Int[]) + idx_in_ball::Union{Nothing, Vector{<:Integer}}, + skip::F) where {F} init_min = get_min_distance_no_end(tree.metric, tree.hyper_rec, point) return inrange_kernel!(tree, 1, point, eval_pow(tree.metric, radius), idx_in_ball, - tree.hyper_rec, init_min) + tree.hyper_rec, init_min, skip, false) end # Explicitly check the distance between leaf node and point while traversing @@ -220,7 +222,9 @@ function inrange_kernel!(tree::KDTree, r::Number, idx_in_ball::Union{Nothing, Vector{<:Integer}}, hyper_rec::HyperRectangle, - min_dist) + min_dist, + skip::F, + unique::Bool) where {F} # Point is outside hyper rectangle, skip the whole sub tree if min_dist > r return 0 @@ -228,7 +232,7 @@ function inrange_kernel!(tree::KDTree, # At a leaf node. Go through all points in node and add those in range if isleaf(tree.tree_data.n_internal_nodes, index) - return add_points_inrange!(idx_in_ball, tree, index, point, r) + return add_points_inrange!(idx_in_ball, tree, index, point, r, skip, unique) end split_val = tree.split_vals[index] @@ -255,7 +259,7 @@ function inrange_kernel!(tree::KDTree, ddiff = max(zero(lo - p_dim), lo - p_dim) end # Call closer sub tree - count += inrange_kernel!(tree, close, point, r, idx_in_ball, hyper_rec_close, min_dist) + count += inrange_kernel!(tree, close, point, r, idx_in_ball, hyper_rec_close, min_dist, skip, unique) # TODO: We could potentially also keep track of the max distance # between the point and the hyper rectangle and add the whole sub tree @@ -267,6 +271,6 @@ function inrange_kernel!(tree::KDTree, ddiff_pow = eval_pow(M, ddiff) diff_tot = eval_diff(M, split_diff_pow, ddiff_pow, split_dim) new_min = eval_reduce(M, min_dist, diff_tot) - count += inrange_kernel!(tree, far, point, r, idx_in_ball, hyper_rec_far, new_min) + count += inrange_kernel!(tree, far, point, r, idx_in_ball, hyper_rec_far, new_min, skip, unique) return count end diff --git a/src/knn.jl b/src/knn.jl index 775f7eb..eeb9ebf 100644 --- a/src/knn.jl +++ b/src/knn.jl @@ -1,5 +1,5 @@ function check_k(tree, k) - if k > length(tree.data) || k < 0 + if k > length(get_tree(tree).data) || k < 0 throw(ArgumentError("k > number of points in tree or < 0")) end end @@ -14,7 +14,7 @@ index. See also: `knn!`, `nn`. """ -function knn(tree::NNTree{V}, points::AbstractVector{T}, k::Int, sortres=false, skip::F=always_false) where {V, T <: AbstractVector, F<:Function} +function knn(tree::NNTree{V}, points::AbstractVector{T}, k::Int, sortres=false, skip::F=Returns(false)) where {V, T <: AbstractVector, F<:Function} check_input(tree, points) check_k(tree, k) n_points = length(points) @@ -26,19 +26,23 @@ function knn(tree::NNTree{V}, points::AbstractVector{T}, k::Int, sortres=false, return idxs, dists end -function knn_point!(tree::NNTree{V}, point::AbstractVector{T}, sortres, dist, idx, skip::F) where {V, T <: Number, F} +knn_point!(tree::NNTree{V}, point::AbstractVector{T}, sortres, dist, idx, skip::F) where {V, T <: Number, F} = + _knn_point!(tree, point, sortres, dist, idx, skip) + +function _knn_point!(tree::NNTree{V}, point::AbstractVector{T}, sortres, dist, idx, skip::F) where {V, T <: Number, F} fill!(idx, -1) fill!(dist, typemax(get_T(eltype(V)))) _knn(tree, point, idx, dist, skip) - if skip !== always_false + if skip !== Returns(false) skipped_idxs = findall(==(-1), idx) deleteat!(idx, skipped_idxs) deleteat!(dist, skipped_idxs) end sortres && heap_sort_inplace!(dist, idx) - if tree.reordered + inner_tree = get_tree(tree) + if inner_tree.reordered for j in eachindex(idx) - @inbounds idx[j] = tree.indices[idx[j]] + @inbounds idx[j] = inner_tree.indices[idx[j]] end end return @@ -52,7 +56,7 @@ Useful if one want to avoid allocations or specify the element type of the outpu See also: `knn`, `nn`. """ -function knn!(idxs::AbstractVector{<:Integer}, dists::AbstractVector, tree::NNTree{V}, point::AbstractVector{T}, k::Int, sortres=false, skip::F=always_false) where {V, T <: Number, F<:Function} +function knn!(idxs::AbstractVector{<:Integer}, dists::AbstractVector, tree::NNTree{V}, point::AbstractVector{T}, k::Int, sortres=false, skip::F=Returns(false)) where {V, T <: Number, F<:Function} check_k(tree, k) length(idxs) == k || throw(ArgumentError("idxs must be of length k")) length(dists) == k || throw(ArgumentError("dists must be of length k")) @@ -60,19 +64,15 @@ function knn!(idxs::AbstractVector{<:Integer}, dists::AbstractVector, tree::NNTr return idxs, dists end -function knn(tree::NNTree{V}, point::AbstractVector{T}, k::Int, sortres=false, skip::F=always_false) where {V, T <: Number, F<:Function} +function knn(tree::NNTree{V}, point::AbstractVector{T}, k::Int, sortres=false, skip::F=Returns(false)) where {V, T <: Number, F<:Function} idx = Vector{Int}(undef, k) dist = Vector{get_T(eltype(V))}(undef, k) return knn!(idx, dist, tree, point, k, sortres, skip) end -function knn(tree::NNTree{V}, points::AbstractMatrix{T}, k::Int, sortres=false, skip::F=always_false) where {V, T <: Number, F<:Function} - dim = size(points, 1) - knn_matrix(tree, points, k, Val(dim), sortres, skip) -end +function knn(tree::NNTree{V}, points::AbstractMatrix{T}, k::Int, sortres=false, skip::F=Returns(false)) where {V, T <: Number, F<:Function} + dim = length(V) -# Function barrier -function knn_matrix(tree::NNTree{V}, points::AbstractMatrix{T}, k::Int, ::Val{dim}, sortres=false, skip::F=always_false) where {V, T <: Number, F<:Function, dim} # TODO: DRY with knn for AbstractVector check_input(tree, points) check_k(tree, k) @@ -95,9 +95,9 @@ Performs a lookup of the single nearest neigbours to the `points` from the data. See also: `knn`. """ -nn(tree::NNTree{V}, points::AbstractVector{T}, skip::F=always_false) where {V, T <: Number, F <: Function} = _nn(tree, points, skip) .|> only -nn(tree::NNTree{V}, points::AbstractVector{T}, skip::F=always_false) where {V, T <: AbstractVector, F <: Function} = _nn(tree, points, skip) |> _onlyeach -nn(tree::NNTree{V}, points::AbstractMatrix{T}, skip::F=always_false) where {V, T <: Number, F <: Function} = _nn(tree, points, skip) |> _onlyeach +nn(tree::NNTree{V}, points::AbstractVector{T}, skip::F=Returns(false)) where {V, T <: Number, F <: Function} = _nn(tree, points, skip) .|> only +nn(tree::NNTree{V}, points::AbstractVector{T}, skip::F=Returns(false)) where {V, T <: AbstractVector, F <: Function} = _nn(tree, points, skip) |> _onlyeach +nn(tree::NNTree{V}, points::AbstractMatrix{T}, skip::F=Returns(false)) where {V, T <: Number, F <: Function} = _nn(tree, points, skip) |> _onlyeach _nn(tree, points, skip) = knn(tree, points, 1, false, skip) diff --git a/src/periodic_tree.jl b/src/periodic_tree.jl new file mode 100644 index 0000000..79db3df --- /dev/null +++ b/src/periodic_tree.jl @@ -0,0 +1,107 @@ +struct PeriodicTree{V<:AbstractVector, M, Tree <: NNTree{V, M}, D} <: NNTree{V,M} + tree::Tree + bbox::HyperRectangle{V} + combos::Vector{SVector{D, Int}} + + function PeriodicTree(tree::NNTree{V,M}, bounds_min, bounds_max) where {V,M} + dim = length(V) + if length(bounds_min) != dim || length(bounds_max) != dim + throw(ArgumentError("Bounding box dimensions do not match data dimensions")) + end + + combos = SVector(ntuple(i -> -1:1, Val(dim))) + box_widths = SVector(ntuple(i -> bounds_max[i] - bounds_min[i], Val(dim))) + + for i in 1:dim + if box_widths[i] <= 0 || isinf(box_widths[i]) + combos = setindex(combos, 0:0, i) + end + end + combos = SVector{dim, Int}.(collect(Iterators.product(combos...))) + + # Put the (0, 0, 0, ...) combo first in the list of combos + filtered_product = filter(x -> x != zero(SVector{dim, Int}), combos) + combos_reordered = pushfirst!(filtered_product, zero(SVector{dim, Int})) + return new{V, M, typeof(tree), dim}(tree, HyperRectangle(SVector{dim}(bounds_min), SVector{dim}(bounds_max)), combos_reordered) + end +end + +get_tree(tree::PeriodicTree) = tree.tree + +function Base.show(io::IO, tree::PeriodicTree{V}) where {V} + println(io, "Periodic Tree: $(typeof(tree.tree))") + println(io, " Bounding box: ", tree.bbox.mins, " ", tree.bbox.maxes) + println(io, " Number of points: ", length(tree.tree.data)) + println(io, " Dimensions: ", length(V)) + println(io, " Metric: ", tree.tree.metric) + print(io, " Reordered: ", tree.tree.reordered) +end + +function _knn(tree::PeriodicTree{V,M}, + point::AbstractVector, + best_idxs::AbstractVector{<:Integer}, + best_dists::AbstractVector, + skip::F) where {V, M, F} + dim = length(V) + box_widths = SVector(ntuple(i -> tree.bbox.maxes[i] - tree.bbox.mins[i], Val(dim))) + for combo in tree.combos + shift_vector = box_widths .* combo + point_shifted = point + shift_vector + + min_dist_to_canonical = get_min_distance_no_end(tree.tree.metric, tree.bbox, point_shifted) + + # TODO: Only search the mirror boxes that are relevant + + if tree.tree isa KDTree + knn_kernel!(tree.tree, 1, point_shifted, best_idxs, best_dists, min_dist_to_canonical, tree.tree.hyper_rec, skip, true) + elseif tree.tree isa BallTree + knn_kernel!(tree.tree, 1, point_shifted, best_idxs, best_dists, skip, true) + else + @assert tree.tree isa BruteTree + knn_kernel!(tree.tree, point_shifted, best_idxs, best_dists, skip, true) + end + end + + if tree.tree isa KDTree + @simd for i in eachindex(best_dists) + @inbounds best_dists[i] = eval_end(tree.tree.metric, best_dists[i]) + end + end + + + @assert allunique(best_idxs) + return +end + +function _inrange(tree::PeriodicTree{V}, + point::AbstractVector, + radius::Number, + idx_in_ball::Union{Nothing, Vector{<:Integer}}, + skip::F) where {V, F} + + dim = length(V) + + box_widths = SVector(ntuple(i -> tree.bbox.maxes[i] - tree.bbox.mins[i], Val(dim))) + + for combo in tree.combos + shift_vector = box_widths .* combo + point_shifted = point + shift_vector + + # TODO: Only search the mirror boxes that are relevant + + if tree.tree isa KDTree + min_dist_to_canonical = get_min_distance_no_end(tree.tree.metric, tree.bbox, point_shifted) + inrange_kernel!(tree.tree, 1, point_shifted, eval_op(tree.tree.metric, radius, zero(min_dist_to_canonical)), idx_in_ball, + tree.tree.hyper_rec, min_dist_to_canonical, skip, true) + elseif tree.tree isa BallTree + ball = HyperSphere(convert(V, point_shifted), convert(eltype(V), radius)) + inrange_kernel!(tree.tree, 1, point_shifted, ball, idx_in_ball, skip, true) + else + @assert tree.tree isa BruteTree + inrange_kernel!(tree.tree, point, radius, idx_in_ball, skip, true) + end + end + + @assert allunique(idx_in_ball) + return +end diff --git a/src/tree_ops.jl b/src/tree_ops.jl index 39338cf..ddc796a 100644 --- a/src/tree_ops.jl +++ b/src/tree_ops.jl @@ -4,7 +4,7 @@ @inline getparent(i::Int) = div(i, 2) @inline isleaf(n_internal_nodes::Int, idx::Int) = idx > n_internal_nodes -function show(io::IO, tree::NNTree{V}) where {V} +function Base.show(io::IO, tree::NNTree{V}) where {V} println(io, typeof(tree)) println(io, " Number of points: ", length(tree.data)) println(io, " Dimensions: ", length(V)) @@ -92,15 +92,25 @@ end # Uses a heap for fast insertion. @inline function add_points_knn!(best_dists::AbstractVector, best_idxs::AbstractVector{<:Integer}, tree::NNTree, index::Int, point::AbstractVector, - do_end::Bool, skip::F) where {F} + do_end::Bool, skip::F, unique::Bool) where {F} for z in get_leaf_range(tree.tree_data, index) + if skip(tree.indices[z]) + continue + end idx = tree.reordered ? z : tree.indices[z] dist_d = evaluate_maybe_end(tree.metric, tree.data[idx], point, do_end) - if dist_d <= best_dists[1] - if skip(tree.indices[z]) - continue + if dist_d < best_dists[1] + if unique + idx_existing = findfirst(==(idx), best_idxs) + if idx_existing !== nothing + dist = best_dists[idx_existing] + if dist_d < dist + best_dists[idx_existing] = dist_d + percolate_down!(best_dists, best_idxs, dist_d, idx, idx_existing) + end + continue + end end - best_dists[1] = dist_d best_idxs[1] = idx percolate_down!(best_dists, best_idxs, dist_d, idx) @@ -115,10 +125,17 @@ end # This will probably prevent SIMD and other optimizations so some care is needed # to evaluate if it is worth it. @inline function add_points_inrange!(idx_in_ball::Union{Nothing, AbstractVector{<:Integer}}, tree::NNTree, - index::Int, point::AbstractVector, r::Number) + index::Int, point::AbstractVector, r::Number, skip::Function, + unique::Bool) count = 0 for z in get_leaf_range(tree.tree_data, index) + if skip(tree.indices[z]) + continue + end idx = tree.reordered ? z : tree.indices[z] + if unique && idx in idx_in_ball + continue + end if check_in_range(tree.metric, tree.data[idx], point, r) count += 1 idx_in_ball !== nothing && push!(idx_in_ball, idx) @@ -138,18 +155,24 @@ end # Add all points in this subtree since we have determined # they are all within the desired range -function addall(tree::NNTree, index::Int, idx_in_ball::Union{Nothing, Vector{<:Integer}}) +function addall(tree::NNTree, index::Int, idx_in_ball::Union{Nothing, Vector{<:Integer}}, skip::Function, unique::Bool) tree_data = tree.tree_data - count = 0 if isleaf(tree_data.n_internal_nodes, index) + count = 0 for z in get_leaf_range(tree_data, index) + if skip(tree.indices[z]) + continue + end idx = tree.reordered ? z : tree.indices[z] + if unique && idx in idx_in_ball + continue + end count += 1 idx_in_ball !== nothing && push!(idx_in_ball, idx) end + return count else - count += addall(tree, getleft(index), idx_in_ball) - count += addall(tree, getright(index), idx_in_ball) + return addall(tree, getleft(index), idx_in_ball, skip, unique) + + addall(tree, getright(index), idx_in_ball, skip, unique) end - return count end diff --git a/src/utilities.jl b/src/utilities.jl index 7a7f30a..00cfa97 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -59,7 +59,7 @@ end @inbounds for i in length(xs):-1:2 xs[i], xs[1] = xs[1], xs[i] xis[i], xis[1] = xis[1], xis[i] - percolate_down!(xs, xis, xs[1], xis[1], i - 1) + percolate_down!(xs, xis, xs[1], xis[1], 1, i - 1) end return end @@ -69,8 +69,9 @@ end xis::AbstractArray, dist::Number, index::Integer, + offset::Integer=1, len::Integer=length(xs)) - i = 1 + i = offset @inbounds while (l = getleft(i)) <= len r = getright(i) j = ifelse(r > len || (xs[l] > xs[r]), l, r) @@ -87,11 +88,6 @@ end return end -# Default skip function, always false -@inline function always_false(::Int) - false -end - # Instead of ReinterpretArray wrapper, copy an array, interpreting it as a vector of SVectors copy_svec(::Type{T}, data, ::Val{dim}) where {T, dim} = [SVector{dim,T}(ntuple(i -> data[n+i], Val(dim))) for n in 0:dim:(length(data)-1)] diff --git a/test/runtests.jl b/test/runtests.jl index 584a271..92a081d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -31,6 +31,7 @@ include("test_knn.jl") include("test_inrange.jl") include("test_monkey.jl") include("test_datafreetree.jl") +include("test_periodic.jl") @testset "views of SVector" begin x = [rand(SVector{3}) for i in 1:20] diff --git a/test/test_knn.jl b/test/test_knn.jl index 99d96a5..88cf484 100644 --- a/test/test_knn.jl +++ b/test/test_knn.jl @@ -1,6 +1,6 @@ # Does not test leafsize # Does not test different metrics -import Distances.evaluate +using Distances: evaluate @testset "knn" begin @testset "metric" for metric in [metrics; WeightedEuclidean(ones(2))] diff --git a/test/test_monkey.jl b/test/test_monkey.jl index 4d40570..f376233 100644 --- a/test/test_monkey.jl +++ b/test/test_monkey.jl @@ -32,7 +32,8 @@ import NearestNeighbors.MinkowskiMetric dim_data = rand(1:5) size_data = rand(100:151) data = rand(T, dim_data, size_data) - tree = TreeType(data, metric; leafsize = rand(1:15)) + leafsize = rand(1:15) + tree = TreeType(data, metric; leafsize) btree = BruteTree(data, metric) k = rand(1:12) p = rand(dim_data) diff --git a/test/test_periodic.jl b/test/test_periodic.jl new file mode 100644 index 0000000..f3a10d1 --- /dev/null +++ b/test/test_periodic.jl @@ -0,0 +1,53 @@ +using Test + +using NearestNeighbors, StaticArrays, Distances + +function create_trees(data, bounds_max, reorder) + kdtree = KDTree(data; leafsize=1, reorder) + balltree = BallTree(data; leafsize=1, reorder) + bounds_min = zeros(length(bounds_max)) + + pkdtree = PeriodicTree(kdtree, bounds_min, bounds_max) + pballtree = PeriodicTree(balltree, bounds_min, bounds_max) + btree = BruteTree(data, PeriodicEuclidean(bounds_max)) + return pkdtree, pballtree, btree +end + +function test_periodic_euclidean_against_brute_inrange(pkdtree, pballtree, btree, point, r) + idx_btree = sort(inrange(btree, point, r)) + idx_pkdtree = sort(inrange(pkdtree, point, r)) + idx_pballtree = sort(inrange(pballtree, point, r)) + @test idx_btree == idx_pkdtree == idx_pballtree +end + +function test_periodic_euclidean_against_brute_knn(pkdtree, pballtree, btree, point, k) + idx_btree, dists_btree = knn(btree, point, k, true) + idx_pkdtree, dists_pkdtree = knn(pkdtree, point, k, true) + idx_pballtree, dists_pballtree = knn(pballtree, point, k, true) + + @test dists_btree ≈ dists_pkdtree ≈ dists_pballtree + @test idx_btree == idx_pkdtree == idx_pballtree + + return dists_pkdtree +end + +function test_data_bounds_point(data, bounds_max, point) + for reorder = (false, true) + pkdtree, pballtree, btree = create_trees(data, bounds_max, reorder) + for k in 1:length(data) + dists = test_periodic_euclidean_against_brute_knn(pkdtree, pballtree, btree, point, k) + r = maximum(dists) + 0.001 + test_periodic_euclidean_against_brute_inrange(pkdtree, pballtree, btree, point, r) + end + end +end + +data = SVector{2, Float64}.([(1, 2), (3, 4), (5, 6), (7, 8), (9, 10)]) +bounds_max = (10.0, 10.0) +point = [8.9, 1.9] +test_data_bounds_point(data, bounds_max, point) + +data = SVector{3, Float64}.([(1, 2, 3), (4, 5, 6), (7, 8, 9), (10, 11, 12), (13, 14, 15)]) +bounds_max = (20.0, 20.0, 20.0) +point = [18.0, 19.0, 0.0] +test_data_bounds_point(data, bounds_max, point)