diff --git a/tests/gentropy/data_samples/coloc_test_data.snappy.parquet b/tests/gentropy/data_samples/coloc_test_data.snappy.parquet deleted file mode 100644 index 71b3913eb..000000000 Binary files a/tests/gentropy/data_samples/coloc_test_data.snappy.parquet and /dev/null differ diff --git a/tests/gentropy/method/test_colocalisation_method.py b/tests/gentropy/method/test_colocalisation_method.py index f90d54b3f..37613354c 100644 --- a/tests/gentropy/method/test_colocalisation_method.py +++ b/tests/gentropy/method/test_colocalisation_method.py @@ -2,11 +2,14 @@ from __future__ import annotations +from typing import Any + +import pytest from gentropy.dataset.colocalisation import Colocalisation from gentropy.dataset.study_locus_overlap import StudyLocusOverlap from gentropy.method.colocalisation import Coloc, ECaviar +from pandas.testing import assert_frame_equal from pyspark.sql import SparkSession -from pyspark.sql import functions as f def test_coloc(mock_study_locus_overlap: StudyLocusOverlap) -> None: @@ -14,165 +17,91 @@ def test_coloc(mock_study_locus_overlap: StudyLocusOverlap) -> None: assert isinstance(Coloc.colocalise(mock_study_locus_overlap), Colocalisation) -def test_coloc_colocalise( - spark: SparkSession, - threshold: float = 1e-5, -) -> None: - """Compare COLOC results with R implementation, using provided sample dataset from R package (StudyLocusOverlap).""" - test_overlap_df = spark.read.parquet( - "tests/gentropy/data_samples/coloc_test_data.snappy.parquet", header=True - ) - test_overlap = StudyLocusOverlap(test_overlap_df, StudyLocusOverlap.get_schema()) - test_result = Coloc.colocalise(test_overlap) - - expected = spark.createDataFrame( - [ - { - "h0": 1.3769995397857477e-18, - "h1": 2.937336451601565e-10, - "h2": 8.593226431647826e-12, - "h3": 8.338916748775843e-4, - "h4": 0.9991661080227981, - } - ] - ) - difference = test_result.df.select("h0", "h1", "h2", "h3", "h4").subtract(expected) - for col in difference.columns: - assert difference.filter(f.abs(f.col(col)) > threshold).count() == 0 - - -def test_single_snp_coloc( - spark: SparkSession, - threshold: float = 1e-5, -) -> None: - """Test edge case of coloc where only one causal SNP is present in the StudyLocusOverlap.""" - test_overlap_df = spark.createDataFrame( - [ - { - "leftStudyLocusId": 1, - "rightStudyLocusId": 2, - "chromosome": "1", - "tagVariantId": "snp", - "left_logBF": 10.3, - "right_logBF": 10.5, - } - ] - ) - test_overlap = StudyLocusOverlap( - test_overlap_df.select( - "leftStudyLocusId", - "rightStudyLocusId", - "chromosome", - "tagVariantId", - f.struct(f.col("left_logBF"), f.col("right_logBF")).alias("statistics"), +@pytest.mark.parametrize( + "observed_data, expected_data", + [ + # associations with a single overlapping SNP + ( + # observed overlap + [ + { + "leftStudyLocusId": 1, + "rightStudyLocusId": 2, + "chromosome": "1", + "tagVariantId": "snp", + "statistics": {"left_logBF": 10.3, "right_logBF": 10.5}, + }, + ], + # expected coloc + [ + { + "h0": 9.254841951638903e-5, + "h1": 2.7517068829182966e-4, + "h2": 3.3609423764447284e-4, + "h3": 9.254841952564387e-13, + "h4": 0.9992961866536217, + }, + ], ), - StudyLocusOverlap.get_schema(), - ) - test_result = Coloc.colocalise(test_overlap) - - expected = spark.createDataFrame( - [ - { - "h0": 9.254841951638903e-5, - "h1": 2.7517068829182966e-4, - "h2": 3.3609423764447284e-4, - "h3": 9.254841952564387e-13, - "h4": 0.9992961866536217, - } - ] - ) - difference = test_result.df.select("h0", "h1", "h2", "h3", "h4").subtract(expected) - for col in difference.columns: - assert difference.filter(f.abs(f.col(col)) > threshold).count() == 0 - - -def test_single_snp_coloc_one_negative( + # associations with multiple overlapping SNPs + ( + # observed overlap + [ + { + "leftStudyLocusId": 1, + "rightStudyLocusId": 2, + "chromosome": "1", + "tagVariantId": "snp1", + "statistics": {"left_logBF": 10.3, "right_logBF": 10.5}, + }, + { + "leftStudyLocusId": 1, + "rightStudyLocusId": 2, + "chromosome": "1", + "tagVariantId": "snp2", + "statistics": {"left_logBF": 10.3, "right_logBF": 10.5}, + }, + ], + # expected coloc + [ + { + "h0": 4.6230151407950416e-5, + "h1": 2.749086942648107e-4, + "h2": 3.357742374172504e-4, + "h3": 9.983447421747411e-4, + "h4": 0.9983447421747356, + }, + ], + ), + ], +) +def test_coloc_semantic( spark: SparkSession, - threshold: float = 1e-5, + observed_data: list[Any], + expected_data: list[Any], ) -> None: - """Test edge case of coloc where only one causal SNP is present (On one side!) in the StudyLocusOverlap.""" - test_overlap_df = spark.createDataFrame( - [ - { - "leftStudyLocusId": 1, - "rightStudyLocusId": 2, - "chromosome": "1", - "tagVariantId": "snp", - "left_logBF": 18.3, - "right_logBF": 0.01, - } - ] + """Test our COLOC with the implementation in R.""" + observed_overlap = StudyLocusOverlap( + _df=spark.createDataFrame(observed_data, schema=StudyLocusOverlap.get_schema()), + _schema=StudyLocusOverlap.get_schema(), ) - test_overlap = StudyLocusOverlap( - test_overlap_df.select( - "leftStudyLocusId", - "rightStudyLocusId", - "chromosome", - "tagVariantId", - f.struct(f.col("left_logBF"), f.col("right_logBF")).alias("statistics"), - ), - StudyLocusOverlap.get_schema(), + observed_coloc_pdf = ( + Coloc.colocalise(observed_overlap) + .df.select("h0", "h1", "h2", "h3", "h4") + .toPandas() ) - test_result = Coloc.colocalise(test_overlap) - test_result.df.show(1, False) - expected = spark.createDataFrame( - [ - { - "h0": 1.0246538505087709e-4, - "h1": 0.9081680002273896, - "h2": 1.0349517929098209e-8, - "h3": 1.0246538506112363e-12, - "h4": 0.09172952403701702, - } - ] + expected_coloc_pdf = ( + spark.createDataFrame(expected_data) + .select("h0", "h1", "h2", "h3", "h4") + .toPandas() ) - difference = test_result.df.select("h0", "h1", "h2", "h3", "h4").subtract(expected) - for col in difference.columns: - assert difference.filter(f.abs(f.col(col)) > threshold).count() == 0 - -def test_single_snp_coloc_both_negative( - spark: SparkSession, - threshold: float = 1e-5, -) -> None: - """Test edge case of coloc where only one non-causal SNP overlaps in the StudyLocusOverlap.""" - test_overlap_df = spark.createDataFrame( - [ - { - "leftStudyLocusId": 1, - "rightStudyLocusId": 2, - "chromosome": "1", - "tagVariantId": "snp", - "left_logBF": 0.03, - "right_logBF": 0.01, - } - ] - ) - test_overlap = StudyLocusOverlap( - test_overlap_df.select( - "leftStudyLocusId", - "rightStudyLocusId", - "chromosome", - "tagVariantId", - f.struct(f.col("left_logBF"), f.col("right_logBF")).alias("statistics"), - ), - StudyLocusOverlap.get_schema(), - ) - test_result = Coloc.colocalise(test_overlap) - expected = spark.createDataFrame( - [ - { - "h0": 0.9997855774090624, - "h1": 1.0302335812225042e-4, - "h2": 1.0098335895103664e-4, - "h3": 9.9978557750904e-9, - "h4": 1.0405876008495098e-5, - } - ] + assert_frame_equal( + observed_coloc_pdf, + expected_coloc_pdf, + check_exact=False, + check_dtype=True, ) - difference = test_result.df.select("h0", "h1", "h2", "h3", "h4").subtract(expected) - for col in difference.columns: - assert difference.filter(f.abs(f.col(col)) > threshold).count() == 0 def test_ecaviar(mock_study_locus_overlap: StudyLocusOverlap) -> None: