Skip to content

Commit

Permalink
perf(l2g): rewrite ColocalisationFactory._get_max_coloc_per_credible_set
Browse files Browse the repository at this point in the history
  • Loading branch information
ireneisdoomed committed Mar 20, 2024
1 parent 160051c commit f788e73
Showing 1 changed file with 66 additions and 90 deletions.
156 changes: 66 additions & 90 deletions src/gentropy/method/l2g/feature_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pyspark.sql.functions as f

from gentropy.common.spark_helpers import (
convert_from_long_to_wide,
convert_from_wide_to_long,
get_record_with_maximum_value,
)
Expand All @@ -29,149 +30,124 @@ def _get_max_coloc_per_credible_set(
credible_set: StudyLocus,
studies: StudyIndex,
colocalisation: Colocalisation,
colocalisation_method: str,
) -> L2GFeature:
"""Get the maximum colocalisation posterior probability for each pair of overlapping study-locus per type of colocalisation method and QTL type.
Args:
credible_set (StudyLocus): Study locus dataset
studies (StudyIndex): Study index dataset
colocalisation (Colocalisation): Colocalisation dataset
colocalisation_method (str): Colocalisation method to extract the max from
Returns:
L2GFeature: Stores the features with the max coloc probabilities for each pair of study-locus
Raises:
ValueError: If the colocalisation method is not supported
"""
if colocalisation_method not in ["COLOC", "eCAVIAR"]:
raise ValueError(
f"Colocalisation method {colocalisation_method} not supported"
)
if colocalisation_method == "COLOC":
coloc_score_col_name = "log2h4h3"
coloc_feature_col_template = "ColocLlrMaximum"

elif colocalisation_method == "eCAVIAR":
coloc_score_col_name = "clpp"
coloc_feature_col_template = "ColocClppMaximum"
colocalisation_df = colocalisation.df.select(
f.col("leftStudyLocusId").alias("studyLocusId"),
"rightStudyLocusId",
f.coalesce("log2h4h3", "clpp").alias("score"),
f.when(f.col("colocalisationMethod") == "COLOC", f.lit("Llr"))
.when(f.col("colocalisationMethod") == "eCAVIAR", f.lit("Clpp"))
.alias("colocalisationMetric"),
)

colocalising_credible_sets = (
credible_set.df.select("studyLocusId", "studyId")
# annotate studyLoci with overlapping IDs on the left - to just keep GWAS associations
.join(
colocalisation.df.selectExpr(
"leftStudyLocusId as studyLocusId",
"rightStudyLocusId",
"colocalisationMethod",
f"{coloc_score_col_name} as coloc_score",
),
colocalisation_df,
on="studyLocusId",
how="inner",
)
# bring study metadata to just keep QTL studies on the right
.join(
credible_set.df.selectExpr(
"studyLocusId as rightStudyLocusId", "studyId as right_studyId"
credible_set.df.join(
studies.df.select("studyId", "studyType", "geneId"), "studyId"
).selectExpr(
"studyLocusId as rightStudyLocusId", "studyType as right_studyType"
),
on="rightStudyLocusId",
how="inner",
)
.join(
f.broadcast(
studies.df.selectExpr(
"studyId as right_studyId",
"studyType as right_studyType",
"geneId",
)
),
on="right_studyId",
how="inner",
)
.filter(
(f.col("colocalisationMethod") == colocalisation_method)
& (f.col("right_studyType") != "gwas")
.filter(f.col("right_studyType") != "gwas")
.select(
"studyLocusId",
"right_studyType",
"geneId",
"score",
"colocalisationMetric",
)
.select("studyLocusId", "right_studyType", "geneId", "coloc_score")
)

# Max PP calculation per studyLocus AND type of QTL
local_max = get_record_with_maximum_value(
colocalising_credible_sets,
["studyLocusId", "right_studyType", "geneId"],
"coloc_score",
# Max PP calculation per credible set AND type of QTL AND colocalisation method
local_max = (
get_record_with_maximum_value(
colocalising_credible_sets,
["studyLocusId", "right_studyType", "geneId", "colocalisationMetric"],
"score",
)
.select(
"*",
f.col("score").alias("max_score"),
f.lit("Local").alias("score_type"),
)
.drop("score")
)

intercept = 0.0001
neighbourhood_max = (
local_max.selectExpr(
"studyLocusId", "coloc_score as coloc_local_max", "geneId"
"studyLocusId", "max_score as local_max_score", "geneId"
)
.join(
# Add maximum in the neighborhood
get_record_with_maximum_value(
colocalising_credible_sets.withColumnRenamed(
"coloc_score", "coloc_neighborhood_max"
"score", "tmp_nbh_max_score"
),
["studyLocusId", "right_studyType"],
"coloc_neighborhood_max",
["studyLocusId", "right_studyType", "colocalisationMetric"],
"tmp_nbh_max_score",
).drop("geneId"),
on="studyLocusId",
)
.withColumn("score_type", f.lit("Neighborhood"))
.withColumn(
f"{coloc_feature_col_template}Neighborhood",
"max_score",
f.log10(
f.abs(
f.col("coloc_local_max")
- f.col("coloc_neighborhood_max")
+ f.lit(intercept)
f.col("local_max_score")
- f.col("tmp_nbh_max_score")
+ f.lit(0.0001) # intercept
)
),
)
).drop("coloc_neighborhood_max", "coloc_local_max")

# Split feature per molQTL
local_dfs = []
nbh_dfs = []
qtl_types: list[str] = (
colocalising_credible_sets.select("right_studyType")
.distinct()
.toPandas()["right_studyType"]
.tolist()
)
for qtl_type in qtl_types:
filtered_local_max = (
local_max.filter(f.col("right_studyType") == qtl_type)
.withColumnRenamed(
"coloc_score",
f"{qtl_type}{coloc_feature_col_template}",
)
.drop("right_studyType")
)
local_dfs.append(filtered_local_max)

filtered_neighbourhood_max = (
neighbourhood_max.filter(f.col("right_studyType") == qtl_type)
.withColumnRenamed(
f"{coloc_feature_col_template}Neighborhood",
f"{qtl_type}{coloc_feature_col_template}Neighborhood",
).drop("tmp_nbh_max_score", "local_max_score")

wide_df = convert_from_long_to_wide(
df=(
local_max.unionByName(
neighbourhood_max, allowMissingColumns=True
).withColumn(
"featureName",
f.concat_ws(
"",
f.col("right_studyType"),
f.lit("Coloc"),
f.col("colocalisationMetric"),
f.lit("Maximum"),
f.col("score_type"),
),
)
.drop("right_studyType")
)
nbh_dfs.append(filtered_neighbourhood_max)

wide_dfs = reduce(
lambda x, y: x.unionByName(y, allowMissingColumns=True),
local_dfs + nbh_dfs,
),
id_vars=["studyLocusId", "geneId"],
var_name="featureName",
value_name="max_score",
)

return L2GFeature(
_df=convert_from_wide_to_long(
wide_dfs.groupBy("studyLocusId", "geneId").agg(
wide_df.groupBy("studyLocusId", "geneId").agg(
*(
f.first(f.col(c), ignorenulls=True).alias(c)
for c in wide_dfs.columns
for c in wide_df.columns
if c
not in [
"studyLocusId",
Expand All @@ -182,7 +158,7 @@ def _get_max_coloc_per_credible_set(
id_vars=("studyLocusId", "geneId"),
var_name="featureName",
value_name="featureValue",
),
).filter(f.col("featureValue").isNotNull()),
_schema=L2GFeature.get_schema(),
)

Expand Down

0 comments on commit f788e73

Please sign in to comment.