feat: adding window based clumping to StudyLocus (#779)
* feat: adding window based clumping to locus

* fix: reverting some changes

* chore: pre-commit auto fixes [...]

* fix: fixing problem introduced by merge conflict

* fix: addressing review comment

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
DSuveges and pre-commit-ci[bot] committed Sep 24, 2024
1 parent dcacaf7 commit df45a6c
Showing 5 changed files with 138 additions and 40 deletions.
42 changes: 27 additions & 15 deletions src/gentropy/dataset/study_locus.py
@@ -17,6 +17,7 @@
order_array_of_structs_by_field,
)
from gentropy.common.utils import get_logsum
from gentropy.config import WindowBasedClumpingStepConfig
from gentropy.dataset.dataset import Dataset
from gentropy.dataset.study_locus_overlap import StudyLocusOverlap
from gentropy.dataset.variant_index import VariantIndex
@@ -45,7 +46,8 @@ class StudyLocusQualityCheck(Enum):
PALINDROMIC_ALLELE_FLAG (str): Alleles are palindromic - cannot harmonize
AMBIGUOUS_STUDY (str): Association with ambiguous study
UNRESOLVED_LD (str): Variant not found in LD reference
LD_CLUMPED (str): Explained by a more significant variant in high LD (clumped)
LD_CLUMPED (str): Explained by a more significant variant in high LD
WINDOW_CLUMPED (str): Explained by a more significant variant in the same window
NO_POPULATION (str): Study does not have population annotation to resolve LD
NOT_QUALIFYING_LD_BLOCK (str): LD block does not contain variants at the required R^2 threshold
FAILED_STUDY (str): Flagging study loci if the study has failed QC
@@ -65,7 +67,8 @@ class StudyLocusQualityCheck(Enum):
PALINDROMIC_ALLELE_FLAG = "Palindrome alleles - cannot harmonize"
AMBIGUOUS_STUDY = "Association with ambiguous study"
UNRESOLVED_LD = "Variant not found in LD reference"
LD_CLUMPED = "Explained by a more significant variant in high LD (clumped)"
LD_CLUMPED = "Explained by a more significant variant in high LD"
WINDOW_CLUMPED = "Explained by a more significant variant in the same window"
NO_POPULATION = "Study does not have population annotation to resolve LD"
NOT_QUALIFYING_LD_BLOCK = (
"LD block does not contain variants at the required R^2 threshold"
@@ -168,9 +171,9 @@ def annotate_study_type(self: StudyLocus, study_index: StudyIndex) -> StudyLocus
"""
return StudyLocus(
_df=(
self.df
.drop("studyType")
.join(study_index.study_type_lut(), on="studyId", how="left")
self.df.drop("studyType").join(
study_index.study_type_lut(), on="studyId", how="left"
)
),
_schema=self.get_schema(),
)
@@ -524,9 +527,7 @@ def get_QC_mappings(cls: type[StudyLocus]) -> dict[str, str]:
"""
return {member.name: member.value for member in StudyLocusQualityCheck}

def filter_by_study_type(
self: StudyLocus, study_type: str
) -> StudyLocus:
def filter_by_study_type(self: StudyLocus, study_type: str) -> StudyLocus:
"""Creates a new StudyLocus dataset filtered by study type.
Args:
@@ -542,11 +543,7 @@ def filter_by_study_type(
raise ValueError(
f"Study type {study_type} not supported. Supported types are: gwas, eqtl, pqtl, sqtl."
)
new_df = (
self.df
.filter(f.col("studyType") == study_type)
.drop("studyType")
)
new_df = self.df.filter(f.col("studyType") == study_type).drop("studyType")
return StudyLocus(
_df=new_df,
_schema=self._schema,
@@ -609,8 +606,7 @@ def find_overlaps(
StudyLocusOverlap: Pairs of overlapping study-locus with aligned tags.
"""
loci_to_overlap = (
self.df
.filter(f.col("studyType").isNotNull())
self.df.filter(f.col("studyType").isNotNull())
.withColumn("locus", f.explode("locus"))
.select(
"studyLocusId",
@@ -1051,3 +1047,19 @@ def annotate_locus_statistics_boundaries(
)

return self

def window_based_clumping(
self: StudyLocus,
window_size: int = WindowBasedClumpingStepConfig().distance,
) -> StudyLocus:
"""Clump study locus by window size.
Args:
window_size (int): Window size for clumping.
Returns:
StudyLocus: Clumped study locus, where clumped associations are flagged.
"""
from gentropy.method.window_based_clumping import WindowBasedClumping

return WindowBasedClumping.clump(self, window_size)
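A minimal usage sketch for the new method, assuming gentropy's Dataset.from_parquet loader, a default Session, and a hypothetical input path (the method only flags non-lead variants, it drops no rows):

    from gentropy.common.session import Session
    from gentropy.dataset.study_locus import StudyLocus

    session = Session()  # local Spark session with gentropy defaults
    study_locus = StudyLocus.from_parquet(session, "path/to/study_locus")  # hypothetical path

    # Flag every association explained by a more significant variant within 250 kb:
    clumped = study_locus.window_based_clumping(window_size=250_000)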
7 changes: 4 additions & 3 deletions src/gentropy/dataset/summary_statistics.py
@@ -77,10 +77,11 @@ def window_based_clumping(
from gentropy.method.window_based_clumping import WindowBasedClumping

return WindowBasedClumping.clump(
self,
# Before clumping, we filter the summary statistics by p-value:
self.pvalue_filter(gwas_significance),
distance=distance,
gwas_significance=gwas_significance,
)
# After applying the clumping, we filter the clumped loci by the flag:
).valid_rows(["WINDOW_CLUMPED"])

def locus_breaker_clumping(
self: SummaryStatistics,
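The SummaryStatistics entry point keeps its old contract of returning lead variants only by composing three calls that all appear in this diff. The same pipeline spelled out step by step (a sketch over the objects named above):

    # 1. Keep only genome-wide significant hits before clumping:
    significant = summary_statistics.pvalue_filter(gwas_significance)
    # 2. Flag every variant explained by a stronger signal within `distance` bp:
    flagged = WindowBasedClumping.clump(significant, distance=distance)
    # 3. Drop the flagged rows, leaving the lead variants:
    lead_only = flagged.valid_rows(["WINDOW_CLUMPED"])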
10 changes: 9 additions & 1 deletion src/gentropy/gwas_catalog_ingestion.py
@@ -3,6 +3,7 @@
from __future__ import annotations

from gentropy.common.session import Session
from gentropy.config import WindowBasedClumpingStepConfig
from gentropy.dataset.variant_index import VariantIndex
from gentropy.datasource.gwas_catalog.associations import (
GWASCatalogCuratedAssociationsParser,
@@ -30,6 +31,7 @@ def __init__(
gnomad_variant_path: str,
catalog_studies_out: str,
catalog_associations_out: str,
distance: int = WindowBasedClumpingStepConfig().distance,
gwas_catalog_study_curation_file: str | None = None,
inclusion_list_path: str | None = None,
) -> None:
@@ -44,6 +46,7 @@
gnomad_variant_path (str): Path to GnomAD variants.
catalog_studies_out (str): Output GWAS catalog studies path.
catalog_associations_out (str): Output GWAS catalog associations path.
distance (int): Distance within which tagging variants are collected around the semi-index.
gwas_catalog_study_curation_file (str | None): file of the curation table. Optional.
inclusion_list_path (str | None): optional inclusion list (parquet)
"""
@@ -86,4 +89,9 @@ def __init__(

# Load
study_index.df.write.mode(session.write_mode).parquet(catalog_studies_out)
study_locus.df.write.mode(session.write_mode).parquet(catalog_associations_out)

(
study_locus.window_based_clumping(distance)
.df.write.mode(session.write_mode)
.parquet(catalog_associations_out)
)
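Both this step and StudyLocus.window_based_clumping pull their default distance from the shared step config, so the ingestion clumps with the same window as the standalone clumping step unless overridden. A one-liner to inspect the default, assuming gentropy is installed:

    from gentropy.config import WindowBasedClumpingStepConfig

    # Per the clump() docstring below, this defaults to 500_000 bp:
    print(WindowBasedClumpingStepConfig().distance)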
51 changes: 30 additions & 21 deletions src/gentropy/method/window_based_clumping.py
@@ -12,7 +12,7 @@
from pyspark.sql.window import Window

from gentropy.config import WindowBasedClumpingStepConfig
from gentropy.dataset.study_locus import StudyLocus
from gentropy.dataset.study_locus import StudyLocus, StudyLocusQualityCheck

if TYPE_CHECKING:
from numpy.typing import NDArray
@@ -154,22 +154,38 @@ def _prune_peak(position: NDArray[np.float64], window_size: int) -> DenseVector:

@staticmethod
def clump(
summary_statistics: SummaryStatistics,
unclumped_associations: SummaryStatistics | StudyLocus,
distance: int = WindowBasedClumpingStepConfig().distance,
gwas_significance: float = WindowBasedClumpingStepConfig().gwas_significance,
) -> StudyLocus:
"""Clump significant signals from summary statistics based on window.
"""Clump single point associations from summary statistics or study locus dataset based on window.
Args:
summary_statistics (SummaryStatistics): Summary statistics to be used for clumping.
unclumped_associations (SummaryStatistics | StudyLocus): Input dataset to be used for clumping. Assumes that the input dataset is already filtered for significant variants.
distance (int): Distance in base pairs to be used for clumping. Defaults to 500_000.
gwas_significance (float): GWAS significance threshold. Defaults to 5e-8.
Returns:
StudyLocus: clumped summary statistics (without locus collection)
Check WindowBasedClumpingStepConfig object for default values
StudyLocus: clumped associations, where the clumped variants are flagged.
"""
# Quality check expression that flags variants that are not considered lead variant:
qc_check = f.col("semiIndices")[f.col("pvRank") - 1] <= 0

# The quality control expression will depend on the input dataset, as the column might be already present:
qc_expression = (
# When the column is already present and the condition is met, the value is appended to the array, otherwise keep as is:
f.when(
qc_check,
f.array_union(
f.col("qualityControls"),
f.array(f.lit(StudyLocusQualityCheck.WINDOW_CLUMPED.value)),
),
).otherwise(f.col("qualityControls"))
if "qualityControls" in unclumped_associations.df.columns
# If column is not there yet, initialize it with the flag value, or an empty array:
else f.when(
qc_check, f.array(f.lit(StudyLocusQualityCheck.WINDOW_CLUMPED.value))
).otherwise(f.array().cast(t.ArrayType(t.StringType())))
)

# Create window for locus clusters
# - variants where the distance between subsequent variants is below the defined threshold.
# - Variants are sorted by descending significance
@@ -179,11 +195,8 @@

return StudyLocus(
_df=(
summary_statistics
# Dropping snps below significance - all subsequent steps are done on significant variants:
.pvalue_filter(gwas_significance)
.df
# Clustering summary variants for efficient windowing (complexity reduction):
unclumped_associations.df
# Clustering variants for efficient windowing (complexity reduction):
.withColumn(
"cluster_id",
WindowBasedClumping._cluster_peaks(
@@ -207,7 +220,7 @@
),
).otherwise(f.array()),
)
# Get semi indices only ONCE per cluster:
# Collect top loci per cluster:
.withColumn(
"semiIndices",
f.when(
@@ -230,9 +243,6 @@
),
).otherwise(f.col("semiIndices")),
)
# Keeping semi indices only:
.filter(f.col("semiIndices")[f.col("pvRank") - 1] > 0)
.drop("pvRank", "collectedPositions", "semiIndices", "cluster_id")
# Adding study-locus id:
.withColumn(
"studyLocusId",
@@ -241,9 +251,8 @@
),
)
# Initialize QC column as array of strings:
.withColumn(
"qualityControls", f.array().cast(t.ArrayType(t.StringType()))
)
.withColumn("qualityControls", qc_expression)
.drop("pvRank", "collectedPositions", "semiIndices", "cluster_id")
),
_schema=StudyLocus.get_schema(),
)
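The conditional qc_expression is the heart of this change: it appends the flag to an existing qualityControls array, or initializes the column when it is absent. A self-contained PySpark sketch of the same pattern on a toy DataFrame, with a boolean standing in for the real `semiIndices[pvRank - 1] <= 0` check (toy column names, not the gentropy schema):

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as f

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    df = spark.createDataFrame(
        [(True, ["OLD_FLAG"]), (False, ["OLD_FLAG"])],
        "is_clumped boolean, qualityControls array<string>",
    )
    flag = f.lit("Explained by a more significant variant in the same window")

    # Append the flag only on clumped rows; f.array_union also de-duplicates,
    # so applying the expression twice never double-flags a row:
    df = df.withColumn(
        "qualityControls",
        f.when(
            f.col("is_clumped"),
            f.array_union(f.col("qualityControls"), f.array(flag)),
        ).otherwise(f.col("qualityControls")),
    )
    df.show(truncate=False)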
68 changes: 68 additions & 0 deletions tests/gentropy/dataset/test_study_locus.py
@@ -749,6 +749,74 @@ def test_study_validation_correctness(self: TestStudyLocusValidation) -> None:
) == 1


class TestStudyLocusWindowClumping:
"""Testing window-based clumping on study locus."""

TEST_DATASET = [
("s1", "c1", 1, -1),
("s1", "c1", 2, -2),
("s1", "c1", 3, -3),
("s2", "c2", 2, -2),
("s3", "c2", 2, -2),
]

TEST_SCHEMA = t.StructType(
[
t.StructField("studyId", t.StringType(), False),
t.StructField("chromosome", t.StringType(), False),
t.StructField("position", t.IntegerType(), False),
t.StructField("pValueExponent", t.IntegerType(), False),
]
)

@pytest.fixture(autouse=True)
def _setup(self: TestStudyLocusWindowClumping, spark: SparkSession) -> None:
"""Setup study locus for testing."""
self.study_locus = StudyLocus(
_df=(
spark.createDataFrame(
self.TEST_DATASET, schema=self.TEST_SCHEMA
).withColumns(
{
"studyLocusId": f.monotonically_increasing_id().cast(
t.LongType()
),
"pValueMantissa": f.lit(1).cast(t.FloatType()),
"variantId": f.concat(
f.lit("v"),
f.monotonically_increasing_id().cast(t.StringType()),
),
}
)
),
_schema=StudyLocus.get_schema(),
)

def test_clump_return_type(self: TestStudyLocusWindowClumping) -> None:
"""Testing if the clumping returns the right type."""
assert isinstance(self.study_locus.window_based_clumping(3), StudyLocus)

def test_clump_no_data_loss(self: TestStudyLocusWindowClumping) -> None:
"""Testing if the clumping returns same number of rows."""
assert (
self.study_locus.window_based_clumping(3).df.count()
== self.study_locus.df.count()
)

def test_correct_flag(self: TestStudyLocusWindowClumping) -> None:
"""Testing if the clumping flags are for variants."""
assert (
self.study_locus.window_based_clumping(3)
.df.filter(
f.array_contains(
f.col("qualityControls"),
StudyLocusQualityCheck.WINDOW_CLUMPED.value,
)
)
.count()
) == 2


def test_build_feature_matrix(
mock_study_locus: StudyLocus,
mock_colocalisation: Colocalisation,
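The expected count of 2 in test_correct_flag follows from the fixture: the three s1/c1 variants at positions 1-3 fall inside a single 3 bp window, and with mantissa 1 their exponents -1, -2, -3 give p-values 0.1, 0.01 and 0.001, so the variant at position 3 stays the lead while positions 1 and 2 are flagged; the s2 and s3 rows sit alone in their windows and stay unflagged. A tiny sanity check of that arithmetic (a sketch, not part of the commit):

    # p-values of the s1/c1 cluster, keyed by position:
    p_values = {1: 1e-1, 2: 1e-2, 3: 1e-3}
    lead = min(p_values, key=p_values.get)  # the most significant variant wins
    assert lead == 3
    assert len(p_values) - 1 == 2  # the two non-lead variants get flagged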
