diff --git a/src/gentropy/dataset/l2g_feature_matrix.py b/src/gentropy/dataset/l2g_feature_matrix.py index 948477600..e5be1a019 100644 --- a/src/gentropy/dataset/l2g_feature_matrix.py +++ b/src/gentropy/dataset/l2g_feature_matrix.py @@ -5,8 +5,6 @@ from functools import reduce from typing import TYPE_CHECKING, Type -from pyspark.sql.functions import col - from gentropy.common.schemas import parse_spark_schema from gentropy.common.spark_helpers import convert_from_long_to_wide from gentropy.dataset.dataset import Dataset @@ -62,22 +60,13 @@ def generate_features( Raises: ValueError: If the feature matrix is empty """ - coloc_methods = ( - colocalisation.df.select("colocalisationMethod") - .distinct() - .toPandas()["colocalisationMethod"] - .tolist() - ) if features_dfs := [ # Extract features ColocalisationFactory._get_max_coloc_per_credible_set( + colocalisation, credible_set, study_index, - colocalisation.filter(col("colocalisationMethod") == method), - method, - ).df - for method in coloc_methods - ] + [ + ).df, StudyLocusFactory._get_tss_distance_features(credible_set, variant_gene).df, StudyLocusFactory._get_vep_features(credible_set, variant_gene).df, ]: diff --git a/src/gentropy/datasource/finngen/summary_stats.py b/src/gentropy/datasource/finngen/summary_stats.py index 0d77f7d5c..08403bed5 100644 --- a/src/gentropy/datasource/finngen/summary_stats.py +++ b/src/gentropy/datasource/finngen/summary_stats.py @@ -50,7 +50,6 @@ def from_source( Returns: SummaryStatistics: Processed summary statistics dataset """ - study_id = raw_file.split("/")[-1].split(".")[0].upper() processed_summary_stats_df = ( spark.read.schema(cls.raw_schema) .option("delimiter", "\t") @@ -59,7 +58,11 @@ def from_source( .filter(f.col("pos").cast(t.IntegerType()).isNotNull()) .select( # From the full path, extracts just the filename, and converts to upper case to get the study ID. - f.lit(study_id).alias("studyId"), + f.upper( + f.regexp_extract( + f.input_file_name(), r"([^/]+)(\.tsv\.gz|\.gz|\.tsv)", 1 + ) + ).alias("studyId"), # Add variant information. f.concat_ws( "_", diff --git a/src/gentropy/method/colocalisation.py b/src/gentropy/method/colocalisation.py index 3e4e91c74..18d97fdf8 100644 --- a/src/gentropy/method/colocalisation.py +++ b/src/gentropy/method/colocalisation.py @@ -26,6 +26,9 @@ class ECaviar: It extends [CAVIAR](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC5142122/#bib18) framework to explicitly estimate the posterior probability that the same variant is causal in 2 studies while accounting for the uncertainty of LD. eCAVIAR computes the colocalization posterior probability (**CLPP**) by utilizing the marginal posterior probabilities. This framework allows for **multiple variants to be causal** in a single locus. """ + METHOD_NAME: str = "eCAVIAR" + METHOD_METRIC: str = "clpp" + @staticmethod def _get_clpp(left_pp: Column, right_pp: Column) -> Column: """Calculate the colocalisation posterior probability (CLPP). @@ -81,7 +84,7 @@ def colocalise( f.count("*").alias("numberColocalisingVariants"), f.sum(f.col("clpp")).alias("clpp"), ) - .withColumn("colocalisationMethod", f.lit("eCAVIAR")) + .withColumn("colocalisationMethod", f.lit(cls.METHOD_NAME)) ), _schema=Colocalisation.get_schema(), ) @@ -108,6 +111,8 @@ class Coloc: PSEUDOCOUNT (float): Pseudocount to avoid log(0). Defaults to 1e-10. """ + METHOD_NAME: str = "COLOC" + METHOD_METRIC: str = "llr" PSEUDOCOUNT: float = 1e-10 @staticmethod @@ -154,24 +159,24 @@ def colocalise( posteriors = f.udf(Coloc._get_posteriors, VectorUDT()) return Colocalisation( _df=( - overlapping_signals.df + overlapping_signals.df.select("*", "statistics.*") # Before summing log_BF columns nulls need to be filled with 0: - .fillna(0, subset=["statistics.left_logBF", "statistics.right_logBF"]) + .fillna(0, subset=["left_logBF", "right_logBF"]) # Sum of log_BFs for each pair of signals .withColumn( "sum_log_bf", - f.col("statistics.left_logBF") + f.col("statistics.right_logBF"), + f.col("left_logBF") + f.col("right_logBF"), ) # Group by overlapping peak and generating dense vectors of log_BF: .groupBy("chromosome", "leftStudyLocusId", "rightStudyLocusId") .agg( f.count("*").alias("numberColocalisingVariants"), - fml.array_to_vector( - f.collect_list(f.col("statistics.left_logBF")) - ).alias("left_logBF"), - fml.array_to_vector( - f.collect_list(f.col("statistics.right_logBF")) - ).alias("right_logBF"), + fml.array_to_vector(f.collect_list(f.col("left_logBF"))).alias( + "left_logBF" + ), + fml.array_to_vector(f.collect_list(f.col("right_logBF"))).alias( + "right_logBF" + ), fml.array_to_vector(f.collect_list(f.col("sum_log_bf"))).alias( "sum_log_bf" ), @@ -253,7 +258,7 @@ def colocalise( "lH3bf", "lH4bf", ) - .withColumn("colocalisationMethod", f.lit("COLOC")) + .withColumn("colocalisationMethod", f.lit(cls.METHOD_NAME)) ), _schema=Colocalisation.get_schema(), ) diff --git a/src/gentropy/method/l2g/feature_factory.py b/src/gentropy/method/l2g/feature_factory.py index 5373be108..f037b57b4 100644 --- a/src/gentropy/method/l2g/feature_factory.py +++ b/src/gentropy/method/l2g/feature_factory.py @@ -2,6 +2,7 @@ from __future__ import annotations from functools import reduce +from itertools import chain from typing import TYPE_CHECKING import pyspark.sql.functions as f @@ -12,6 +13,7 @@ ) from gentropy.dataset.l2g_feature import L2GFeature from gentropy.dataset.study_locus import CredibleInterval, StudyLocus +from gentropy.method.colocalisation import Coloc, ECaviar if TYPE_CHECKING: from pyspark.sql import Column, DataFrame @@ -24,164 +26,135 @@ class ColocalisationFactory: """Feature extraction in colocalisation.""" + @classmethod + def _add_colocalisation_metric(cls: type[ColocalisationFactory]) -> Column: + """Expression that adds a `colocalisationMetric` column to the colocalisation dataframe in preparation for feature extraction. + + Returns: + Column: The expression that adds a `colocalisationMetric` column with the derived metric + """ + method_metric_map = { + ECaviar.METHOD_NAME: ECaviar.METHOD_METRIC, + Coloc.METHOD_NAME: Coloc.METHOD_METRIC, + } + map_expr = f.create_map(*[f.lit(x) for x in chain(*method_metric_map.items())]) + return map_expr[f.col("colocalisationMethod")].alias("colocalisationMetric") + @staticmethod def _get_max_coloc_per_credible_set( + colocalisation: Colocalisation, credible_set: StudyLocus, studies: StudyIndex, - colocalisation: Colocalisation, - colocalisation_method: str, ) -> L2GFeature: """Get the maximum colocalisation posterior probability for each pair of overlapping study-locus per type of colocalisation method and QTL type. Args: + colocalisation (Colocalisation): Colocalisation dataset credible_set (StudyLocus): Study locus dataset studies (StudyIndex): Study index dataset - colocalisation (Colocalisation): Colocalisation dataset - colocalisation_method (str): Colocalisation method to extract the max from Returns: L2GFeature: Stores the features with the max coloc probabilities for each pair of study-locus - - Raises: - ValueError: If the colocalisation method is not supported """ - if colocalisation_method not in ["COLOC", "eCAVIAR"]: - raise ValueError( - f"Colocalisation method {colocalisation_method} not supported" - ) - if colocalisation_method == "COLOC": - coloc_score_col_name = "log2h4h3" - coloc_feature_col_template = "ColocLlrMaximum" - - elif colocalisation_method == "eCAVIAR": - coloc_score_col_name = "clpp" - coloc_feature_col_template = "ColocClppMaximum" + colocalisation_df = colocalisation.df.select( + f.col("leftStudyLocusId").alias("studyLocusId"), + "rightStudyLocusId", + f.coalesce("log2h4h3", "clpp").alias("score"), + ColocalisationFactory._add_colocalisation_metric(), + ) colocalising_credible_sets = ( credible_set.df.select("studyLocusId", "studyId") # annotate studyLoci with overlapping IDs on the left - to just keep GWAS associations .join( - colocalisation.df.selectExpr( - "leftStudyLocusId as studyLocusId", - "rightStudyLocusId", - "colocalisationMethod", - f"{coloc_score_col_name} as coloc_score", - ), + colocalisation_df, on="studyLocusId", how="inner", ) # bring study metadata to just keep QTL studies on the right .join( - credible_set.df.selectExpr( - "studyLocusId as rightStudyLocusId", "studyId as right_studyId" + credible_set.df.join( + studies.df.select("studyId", "studyType", "geneId"), "studyId" + ).selectExpr( + "studyLocusId as rightStudyLocusId", + "studyType as right_studyType", + "geneId", ), on="rightStudyLocusId", how="inner", ) - .join( - f.broadcast( - studies.df.selectExpr( - "studyId as right_studyId", - "studyType as right_studyType", - "geneId", - ) - ), - on="right_studyId", - how="inner", - ) - .filter( - (f.col("colocalisationMethod") == colocalisation_method) - & (f.col("right_studyType") != "gwas") + .filter(f.col("right_studyType") != "gwas") + .select( + "studyLocusId", + "right_studyType", + "geneId", + "score", + "colocalisationMetric", ) - .select("studyLocusId", "right_studyType", "geneId", "coloc_score") ) - # Max PP calculation per studyLocus AND type of QTL - local_max = get_record_with_maximum_value( - colocalising_credible_sets, - ["studyLocusId", "right_studyType", "geneId"], - "coloc_score", + # Max PP calculation per credible set AND type of QTL AND colocalisation method + local_max = ( + get_record_with_maximum_value( + colocalising_credible_sets, + ["studyLocusId", "right_studyType", "geneId", "colocalisationMetric"], + "score", + ) + .select( + "*", + f.col("score").alias("max_score"), + f.lit("Local").alias("score_type"), + ) + .drop("score") ) - intercept = 0.0001 neighbourhood_max = ( local_max.selectExpr( - "studyLocusId", "coloc_score as coloc_local_max", "geneId" + "studyLocusId", "max_score as local_max_score", "geneId" ) .join( # Add maximum in the neighborhood get_record_with_maximum_value( colocalising_credible_sets.withColumnRenamed( - "coloc_score", "coloc_neighborhood_max" + "score", "tmp_nbh_max_score" ), - ["studyLocusId", "right_studyType"], - "coloc_neighborhood_max", + ["studyLocusId", "right_studyType", "colocalisationMetric"], + "tmp_nbh_max_score", ).drop("geneId"), on="studyLocusId", ) + .withColumn("score_type", f.lit("Neighborhood")) .withColumn( - f"{coloc_feature_col_template}Neighborhood", + "max_score", f.log10( f.abs( - f.col("coloc_local_max") - - f.col("coloc_neighborhood_max") - + f.lit(intercept) + f.col("local_max_score") + - f.col("tmp_nbh_max_score") + + f.lit(0.0001) # intercept ) ), ) - ).drop("coloc_neighborhood_max", "coloc_local_max") - - # Split feature per molQTL - local_dfs = [] - nbh_dfs = [] - qtl_types: list[str] = ( - colocalising_credible_sets.select("right_studyType") - .distinct() - .toPandas()["right_studyType"] - .tolist() - ) - for qtl_type in qtl_types: - filtered_local_max = ( - local_max.filter(f.col("right_studyType") == qtl_type) - .withColumnRenamed( - "coloc_score", - f"{qtl_type}{coloc_feature_col_template}", - ) - .drop("right_studyType") - ) - local_dfs.append(filtered_local_max) - - filtered_neighbourhood_max = ( - neighbourhood_max.filter(f.col("right_studyType") == qtl_type) - .withColumnRenamed( - f"{coloc_feature_col_template}Neighborhood", - f"{qtl_type}{coloc_feature_col_template}Neighborhood", - ) - .drop("right_studyType") - ) - nbh_dfs.append(filtered_neighbourhood_max) - - wide_dfs = reduce( - lambda x, y: x.unionByName(y, allowMissingColumns=True), - local_dfs + nbh_dfs, - ) + ).drop("tmp_nbh_max_score", "local_max_score") return L2GFeature( - _df=convert_from_wide_to_long( - wide_dfs.groupBy("studyLocusId", "geneId").agg( - *( - f.first(f.col(c), ignorenulls=True).alias(c) - for c in wide_dfs.columns - if c - not in [ - "studyLocusId", - "geneId", - ] - ) - ), - id_vars=("studyLocusId", "geneId"), - var_name="featureName", - value_name="featureValue", + _df=( + # Combine local and neighborhood metrics + local_max.unionByName( + neighbourhood_max, allowMissingColumns=True + ).select( + "studyLocusId", + "geneId", + # Feature name is a concatenation of the QTL type, colocalisation metric and if it's local or in the vicinity + f.concat_ws( + "", + f.col("right_studyType"), + f.lit("Coloc"), + f.initcap(f.col("colocalisationMetric")), + f.lit("Maximum"), + f.regexp_replace(f.col("score_type"), "Local", ""), + ).alias("featureName"), + f.col("max_score").cast("float").alias("featureValue"), + ) ), _schema=L2GFeature.get_schema(), ) diff --git a/tests/gentropy/method/test_colocalisation_method.py b/tests/gentropy/method/test_colocalisation_method.py index 37613354c..e58b0e562 100644 --- a/tests/gentropy/method/test_colocalisation_method.py +++ b/tests/gentropy/method/test_colocalisation_method.py @@ -10,6 +10,7 @@ from gentropy.method.colocalisation import Coloc, ECaviar from pandas.testing import assert_frame_equal from pyspark.sql import SparkSession +from pyspark.sql.types import DoubleType, LongType, StringType, StructField, StructType def test_coloc(mock_study_locus_overlap: StudyLocusOverlap) -> None: @@ -104,6 +105,57 @@ def test_coloc_semantic( ) +def test_coloc_no_logbf( + spark: SparkSession, + minimum_expected_h0: float = 0.99, + maximum_expected_h4: float = 1e-5, +) -> None: + """Test COLOC output when the input data has irrelevant logBF.""" + observed_overlap = StudyLocusOverlap( + ( + spark.createDataFrame( + [ + { + "leftStudyLocusId": 1, + "rightStudyLocusId": 2, + "chromosome": "1", + "tagVariantId": "snp", + "statistics": { + "left_logBF": None, + "right_logBF": None, + }, # irrelevant for COLOC + } + ], + schema=StructType( + [ + StructField("leftStudyLocusId", LongType(), False), + StructField("rightStudyLocusId", LongType(), False), + StructField("chromosome", StringType(), False), + StructField("tagVariantId", StringType(), False), + StructField( + "statistics", + StructType( + [ + StructField("left_logBF", DoubleType(), True), + StructField("right_logBF", DoubleType(), True), + ] + ), + ), + ] + ), + ) + ), + StudyLocusOverlap.get_schema(), + ) + observed_coloc_df = Coloc.colocalise(observed_overlap).df + assert ( + observed_coloc_df.select("h0").collect()[0]["h0"] > minimum_expected_h0 + ), "COLOC should return a high h0 (no association) when the input data has irrelevant logBF." + assert ( + observed_coloc_df.select("h4").collect()[0]["h4"] < maximum_expected_h4 + ), "COLOC should return a low h4 (traits are associated) when the input data has irrelevant logBF." + + def test_ecaviar(mock_study_locus_overlap: StudyLocusOverlap) -> None: """Test eCAVIAR.""" assert isinstance(ECaviar.colocalise(mock_study_locus_overlap), Colocalisation) diff --git a/tests/gentropy/method/test_locus_to_gene.py b/tests/gentropy/method/test_locus_to_gene.py index 3976406c9..898252f9f 100644 --- a/tests/gentropy/method/test_locus_to_gene.py +++ b/tests/gentropy/method/test_locus_to_gene.py @@ -80,30 +80,21 @@ def test_train( class TestColocalisationFactory: """Test the ColocalisationFactory methods.""" - @pytest.mark.parametrize( - "colocalisation_method", - [ - "COLOC", - "eCAVIAR", - ], - ) def test_get_max_coloc_per_credible_set( self: TestColocalisationFactory, mock_study_locus: StudyLocus, mock_study_index: StudyIndex, mock_colocalisation: Colocalisation, - colocalisation_method: str, ) -> None: """Test the function that extracts the maximum log likelihood ratio for each pair of overlapping study-locus returns the right data type.""" coloc_features = ColocalisationFactory._get_max_coloc_per_credible_set( + mock_colocalisation, mock_study_locus, mock_study_index, - mock_colocalisation, - colocalisation_method, ) assert isinstance( coloc_features, L2GFeature - ), "Unexpected model type returned from _get_max_coloc_per_credible_set" + ), "Unexpected type returned from _get_max_coloc_per_credible_set" def test_get_max_coloc_per_credible_set_semantic( self: TestColocalisationFactory, @@ -169,8 +160,10 @@ def test_get_max_coloc_per_credible_set_semantic( "colocalisationMethod": "eCAVIAR", "numberColocalisingVariants": 1, "clpp": 0.81, # 0.9*0.9 + "log2h4h3": None, } - ] + ], + schema=Colocalisation.get_schema(), ), _schema=Colocalisation.get_schema(), ) @@ -183,10 +176,9 @@ def test_get_max_coloc_per_credible_set_semantic( ) # Test coloc_features = ColocalisationFactory._get_max_coloc_per_credible_set( + coloc, credset, studies, - coloc, - "eCAVIAR", ) assert coloc_features.df.collect() == expected_coloc_features_df.collect()