From 61bb000c90874f66ea6802c01d75f0f6e8ecfa90 Mon Sep 17 00:00:00 2001
From: James Willis
Date: Mon, 28 Oct 2024 14:14:40 -0700
Subject: [PATCH] [SEDONA-667] Getis Ord (#1652)

* Getis Ord

* add matplotlib dependency

* scipy/pysal compatibility fix

* rebase onto pre-commit changes

* add missing paren

---------

Co-authored-by: jameswillis
---
 docs/api/stats/sql.md                         |  67 +++++++
 docs/tutorial/sql.md                          |  61 ++++++
 python/Pipfile                                |   4 +
 .../stats/hotspot_detection/__init__.py       |  18 ++
 .../stats/hotspot_detection/getis_ord.py      |  65 +++++++
 python/sedona/stats/weighting.py              | 110 +++++++++++
 python/tests/stats/test_getis_ord.py          | 157 +++++++++++++++
 python/tests/stats/test_weighting.py          |  49 +++++
 python/tests/test_base.py                     |  28 ++-
 .../org/apache/sedona/stats/Weighting.scala   | 180 ++++++++++++++++++
 .../stats/hotspotDetection/GetisOrd.scala     | 105 ++++++++++
 .../org/apache/sedona/sql/TestBaseScala.scala |   8 +-
 .../apache/sedona/stats/WeightingTest.scala   | 178 +++++++++++++++++
 .../stats/hotspotDetection/GetisOrdTest.scala |  92 +++++++++
 14 files changed, 1120 insertions(+), 2 deletions(-)
 create mode 100644 python/sedona/stats/hotspot_detection/__init__.py
 create mode 100644 python/sedona/stats/hotspot_detection/getis_ord.py
 create mode 100644 python/sedona/stats/weighting.py
 create mode 100644 python/tests/stats/test_getis_ord.py
 create mode 100644 python/tests/stats/test_weighting.py
 create mode 100644 spark/common/src/main/scala/org/apache/sedona/stats/Weighting.scala
 create mode 100644 spark/common/src/main/scala/org/apache/sedona/stats/hotspotDetection/GetisOrd.scala
 create mode 100644 spark/common/src/test/scala/org/apache/sedona/stats/WeightingTest.scala
 create mode 100644 spark/common/src/test/scala/org/apache/sedona/stats/hotspotDetection/GetisOrdTest.scala

diff --git a/docs/api/stats/sql.md b/docs/api/stats/sql.md
index 2906917102..45c17dc4df 100644
--- a/docs/api/stats/sql.md
+++ b/docs/api/stats/sql.md
@@ -49,3 +49,61 @@ names in parentheses are python variable names
 - geometry - name of the geometry column
 - handleTies (handle_ties) - whether to handle ties in the k-distance calculation. Default is false
 - useSpheroid (use_spheroid) - whether to use a cartesian or spheroidal distance calculation. Default is false
+
+The output is the input DataFrame with the LOF added to each row.
+
+## Using Getis-Ord Gi(*)
+
+The G Local function is provided at `org.apache.sedona.stats.hotspotDetection.GetisOrd.gLocal` in scala/java and `sedona.stats.hotspot_detection.getis_ord.g_local` in python.
+
+It performs the Gi or Gi* statistic on the x column of the dataframe.
+
+Weights should be the neighbors of each row. Each element of the weights array should be a struct containing a value column and a neighbor column. The neighbor column should contain the contents of the neighbor row, with the same types as the parent row (minus the weights column). See the _Using the Distance Weighting Function_ section for instructions on generating this column. To calculate the Gi* statistic, ensure the focal observation is in the neighbors array (i.e. the row is in its own weights column) and set `star=true`. Significance is calculated with a z score.
+
+### Parameters
+
+- dataframe - the dataframe to perform the G statistic on
+- x - the name of the column to perform hotspot analysis on
+- weights - the name of the column containing the neighbors array. The neighbor column should contain the contents of the neighbor row, with the same types as the parent row (minus the weights column). You can use the `Weighting` class functions to achieve this.
+- star - whether the focal observation is in the neighbors array. If true, this calculates Gi*; otherwise, Gi
+
+The output is the input DataFrame with the following columns added: G, E[G], V[G], Z, P.
+
+## Using the Distance Weighting Function
+
+The Weighting functions are provided at `org.apache.sedona.stats.Weighting` in scala/java and `sedona.stats.weighting` in python.
+
+These functions generate a column containing an array of structs, each containing a value column and a neighbor column.
+
+The generic `addDistanceBandColumn` (`add_distance_band_column` in python) function annotates a dataframe with a weights column containing the other records within the threshold and their weight.
+
+The dataframe should contain at least one `GeometryType` column. Rows must be unique. If one geometry column is present it will be used automatically. If two are present, the one named 'geometry' will be used. If more than one is present and none is named 'geometry', the column name must be provided. The new column will be named 'weights'.
+
+### Parameters
+
+#### addDistanceBandColumn
+
+names in parentheses are python variable names
+
+- dataframe - DataFrame with geometry column
+- threshold - Distance threshold for considering neighbors
+- binary - whether to use binary weights or inverse distance weights for neighbors (dist^alpha)
+- alpha - alpha to use for inverse distance weights; ignored when binary is true
+- includeZeroDistanceNeighbors (include_zero_distance_neighbors) - whether to include neighbors that are 0 distance. If 0 distance neighbors are included and binary is false, weights are infinity as per the floating point spec (divide by 0)
+- includeSelf (include_self) - whether to include self in the list of neighbors
+- selfWeight (self_weight) - the value to use for the self weight
+- geometry - name of the geometry column
+- useSpheroid (use_spheroid) - whether to use a cartesian or spheroidal distance calculation. Default is false
+
+#### addBinaryDistanceBandColumn
+
+names in parentheses are python variable names
+
+- dataframe - DataFrame with geometry column
+- threshold - Distance threshold for considering neighbors
+- includeZeroDistanceNeighbors (include_zero_distance_neighbors) - whether to include neighbors that are 0 distance. Weights are always 1, so zero-distance neighbors do not produce infinite weights
+- includeSelf (include_self) - whether to include self in the list of neighbors
+- geometry - name of the geometry column
+- useSpheroid (use_spheroid) - whether to use a cartesian or spheroidal distance calculation. Default is false
+
+In both cases the output is the input DataFrame with the weights column added to each row.

diff --git a/docs/tutorial/sql.md b/docs/tutorial/sql.md
index b3e4ecf6ac..98e7006266 100644
--- a/docs/tutorial/sql.md
+++ b/docs/tutorial/sql.md
@@ -946,6 +946,77 @@ The output will look like this:
 +--------------------+------------------+
 ```
+
+## Perform Getis-Ord Gi(*) Hot Spot Analysis
+
+Sedona provides an implementation of the [Gi and Gi*](https://en.wikipedia.org/wiki/Getis%E2%80%93Ord_statistics) algorithms to identify local hotspots in spatial data.
+
+The algorithm is available as Scala, Java, and Python functions called on a spatial dataframe. The returned dataframe has additional columns containing the G statistic, E[G], V[G], the Z score, and the p-value.
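+
+For reference, the statistic computed for each observation follows Getis & Ord (1992); this is a simplified statement of it, with E[G] and V[G] given by the moment formulas in the paper:
+
+$$
+G_i = \frac{\sum_{j} w_{ij} x_j}{\sum_{j} x_j}, \qquad Z_i = \frac{G_i - E[G_i]}{\sqrt{V[G_i]}}
+$$
+
+The sums run over the neighbors of observation i, excluding i itself for Gi and including it for Gi*; the returned p-value is computed from the standard normal distribution as 1 - Φ(|Z|).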
+
+Using Gi involves first generating the neighbors list for each record, then calling the g_local function.
+
+=== "Scala"
+
+    ```scala
+    import org.apache.sedona.stats.Weighting.addBinaryDistanceBandColumn
+    import org.apache.sedona.stats.hotspotDetection.GetisOrd.gLocal
+
+    val distanceRadius = 1.0
+    val weightedDf = addBinaryDistanceBandColumn(df, distanceRadius)
+    gLocal(weightedDf, "val").show()
+    ```
+
+=== "Java"
+
+    ```java
+    import org.apache.sedona.stats.Weighting;
+    import org.apache.sedona.stats.hotspotDetection.GetisOrd;
+    import org.apache.spark.sql.Dataset;
+    import org.apache.spark.sql.Row;
+
+    double distanceRadius = 1.0;
+    Dataset<Row> weightedDf = Weighting.addBinaryDistanceBandColumn(df, distanceRadius);
+    GetisOrd.gLocal(weightedDf, "val").show();
+    ```
+
+=== "Python"
+
+    ```python
+    from sedona.stats.weighting import add_binary_distance_band_column
+    from sedona.stats.hotspot_detection.getis_ord import g_local
+
+    distance_radius = 1.0
+    weighted_df = add_binary_distance_band_column(df, distance_radius)
+    g_local(weighted_df, "val").show()
+    ```
+
+The output will look like this:
+
+```
++-----------+---+--------------------+-------------------+-------------------+--------------------+--------------------+--------------------+
+|   geometry|val|             weights|                  G|                 EG|                  VG|                   Z|                   P|
++-----------+---+--------------------+-------------------+-------------------+--------------------+--------------------+--------------------+
+|POINT (2 2)|0.9|[{{POINT (2 3), 1...| 0.4488188976377953|0.45454545454545453| 0.00356321373799772|-0.09593402008347063|  0.4617864875295957|
+|POINT (2 3)|1.2|[{{POINT (2 2), 0...|0.35433070866141736|0.36363636363636365|0.003325666155464539|-0.16136436037034918|  0.4359032175415549|
+|POINT (3 3)|1.2|[{{POINT (2 3), 1...|0.28346456692913385| 0.2727272727272727|0.002850570990398176| 0.20110780337013057| 0.42030714022155924|
+|POINT (3 2)|1.2|[{{POINT (2 2), 0...| 0.4488188976377953|0.45454545454545453| 0.00356321373799772|-0.09593402008347063|  0.4617864875295957|
+|POINT (3 1)|1.2|[{{POINT (3 2), 3...| 0.3622047244094489| 0.2727272727272727|0.002850570990398176|  1.6758983614177538| 0.04687905137429871|
+|POINT (2 1)|2.2|[{{POINT (2 2), 0...| 0.4330708661417323|0.36363636363636365|0.003325666155464539|  1.2040263812249166| 0.11428969105925013|
+|POINT (1 1)|1.2|[{{POINT (2 1), 5...| 0.2834645669291339| 0.2727272727272727|0.002850570990398176|  0.2011078033701316|  0.4203071402215588|
+|POINT (1 2)|0.2|[{{POINT (2 2), 0...|0.35433070866141736|0.45454545454545453| 0.00356321373799772|   -1.67884535146075|0.046591093685710794|
+|POINT (1 3)|1.2|[{{POINT (2 3), 1...| 0.2047244094488189| 0.2727272727272727|0.002850570990398176| -1.2736827546774914| 0.10138793530151635|
+|POINT (0 2)|1.0|[{{POINT (1 2), 7...|0.09448818897637795|0.18181818181818182|0.002137928242798632| -1.8887168824332323|0.029464887612748458|
+|POINT (4 2)|1.2|[{{POINT (3 2), 3...| 0.1889763779527559|0.18181818181818182|0.002137928242798632| 0.15481285921583854| 0.43848442662481324|
++-----------+---+--------------------+-------------------+-------------------+--------------------+--------------------+--------------------+
+```
+
 ## Run spatial queries
 
 After creating a Geometry type column, you are able to run spatial queries.
diff --git a/python/Pipfile b/python/Pipfile
index 6c7e142b38..8c899b263e 100644
--- a/python/Pipfile
+++ b/python/Pipfile
@@ -11,6 +11,10 @@ mkdocs="*"
 pytest-cov = "*"
 scikit-learn = "*"
+esda = "*"
+libpysal = "*"
+matplotlib = "*" # implicit dependency of esda
+scipy = "<=1.10.0" # prevent incompatibility with pysal 4.7.0, which is what resolves when shapely >2 is specified
 
 [packages]
 pandas="<=1.5.3"

diff --git a/python/sedona/stats/hotspot_detection/__init__.py b/python/sedona/stats/hotspot_detection/__init__.py
new file mode 100644
index 0000000000..ed2c4706d1
--- /dev/null
+++ b/python/sedona/stats/hotspot_detection/__init__.py
@@ -0,0 +1,18 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Functions for detecting locations where a variable's value is significantly different from values nearby."""

diff --git a/python/sedona/stats/hotspot_detection/getis_ord.py b/python/sedona/stats/hotspot_detection/getis_ord.py
new file mode 100644
index 0000000000..f8da89afaf
--- /dev/null
+++ b/python/sedona/stats/hotspot_detection/getis_ord.py
@@ -0,0 +1,66 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Getis Ord functions, from the 1992 paper by Getis & Ord.
+
+Getis, A., & Ord, J. K. (1992). The analysis of spatial association by use of distance statistics.
+Geographical Analysis, 24(3), 189-206. https://doi.org/10.1111/j.1538-4632.1992.tb00261.x
+"""
+
+from pyspark.sql import DataFrame, SparkSession
+
+
+def g_local(
+    dataframe: DataFrame,
+    x: str,
+    weights: str = "weights",
+    permutations: int = 0,
+    star: bool = False,
+    island_weight: float = 0.0,
+) -> DataFrame:
+    """Performs the Gi or Gi* statistic on the x column of the dataframe.
+
+    Weights should be the neighbors of each row. Each element of the weights array should be a struct containing a
+    value column and a neighbor column. The neighbor column should contain the contents of the neighbor row, with
+    the same types as the parent row (minus the weights column). You can use
+    `sedona.stats.weighting.add_distance_band_column` to achieve this. To calculate the Gi* statistic, ensure the
+    focal observation is in the neighbors array (i.e. the row is in its own weights column) and pass `star=True`.
+    Significance is calculated with a z score. Permutation tests are not yet implemented, so island_weight does
+    nothing. The following columns will be added: G, E[G], V[G], Z, P.
+
+    Args:
+        dataframe: the dataframe to perform the G statistic on
+        x: the name of the column to perform hotspot analysis on
+        weights: the name of the column containing the neighbors array. The neighbor column should contain the
+            contents of the neighbor row, with the same types as the parent row (minus the weights column). You can
+            use `sedona.stats.weighting.add_distance_band_column` to achieve this.
+        permutations: not used. Permutation tests are not supported yet. The number of permutations to use for the
+            significance test.
+        star: whether the focal observation is in the neighbors array. If True, this calculates Gi*; otherwise Gi.
+        island_weight: not used. The weight for the simulated neighbor used for records without a neighbor in
+            permutation tests.
+
+    Returns:
+        A dataframe with the original columns plus the columns G, E[G], V[G], Z, P.
+    """
+    sedona = SparkSession.getActiveSession()
+
+    result_df = sedona._jvm.org.apache.sedona.stats.hotspotDetection.GetisOrd.gLocal(
+        dataframe._jdf, x, weights, permutations, star, island_weight
+    )
+
+    return DataFrame(result_df, sedona)

diff --git a/python/sedona/stats/weighting.py b/python/sedona/stats/weighting.py
new file mode 100644
index 0000000000..8a5fc7e07a
--- /dev/null
+++ b/python/sedona/stats/weighting.py
@@ -0,0 +1,112 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Weighting functions for spatial data."""
+
+from typing import Optional
+
+from pyspark.sql import DataFrame, SparkSession
+
+
+def add_distance_band_column(
+    dataframe: DataFrame,
+    threshold: float,
+    binary: bool = True,
+    alpha: float = -1.0,
+    include_zero_distance_neighbors: bool = False,
+    include_self: bool = False,
+    self_weight: float = 1.0,
+    geometry: Optional[str] = None,
+    use_spheroid: bool = False,
+) -> DataFrame:
+    """Annotates a dataframe with a weights column containing the other records within the threshold and their weight.
+
+    The dataframe should contain at least one GeometryType column. Rows must be unique. If one
+    geometry column is present it will be used automatically. If two are present, the one named
+    'geometry' will be used. If more than one is present and none is named 'geometry', the
+    column name must be provided. The new column will be named 'weights'.
+
+    Args:
+        dataframe: DataFrame with geometry column
+        threshold: Distance threshold for considering neighbors
+        binary: whether to use binary weights or inverse distance weights for neighbors (dist^alpha)
+        alpha: alpha to use for inverse distance weights; ignored when binary is true
+        include_zero_distance_neighbors: whether to include neighbors that are 0 distance. If 0 distance neighbors are
+            included and binary is false, weights are infinity as per the floating point spec (divide by 0)
+        include_self: whether to include self in the list of neighbors
+        self_weight: the value to use for the self weight
+        geometry: name of the geometry column
+        use_spheroid: whether to use a cartesian or spheroidal distance calculation. Default is false
+
+    Returns:
+        The input DataFrame with a weights column added containing neighbors and their weights added to each row.
+
+    """
+    sedona = SparkSession.getActiveSession()
+    result_df = sedona._jvm.org.apache.sedona.stats.Weighting.addDistanceBandColumn(
+        dataframe._jdf,
+        float(threshold),
+        binary,
+        float(alpha),
+        include_zero_distance_neighbors,
+        include_self,
+        float(self_weight),
+        geometry,
+        use_spheroid,
+    )
+    return DataFrame(result_df, sedona)
+
+
+def add_binary_distance_band_column(
+    dataframe: DataFrame,
+    threshold: float,
+    include_zero_distance_neighbors: bool = True,
+    include_self: bool = False,
+    geometry: Optional[str] = None,
+    use_spheroid: bool = False,
+) -> DataFrame:
+    """Annotates a dataframe with a weights column containing the other records within the threshold and their weight.
+
+    Weights will always be 1.0. The dataframe should contain at least one GeometryType column. Rows must be unique. If
+    one geometry column is present it will be used automatically. If two are present, the one named 'geometry' will be
+    used. If more than one is present and none is named 'geometry', the column name must be provided. The new column
+    will be named 'weights'.
+
+    Args:
+        dataframe: DataFrame with geometry column
+        threshold: Distance threshold for considering neighbors
+        include_zero_distance_neighbors: whether to include neighbors that are 0 distance. Weights are always 1, so
+            zero-distance neighbors do not produce infinite weights
+        include_self: whether to include self in the list of neighbors
+        geometry: name of the geometry column
+        use_spheroid: whether to use a cartesian or spheroidal distance calculation. Default is false
+
+    Returns:
+        The input DataFrame with a weights column added containing neighbors and their weights (always 1) added to each
+        row.
+
+    """
+    sedona = SparkSession.getActiveSession()
+    result_df = sedona._jvm.org.apache.sedona.stats.Weighting.addBinaryDistanceBandColumn(
+        dataframe._jdf,
+        float(threshold),
+        include_zero_distance_neighbors,
+        include_self,
+        geometry,
+        use_spheroid,
+    )
+    return DataFrame(result_df, sedona)

diff --git a/python/tests/stats/test_getis_ord.py b/python/tests/stats/test_getis_ord.py
new file mode 100644
index 0000000000..d40ef4056a
--- /dev/null
+++ b/python/tests/stats/test_getis_ord.py
@@ -0,0 +1,157 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from esda.getisord import G_Local
+from libpysal.weights import DistanceBand
+from pyspark.sql import functions as f
+from tests.test_base import TestBase
+
+from sedona.sql.st_constructors import ST_MakePoint
+from sedona.stats.hotspot_detection.getis_ord import g_local
+from sedona.stats.weighting import add_distance_band_column
+
+
+class TestGetisOrd(TestBase):
+    def get_data(self):
+        return [
+            {"id": 0, "x": 2.0, "y": 2.0, "val": 0.9},
+            {"id": 1, "x": 2.0, "y": 3.0, "val": 1.2},
+            {"id": 2, "x": 3.0, "y": 3.0, "val": 1.2},
+            {"id": 3, "x": 3.0, "y": 2.0, "val": 1.2},
+            {"id": 4, "x": 3.0, "y": 1.0, "val": 1.2},
+            {"id": 5, "x": 2.0, "y": 1.0, "val": 2.2},
+            {"id": 6, "x": 1.0, "y": 1.0, "val": 1.2},
+            {"id": 7, "x": 1.0, "y": 2.0, "val": 0.2},
+            {"id": 8, "x": 1.0, "y": 3.0, "val": 1.2},
+            {"id": 9, "x": 0.0, "y": 2.0, "val": 1.0},
+            {"id": 10, "x": 4.0, "y": 2.0, "val": 1.2},
+        ]
+
+    def get_dataframe(self):
+        return self.spark.createDataFrame(self.get_data()).select(
+            ST_MakePoint("x", "y").alias("geometry"), "id", "val"
+        )
+
+    def test_gi_results_match_pysal(self):
+        # actual
+        input_dataframe = add_distance_band_column(self.get_dataframe(), 1.0)
+        actual_df = g_local(input_dataframe, "val", "weights")
+
+        # expected results
+        data = self.get_data()
+        points = [(datum["x"], datum["y"]) for datum in data]
+        w = DistanceBand(points, threshold=1.0)
+        y = [datum["val"] for datum in data]
+        expected_data = G_Local(y, w, transform="B")
+
+        # assert
+        actuals = actual_df.orderBy(f.col("id").asc()).collect()
+        self.assert_almost_equal(expected_data.Gs.tolist(), [row.G for row in actuals])
+        self.assert_almost_equal(
+            expected_data.EGs.tolist(), [row.EG for row in actuals]
+        )
+        self.assert_almost_equal(
+            expected_data.VGs.tolist(), [row.VG for row in actuals]
+        )
+        self.assert_almost_equal(expected_data.Zs.tolist(), [row.Z for row in actuals])
+        self.assert_almost_equal(
+            expected_data.p_norm.tolist(), [row.P for row in actuals]
+        )
+
+    def test_gistar_results_match_pysal(self):
+        # actual
+        input_dataframe = add_distance_band_column(
+            self.get_dataframe(), 1.0, include_self=True
+        )
+        actual_df = g_local(input_dataframe, "val", "weights", star=True)
+
+        # expected results
+        data = self.get_data()
+        points = [(datum["x"], datum["y"]) for datum in data]
+        w = DistanceBand(points, threshold=1.0)
+        y = [datum["val"] for datum in data]
+        expected_data = G_Local(y, w, transform="B", star=True)
+
+        # assert
+        actuals = actual_df.orderBy(f.col("id").asc()).collect()
+        self.assert_almost_equal(expected_data.Gs.tolist(), [row.G for row in actuals])
+        self.assert_almost_equal(
+            expected_data.EGs.tolist(), [row.EG for row in actuals]
+        )
+        self.assert_almost_equal(
+            expected_data.VGs.tolist(), [row.VG for row in actuals]
+        )
+        self.assert_almost_equal(expected_data.Zs.tolist(), [row.Z for row in actuals])
+        self.assert_almost_equal(
+            expected_data.p_norm.tolist(), [row.P for row in actuals]
+        )
+
+    def test_gi_results_match_pysal_nb(self):
+        # actual
+        input_dataframe = add_distance_band_column(
+            self.get_dataframe(), 1.0, binary=False
+        )
+        actual_df = g_local(input_dataframe, "val", "weights")
+
+        # expected results
+        data = self.get_data()
+        points = [(datum["x"], datum["y"]) for datum in data]
+        w = DistanceBand(points, threshold=1.0, binary=False)
+        y = [datum["val"] for datum in data]
+        expected_data = G_Local(y, w, transform="B")
+
+        # assert
+        actuals = actual_df.orderBy(f.col("id").asc()).collect()
+        self.assert_almost_equal(expected_data.Gs.tolist(), [row.G for row in actuals])
+        self.assert_almost_equal(
+            expected_data.EGs.tolist(), [row.EG for row in actuals]
+        )
+        self.assert_almost_equal(
+            expected_data.VGs.tolist(), [row.VG for row in actuals]
+        )
+        self.assert_almost_equal(expected_data.Zs.tolist(), [row.Z for row in actuals])
+        self.assert_almost_equal(
+            expected_data.p_norm.tolist(), [row.P for row in actuals]
+        )
+
+    def test_gistar_results_match_pysal_nb(self):
+        # actual
+        input_dataframe = add_distance_band_column(
+            self.get_dataframe(), 1.0, include_self=True, binary=False
+        )
+        actual_df = g_local(input_dataframe, "val", "weights", star=True)
+
+        # expected results
+        data = self.get_data()
+        points = [(datum["x"], datum["y"]) for datum in data]
+        w = DistanceBand(points, threshold=1.0, binary=False)
+        y = [datum["val"] for datum in data]
+        expected_data = G_Local(y, w, transform="B", star=True)
+
+        # assert
+        actuals = actual_df.orderBy(f.col("id").asc()).collect()
+        self.assert_almost_equal(expected_data.Gs.tolist(), [row.G for row in actuals])
+        self.assert_almost_equal(
+            expected_data.EGs.tolist(), [row.EG for row in actuals]
+        )
+        self.assert_almost_equal(
+            expected_data.VGs.tolist(), [row.VG for row in actuals]
+        )
+        self.assert_almost_equal(expected_data.Zs.tolist(), [row.Z for row in actuals])
+        self.assert_almost_equal(
+            expected_data.p_norm.tolist(), [row.P for row in actuals]
+        )

diff --git a/python/tests/stats/test_weighting.py b/python/tests/stats/test_weighting.py
new file mode 100644
index 0000000000..5dadcaef46
--- /dev/null
+++ b/python/tests/stats/test_weighting.py
@@ -0,0 +1,49 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pyspark.sql.functions as f
+from tests.test_base import TestBase
+
+from sedona.sql.st_constructors import ST_MakePoint
+from sedona.stats.weighting import (
+    add_binary_distance_band_column,
+    add_distance_band_column,
+)
+
+
+class TestWeighting(TestBase):
+    def get_dataframe(self):
+        data = [[0, 1, 1], [1, 1, 2]]
+
+        return (
+            self.spark.createDataFrame(data)
+            .select(ST_MakePoint("_1", "_2").alias("geometry"))
+            .withColumn("anotherColumn", f.rand())
+        )
+
+    def test_calling_weighting_works(self):
+        df = self.get_dataframe()
+        add_distance_band_column(df, 1.0)
+
+    def test_calling_binary_weighting_matches_expected(self):
+        df = self.get_dataframe()
+        self.assert_dataframes_equal(
+            add_distance_band_column(
+                df, 1.0, binary=True, include_zero_distance_neighbors=True
+            ),
+            add_binary_distance_band_column(df, 1.0),
+        )

diff --git a/python/tests/test_base.py b/python/tests/test_base.py
index d600c45f5a..84f27356f6 100644
--- a/python/tests/test_base.py
+++ b/python/tests/test_base.py
@@ -16,9 +16,11 @@
 # under the License.
 import os
 from tempfile import mkdtemp
-from typing import Union
+from typing import Iterable, Union
+
 import pyspark
+from pyspark.sql import DataFrame
 from sedona.spark import *
 from sedona.utils.decorators import classproperty
@@ -65,6 +67,30 @@ def sc(self):
         setattr(self, "__sc", self.spark._sc)
         return getattr(self, "__sc")
 
+    @classmethod
+    def assert_almost_equal(
+        cls,
+        a: Union[Iterable[float], float],
+        b: Union[Iterable[float], float],
+        tolerance: float = 0.00001,
+    ):
+        assert type(a) is type(b)
+        if isinstance(a, Iterable):
+            assert len(a) == len(b)
+            for i in range(len(a)):
+                cls.assert_almost_equal(a[i], b[i], tolerance)
+        elif isinstance(b, float):
+            assert abs(a - b) < tolerance
+        else:
+            raise TypeError("this function is only for floats and iterables of floats")
+
+    @classmethod
+    def assert_dataframes_equal(cls, df1: DataFrame, df2: DataFrame):
+        df_diff1 = df1.exceptAll(df2)
+        df_diff2 = df2.exceptAll(df1)
+
+        assert df_diff1.isEmpty() and df_diff2.isEmpty()
+
 @classmethod
 def assert_geometry_almost_equal(
     cls,

diff --git a/spark/common/src/main/scala/org/apache/sedona/stats/Weighting.scala b/spark/common/src/main/scala/org/apache/sedona/stats/Weighting.scala
new file mode 100644
index 0000000000..6d5a273854
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/sedona/stats/Weighting.scala
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ +package org.apache.sedona.stats + +import org.apache.sedona.stats.Util.getGeometryColumnName +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.sedona_sql.expressions.st_functions.{ST_Distance, ST_DistanceSpheroid} +import org.apache.spark.sql.{Column, DataFrame} + +object Weighting { + + private val ID_COLUMN = "__id" + + /** + * Annotates a dataframe with a weights column for each data record containing the other members + * within the threshold and their weight. The dataframe should contain at least one GeometryType + * column. Rows must be unique. If one geometry column is present it will be used automatically. + * If two are present, the one named 'geometry' will be used. If more than one are present and + * neither is named 'geometry', the column name must be provided. The new column will be named + * 'cluster'. + * + * @param dataframe + * DataFrame with geometry column + * @param threshold + * Distance threshold for considering neighbors + * @param binary + * whether to use binary weights or inverse distance weights for neighbors (dist^alpha) + * @param alpha + * alpha to use for inverse distance weights ignored when binary is true + * @param includeZeroDistanceNeighbors + * whether to include neighbors that are 0 distance. If 0 distance neighbors are included and + * binary is false, values are infinity as per the floating point spec (divide by 0) + * @param includeSelf + * whether to include self in the list of neighbors + * @param selfWeight + * the value to use for the self weight + * @param geometry + * name of the geometry column + * @param useSpheroid + * whether to use a cartesian or spheroidal distance calculation. Default is false + * @return + * The input DataFrame with a weight column added containing neighbors and their weights added + * to each row. + */ + def addDistanceBandColumn( + dataframe: DataFrame, + threshold: Double, + binary: Boolean = true, + alpha: Double = -1.0, + includeZeroDistanceNeighbors: Boolean = false, + includeSelf: Boolean = false, + selfWeight: Double = 1.0, + geometry: String = null, + useSpheroid: Boolean = false): DataFrame = { + + require(threshold >= 0, "Threshold must be greater than or equal to 0") + require(alpha < 0, "Alpha must be less than 0") + + val geometryColumn = geometry match { + case null => getGeometryColumnName(dataframe) + case _ => + require( + dataframe.schema.fields.exists(_.name == geometry), + s"Geometry column $geometry not found in dataframe") + geometry + } + + val distanceFunction: (Column, Column) => Column = + if (useSpheroid) ST_DistanceSpheroid else ST_Distance + + val joinCondition = if (includeZeroDistanceNeighbors) { + distanceFunction(col(s"l.$geometryColumn"), col(s"r.$geometryColumn")) <= threshold + } else { + distanceFunction( + col(s"l.$geometryColumn"), + col(s"r.$geometryColumn")) <= threshold && distanceFunction( + col(s"l.$geometryColumn"), + col(s"r.$geometryColumn")) > 0 + } + + val formattedDataFrame = dataframe.withColumn(ID_COLUMN, sha2(to_json(struct("*")), 256)) + + // Since spark 3.0 doesn't support dropFields, we need a work around + val withoutId = (prefix: String, colFunc: String => Column) => { + formattedDataFrame.schema.fields + .map(_.name) + .filter(name => name != ID_COLUMN) + .map(x => colFunc(prefix + "." 
+ x).alias(x)) + } + + formattedDataFrame + .alias("l") + .join( + formattedDataFrame.alias("r"), + joinCondition && col(s"l.$ID_COLUMN") =!= col( + s"r.$ID_COLUMN" + ), // we will add self back later if self.includeSelf + "left") + .select( + col(s"l.$ID_COLUMN"), + struct("l.*").alias("left_contents"), + struct( + struct(withoutId("r", col): _*).alias("neighbor"), + if (!binary) + pow(distanceFunction(col(s"l.$geometryColumn"), col(s"r.$geometryColumn")), alpha) + .alias("value") + else lit(1.0).alias("value")).alias("weight")) + .groupBy(s"l.$ID_COLUMN") + .agg( + first("left_contents").alias("left_contents"), + concat( + collect_list(col("weight")), + if (includeSelf) + array( + struct( + struct(withoutId("left_contents", first): _*).alias("neighbor"), + lit(selfWeight).alias("value"))) + else array()).alias("weights")) + .select("left_contents.*", "weights") + .drop(ID_COLUMN) + .withColumn("weights", filter(col("weights"), _(f"neighbor")(geometryColumn).isNotNull)) + } + + /** + * Annotates a dataframe with a weights column for each data record containing the other members + * within the threshold and their weight. Weights will always be 1.0. The dataframe should + * contain at least one GeometryType column. Rows must be unique. If one geometry column is + * present it will be used automatically. If two are present, the one named 'geometry' will be + * used. If more than one are present and neither is named 'geometry', the column name must be + * provided. The new column will be named 'cluster'. + * + * @param dataframe + * DataFrame with geometry column + * @param threshold + * Distance threshold for considering neighbors + * @param includeZeroDistanceNeighbors + * whether to include neighbors that are 0 distance. If 0 distance neighbors are included and + * binary is false, values are infinity as per the floating point spec (divide by 0) + * @param includeSelf + * whether to include self in the list of neighbors + * @param geometry + * name of the geometry column + * @param useSpheroid + * whether to use a cartesian or spheroidal distance calculation. Default is false + * @return + * The input DataFrame with a weight column added containing neighbors and their weights + * (always 1) added to each row. + */ + def addBinaryDistanceBandColumn( + dataframe: DataFrame, + threshold: Double, + includeZeroDistanceNeighbors: Boolean = true, + includeSelf: Boolean = false, + geometry: String = null, + useSpheroid: Boolean = false): DataFrame = addDistanceBandColumn( + dataframe, + threshold, + binary = true, + includeZeroDistanceNeighbors = includeZeroDistanceNeighbors, + includeSelf = includeSelf, + geometry = geometry, + useSpheroid = useSpheroid) + +} diff --git a/spark/common/src/main/scala/org/apache/sedona/stats/hotspotDetection/GetisOrd.scala b/spark/common/src/main/scala/org/apache/sedona/stats/hotspotDetection/GetisOrd.scala new file mode 100644 index 0000000000..f14edf463b --- /dev/null +++ b/spark/common/src/main/scala/org/apache/sedona/stats/hotspotDetection/GetisOrd.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.stats.hotspotDetection + +import org.apache.commons.math3.distribution.NormalDistribution +import org.apache.spark.sql.functions.udf +import org.apache.spark.sql.{Column, DataFrame, functions => f} + +object GetisOrd { + + def arraySum(arr: Column): Column = f.aggregate(arr, f.lit(0.0), (acc, x) => acc + x) + + private val cdfUDF = udf((z: Double) => { + new NormalDistribution().cumulativeProbability(z) + }) + + /** + * Performs the Gi or Gi* statistic on the x column of the dataframe. + * + * Weights should be the neighbors of this row. The members of the weights should be comprised + * of structs containing a value column and a neighbor column. The neighbor column should be the + * contents of the neighbors with the same types as the parent row (minus neighbors). You can + * use `wherobots.weighing.add_distance_band_column` to achieve this. To calculate the Gi* + * statistic, ensure the focal observation is in the neighbors array (i.e. the row is in the + * weights column) and `star=true`. Significance is calculated with a z score. Permutation tests + * are not yet implemented and thus island weight does nothing. The following columns will be + * added: G, E[G], V[G], Z, P. + * + * @param dataframe + * the dataframe to perform the G statistic on + * @param x + * The column name we want to perform hotspot analysis on + * @param weights + * The column name containing the neighbors array. The neighbor column should be the contents + * of the neighbors with the same types as the parent row (minus neighbors). You can use + * `wherobots.weighing.add_distance_band_column` to achieve this. + * @param permutations + * Not used. Permutation tests are not supported yet. The number of permutations to use for + * the significance test. + * @param star + * Whether the focal observation is in the neighbors array. If true this calculates Gi*, + * otherwise Gi + * @param islandWeight + * Not used. The weight for the simulated neighbor used for records without a neighbor in perm + * tests + * + * @return + * A dataframe with the original columns plus the columns G, E[G], V[G], Z, P. 
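+   *
+   * A minimal usage sketch, assuming a DataFrame `df` with a numeric column `val`:
+   * {{{
+   * val weighted = Weighting.addBinaryDistanceBandColumn(df, 1.0, includeSelf = true)
+   * GetisOrd.gLocal(weighted, "val", star = true).show()
+   * }}}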
+ */ + def gLocal( + dataframe: DataFrame, + x: String, + weights: String = "weights", + permutations: Int = 0, + star: Boolean = false, + islandWeight: Double = 0.0): DataFrame = { + + val removeSelf = f.lit(if (star) 0.0 else 1.0) + + val setStats = dataframe.agg(f.sum(f.col(x)), f.sum(f.pow(f.col(x), 2)), f.count("*")).first() + val sumOfAllY = f.lit(setStats.get(0)) + val sumOfSquaresofAllY = f.lit(setStats.get(1)) + val countOfAllY = f.lit(setStats.get(2)) + + dataframe + .withColumn( + "G", + arraySum( + f.transform( + f.col(weights), + weight => + weight("value") * weight("neighbor")(x))) / (sumOfAllY - removeSelf * f.col(x))) + .withColumn("W", arraySum(f.transform(f.col(weights), weight => weight.getField("value")))) + .withColumn("EG", f.col("W") / (countOfAllY - removeSelf)) + .withColumn("Y1", (sumOfAllY - removeSelf * f.col(x)) / (countOfAllY - removeSelf)) + .withColumn( + "Y2", + ((sumOfSquaresofAllY - removeSelf * f.pow(f.col(x), 2)) / (countOfAllY - removeSelf)) - f + .pow("Y1", 2.0)) + .withColumn( + "VG", + (f.col("W") * (countOfAllY - removeSelf - f.col("W")) * f.col("Y2")) / (f.pow( + countOfAllY - removeSelf, + 2.0) * (countOfAllY - 1 - removeSelf) * f.pow(f.col("Y1"), 2.0))) + .withColumn("Z", (f.col("G") - f.col("EG")) / f.sqrt(f.col("VG"))) + .withColumn("P", f.lit(1.0) - cdfUDF(f.abs(f.col("Z")))) + .drop("W", "Y1", "Y2") + } +} diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala index a41cdab6ea..86fedeb6a2 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala @@ -28,7 +28,6 @@ import org.apache.sedona.common.sphere.{Haversine, Spheroid} import org.apache.sedona.spark.SedonaContext import org.apache.spark.SparkContext import org.apache.spark.sql.{DataFrame, SparkSession} -import org.junit.Assert.fail import org.locationtech.jts.geom._ import org.locationtech.jts.io.WKTReader import org.scalatest.{BeforeAndAfterAll, FunSpec} @@ -321,6 +320,13 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll { (hdfsCluster, "hdfs://127.0.0.1:" + hdfsCluster.getNameNodePort + "/") } + protected def assertDataFramesEqual(df1: DataFrame, df2: DataFrame): Unit = { + val dfDiff1 = df1.except(df2) + val dfDiff2 = df2.except(df1) + + assert(dfDiff1.isEmpty && dfDiff2.isEmpty) + } + def assertGeometryEquals(expectedWkt: String, actualWkt: String, tolerance: Double): Unit = { val reader = new WKTReader val expectedGeom = reader.read(expectedWkt) diff --git a/spark/common/src/test/scala/org/apache/sedona/stats/WeightingTest.scala b/spark/common/src/test/scala/org/apache/sedona/stats/WeightingTest.scala new file mode 100644 index 0000000000..a7a8865dda --- /dev/null +++ b/spark/common/src/test/scala/org/apache/sedona/stats/WeightingTest.scala @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.stats + +import org.apache.sedona.sql.TestBaseScala +import org.apache.spark.sql.sedona_sql.expressions.st_constructors.ST_MakePoint +import org.apache.spark.sql.{DataFrame, Row, functions => f} + +class WeightingTest extends TestBaseScala { + + case class Neighbors(id: Int, neighbor: Seq[Int]) + + case class Point(id: Int, x: Double, y: Double) + + var originalAQEValue: String = null + + override def beforeAll: Unit = { + super.beforeAll() + originalAQEValue = sparkSession.conf.get("spark.sql.adaptive.enabled") + sparkSession.conf.set("spark.sql.adaptive.enabled", "false") + } + + override def afterAll: Unit = { + super.beforeAll() + sparkSession.conf.set("spark.sql.adaptive.enabled", originalAQEValue) + } + + def getData(): DataFrame = { + sparkSession + .createDataFrame( + Seq( + Point(0, 2.0, 2.0), + Point(1, 2.0, 3.0), + Point(2, 3.0, 3.0), + Point(3, 3.0, 2.0), + Point(4, 3.0, 1.0), + Point(5, 2.0, 1.0), + Point(6, 1.0, 1.0), + Point(7, 1.0, 2.0), + Point(8, 1.0, 3.0), + Point(9, 0.0, 2.0), + Point(10, 4.0, 2.0))) + .withColumn("geometry", ST_MakePoint("x", "y")) + .drop("x", "y") + } + + def getDupedData(): DataFrame = { + sparkSession + .createDataFrame(Seq(Point(0, 1.0, 1.0), Point(1, 1.0, 1.0), Point(2, 2.0, 1.0))) + .withColumn("geometry", ST_MakePoint("x", "y")) + .drop("x", "y") + } + + describe("addDistanceBandColumn") { + + it("returns correct results") { + // Explode avoids the need to read a nested struct + // https://issues.apache.org/jira/browse/SPARK-48942 + val actualDf = Weighting + .addDistanceBandColumn(getData(), 1.0) + .select( + f.col("id"), + f.array_sort( + f.transform(f.col("weights"), w => w("neighbor")("id")).as("neighbor_ids"))) + val expectedDf = sparkSession.createDataFrame( + Seq( + Neighbors(0, Seq(1, 3, 5, 7)), + Neighbors(1, Seq(0, 2, 8)), + Neighbors(2, Seq(1, 3)), + Neighbors(3, Seq(0, 2, 4, 10)), + Neighbors(4, Seq(3, 5)), + Neighbors(5, Seq(0, 4, 6)), + Neighbors(6, Seq(5, 7)), + Neighbors(7, Seq(0, 6, 8, 9)), + Neighbors(8, Seq(1, 7)), + Neighbors(9, Seq(7)), + Neighbors(10, Seq(3)))) + + assertDataFramesEqual(actualDf, expectedDf) + } + + it("return empty weights array when no neighbors") { + val actualDf = Weighting.addDistanceBandColumn(getData(), .9) + + assert(actualDf.count() == 11) + assert(actualDf.filter(f.size(f.col("weights")) > 0).count() == 0) + } + + it("respect includeZeroDistanceNeighbors flag") { + val actualDfWithZeroDistanceNeighbors = Weighting + .addDistanceBandColumn( + getDupedData(), + 1.1, + includeZeroDistanceNeighbors = true, + binary = false) + .select( + f.col("id"), + f.transform(f.col("weights"), w => w("neighbor")("id")).as("neighbor_ids")) + + assertDataFramesEqual( + actualDfWithZeroDistanceNeighbors, + sparkSession.createDataFrame( + Seq(Neighbors(0, Seq(1, 2)), Neighbors(1, Seq(0, 2)), Neighbors(2, Seq(0, 1))))) + + val actualDfWithoutZeroDistanceNeighbors = Weighting + .addDistanceBandColumn(getDupedData(), 1.1) + .select( + f.col("id"), + f.transform(f.col("weights"), w => w("neighbor")("id")).as("neighbor_ids")) + + assertDataFramesEqual( + 
actualDfWithoutZeroDistanceNeighbors, + sparkSession.createDataFrame( + Seq(Neighbors(0, Seq(2)), Neighbors(1, Seq(2)), Neighbors(2, Seq(0, 1))))) + + } + + it("adds binary weights") { + + val result = Weighting.addDistanceBandColumn(getData(), 2.0, geometry = "geometry") + val weights = result.select("weights").collect().map(_.getSeq[Row](0)) + assert(weights.forall(_.forall(_.getAs[Double]("value") == 1.0))) + } + + it("adds non-binary weights when binary is false") { + + val result = Weighting.addDistanceBandColumn( + getData(), + 2.0, + binary = false, + alpha = -.9, + geometry = "geometry") + val weights = result.select("weights").collect().map(_.getSeq[Row](0)) + assert(weights.exists(_.exists(_.getAs[Double]("value") != 1.0))) + } + + it("throws IllegalArgumentException when threshold is negative") { + + assertThrows[IllegalArgumentException] { + Weighting.addDistanceBandColumn(getData(), -1.0, geometry = "geometry") + } + } + + it("throws IllegalArgumentException when alpha is >=0") { + + assertThrows[IllegalArgumentException] { + Weighting.addDistanceBandColumn( + getData(), + 2.0, + binary = false, + alpha = 1.0, + geometry = "geometry") + } + } + + it("throw IllegalArgumentException with non-existent geometry column") { + assertThrows[IllegalArgumentException] { + Weighting.addDistanceBandColumn(getData(), 2.0, geometry = "non_existent") + } + } + } +} diff --git a/spark/common/src/test/scala/org/apache/sedona/stats/hotspotDetection/GetisOrdTest.scala b/spark/common/src/test/scala/org/apache/sedona/stats/hotspotDetection/GetisOrdTest.scala new file mode 100644 index 0000000000..0e93195b4a --- /dev/null +++ b/spark/common/src/test/scala/org/apache/sedona/stats/hotspotDetection/GetisOrdTest.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.stats.hotspotDetection + +import org.apache.sedona.sql.TestBaseScala +import org.apache.sedona.stats.Weighting +import org.apache.spark.sql.sedona_sql.expressions.st_constructors.ST_MakePoint +import org.apache.spark.sql.{DataFrame, functions => f} + +class GetisOrdTest extends TestBaseScala { + case class Point(id: Int, x: Double, y: Double, v: Double) + + def get_data(): DataFrame = { + sparkSession + .createDataFrame( + Seq( + Point(0, 2.0, 2.0, 0.9), + Point(1, 2.0, 3.0, 1.2), + Point(2, 3.0, 3.0, 1.2), + Point(3, 3.0, 2.0, 1.2), + Point(4, 3.0, 1.0, 1.2), + Point(5, 2.0, 1.0, 2.2), + Point(6, 1.0, 1.0, 1.2), + Point(7, 1.0, 2.0, 0.2), + Point(8, 1.0, 3.0, 1.2), + Point(9, 0.0, 2.0, 1.0), + Point(10, 4.0, 2.0, 1.2))) + .withColumn("geometry", ST_MakePoint("x", "y")) + .drop("x", "y") + } + + describe("glocal") { + it("returns one row per input row with expected columns, binary") { + val distanceBandedDf = Weighting.addDistanceBandColumn( + get_data(), + 1.0, + includeZeroDistanceNeighbors = true, + includeSelf = true, + geometry = "geometry") + + val actualResults = + GetisOrd.gLocal(distanceBandedDf, "v", "weights", star = true) + + val expectedColumnNames = + Array("id", "v", "geometry", "weights", "G", "EG", "VG", "Z", "P").sorted + + assert(actualResults.count() == 11) + assert(actualResults.columns.sorted === expectedColumnNames) + + for (columnName <- expectedColumnNames) { + assert(actualResults.filter(f.col(columnName).isNull).count() == 0) + } + } + + it("returns one row per input row with expected columns, idw") { + + val distanceBandedDf = + Weighting.addDistanceBandColumn(get_data(), 1.0, binary = false, alpha = -.5) + + val expectedResults = GetisOrd.gLocal(distanceBandedDf, "v", "weights") + + assert(expectedResults.count() == 11) + assert( + expectedResults.columns.sorted === Array( + "id", + "v", + "geometry", + "weights", + "G", + "EG", + "VG", + "Z", + "P").sorted) + } + } +}