test(method): improved performance in coloc tests #536

Merged (9 commits, Mar 21, 2024)
44 changes: 44 additions & 0 deletions tests/gentropy/conftest.py
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 from pathlib import Path
+from typing import Any
 
 import dbldatagen as dg
 import hail as hl
@@ -647,3 +648,46 @@ def sample_data_for_susie_inf() -> list[np.ndarray]:
     lbf_moments = np.loadtxt("tests/gentropy/data_samples/01_test_lbf_moments.csv")
     lbf_mle = np.loadtxt("tests/gentropy/data_samples/01_test_lbf_mle.csv")
     return [ld, z, lbf_moments, lbf_mle]
+
+
+@pytest.fixture()
+def sample_data_for_coloc(spark: SparkSession) -> list[Any]:
+    """Sample data for Coloc tests."""
+    overlap_df = spark.read.parquet(
+        "tests/gentropy/data_samples/coloc_test_data.snappy.parquet"

Contributor:

How was this file generated? For semantic tests, it's easier to understand if you create a data subset in the testing module directly.
Instead of reading a file of 500 rows, create a dataframe with 2 overlapping variants, for example.

The same testing function can be parametrised for both scenarios: associations that overlap on multiple SNPs, and on a single SNP.

Contributor Author:

It was directly extracted from the test dataset shipped with the R package.

+    )
+    expected_df = spark.createDataFrame(
+        [
+            {
+                "h0": 1.3769995397857477e-18,
+                "h1": 2.937336451601565e-10,
+                "h2": 8.593226431647826e-12,
+                "h3": 8.338916748775843e-4,
+                "h4": 0.9991661080227981,
+            }
+        ]
+    )
+    single_snp_coloc = spark.createDataFrame(
+        [
+            {
+                "leftStudyLocusId": 1,
+                "rightStudyLocusId": 2,
+                "chromosome": "1",
+                "tagVariantId": "snp",
+                "left_logBF": 10.3,
+                "right_logBF": 10.5,
+            }
+        ]
+    )
+    expected_single_snp_coloc = spark.createDataFrame(
+        [
+            {
+                "h0": 9.254841951638903e-5,
+                "h1": 2.7517068829182966e-4,
+                "h2": 3.3609423764447284e-4,
+                "h3": 9.254841952564387e-13,
+                "h4": 0.9992961866536217,
+            }
+        ]
+    )
+    return [overlap_df, expected_df, single_snp_coloc, expected_single_snp_coloc]
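
A minimal sketch of the parametrisation the reviewer suggests, building the overlap rows inline rather than reading the 500-row parquet sample. The test name, the inline row values, and the sum-to-one sanity check are illustrative additions, not part of this PR; the `StudyLocusOverlap` and `Coloc` calls mirror the ones in the diff:

```python
import pytest
from gentropy.dataset.study_locus_overlap import StudyLocusOverlap
from gentropy.method.colocalisation import Coloc
from pyspark.sql import functions as f


@pytest.mark.parametrize(
    "overlap_rows",
    [
        # scenario 1: associations overlapping on a single SNP
        [(1, 2, "1", "snp", 10.3, 10.5)],
        # scenario 2: associations overlapping on multiple SNPs
        [(1, 2, "1", "snp1", 10.3, 10.5), (1, 2, "1", "snp2", 1.2, 0.5)],
    ],
)
def test_coloc_overlap_scenarios(spark, overlap_rows) -> None:
    """COLOC posteriors h0..h4 form a valid probability vector for any overlap."""
    columns = [
        "leftStudyLocusId",
        "rightStudyLocusId",
        "chromosome",
        "tagVariantId",
        "left_logBF",
        "right_logBF",
    ]
    overlap_df = spark.createDataFrame(overlap_rows, columns).select(
        "leftStudyLocusId",
        "rightStudyLocusId",
        "chromosome",
        "tagVariantId",
        # pack the per-side log Bayes factors into the struct the schema expects
        f.struct(f.col("left_logBF"), f.col("right_logBF")).alias("statistics"),
    )
    overlap = StudyLocusOverlap(_df=overlap_df, _schema=StudyLocusOverlap.get_schema())
    result = Coloc.colocalise(overlap).df.select("h0", "h1", "h2", "h3", "h4").first()
    # the five posteriors partition the probability space, so they sum to ~1
    assert abs(sum(result[h] for h in ["h0", "h1", "h2", "h3", "h4"]) - 1.0) < 1e-6
```
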
147 changes: 13 additions & 134 deletions tests/gentropy/method/test_colocalisation_method.py
@@ -2,10 +2,11 @@
 
 from __future__ import annotations
 
+from typing import Any
+
 from gentropy.dataset.colocalisation import Colocalisation
 from gentropy.dataset.study_locus_overlap import StudyLocusOverlap
 from gentropy.method.colocalisation import Coloc, ECaviar
-from pyspark.sql import SparkSession
 from pyspark.sql import functions as f

@@ -15,161 +16,39 @@ def test_coloc(mock_study_locus_overlap: StudyLocusOverlap) -> None:


 def test_coloc_colocalise(
-    spark: SparkSession,
-    threshold: float = 1e-5,
+    sample_data_for_coloc: list[Any],
+    threshold: float = 1e-4,
 ) -> None:
     """Compare COLOC results with R implementation, using provided sample dataset from R package (StudyLocusOverlap)."""
-    test_overlap_df = spark.read.parquet(
-        "tests/gentropy/data_samples/coloc_test_data.snappy.parquet", header=True
+    test_overlap_df = sample_data_for_coloc[0]
+    test_overlap = StudyLocusOverlap(
+        _df=test_overlap_df, _schema=StudyLocusOverlap.get_schema()
     )
-    test_overlap = StudyLocusOverlap(test_overlap_df, StudyLocusOverlap.get_schema())
     test_result = Coloc.colocalise(test_overlap)
 
-    expected = spark.createDataFrame(
-        [
-            {
-                "h0": 1.3769995397857477e-18,
-                "h1": 2.937336451601565e-10,
-                "h2": 8.593226431647826e-12,
-                "h3": 8.338916748775843e-4,
-                "h4": 0.9991661080227981,
-            }
-        ]
-    )
+    expected = sample_data_for_coloc[1]
     difference = test_result.df.select("h0", "h1", "h2", "h3", "h4").subtract(expected)
     for col in difference.columns:
         assert difference.filter(f.abs(f.col(col)) > threshold).count() == 0


 def test_single_snp_coloc(
-    spark: SparkSession,
+    sample_data_for_coloc: list[Any],
     threshold: float = 1e-5,
 ) -> None:
     """Test edge case of coloc where only one causal SNP is present in the StudyLocusOverlap."""
-    test_overlap_df = spark.createDataFrame(
-        [
-            {
-                "leftStudyLocusId": 1,
-                "rightStudyLocusId": 2,
-                "chromosome": "1",
-                "tagVariantId": "snp",
-                "left_logBF": 10.3,
-                "right_logBF": 10.5,
-            }
-        ]
-    )
-    test_overlap = StudyLocusOverlap(
-        test_overlap_df.select(
-            "leftStudyLocusId",
-            "rightStudyLocusId",
-            "chromosome",
-            "tagVariantId",
-            f.struct(f.col("left_logBF"), f.col("right_logBF")).alias("statistics"),
-        ),
-        StudyLocusOverlap.get_schema(),
-    )
-    test_result = Coloc.colocalise(test_overlap)
-
-    expected = spark.createDataFrame(
-        [
-            {
-                "h0": 9.254841951638903e-5,
-                "h1": 2.7517068829182966e-4,
-                "h2": 3.3609423764447284e-4,
-                "h3": 9.254841952564387e-13,
-                "h4": 0.9992961866536217,
-            }
-        ]
-    )
-    difference = test_result.df.select("h0", "h1", "h2", "h3", "h4").subtract(expected)
-    for col in difference.columns:
-        assert difference.filter(f.abs(f.col(col)) > threshold).count() == 0
-
-
-def test_single_snp_coloc_one_negative(
-    spark: SparkSession,
-    threshold: float = 1e-5,
-) -> None:
-    """Test edge case of coloc where only one causal SNP is present (On one side!) in the StudyLocusOverlap."""
-    test_overlap_df = spark.createDataFrame(
-        [
-            {
-                "leftStudyLocusId": 1,
-                "rightStudyLocusId": 2,
-                "chromosome": "1",
-                "tagVariantId": "snp",
-                "left_logBF": 18.3,
-                "right_logBF": 0.01,
-            }
-        ]
-    )
-    test_overlap = StudyLocusOverlap(
-        test_overlap_df.select(
-            "leftStudyLocusId",
-            "rightStudyLocusId",
-            "chromosome",
-            "tagVariantId",
-            f.struct(f.col("left_logBF"), f.col("right_logBF")).alias("statistics"),
-        ),
-        StudyLocusOverlap.get_schema(),
-    )
-    test_result = Coloc.colocalise(test_overlap)
-    test_result.df.show(1, False)
-    expected = spark.createDataFrame(
-        [
-            {
-                "h0": 1.0246538505087709e-4,
-                "h1": 0.9081680002273896,
-                "h2": 1.0349517929098209e-8,
-                "h3": 1.0246538506112363e-12,
-                "h4": 0.09172952403701702,
-            }
-        ]
-    )
-    difference = test_result.df.select("h0", "h1", "h2", "h3", "h4").subtract(expected)
-    for col in difference.columns:
-        assert difference.filter(f.abs(f.col(col)) > threshold).count() == 0
-
-
-def test_single_snp_coloc_both_negative(
-    spark: SparkSession,
-    threshold: float = 1e-5,
-) -> None:
-    """Test edge case of coloc where only one non-causal SNP overlaps in the StudyLocusOverlap."""
-    test_overlap_df = spark.createDataFrame(
-        [
-            {
-                "leftStudyLocusId": 1,
-                "rightStudyLocusId": 2,
-                "chromosome": "1",
-                "tagVariantId": "snp",
-                "left_logBF": 0.03,
-                "right_logBF": 0.01,
-            }
-        ]
-    )
+    test_overlap_df = sample_data_for_coloc[2]
     test_overlap = StudyLocusOverlap(
-        test_overlap_df.select(
+        _df=test_overlap_df.select(
             "leftStudyLocusId",
             "rightStudyLocusId",
             "chromosome",
             "tagVariantId",
             f.struct(f.col("left_logBF"), f.col("right_logBF")).alias("statistics"),
         ),
-        StudyLocusOverlap.get_schema(),
+        _schema=StudyLocusOverlap.get_schema(),
     )
     test_result = Coloc.colocalise(test_overlap)
-    expected = spark.createDataFrame(
-        [
-            {
-                "h0": 0.9997855774090624,
-                "h1": 1.0302335812225042e-4,
-                "h2": 1.0098335895103664e-4,
-                "h3": 9.9978557750904e-9,
-                "h4": 1.0405876008495098e-5,
-            }
-        ]
-    )
+    expected = sample_data_for_coloc[3]
     difference = test_result.df.select("h0", "h1", "h2", "h3", "h4").subtract(expected)
     for col in difference.columns:
         assert difference.filter(f.abs(f.col(col)) > threshold).count() == 0
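
For orientation, the h0 through h4 columns asserted throughout these tests are the posterior probabilities of the five standard COLOC hypotheses (Giambartolomei et al., 2014). The mapping below is a reference note added here for readability, not a gentropy symbol:

```python
# The five COLOC hypotheses whose posterior probabilities the tests compare
# against the R implementation; they are mutually exclusive and sum to 1.
COLOC_HYPOTHESES = {
    "h0": "no causal variant for either study at the locus",
    "h1": "causal variant for the left study only",
    "h2": "causal variant for the right study only",
    "h3": "distinct causal variants for the two studies",
    "h4": "a single causal variant shared by both studies",
}
```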