perf(clump): refactored window based clumping (#492)
* perf: refactored window based clumping

* fix: gwas catalog clump step adjusted

* fix: broadcast logic

* revert: coalesce instruction

* fix: add alias

* docs: enhance docs

* test: accidentally removed test

* test: rescue missing test

* chore: remove unused step

* refactor: increased modularisation

* feat: restructure data model

* test: structural test for annotate function

* fix: prevent locus column from being duplicated

* chore: up-to-date with pre-commit
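
Taken together, these commits decouple clumping from locus collection: the removed `clump_with_locus` is superseded by `WindowBasedClumping.clump` followed by the new `StudyLocus.annotate_locus_statistics`. A minimal before/after sketch, assuming an existing `SummaryStatistics` instance named `summary_stats` (distances are illustrative):

from gentropy.method.window_based_clumping import WindowBasedClumping

# Before this commit (API removed below):
# study_locus = WindowBasedClumping.clump_with_locus(
#     summary_stats, window_length=500_000, locus_window_length=500_000
# )

# After: clump first, then collect the surrounding locus as an explicit step.
study_locus = WindowBasedClumping.clump(summary_stats, distance=500_000)
study_locus = study_locus.annotate_locus_statistics(
    summary_stats, collect_locus_distance=500_000
)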
d0choa authored Mar 20, 2024
1 parent 160051c commit ad50c15
Showing 9 changed files with 147 additions and 214 deletions.
2 changes: 1 addition & 1 deletion docs/src_snippets/howto/python_api/c_applying_methods.py
@@ -23,7 +23,7 @@ def apply_class_method_clumping(summary_stats: SummaryStatistics) -> StudyLocus:
from gentropy.method.window_based_clumping import WindowBasedClumping

clumped_summary_statistics = WindowBasedClumping.clump(
summary_stats, window_length=500_000
summary_stats, distance=250_000
)
# --8<-- [end:apply_class_method_clumping]
return clumped_summary_statistics
70 changes: 0 additions & 70 deletions src/gentropy/clump.py

This file was deleted.

5 changes: 3 additions & 2 deletions src/gentropy/config.py
@@ -321,9 +321,10 @@ class WindowBasedClumpingStep(StepConfig):

summary_statistics_input_path: str = MISSING
study_locus_output_path: str = MISSING
distance: int = 500_000
collect_locus: bool = False
collect_locus_distance: int = 500_000
inclusion_list_path: str | None = None
locus_collect_distance: str | None = None

_target_: str = "gentropy.window_based_clumping.WindowBasedClumpingStep"


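The step config now exposes the clumping distance and the locus-collection settings as independent fields. A hedged sketch of parameterising the step by constructing the dataclass directly (paths are placeholders; in practice values are usually supplied through the Hydra config that `_target_` points at):

from gentropy.config import WindowBasedClumpingStep

step_config = WindowBasedClumpingStep(
    summary_statistics_input_path="path/to/summary_statistics",  # placeholder
    study_locus_output_path="path/to/study_locus",  # placeholder
    distance=500_000,  # window used for clumping
    collect_locus=True,  # opt in to locus collection
    collect_locus_distance=500_000,  # window used when collecting the locus
)
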
70 changes: 70 additions & 0 deletions src/gentropy/dataset/study_locus.py
@@ -1,4 +1,5 @@
"""Study locus dataset."""

from __future__ import annotations

from dataclasses import dataclass
@@ -24,6 +25,7 @@

from gentropy.dataset.ld_index import LDIndex
from gentropy.dataset.study_index import StudyIndex
from gentropy.dataset.summary_statistics import SummaryStatistics


class StudyLocusQualityCheck(Enum):
@@ -427,6 +429,74 @@ def annotate_credible_sets(self: StudyLocus) -> StudyLocus:
)
return self

def annotate_locus_statistics(
self: StudyLocus,
summary_statistics: SummaryStatistics,
collect_locus_distance: int,
) -> StudyLocus:
"""Annotates study locus with summary statistics in the specified distance around the position.
Args:
summary_statistics (SummaryStatistics): Summary statistics to be used for annotation.
collect_locus_distance (int): distance from variant defining window for inclusion of variants in locus.
Returns:
StudyLocus: Study locus annotated with summary statistics in `locus` column. If no statistics are found, the `locus` column will be empty.
"""
# The clumps will be used several times (persisting)
self.df.persist()
# Renaming columns:
sumstats_renamed = summary_statistics.df.selectExpr(
*[f"{col} as tag_{col}" for col in summary_statistics.df.columns]
).alias("sumstat")

locus_df = (
sumstats_renamed
# Joining the two datasets together:
.join(
f.broadcast(
self.df.alias("clumped").select(
"position", "chromosome", "studyId", "studyLocusId"
)
),
on=[
(f.col("sumstat.tag_studyId") == f.col("clumped.studyId"))
& (f.col("sumstat.tag_chromosome") == f.col("clumped.chromosome"))
& (
f.col("sumstat.tag_position")
>= (f.col("clumped.position") - collect_locus_distance)
)
& (
f.col("sumstat.tag_position")
<= (f.col("clumped.position") + collect_locus_distance)
)
],
how="inner",
)
.withColumn(
"locus",
f.struct(
f.col("tag_variantId").alias("variantId"),
f.col("tag_beta").alias("beta"),
f.col("tag_pValueMantissa").alias("pValueMantissa"),
f.col("tag_pValueExponent").alias("pValueExponent"),
f.col("tag_standardError").alias("standardError"),
),
)
.groupBy("studyLocusId")
.agg(
f.collect_list(f.col("locus")).alias("locus"),
)
)

self.df = self.df.drop("locus").join(
locus_df,
on="studyLocusId",
how="left",
)

return self
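
Design note: the clumped dataset is broadcast because it is typically far smaller than the summary statistics, so the range join avoids shuffling the large side; the final left join then preserves study loci that collected no tag variants. A minimal usage sketch, assuming existing `clumped` (StudyLocus) and `summary_stats` (SummaryStatistics) instances:

annotated = clumped.annotate_locus_statistics(
    summary_statistics=summary_stats,
    collect_locus_distance=250_000,  # illustrative: collect tags within ±250 kb
)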

def annotate_ld(
self: StudyLocus, study_index: StudyIndex, ld_index: LDIndex
) -> StudyLocus:
31 changes: 10 additions & 21 deletions src/gentropy/dataset/summary_statistics.py
@@ -9,7 +9,6 @@
from gentropy.common.schemas import parse_spark_schema
from gentropy.common.utils import parse_region, split_pvalue
from gentropy.dataset.dataset import Dataset
from gentropy.method.window_based_clumping import WindowBasedClumping

if TYPE_CHECKING:
from pyspark.sql.types import StructType
@@ -59,34 +58,24 @@ def window_based_clumping(
self: SummaryStatistics,
distance: int = 500_000,
gwas_significance: float = 5e-8,
baseline_significance: float = 0.05,
locus_collect_distance: int | None = None,
) -> StudyLocus:
"""Generate study-locus from summary statistics by distance based clumping + collect locus.
"""Generate study-locus from summary statistics using window-based clumping.
For more info, see [`WindowBasedClumping`][gentropy.method.window_based_clumping.WindowBasedClumping]
Args:
distance (int): Distance in base pairs to be used for clumping. Defaults to 500_000.
gwas_significance (float, optional): GWAS significance threshold. Defaults to 5e-8.
baseline_significance (float, optional): Baseline significance threshold for inclusion in the locus. Defaults to 0.05.
locus_collect_distance (int | None): The distance to collect locus around semi-indices. If not provided, locus is not collected.
Returns:
StudyLocus: Clumped study-locus containing variants based on window.
StudyLocus: Clumped study-locus optionally containing variants based on window.
"""
return (
WindowBasedClumping.clump_with_locus(
self,
window_length=distance,
p_value_significance=gwas_significance,
p_value_baseline=baseline_significance,
locus_window_length=locus_collect_distance,
)
if locus_collect_distance
else WindowBasedClumping.clump(
self,
window_length=distance,
p_value_significance=gwas_significance,
)
from gentropy.method.window_based_clumping import WindowBasedClumping

return WindowBasedClumping.clump(
self,
distance=distance,
gwas_significance=gwas_significance,
)
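
With the module-level import removed above, `WindowBasedClumping` is now imported inside the method body, a common way to avoid a circular import between dataset and method modules. After this refactor the two entry points are equivalent; a sketch assuming an existing `summary_stats` instance:

from gentropy.method.window_based_clumping import WindowBasedClumping

via_dataset = summary_stats.window_based_clumping(distance=500_000)
via_class = WindowBasedClumping.clump(summary_stats, distance=500_000)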

def exclude_region(self: SummaryStatistics, region: str) -> SummaryStatistics:
115 changes: 13 additions & 102 deletions src/gentropy/method/window_based_clumping.py
@@ -151,22 +151,21 @@ def _prune_peak(position: NDArray[np.float64], window_size: int) -> DenseVector:

return DenseVector(is_lead)

@classmethod
@staticmethod
def clump(
cls: type[WindowBasedClumping],
summary_stats: SummaryStatistics,
window_length: int,
p_value_significance: float = 5e-8,
summary_statistics: SummaryStatistics,
distance: int = 500_000,
gwas_significance: float = 5e-8,
) -> StudyLocus:
"""Clump summary statistics by distance.
"""Clump significant signals from summary statistics based on window.
Args:
summary_stats (SummaryStatistics): summary statistics to clump
window_length (int): window length in basepair
p_value_significance (float): only more significant variants are considered
summary_statistics (SummaryStatistics): Summary statistics to be used for clumping.
distance (int): Distance in base pairs to be used for clumping. Defaults to 500_000.
gwas_significance (float): GWAS significance threshold. Defaults to 5e-8.
Returns:
StudyLocus: clumped summary statistics
StudyLocus: clumped summary statistics (without locus collection)
"""
# Create window for locus clusters
# - variants where the distance between subsequent variants is below the defined threshold.
@@ -177,9 +176,9 @@

return StudyLocus(
_df=(
summary_stats
summary_statistics
# Dropping snps below significance - all subsequent steps are done on significant variants:
.pvalue_filter(p_value_significance)
.pvalue_filter(gwas_significance)
.df
# Clustering summary variants for efficient windowing (complexity reduction):
.withColumn(
@@ -188,7 +187,7 @@
f.col("studyId"),
f.col("chromosome"),
f.col("position"),
window_length,
distance,
),
)
# Within each cluster variants are ranked by significance:
@@ -213,7 +212,7 @@
fml.vector_to_array(
f.udf(WindowBasedClumping._prune_peak, VectorUDT())(
fml.array_to_vector(f.col("collectedPositions")),
f.lit(window_length),
f.lit(distance),
)
),
),
@@ -245,91 +244,3 @@
),
_schema=StudyLocus.get_schema(),
)
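
The pruning UDF receives each cluster's positions sorted by significance and keeps a position as a lead only if it lies farther than `distance` from every lead accepted so far. An illustrative pure-NumPy re-implementation of that idea (a sketch of the logic, not the library's `_prune_peak` itself):

import numpy as np
from numpy.typing import NDArray

def prune_peaks_sketch(
    positions: NDArray[np.float64], distance: int
) -> NDArray[np.float64]:
    # positions must be sorted by significance (most significant first).
    leads: list[float] = []
    is_lead = np.zeros(len(positions))
    for i, pos in enumerate(positions):
        # Accept a new lead only if it falls outside the window of every
        # previously accepted lead.
        if all(abs(pos - lead) > distance for lead in leads):
            leads.append(pos)
            is_lead[i] = 1.0
    return is_lead

# Example: positions by significance [1_000_000, 1_200_000, 2_000_000] with
# distance 500_000 -> [1.0, 0.0, 1.0]; the second variant sits inside the
# first lead's window and is pruned.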

@classmethod
def clump_with_locus(
cls: type[WindowBasedClumping],
summary_stats: SummaryStatistics,
window_length: int,
p_value_significance: float = 5e-8,
p_value_baseline: float = 0.05,
locus_window_length: int | None = None,
) -> StudyLocus:
"""Clump significant associations while collecting locus around them.
Args:
summary_stats (SummaryStatistics): Input summary statistics dataset
window_length (int): Window size in bp, used for distance based clumping.
p_value_significance (float): GWAS significance threshold used to filter peaks. Defaults to 5e-8.
p_value_baseline (float): Least significant threshold. Below this, all snps are dropped. Defaults to 0.05.
locus_window_length (int | None): The distance for collecting locus around the semi indices. Defaults to None.
Returns:
StudyLocus: StudyLocus after clumping with information about the `locus`
"""
# If no locus window provided, using the same value:
if locus_window_length is None:
locus_window_length = window_length

# Run distance based clumping on the summary stats:
clumped_dataframe = WindowBasedClumping.clump(
summary_stats,
window_length=window_length,
p_value_significance=p_value_significance,
).df.alias("clumped")

# Get list of columns from clumped dataset for further propagation:
clumped_columns = clumped_dataframe.columns

# Dropping variants not meeting the baseline criteria:
sumstats_baseline = summary_stats.pvalue_filter(p_value_baseline).df

# Renaming columns:
sumstats_baseline_renamed = sumstats_baseline.selectExpr(
*[f"{col} as tag_{col}" for col in sumstats_baseline.columns]
).alias("sumstat")

study_locus_df = (
sumstats_baseline_renamed
# Joining the two datasets together:
.join(
f.broadcast(clumped_dataframe),
on=[
(f.col("sumstat.tag_studyId") == f.col("clumped.studyId"))
& (f.col("sumstat.tag_chromosome") == f.col("clumped.chromosome"))
& (
f.col("sumstat.tag_position")
>= (f.col("clumped.position") - locus_window_length)
)
& (
f.col("sumstat.tag_position")
<= (f.col("clumped.position") + locus_window_length)
)
],
how="right",
)
.withColumn(
"locus",
f.struct(
f.col("tag_variantId").alias("variantId"),
f.col("tag_beta").alias("beta"),
f.col("tag_pValueMantissa").alias("pValueMantissa"),
f.col("tag_pValueExponent").alias("pValueExponent"),
f.col("tag_standardError").alias("standardError"),
),
)
.groupby("studyLocusId")
.agg(
*[
f.first(col).alias(col)
for col in clumped_columns
if col != "studyLocusId"
],
f.collect_list(f.col("locus")).alias("locus"),
)
)

return StudyLocus(
_df=study_locus_df,
_schema=StudyLocus.get_schema(),
)