diff --git a/src/gentropy/method/colocalisation.py b/src/gentropy/method/colocalisation.py index 3e4e91c74..eaf5b9de6 100644 --- a/src/gentropy/method/colocalisation.py +++ b/src/gentropy/method/colocalisation.py @@ -26,6 +26,9 @@ class ECaviar: It extends [CAVIAR](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC5142122/#bib18) framework to explicitly estimate the posterior probability that the same variant is causal in 2 studies while accounting for the uncertainty of LD. eCAVIAR computes the colocalization posterior probability (**CLPP**) by utilizing the marginal posterior probabilities. This framework allows for **multiple variants to be causal** in a single locus. """ + METHOD_NAME: str = "eCAVIAR" + METHOD_METRIC: str = "clpp" + @staticmethod def _get_clpp(left_pp: Column, right_pp: Column) -> Column: """Calculate the colocalisation posterior probability (CLPP). @@ -81,7 +84,7 @@ def colocalise( f.count("*").alias("numberColocalisingVariants"), f.sum(f.col("clpp")).alias("clpp"), ) - .withColumn("colocalisationMethod", f.lit("eCAVIAR")) + .withColumn("colocalisationMethod", f.lit(cls.METHOD_NAME)) ), _schema=Colocalisation.get_schema(), ) @@ -108,6 +111,8 @@ class Coloc: PSEUDOCOUNT (float): Pseudocount to avoid log(0). Defaults to 1e-10. """ + METHOD_NAME: str = "COLOC" + METHOD_METRIC: str = "llr" PSEUDOCOUNT: float = 1e-10 @staticmethod @@ -253,7 +258,7 @@ def colocalise( "lH3bf", "lH4bf", ) - .withColumn("colocalisationMethod", f.lit("COLOC")) + .withColumn("colocalisationMethod", f.lit(cls.METHOD_NAME)) ), _schema=Colocalisation.get_schema(), ) diff --git a/src/gentropy/method/l2g/feature_factory.py b/src/gentropy/method/l2g/feature_factory.py index 3b9b5b278..f037b57b4 100644 --- a/src/gentropy/method/l2g/feature_factory.py +++ b/src/gentropy/method/l2g/feature_factory.py @@ -2,6 +2,7 @@ from __future__ import annotations from functools import reduce +from itertools import chain from typing import TYPE_CHECKING import pyspark.sql.functions as f @@ -12,6 +13,7 @@ ) from gentropy.dataset.l2g_feature import L2GFeature from gentropy.dataset.study_locus import CredibleInterval, StudyLocus +from gentropy.method.colocalisation import Coloc, ECaviar if TYPE_CHECKING: from pyspark.sql import Column, DataFrame @@ -24,6 +26,20 @@ class ColocalisationFactory: """Feature extraction in colocalisation.""" + @classmethod + def _add_colocalisation_metric(cls: type[ColocalisationFactory]) -> Column: + """Expression that adds a `colocalisationMetric` column to the colocalisation dataframe in preparation for feature extraction. + + Returns: + Column: The expression that adds a `colocalisationMetric` column with the derived metric + """ + method_metric_map = { + ECaviar.METHOD_NAME: ECaviar.METHOD_METRIC, + Coloc.METHOD_NAME: Coloc.METHOD_METRIC, + } + map_expr = f.create_map(*[f.lit(x) for x in chain(*method_metric_map.items())]) + return map_expr[f.col("colocalisationMethod")].alias("colocalisationMetric") + @staticmethod def _get_max_coloc_per_credible_set( colocalisation: Colocalisation, @@ -44,9 +60,7 @@ def _get_max_coloc_per_credible_set( f.col("leftStudyLocusId").alias("studyLocusId"), "rightStudyLocusId", f.coalesce("log2h4h3", "clpp").alias("score"), - f.when(f.col("colocalisationMethod") == "COLOC", f.lit("Llr")) - .when(f.col("colocalisationMethod") == "eCAVIAR", f.lit("Clpp")) - .alias("colocalisationMetric"), + ColocalisationFactory._add_colocalisation_metric(), ) colocalising_credible_sets = ( @@ -135,7 +149,7 @@ def _get_max_coloc_per_credible_set( "", f.col("right_studyType"), f.lit("Coloc"), - f.col("colocalisationMetric"), + f.initcap(f.col("colocalisationMetric")), f.lit("Maximum"), f.regexp_replace(f.col("score_type"), "Local", ""), ).alias("featureName"),