[SEDONA-648] Implement Distributed K Nearest Neighbor Join (#1561)
* [SEDONA-648] Implement Distributed K Nearest Neighbor Join
* add test data
* pass KnnJoinSuite
* add KnnJoinQueryTest
* add documentation
* fix test failures
1 parent df229ce, commit 0d300a2
Showing 80 changed files with 404,309 additions and 22 deletions.
@@ -0,0 +1,88 @@

Sedona supports nearest-neighbor search on geospatial data through a geospatial k-nearest-neighbor (kNN) join. For each query point or region, the join identifies the k nearest neighbors by geographic proximity, typically using spatial coordinates and a suitable distance metric such as Euclidean or great-circle distance.

## ST_KNN

Introduction: a join operation that finds the k nearest neighbors of a point or region in a spatial dataset.

Format: `ST_KNN(R: Table, S: Table, k: Integer, use_spheroid: Boolean)`

R is the query-side table, S is the object-side table, and k is the number of neighbors. use_spheroid is a Boolean that determines whether to use the spheroid distance. In SQL, R and S are passed as the geometry columns of the two tables, as in the example below.

The query-side table contains the geometries whose k nearest neighbors are looked up in the object-side table.

When either the query data or the object data contains non-point geometries, Sedona takes the centroid of each geometry.
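
To make the centroid rule concrete, here is a minimal sketch in plain Java that computes the area-weighted centroid of a simple polygon with the shoelace formula. It is an illustration only: the class and method names (`CentroidSketch`, `polygonCentroid`) are hypothetical, and Sedona's own centroid computation is not shown in this commit excerpt and may differ in detail.

```java
/** Hypothetical helper, not part of Sedona: centroid of a simple (non-self-intersecting) polygon. */
final class CentroidSketch {

  /**
   * Returns {cx, cy} for a polygon ring given as {x, y} vertices. The ring may be open or closed;
   * the edge from the last vertex back to the first is implied.
   */
  static double[] polygonCentroid(double[][] ring) {
    double twiceArea = 0.0;
    double cx = 0.0;
    double cy = 0.0;
    int n = ring.length;
    for (int i = 0; i < n; i++) {
      double[] p = ring[i];
      double[] q = ring[(i + 1) % n];
      // Signed cross product of consecutive vertices (shoelace term).
      double cross = p[0] * q[1] - q[0] * p[1];
      twiceArea += cross;
      cx += (p[0] + q[0]) * cross;
      cy += (p[1] + q[1]) * cross;
    }
    if (twiceArea == 0.0) {
      throw new IllegalArgumentException("Degenerate polygon: zero area");
    }
    // Centroid = sum / (6 * area), and area = twiceArea / 2.
    return new double[] {cx / (3.0 * twiceArea), cy / (3.0 * twiceArea)};
  }

  public static void main(String[] args) {
    double[][] rectangle = {{0, 0}, {4, 0}, {4, 2}, {0, 2}};
    double[] c = polygonCentroid(rectangle);
    System.out.println("POINT (" + c[0] + " " + c[1] + ")"); // prints POINT (2.0 1.0)
  }
}
```

Under this rule, a 4 x 2 rectangle anchored at the origin would take part in the join as the single point (2, 1).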

When distances are tied, the result includes all of the tied geometries only if the following Sedona configuration is set to true:

```
spark.sedona.join.knn.includeTieBreakers=true
```

SQL Example

Suppose we have two tables, `QUERIES` and `OBJECTS`, with the following data:

QUERIES table:

```
ID  GEOMETRY          NAME
1   POINT(1 1)        station1
2   POINT(10 10)      station2
3   POINT(-0.5 -0.5)  station3
```

OBJECTS table:

```
ID  GEOMETRY       NAME
1   POINT(11 5)    bank1
2   POINT(12 1)    bank2
3   POINT(-1 -1)   bank3
4   POINT(-3 5)    bank4
5   POINT(9 8)     bank5
6   POINT(4 3)     bank6
7   POINT(-4 -5)   bank7
8   POINT(4 -2)    bank8
9   POINT(-3 1)    bank9
10  POINT(-7 3)    bank10
11  POINT(11 5)    bank11
12  POINT(12 1)    bank12
13  POINT(-1 -1)   bank13
14  POINT(-3 5)    bank14
15  POINT(9 8)     bank15
16  POINT(4 3)     bank16
17  POINT(-4 -5)   bank17
18  POINT(4 -2)    bank18
19  POINT(-3 1)    bank19
20  POINT(-7 3)    bank20
```

```sql
SELECT
    QUERIES.ID AS QUERY_ID,
    QUERIES.GEOMETRY AS QUERIES_GEOM,
    OBJECTS.GEOMETRY AS OBJECTS_GEOM
FROM QUERIES JOIN OBJECTS ON ST_KNN(QUERIES.GEOMETRY, OBJECTS.GEOMETRY, 4, FALSE)
```

Output:

```
+--------+-----------------+-------------+
|QUERY_ID|QUERIES_GEOM     |OBJECTS_GEOM |
+--------+-----------------+-------------+
|3       |POINT (-0.5 -0.5)|POINT (-1 -1)|
|3       |POINT (-0.5 -0.5)|POINT (-1 -1)|
|3       |POINT (-0.5 -0.5)|POINT (-3 1) |
|3       |POINT (-0.5 -0.5)|POINT (-3 1) |
|1       |POINT (1 1)      |POINT (-1 -1)|
|1       |POINT (1 1)      |POINT (-1 -1)|
|1       |POINT (1 1)      |POINT (4 3)  |
|1       |POINT (1 1)      |POINT (4 3)  |
|2       |POINT (10 10)    |POINT (9 8)  |
|2       |POINT (10 10)    |POINT (9 8)  |
|2       |POINT (10 10)    |POINT (11 5) |
|2       |POINT (10 10)    |POINT (11 5) |
+--------+-----------------+-------------+
```
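
As a cross-check of the output above, the following minimal sketch computes the same join by brute force in plain Java, using Euclidean distance (matching `use_spheroid = FALSE`). It is not how Sedona executes the join, which is distributed and index-based. The class `BruteForceKnnJoin` and its methods are hypothetical names introduced for illustration, and breaking ties by input order is an assumption of this sketch.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

/** Hypothetical brute-force kNN join over 2D points; not part of Sedona. */
final class BruteForceKnnJoin {

  /** For each query point, returns the indexes of its k nearest objects, nearest first. */
  static List<List<Integer>> knnJoin(double[][] queries, double[][] objects, int k) {
    List<List<Integer>> result = new ArrayList<>();
    for (double[] q : queries) {
      List<Integer> order = new ArrayList<>();
      for (int i = 0; i < objects.length; i++) {
        order.add(i);
      }
      // List.sort is stable, so equidistant objects keep their input order.
      order.sort(Comparator.comparingDouble((Integer i) -> distance(q, objects[i])));
      int limit = Math.max(0, Math.min(k, order.size()));
      result.add(new ArrayList<>(order.subList(0, limit)));
    }
    return result;
  }

  /** Planar Euclidean distance between two {x, y} points. */
  static double distance(double[] a, double[] b) {
    return Math.hypot(a[0] - b[0], a[1] - b[1]);
  }

  public static void main(String[] args) {
    double[][] queries = {{1, 1}, {10, 10}, {-0.5, -0.5}};
    // bank1..bank10; bank11..bank20 repeat the same coordinates.
    double[][] base = {
      {11, 5}, {12, 1}, {-1, -1}, {-3, 5}, {9, 8}, {4, 3}, {-4, -5}, {4, -2}, {-3, 1}, {-7, 3}
    };
    double[][] objects = new double[20][];
    for (int i = 0; i < objects.length; i++) {
      objects[i] = base[i % base.length];
    }
    List<List<Integer>> neighbors = knnJoin(queries, objects, 4);
    for (int q = 0; q < queries.length; q++) {
      for (int idx : neighbors.get(q)) {
        System.out.println("query " + (q + 1) + " -> bank" + (idx + 1));
      }
    }
  }
}
```

For the sample data this pairs query 1 with bank3, bank13, bank6, and bank16, query 2 with bank5, bank15, bank1, and bank11, and query 3 with bank3, bank13, bank9, and bank19. Their coordinates agree with the output table.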
40 changes: 40 additions & 0 deletions
spark/common/src/main/java/org/apache/sedona/core/enums/DistanceMetric.java
@@ -0,0 +1,40 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.sedona.core.enums;

/**
 * The DistanceMetric enum represents the different distance metrics that can be used in the
 * application.
 */
public enum DistanceMetric {
  /** The Euclidean distance metric, also known as straight line distance. */
  EUCLIDEAN,

  /**
   * The Haversine distance metric, which measures the shortest distance between two points on the
   * surface of a sphere.
   */
  HAVERSINE,

  /**
   * The Spheroid distance metric, which measures the shortest distance between two points on the
   * surface of a spheroid.
   */
  SPHEROID
}
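
To make the difference between the first two metrics concrete, here is a minimal sketch in plain Java of a Euclidean distance and a haversine (great-circle) distance. It is not Sedona's implementation: the class `DistanceSketch`, its method names, the (longitude, latitude) argument order in degrees, and the mean Earth radius of 6,371,008.8 m are all assumptions made for illustration. The spheroid metric is left out because geodesic distance on an ellipsoid needs an iterative algorithm that does not fit a short example.

```java
/** Hypothetical distance helpers for illustration; not part of Sedona. */
final class DistanceSketch {

  /** Assumed mean Earth radius in meters. */
  static final double EARTH_RADIUS_M = 6_371_008.8;

  /** Straight-line distance in the units of the input coordinates. */
  static double euclidean(double x1, double y1, double x2, double y2) {
    return Math.hypot(x2 - x1, y2 - y1);
  }

  /** Great-circle distance in meters between two (lon, lat) points given in degrees. */
  static double haversine(double lon1, double lat1, double lon2, double lat2) {
    double phi1 = Math.toRadians(lat1);
    double phi2 = Math.toRadians(lat2);
    double dPhi = phi2 - phi1;
    double dLambda = Math.toRadians(lon2 - lon1);
    double a =
        Math.sin(dPhi / 2) * Math.sin(dPhi / 2)
            + Math.cos(phi1) * Math.cos(phi2) * Math.sin(dLambda / 2) * Math.sin(dLambda / 2);
    // Clamp guards against rounding pushing sqrt(a) slightly above 1.
    return 2 * EARTH_RADIUS_M * Math.asin(Math.min(1.0, Math.sqrt(a)));
  }

  public static void main(String[] args) {
    // One degree of longitude is always 1.0 in planar units,
    // but its great-circle length shrinks away from the equator.
    System.out.println(euclidean(0, 0, 1, 0));
    System.out.println(haversine(0, 0, 1, 0));
    System.out.println(haversine(0, 60, 1, 60));
  }
}
```

The point of the comparison is that two candidates can rank differently under the two metrics when the data is in longitude and latitude, which is why the join lets the caller choose.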
187 changes: 187 additions & 0 deletions
spark/common/src/main/java/org/apache/sedona/core/joinJudgement/KnnJoinIndexJudgement.java
@@ -0,0 +1,187 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.sedona.core.joinJudgement;

import java.io.Serializable;
import java.util.*;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sedona.core.enums.DistanceMetric;
import org.apache.sedona.core.knnJudgement.EuclideanItemDistance;
import org.apache.sedona.core.knnJudgement.HaversineItemDistance;
import org.apache.sedona.core.knnJudgement.SpheroidDistance;
import org.apache.spark.api.java.function.FlatMapFunction2;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.util.LongAccumulator;
import org.locationtech.jts.geom.Envelope;
import org.locationtech.jts.geom.Geometry;
import org.locationtech.jts.index.SpatialIndex;
import org.locationtech.jts.index.strtree.GeometryItemDistance;
import org.locationtech.jts.index.strtree.ItemDistance;
import org.locationtech.jts.index.strtree.STRtree;

/**
 * This class is responsible for performing a K-nearest neighbors (KNN) join operation using a
 * spatial index. It extends the JudgementBase class and implements the FlatMapFunction2 interface.
 *
 * @param <T> extends Geometry - the type of geometries in the left set
 * @param <U> extends Geometry - the type of geometries in the right set
 */
public class KnnJoinIndexJudgement<T extends Geometry, U extends Geometry>
    extends JudgementBase<T, U>
    implements FlatMapFunction2<Iterator<T>, Iterator<SpatialIndex>, Pair<U, T>>, Serializable {
  private final int k;
  private final DistanceMetric distanceMetric;
  private final boolean includeTies;
  private final Broadcast<STRtree> broadcastedTreeIndex;

  /**
   * Constructor for the KnnJoinIndexJudgement class.
   *
   * @param k the number of nearest neighbors to find
   * @param distanceMetric the distance metric to use
   * @param includeTies whether to also return geometries tied with the farthest of the k neighbors
   * @param broadcastedTreeIndex the broadcasted spatial index, or null if the index comes from the
   *     partition iterator
   * @param buildCount accumulator for the number of geometries processed from the build side
   * @param streamCount accumulator for the number of geometries processed from the stream side
   * @param resultCount accumulator for the number of join results
   * @param candidateCount accumulator for the number of candidate matches
   */
  public KnnJoinIndexJudgement(
      int k,
      DistanceMetric distanceMetric,
      boolean includeTies,
      Broadcast<STRtree> broadcastedTreeIndex,
      LongAccumulator buildCount,
      LongAccumulator streamCount,
      LongAccumulator resultCount,
      LongAccumulator candidateCount) {
    super(null, buildCount, streamCount, resultCount, candidateCount);
    this.k = k;
    this.distanceMetric = distanceMetric;
    this.includeTies = includeTies;
    this.broadcastedTreeIndex = broadcastedTreeIndex;
  }

  /**
   * This method performs the KNN join operation. It iterates over the geometries in the stream side
   * and uses the spatial index to find the k nearest neighbors for each geometry. The method
   * returns an iterator over the join results.
   *
   * @param streamShapes iterator over the geometries in the stream side
   * @param treeIndexes iterator over the spatial indexes
   * @return an iterator over the join results
   * @throws Exception if the spatial index is not of type STRtree
   */
  @Override
  public Iterator<Pair<U, T>> call(Iterator<T> streamShapes, Iterator<SpatialIndex> treeIndexes)
      throws Exception {
    if (!treeIndexes.hasNext() || !streamShapes.hasNext()) {
      buildCount.add(0);
      streamCount.add(0);
      resultCount.add(0);
      candidateCount.add(0);
      return Collections.emptyIterator();
    }

    STRtree strTree;
    if (broadcastedTreeIndex != null) {
      // get the broadcasted spatial index if available
      // this is to support the broadcast join
      strTree = broadcastedTreeIndex.getValue();
    } else {
      // get the spatial index from the iterator
      SpatialIndex treeIndex = treeIndexes.next();
      if (!(treeIndex instanceof STRtree)) {
        throw new Exception(
            "[KnnJoinIndexJudgement][Call] Only STRtree index supports KNN search.");
      }
      strTree = (STRtree) treeIndex;
    }

    List<Pair<U, T>> result = new ArrayList<>();
    ItemDistance itemDistance;

    while (streamShapes.hasNext()) {
      T streamShape = streamShapes.next();
      streamCount.add(1);

      Object[] localK;
      // select the item distance that matches the configured metric
      switch (distanceMetric) {
        case EUCLIDEAN:
          itemDistance = new EuclideanItemDistance();
          break;
        case HAVERSINE:
          itemDistance = new HaversineItemDistance();
          break;
        case SPHEROID:
          itemDistance = new SpheroidDistance();
          break;
        default:
          itemDistance = new GeometryItemDistance();
          break;
      }

      localK =
          strTree.nearestNeighbour(streamShape.getEnvelopeInternal(), streamShape, itemDistance, k);
      if (includeTies) {
        localK = getUpdatedLocalKWithTies(streamShape, localK, strTree);
      }

      for (Object obj : localK) {
        // unchecked casts: the index stores its items as Object
        T candidate = (T) obj;
        Pair<U, T> pair = Pair.of((U) streamShape, candidate);
        result.add(pair);
        resultCount.add(1);
      }
    }

    return result.iterator();
  }

  /**
   * Extends the k nearest neighbors with any other indexed geometries that lie at exactly the same
   * distance as the farthest of them.
   *
   * <p>Note: distances here are computed with {@code Geometry.distance}, independent of the
   * configured distance metric, and ties are detected by exact floating-point equality.
   */
  private Object[] getUpdatedLocalKWithTies(T streamShape, Object[] localK, STRtree strTree) {
    Envelope searchEnvelope = streamShape.getEnvelopeInternal();
    // get the maximum distance from the k nearest neighbors
    double maxDistance = 0.0;
    LinkedHashSet<T> uniqueCandidates = new LinkedHashSet<>();
    for (Object obj : localK) {
      T candidate = (T) obj;
      uniqueCandidates.add(candidate);
      double distance = streamShape.distance(candidate);
      if (distance > maxDistance) {
        maxDistance = distance;
      }
    }
    // widen the search window so it covers everything within maxDistance of the stream geometry
    searchEnvelope.expandBy(maxDistance);
    List<T> candidates = strTree.query(searchEnvelope);
    if (!candidates.isEmpty()) {
      // update localK with all candidates that are at exactly maxDistance
      List<Object> tiedResults = new ArrayList<>();
      // add all localK
      Collections.addAll(tiedResults, localK);

      for (T candidate : candidates) {
        double distance = streamShape.distance(candidate);
        if (distance == maxDistance && !uniqueCandidates.contains(candidate)) {
          tiedResults.add(candidate);
        }
      }
      localK = tiedResults.toArray();
    }
    return localK;
  }
}
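
The tie handling in `getUpdatedLocalKWithTies` can be illustrated without Spark or JTS. The sketch below is a simplified, hypothetical stand-in (the class `TieSketch` and the method `kNearestWithTies` are not part of Sedona): it works on plain 2D points with Euclidean distance and scans every object instead of querying an STRtree. It applies the same rule, though: keep the k nearest neighbors, then append any other object whose distance equals the largest distance among them.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

/** Hypothetical illustration of kNN with ties; not part of Sedona. */
final class TieSketch {

  /** Returns indexes of the k nearest objects, plus any others tied with the farthest of them. */
  static List<Integer> kNearestWithTies(double[] query, double[][] objects, int k) {
    List<Integer> order = new ArrayList<>();
    for (int i = 0; i < objects.length; i++) {
      order.add(i);
    }
    // Stable sort by distance to the query, nearest first.
    order.sort(Comparator.comparingDouble((Integer i) -> distance(query, objects[i])));

    int limit = Math.max(0, Math.min(k, order.size()));
    List<Integer> result = new ArrayList<>(order.subList(0, limit));
    if (limit == 0) {
      return result;
    }
    // Largest distance among the k nearest; the list is sorted, so it belongs to the last one.
    double maxDistance = distance(query, objects[order.get(limit - 1)]);
    for (int j = limit; j < order.size(); j++) {
      // Exact comparison, mirroring the equality test in the method above.
      if (distance(query, objects[order.get(j)]) == maxDistance) {
        result.add(order.get(j));
      } else {
        break; // sorted order: no later object can be tied
      }
    }
    return result;
  }

  /** Planar Euclidean distance between two {x, y} points. */
  static double distance(double[] a, double[] b) {
    return Math.hypot(a[0] - b[0], a[1] - b[1]);
  }

  public static void main(String[] args) {
    double[] query = {0, 0};
    double[][] objects = {{1, 0}, {0, 2}, {2, 0}, {0, -2}, {5, 5}};
    System.out.println(kNearestWithTies(query, objects, 2)); // prints [0, 1, 2, 3]
  }
}
```

With k = 2 the second neighbor lies at distance 2, and two further objects lie at exactly that distance, so four results come back instead of two. Because the comparison is exact, distances that differ only by rounding error are not treated as ties.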