Skip to content

Commit

Permalink
[SEDONA-648] Implement Distributed K Nearest Neighbor Join (#1561)
Browse files Browse the repository at this point in the history
* [SEDONA-648] Implement Distributed K Nearest Neighbor Join

* add test data

* pass KnnJoinSuite

* add KnnJoinQueryTest

* add documentation

* fix test failures
  • Loading branch information
zhangfengcdt authored Aug 28, 2024
1 parent df229ce commit 0d300a2
Show file tree
Hide file tree
Showing 80 changed files with 404,309 additions and 22 deletions.
10 changes: 10 additions & 0 deletions common/src/main/java/org/apache/sedona/common/Predicates.java
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,14 @@ public static boolean relate(
public static boolean relateMatch(String matrix1, String matrix2) {
return IntersectionMatrix.matches(matrix1, matrix2);
}

public static boolean knn(Geometry leftGeometry, Geometry rightGeometry, int k) {
return knn(leftGeometry, rightGeometry, k, false);
}

public static boolean knn(
Geometry leftGeometry, Geometry rightGeometry, int k, boolean useSpheroid) {
// This should only be used as a test predicate used with extra join condition
return true;
}
}
88 changes: 88 additions & 0 deletions docs/api/sql/NearestNeighbourSearching.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@

Sedona supports nearest-neighbour searching on geospatial data by providing a geospatial k-Nearest Neighbors (kNN) join method. This method involves identifying the k-nearest neighbors for a given spatial point or region based on geographic proximity, typically using spatial coordinates and a suitable distance metric like Euclidean or great-circle distance.

## ST_KNN

Introduction: join operation to find the k-nearest neighbors of a point or region in a spatial dataset.

Format: `ST_KNN(R: Table, S: Table, k: Integer, use_spheroid: Boolean)`

Where R is the queries side table and S is the object side table, K is the number of neighbors. use_spheroid is a boolean value that determines whether to use the spheroid distance or not.

Queries side table contains geometries that are used to find the k-nearest neighbors in the object side table.

When either queries or objects data contain non-point data (geometries), we take the centroid of each geometry.

In case there are ties in the distance, the result will include all the tied geometries only when the following sedona config is set to true:

```
spark.sedona.join.knn.includeTieBreakers=true
```

SQL Example

Suppose we have two tables `QUERIES` and `OBJECTS` with the following data:

QUERIES table:

```
ID GEOMETRY NAME
1 POINT(1 1) station1
2 POINT(10 10) station2
3 POINT(-0.5 -0.5) station3
```

OBJECTS table:

```
ID GEOMETRY NAME
1 POINT(11 5) bank1
2 POINT(12 1) bank2
3 POINT(-1 -1) bank3
4 POINT(-3 5) bank4
5 POINT(9 8) bank5
6 POINT(4 3) bank6
7 POINT(-4 -5) bank7
8 POINT(4 -2) bank8
9 POINT(-3 1) bank9
10 POINT(-7 3) bank10
11 POINT(11 5) bank11
12 POINT(12 1) bank12
13 POINT(-1 -1) bank13
14 POINT(-3 5) bank14
15 POINT(9 8) bank15
16 POINT(4 3) bank16
17 POINT(-4 -5) bank17
18 POINT(4 -2) bank18
19 POINT(-3 1) bank19
20 POINT(-7 3) bank20
```

```sql
SELECT
QUERIES.ID AS QUERY_ID,
QUERIES.GEOMETRY AS QUERIES_GEOM,
OBJECTS.GEOMETRY AS OBJECTS_GEOM
FROM QUERIES JOIN OBJECTS ON ST_KNN(QUERIES.GEOMETRY, OBJECTS.GEOMETRY, 4, FALSE)
```

Output:

```
+--------+-----------------+-------------+
|QUERY_ID|QUERIES_GEOM |OBJECTS_GEOM |
+--------+-----------------+-------------+
|3 |POINT (-0.5 -0.5)|POINT (-1 -1)|
|3 |POINT (-0.5 -0.5)|POINT (-1 -1)|
|3 |POINT (-0.5 -0.5)|POINT (-3 1) |
|3 |POINT (-0.5 -0.5)|POINT (-3 1) |
|1 |POINT (1 1) |POINT (-1 -1)|
|1 |POINT (1 1) |POINT (-1 -1)|
|1 |POINT (1 1) |POINT (4 3) |
|1 |POINT (1 1) |POINT (4 3) |
|2 |POINT (10 10) |POINT (9 8) |
|2 |POINT (10 10) |POINT (9 8) |
|2 |POINT (10 10) |POINT (11 5) |
|2 |POINT (10 10) |POINT (11 5) |
+--------+-----------------+-------------+
```
4 changes: 4 additions & 0 deletions docs/api/sql/Parameter.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ sparkSession.conf.set("sedona.global.index","false")
* Spatial partitioning grid type for join query
* Default: kdbtree
* Possible values: quadtree, kdbtree
* spark.sedona.join.knn.includeTieBreakers
* KNN join will include all ties in the result, possibly returning more than k results
* Default: false
* Possible values: true, false
* sedona.join.indexbuildside **(Advanced users only!)**
* The side which Sedona builds spatial indices on
* Default: left
Expand Down
4 changes: 4 additions & 0 deletions docs/tutorial/sql.md
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,10 @@ LIMIT 5

The details of a join query is available here [Join query](../api/sql/Optimizer.md).

### KNN join query

The details of a KNN join query is available here [KNN join query](../api/sql/NearestNeighbourSearching.md).

### Other queries

There are lots of other functions can be combined with these queries. Please read [SedonaSQL functions](../api/sql/Function.md) and [SedonaSQL aggregate functions](../api/sql/AggregateFunction.md).
Expand Down
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ nav:
- Aggregate function: api/sql/AggregateFunction.md
- DataFrame Style functions: api/sql/DataFrameAPI.md
- Query optimization: api/sql/Optimizer.md
- Nearest-Neighbour searching: api/sql/NearestNeighbourSearching.md
- Reading Legacy Parquet Files: api/sql/Reading-legacy-parquet.md
- Visualization:
- SedonaPyDeck: api/sql/Visualization_SedonaPyDeck.md
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sedona.core.enums;

/**
* The DistanceMetric enum represents the different distance metrics that can be used in the
* application.
*/
public enum DistanceMetric {
/** The Euclidean distance metric, also known as straight line distance. */
EUCLIDEAN,

/**
* The Haversine distance metric, which measures the shortest distance between two points on the
* surface of a sphere.
*/
HAVERSINE,

/**
* The Spheroid distance metric, which measures the shortest distance between two points on the
* surface of a spheroid.
*/
SPHEROID
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@ public enum GridType implements Serializable {
QUADTREE,

/** K-D-B-tree partitioning (k-dimensional B-tree) */
KDBTREE;
KDBTREE,

/** Z-ORDER based partitioning (morton space-filling curve) for KNN joins */
ZORDER,

/** Modified Quad-tree partitioning for KNN joins */
QUADTREE_RTREE;

/**
* Gets the grid type.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sedona.core.joinJudgement;

import java.io.Serializable;
import java.util.*;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sedona.core.enums.DistanceMetric;
import org.apache.sedona.core.knnJudgement.EuclideanItemDistance;
import org.apache.sedona.core.knnJudgement.HaversineItemDistance;
import org.apache.sedona.core.knnJudgement.SpheroidDistance;
import org.apache.spark.api.java.function.FlatMapFunction2;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.util.LongAccumulator;
import org.locationtech.jts.geom.Envelope;
import org.locationtech.jts.geom.Geometry;
import org.locationtech.jts.index.SpatialIndex;
import org.locationtech.jts.index.strtree.GeometryItemDistance;
import org.locationtech.jts.index.strtree.ItemDistance;
import org.locationtech.jts.index.strtree.STRtree;

/**
* This class is responsible for performing a K-nearest neighbors (KNN) join operation using a
* spatial index. It extends the JudgementBase class and implements the FlatMapFunction2 interface.
*
* @param <T> extends Geometry - the type of geometries in the left set
* @param <U> extends Geometry - the type of geometries in the right set
*/
public class KnnJoinIndexJudgement<T extends Geometry, U extends Geometry>
extends JudgementBase<T, U>
implements FlatMapFunction2<Iterator<T>, Iterator<SpatialIndex>, Pair<U, T>>, Serializable {
private final int k;
private final DistanceMetric distanceMetric;
private final boolean includeTies;
private final Broadcast<STRtree> broadcastedTreeIndex;

/**
* Constructor for the KnnJoinIndexJudgement class.
*
* @param k the number of nearest neighbors to find
* @param distanceMetric the distance metric to use
* @param buildCount accumulator for the number of geometries processed from the build side
* @param streamCount accumulator for the number of geometries processed from the stream side
* @param resultCount accumulator for the number of join results
* @param candidateCount accumulator for the number of candidate matches
* @param broadcastedTreeIndex the broadcasted spatial index
*/
public KnnJoinIndexJudgement(
int k,
DistanceMetric distanceMetric,
boolean includeTies,
Broadcast<STRtree> broadcastedTreeIndex,
LongAccumulator buildCount,
LongAccumulator streamCount,
LongAccumulator resultCount,
LongAccumulator candidateCount) {
super(null, buildCount, streamCount, resultCount, candidateCount);
this.k = k;
this.distanceMetric = distanceMetric;
this.includeTies = includeTies;
this.broadcastedTreeIndex = broadcastedTreeIndex;
}

/**
* This method performs the KNN join operation. It iterates over the geometries in the stream side
* and uses the spatial index to find the k nearest neighbors for each geometry. The method
* returns an iterator over the join results.
*
* @param streamShapes iterator over the geometries in the stream side
* @param treeIndexes iterator over the spatial indexes
* @return an iterator over the join results
* @throws Exception if the spatial index is not of type STRtree
*/
@Override
public Iterator<Pair<U, T>> call(Iterator<T> streamShapes, Iterator<SpatialIndex> treeIndexes)
throws Exception {
if (!treeIndexes.hasNext() || !streamShapes.hasNext()) {
buildCount.add(0);
streamCount.add(0);
resultCount.add(0);
candidateCount.add(0);
return Collections.emptyIterator();
}

STRtree strTree;
if (broadcastedTreeIndex != null) {
// get the broadcasted spatial index if available
// this is to support the broadcast join
strTree = broadcastedTreeIndex.getValue();
} else {
// get the spatial index from the iterator
SpatialIndex treeIndex = treeIndexes.next();
if (!(treeIndex instanceof STRtree)) {
throw new Exception(
"[KnnJoinIndexJudgement][Call] Only STRtree index supports KNN search.");
}
strTree = (STRtree) treeIndex;
}

List<Pair<U, T>> result = new ArrayList<>();
ItemDistance itemDistance;

while (streamShapes.hasNext()) {
T streamShape = streamShapes.next();
streamCount.add(1);

Object[] localK;
switch (distanceMetric) {
case EUCLIDEAN:
itemDistance = new EuclideanItemDistance();
break;
case HAVERSINE:
itemDistance = new HaversineItemDistance();
break;
case SPHEROID:
itemDistance = new SpheroidDistance();
break;
default:
itemDistance = new GeometryItemDistance();
break;
}

localK =
strTree.nearestNeighbour(streamShape.getEnvelopeInternal(), streamShape, itemDistance, k);
if (includeTies) {
localK = getUpdatedLocalKWithTies(streamShape, localK, strTree);
}

for (Object obj : localK) {
T candidate = (T) obj;
Pair<U, T> pair = Pair.of((U) streamShape, candidate);
result.add(pair);
resultCount.add(1);
}
}

return result.iterator();
}

private Object[] getUpdatedLocalKWithTies(T streamShape, Object[] localK, STRtree strTree) {
Envelope searchEnvelope = streamShape.getEnvelopeInternal();
// get the maximum distance from the k nearest neighbors
double maxDistance = 0.0;
LinkedHashSet<T> uniqueCandidates = new LinkedHashSet<>();
for (Object obj : localK) {
T candidate = (T) obj;
uniqueCandidates.add(candidate);
double distance = streamShape.distance(candidate);
if (distance > maxDistance) {
maxDistance = distance;
}
}
searchEnvelope.expandBy(maxDistance);
List<T> candidates = strTree.query(searchEnvelope);
if (!candidates.isEmpty()) {
// update localK with all candidates that are within the maxDistance
List<Object> tiedResults = new ArrayList<>();
// add all localK
Collections.addAll(tiedResults, localK);

for (T candidate : candidates) {
double distance = streamShape.distance(candidate);
if (distance == maxDistance && !uniqueCandidates.contains(candidate)) {
tiedResults.add(candidate);
}
}
localK = tiedResults.toArray();
}
return localK;
}
}
Loading

0 comments on commit 0d300a2

Please sign in to comment.