From ad50c158d45004c36dbeee0860c9a96ea052f2d7 Mon Sep 17 00:00:00 2001
From: David Ochoa
Date: Wed, 20 Mar 2024 10:34:56 +0000
Subject: [PATCH] perf(clump): refactored window based clumping (#492)

* perf: refactored window based clumping

* fix: gwas catalog clump step adjusted

* fix: broadcast logic

* revert: coalesce instruction

* fix: add alias

* docs: enhance docs

* test: accidentally removed test

* test: rescue missing test

* chore: remove unused step

* refactor: increased modularisation

* feat: restructure data model

* test: structural test for annotate function

* fix: prevent the locus column from being duplicated

* chore: up-to-date with pre-commit
---
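The clumping flow is now split into two explicit calls: window-based clumping
first, then optional locus collection. A minimal sketch of the new usage, with
an assumed default local Session and illustrative paths (both hypothetical,
not part of this commit):

    from gentropy.common.session import Session
    from gentropy.dataset.summary_statistics import SummaryStatistics

    session = Session()  # assumes the default local Spark configuration
    summary_stats = SummaryStatistics.from_parquet(
        session, "path/to/summary_statistics", recursiveFileLookup=True
    )

    # Window-based clumping keeps only genome-wide significant lead variants;
    # the `locus` column is left empty at this stage:
    study_locus = summary_stats.window_based_clumping(distance=500_000)

    # Optionally collect tagging variants around each lead into `locus`:
    study_locus = study_locus.annotate_locus_statistics(
        summary_stats, collect_locus_distance=500_000
    )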
 .../howto/python_api/c_applying_methods.py    |   2 +-
 src/gentropy/clump.py                         |  70 -----------
 src/gentropy/config.py                        |   5 +-
 src/gentropy/dataset/study_locus.py           |  70 +++++++++++
 src/gentropy/dataset/summary_statistics.py    |  31 ++---
 src/gentropy/method/window_based_clumping.py  | 115 ++----------------
 src/gentropy/window_based_clumping.py         |  39 +++---
 tests/gentropy/dataset/test_study_locus.py    |  11 ++
 .../method/test_window_based_clumping.py      |  18 ++-
 9 files changed, 147 insertions(+), 214 deletions(-)
 delete mode 100644 src/gentropy/clump.py

diff --git a/docs/src_snippets/howto/python_api/c_applying_methods.py b/docs/src_snippets/howto/python_api/c_applying_methods.py
index 12eaf61ac..d0bec9edc 100644
--- a/docs/src_snippets/howto/python_api/c_applying_methods.py
+++ b/docs/src_snippets/howto/python_api/c_applying_methods.py
@@ -23,7 +23,7 @@ def apply_class_method_clumping(summary_stats: SummaryStatistics) -> StudyLocus:
     from gentropy.method.window_based_clumping import WindowBasedClumping
 
     clumped_summary_statistics = WindowBasedClumping.clump(
-        summary_stats, window_length=500_000
+        summary_stats, distance=250_000
     )
     # --8<-- [end:apply_class_method_clumping]
     return clumped_summary_statistics
diff --git a/src/gentropy/clump.py b/src/gentropy/clump.py
deleted file mode 100644
index 9ea3306ac..000000000
--- a/src/gentropy/clump.py
+++ /dev/null
@@ -1,70 +0,0 @@
-"""Step to run clump associations from summary statistics or study locus."""
-from __future__ import annotations
-
-from typing import Optional
-
-from gentropy.common.session import Session
-from gentropy.dataset.ld_index import LDIndex
-from gentropy.dataset.study_index import StudyIndex
-from gentropy.dataset.study_locus import StudyLocus
-from gentropy.dataset.summary_statistics import SummaryStatistics
-
-
-class ClumpStep:
-    """Perform clumping of an association dataset to identify independent signals.
-
-    Two types of clumping are supported and are applied based on the input dataset:
-    - Clumping of summary statistics based on a window-based approach.
-    - Clumping of study locus based on LD.
-
-    Both approaches yield a StudyLocus dataset.
-    """
-
-    def __init__(
-        self,
-        session: Session,
-        input_path: str,
-        clumped_study_locus_path: str,
-        study_index_path: Optional[str] = None,
-        ld_index_path: Optional[str] = None,
-        locus_collect_distance: Optional[int] = None,
-    ) -> None:
-        """Run the clumping step.
-
-        Args:
-            session (Session): Session object.
-            input_path (str): Input path for the study locus or summary statistics files.
-            clumped_study_locus_path (str): Output path for the clumped study locus dataset.
-            study_index_path (Optional[str]): Input path for the study index dataset.
-            ld_index_path (Optional[str]): Input path for the LD index dataset.
-            locus_collect_distance (Optional[int]): Distance in base pairs to collect variants around the study locus.
-
-        Raises:
-            ValueError: If study index and LD index paths are not provided for study locus.
-        """
-        input_cols = session.spark.read.parquet(
-            input_path, recursiveFileLookup=True
-        ).columns
-        if "studyLocusId" in input_cols:
-            if study_index_path is None or ld_index_path is None:
-                raise ValueError(
-                    "Study index and LD index paths are required for clumping study locus."
-                )
-            study_locus = StudyLocus.from_parquet(session, input_path)
-            ld_index = LDIndex.from_parquet(session, ld_index_path)
-            study_index = StudyIndex.from_parquet(session, study_index_path)
-
-            clumped_study_locus = study_locus.annotate_ld(
-                study_index=study_index, ld_index=ld_index
-            ).clump()
-        else:
-            sumstats = SummaryStatistics.from_parquet(
-                session, input_path, recursiveFileLookup=True
-            ).coalesce(4000)
-            clumped_study_locus = sumstats.window_based_clumping(
-                locus_collect_distance=locus_collect_distance
-            )
-
-        clumped_study_locus.df.write.mode(session.write_mode).parquet(
-            clumped_study_locus_path
-        )
diff --git a/src/gentropy/config.py b/src/gentropy/config.py
index 6a2469c0c..127d90844 100644
--- a/src/gentropy/config.py
+++ b/src/gentropy/config.py
@@ -321,9 +321,10 @@ class WindowBasedClumpingStep(StepConfig):
 
     summary_statistics_input_path: str = MISSING
     study_locus_output_path: str = MISSING
+    distance: int = 500_000
+    collect_locus: bool = False
+    collect_locus_distance: int = 500_000
     inclusion_list_path: str | None = None
-    locus_collect_distance: str | None = None
-
     _target_: str = "gentropy.window_based_clumping.WindowBasedClumpingStep"
diff --git a/src/gentropy/dataset/study_locus.py b/src/gentropy/dataset/study_locus.py
index 4ff9f0172..185aeb569 100644
--- a/src/gentropy/dataset/study_locus.py
+++ b/src/gentropy/dataset/study_locus.py
@@ -1,4 +1,5 @@
 """Study locus dataset."""
+
 from __future__ import annotations
 
 from dataclasses import dataclass
@@ -24,6 +25,7 @@
     from gentropy.dataset.ld_index import LDIndex
     from gentropy.dataset.study_index import StudyIndex
+    from gentropy.dataset.summary_statistics import SummaryStatistics
 
 
 class StudyLocusQualityCheck(Enum):
@@ -427,6 +429,74 @@ def annotate_credible_sets(self: StudyLocus) -> StudyLocus:
         )
         return self
 
+    def annotate_locus_statistics(
+        self: StudyLocus,
+        summary_statistics: SummaryStatistics,
+        collect_locus_distance: int,
+    ) -> StudyLocus:
+        """Annotates study locus with summary statistics within the specified distance around the position.
+
+        Args:
+            summary_statistics (SummaryStatistics): Summary statistics to be used for annotation.
+            collect_locus_distance (int): Distance in base pairs around the lead variant that defines the window for inclusion of variants in the locus.
+
+        Returns:
+            StudyLocus: Study locus annotated with summary statistics in the `locus` column. If no statistics are found, the `locus` column will be empty.
+        """
+        # The clumped study locus is used several times downstream, so persist it:
+        self.df.persist()
+        # Renaming columns:
+        sumstats_renamed = summary_statistics.df.selectExpr(
+            *[f"{col} as tag_{col}" for col in summary_statistics.df.columns]
+        ).alias("sumstat")
+
+        locus_df = (
+            sumstats_renamed
+            # Joining the two datasets together:
+            .join(
+                f.broadcast(
+                    self.df.alias("clumped").select(
+                        "position", "chromosome", "studyId", "studyLocusId"
+                    )
+                ),
+                on=[
+                    (f.col("sumstat.tag_studyId") == f.col("clumped.studyId"))
+                    & (f.col("sumstat.tag_chromosome") == f.col("clumped.chromosome"))
+                    & (
+                        f.col("sumstat.tag_position")
+                        >= (f.col("clumped.position") - collect_locus_distance)
+                    )
+                    & (
+                        f.col("sumstat.tag_position")
+                        <= (f.col("clumped.position") + collect_locus_distance)
+                    )
+                ],
+                how="inner",
+            )
+            .withColumn(
+                "locus",
+                f.struct(
+                    f.col("tag_variantId").alias("variantId"),
+                    f.col("tag_beta").alias("beta"),
+                    f.col("tag_pValueMantissa").alias("pValueMantissa"),
+                    f.col("tag_pValueExponent").alias("pValueExponent"),
+                    f.col("tag_standardError").alias("standardError"),
+                ),
+            )
+            .groupBy("studyLocusId")
+            .agg(
+                f.collect_list(f.col("locus")).alias("locus"),
+            )
+        )
+
+        self.df = self.df.drop("locus").join(
+            locus_df,
+            on="studyLocusId",
+            how="left",
+        )
+
+        return self
+
     def annotate_ld(
         self: StudyLocus, study_index: StudyIndex, ld_index: LDIndex
     ) -> StudyLocus:
diff --git a/src/gentropy/dataset/summary_statistics.py b/src/gentropy/dataset/summary_statistics.py
index 442672c58..6cde03988 100644
--- a/src/gentropy/dataset/summary_statistics.py
+++ b/src/gentropy/dataset/summary_statistics.py
@@ -9,7 +9,6 @@
 from gentropy.common.schemas import parse_spark_schema
 from gentropy.common.utils import parse_region, split_pvalue
 from gentropy.dataset.dataset import Dataset
-from gentropy.method.window_based_clumping import WindowBasedClumping
 
 if TYPE_CHECKING:
     from pyspark.sql.types import StructType
@@ -59,34 +58,24 @@ def window_based_clumping(
         self: SummaryStatistics,
         distance: int = 500_000,
         gwas_significance: float = 5e-8,
-        baseline_significance: float = 0.05,
-        locus_collect_distance: int | None = None,
     ) -> StudyLocus:
-        """Generate study-locus from summary statistics by distance based clumping + collect locus.
+        """Generate study-locus from summary statistics using window-based clumping.
+
+        For more info, see [`WindowBasedClumping`][gentropy.method.window_based_clumping.WindowBasedClumping]
 
         Args:
             distance (int): Distance in base pairs to be used for clumping. Defaults to 500_000.
             gwas_significance (float, optional): GWAS significance threshold. Defaults to 5e-8.
-            baseline_significance (float, optional): Baseline significance threshold for inclusion in the locus. Defaults to 0.05.
-            locus_collect_distance (int | None): The distance to collect locus around semi-indices. If not provided, locus is not collected.
 
         Returns:
-            StudyLocus: Clumped study-locus containing variants based on window.
+            StudyLocus: Clumped study-locus (without locus collection).
""" - return ( - WindowBasedClumping.clump_with_locus( - self, - window_length=distance, - p_value_significance=gwas_significance, - p_value_baseline=baseline_significance, - locus_window_length=locus_collect_distance, - ) - if locus_collect_distance - else WindowBasedClumping.clump( - self, - window_length=distance, - p_value_significance=gwas_significance, - ) + from gentropy.method.window_based_clumping import WindowBasedClumping + + return WindowBasedClumping.clump( + self, + distance=distance, + gwas_significance=gwas_significance, ) def exclude_region(self: SummaryStatistics, region: str) -> SummaryStatistics: diff --git a/src/gentropy/method/window_based_clumping.py b/src/gentropy/method/window_based_clumping.py index a2ae12419..57a24c559 100644 --- a/src/gentropy/method/window_based_clumping.py +++ b/src/gentropy/method/window_based_clumping.py @@ -151,22 +151,21 @@ def _prune_peak(position: NDArray[np.float64], window_size: int) -> DenseVector: return DenseVector(is_lead) - @classmethod + @staticmethod def clump( - cls: type[WindowBasedClumping], - summary_stats: SummaryStatistics, - window_length: int, - p_value_significance: float = 5e-8, + summary_statistics: SummaryStatistics, + distance: int = 500_000, + gwas_significance: float = 5e-8, ) -> StudyLocus: - """Clump summary statistics by distance. + """Clump significant signals from summary statistics based on window. Args: - summary_stats (SummaryStatistics): summary statistics to clump - window_length (int): window length in basepair - p_value_significance (float): only more significant variants are considered + summary_statistics (SummaryStatistics): Summary statistics to be used for clumping. + distance (int): Distance in base pairs to be used for clumping. Defaults to 500_000. + gwas_significance (float): GWAS significance threshold. Defaults to 5e-8. Returns: - StudyLocus: clumped summary statistics + StudyLocus: clumped summary statistics (without locus collection) """ # Create window for locus clusters # - variants where the distance between subsequent variants is below the defined threshold. @@ -177,9 +176,9 @@ def clump( return StudyLocus( _df=( - summary_stats + summary_statistics # Dropping snps below significance - all subsequent steps are done on significant variants: - .pvalue_filter(p_value_significance) + .pvalue_filter(gwas_significance) .df # Clustering summary variants for efficient windowing (complexity reduction): .withColumn( @@ -188,7 +187,7 @@ def clump( f.col("studyId"), f.col("chromosome"), f.col("position"), - window_length, + distance, ), ) # Within each cluster variants are ranked by significance: @@ -213,7 +212,7 @@ def clump( fml.vector_to_array( f.udf(WindowBasedClumping._prune_peak, VectorUDT())( fml.array_to_vector(f.col("collectedPositions")), - f.lit(window_length), + f.lit(distance), ) ), ), @@ -245,91 +244,3 @@ def clump( ), _schema=StudyLocus.get_schema(), ) - - @classmethod - def clump_with_locus( - cls: type[WindowBasedClumping], - summary_stats: SummaryStatistics, - window_length: int, - p_value_significance: float = 5e-8, - p_value_baseline: float = 0.05, - locus_window_length: int | None = None, - ) -> StudyLocus: - """Clump significant associations while collecting locus around them. - - Args: - summary_stats (SummaryStatistics): Input summary statistics dataset - window_length (int): Window size in bp, used for distance based clumping. - p_value_significance (float): GWAS significance threshold used to filter peaks. Defaults to 5e-8. 
-            p_value_baseline (float): Least significant threshold. Below this, all snps are dropped. Defaults to 0.05.
-            locus_window_length (int | None): The distance for collecting locus around the semi indices. Defaults to None.
-
-        Returns:
-            StudyLocus: StudyLocus after clumping with information about the `locus`
-        """
-        # If no locus window provided, using the same value:
-        if locus_window_length is None:
-            locus_window_length = window_length
-
-        # Run distance based clumping on the summary stats:
-        clumped_dataframe = WindowBasedClumping.clump(
-            summary_stats,
-            window_length=window_length,
-            p_value_significance=p_value_significance,
-        ).df.alias("clumped")
-
-        # Get list of columns from clumped dataset for further propagation:
-        clumped_columns = clumped_dataframe.columns
-
-        # Dropping variants not meeting the baseline criteria:
-        sumstats_baseline = summary_stats.pvalue_filter(p_value_baseline).df
-
-        # Renaming columns:
-        sumstats_baseline_renamed = sumstats_baseline.selectExpr(
-            *[f"{col} as tag_{col}" for col in sumstats_baseline.columns]
-        ).alias("sumstat")
-
-        study_locus_df = (
-            sumstats_baseline_renamed
-            # Joining the two datasets together:
-            .join(
-                f.broadcast(clumped_dataframe),
-                on=[
-                    (f.col("sumstat.tag_studyId") == f.col("clumped.studyId"))
-                    & (f.col("sumstat.tag_chromosome") == f.col("clumped.chromosome"))
-                    & (
-                        f.col("sumstat.tag_position")
-                        >= (f.col("clumped.position") - locus_window_length)
-                    )
-                    & (
-                        f.col("sumstat.tag_position")
-                        <= (f.col("clumped.position") + locus_window_length)
-                    )
-                ],
-                how="right",
-            )
-            .withColumn(
-                "locus",
-                f.struct(
-                    f.col("tag_variantId").alias("variantId"),
-                    f.col("tag_beta").alias("beta"),
-                    f.col("tag_pValueMantissa").alias("pValueMantissa"),
-                    f.col("tag_pValueExponent").alias("pValueExponent"),
-                    f.col("tag_standardError").alias("standardError"),
-                ),
-            )
-            .groupby("studyLocusId")
-            .agg(
-                *[
-                    f.first(col).alias(col)
-                    for col in clumped_columns
-                    if col != "studyLocusId"
-                ],
-                f.collect_list(f.col("locus")).alias("locus"),
-            )
-        )
-
-        return StudyLocus(
-            _df=study_locus_df,
-            _schema=StudyLocus.get_schema(),
-        )
diff --git a/src/gentropy/window_based_clumping.py b/src/gentropy/window_based_clumping.py
index fcc680ef7..bce9edd37 100644
--- a/src/gentropy/window_based_clumping.py
+++ b/src/gentropy/window_based_clumping.py
@@ -1,4 +1,5 @@
 """Step to run window based clumping on summary statistics datasets."""
+
 from __future__ import annotations
 
 from gentropy.common.session import Session
@@ -13,8 +14,10 @@ def __init__(
         session: Session,
         summary_statistics_input_path: str,
         study_locus_output_path: str,
+        distance: int = 500_000,
+        collect_locus: bool = False,
+        collect_locus_distance: int = 500_000,
         inclusion_list_path: str | None = None,
-        locus_collect_distance: int | None = None,
     ) -> None:
         """Run window-based clumping step.
 
         Args:
             session (Session): Session object.
             summary_statistics_input_path (str): Path to the harmonized summary statistics dataset.
             study_locus_output_path (str): Output path for the resulting study locus dataset.
+            distance (int): Distance in base pairs used for window-based clumping. Optional.
+            collect_locus (bool): Whether to collect the locus around semi-indices. Optional.
+            collect_locus_distance (int): Distance within which tagging variants are collected around the semi-index. Optional.
             inclusion_list_path (str | None): Path to the inclusion list (list of white-listed study identifiers). Optional.
-            locus_collect_distance (int | None): Distance, within which tagging variants are collected around the semi-index. Optional.
         """
         # If inclusion list path is provided, only these studies will be read:
         if inclusion_list_path:
@@ -35,16 +40,22 @@
             # If no inclusion list is provided, read all summary stats in folder:
             study_ids_to_ingest = [summary_statistics_input_path]
 
-        (
-            SummaryStatistics.from_parquet(
-                session,
-                study_ids_to_ingest,
-                recursiveFileLookup=True,
-            )
-            .coalesce(4000)
-            # Applying window based clumping:
-            .window_based_clumping(locus_collect_distance=locus_collect_distance)
-            # Save resulting study locus dataset:
-            .df.write.mode(session.write_mode)
-            .parquet(study_locus_output_path)
+        ss = SummaryStatistics.from_parquet(
+            session,
+            study_ids_to_ingest,
+            recursiveFileLookup=True,
+        )
+
+        # Clumping:
+        study_locus = ss.window_based_clumping(
+            distance=distance,
         )
+
+        # Optional locus collection:
+        if collect_locus:
+            # Collecting locus around semi-indices:
+            study_locus = study_locus.annotate_locus_statistics(
+                ss, collect_locus_distance=collect_locus_distance
+            )
+
+        study_locus.df.write.mode(session.write_mode).parquet(study_locus_output_path)
diff --git a/tests/gentropy/dataset/test_study_locus.py b/tests/gentropy/dataset/test_study_locus.py
index c12597d54..1401b9dd3 100644
--- a/tests/gentropy/dataset/test_study_locus.py
+++ b/tests/gentropy/dataset/test_study_locus.py
@@ -10,6 +10,7 @@
 from gentropy.dataset.study_index import StudyIndex
 from gentropy.dataset.study_locus import CredibleInterval, StudyLocus
 from gentropy.dataset.study_locus_overlap import StudyLocusOverlap
+from gentropy.dataset.summary_statistics import SummaryStatistics
 from pyspark.sql import Column, SparkSession
 from pyspark.sql.types import (
     ArrayType,
@@ -214,6 +215,16 @@ def test_filter_by_study_type(
     assert observed.df.count() == expected_sl_count
 
 
+def test_annotate_locus_statistics(
+    mock_study_locus: StudyLocus, mock_summary_statistics: SummaryStatistics
+) -> None:
+    """Test annotate locus statistics returns a StudyLocus."""
+    assert isinstance(
+        mock_study_locus.annotate_locus_statistics(mock_summary_statistics, 100),
+        StudyLocus,
+    )
+
+
 def test_filter_credible_set(mock_study_locus: StudyLocus) -> None:
     """Test credible interval filter."""
     assert isinstance(
diff --git a/tests/gentropy/method/test_window_based_clumping.py b/tests/gentropy/method/test_window_based_clumping.py
index 03546df9d..cd583bac2 100644
--- a/tests/gentropy/method/test_window_based_clumping.py
+++ b/tests/gentropy/method/test_window_based_clumping.py
@@ -8,12 +8,12 @@
 from gentropy.method.window_based_clumping import WindowBasedClumping
 from pyspark.ml import functions as fml
 from pyspark.ml.linalg import VectorUDT
+from pyspark.sql import SparkSession
 from pyspark.sql import functions as f
 from pyspark.sql.window import Window
 
 if TYPE_CHECKING:
     from gentropy.dataset.summary_statistics import SummaryStatistics
-    from pyspark.sql import SparkSession
 
 
 def test_window_based_clump__return_type(
@@ -21,7 +21,14 @@ def test_window_based_clump__return_type(
 ) -> None:
     """Test window-based clumping."""
     assert isinstance(
-        WindowBasedClumping.clump_with_locus(mock_summary_statistics, 250_000),
+        WindowBasedClumping.clump(mock_summary_statistics, distance=250_000),
+        StudyLocus,
+    )
+    assert isinstance(
+        WindowBasedClumping.clump(
+            mock_summary_statistics,
+            distance=250_000,
+        ),
         StudyLocus,
     )
 
@@ -44,7 +51,10 @@ def test_window_based_clump_with_locus__correctness(
     sample_summary_statistics: SummaryStatistics,
 ) -> None:
     """Test window-based clumping."""
     clumped = sample_summary_statistics.window_based_clumping(
-        distance=250_000, locus_collect_distance=250_000
+        distance=250_000,
+    )
+    clumped = clumped.annotate_locus_statistics(
+        sample_summary_statistics, collect_locus_distance=250_000
     )
 
     # Asserting the presence of locus key:
@@ -57,7 +67,7 @@
     assert (clumped.df.filter(f.col("variantId") == "18_12843138_T_C").count()) == 1
 
     # Assert the number of variants in the locus:
-    assert (clumped.df.select(f.explode_outer("locus").alias("loci")).count()) == 132
+    assert (clumped.df.select(f.explode_outer("locus").alias("loci")).count()) == 218
 
 
 def test_prune_peak(spark: SparkSession) -> None:
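
For reference, the refactored step wires the same two calls end to end; a
direct invocation with illustrative, hypothetical paths would look like:

    from gentropy.common.session import Session
    from gentropy.window_based_clumping import WindowBasedClumpingStep

    # The step reads, clumps, optionally collects the locus, and writes the
    # result to study_locus_output_path in its constructor:
    WindowBasedClumpingStep(
        session=Session(),  # assumes the default local Spark configuration
        summary_statistics_input_path="path/to/summary_statistics",
        study_locus_output_path="path/to/study_locus",
        distance=500_000,
        collect_locus=True,
        collect_locus_distance=500_000,
    )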