Commit 170dd09

Merge branch 'dev' into ds_pairwise_ld
Daniel-Considine committed Mar 28, 2024
2 parents b9b817d + 255c42d commit 170dd09
Showing 6 changed files with 160 additions and 146 deletions.
15 changes: 2 additions & 13 deletions src/gentropy/dataset/l2g_feature_matrix.py
@@ -5,8 +5,6 @@
 from functools import reduce
 from typing import TYPE_CHECKING, Type
 
-from pyspark.sql.functions import col
-
 from gentropy.common.schemas import parse_spark_schema
 from gentropy.common.spark_helpers import convert_from_long_to_wide
 from gentropy.dataset.dataset import Dataset
@@ -62,22 +60,13 @@ def generate_features(
         Raises:
             ValueError: If the feature matrix is empty
         """
-        coloc_methods = (
-            colocalisation.df.select("colocalisationMethod")
-            .distinct()
-            .toPandas()["colocalisationMethod"]
-            .tolist()
-        )
         if features_dfs := [
             # Extract features
             ColocalisationFactory._get_max_coloc_per_credible_set(
+                colocalisation,
                 credible_set,
                 study_index,
-                colocalisation.filter(col("colocalisationMethod") == method),
-                method,
-            ).df
-            for method in coloc_methods
-        ] + [
+            ).df,
             StudyLocusFactory._get_tss_distance_features(credible_set, variant_gene).df,
             StudyLocusFactory._get_vep_features(credible_set, variant_gene).df,
         ]:
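The hunk is truncated before the step that combines `features_dfs` into the final matrix. Given the `reduce` import retained in the first hunk, the combining step presumably follows the usual union pattern; a minimal runnable sketch of that pattern, with invented dataframe contents:

from functools import reduce

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Hypothetical long-format feature frames sharing the same schema.
df_a = spark.createDataFrame(
    [("sl1", "g1", "eqtlColocClppMaximum", 0.9)],
    ["studyLocusId", "geneId", "featureName", "featureValue"],
)
df_b = spark.createDataFrame(
    [("sl1", "g1", "distanceTssMinimum", 1200.0)],
    ["studyLocusId", "geneId", "featureName", "featureValue"],
)

# Stack all feature frames into one long dataframe, tolerating missing columns.
features = reduce(lambda x, y: x.unionByName(y, allowMissingColumns=True), [df_a, df_b])
features.show()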
7 changes: 5 additions & 2 deletions src/gentropy/datasource/finngen/summary_stats.py
@@ -50,7 +50,6 @@ def from_source(
         Returns:
             SummaryStatistics: Processed summary statistics dataset
         """
-        study_id = raw_file.split("/")[-1].split(".")[0].upper()
         processed_summary_stats_df = (
             spark.read.schema(cls.raw_schema)
             .option("delimiter", "\t")
@@ -59,7 +58,11 @@
             .filter(f.col("pos").cast(t.IntegerType()).isNotNull())
             .select(
                 # From the full path, extracts just the filename, and converts to upper case to get the study ID.
-                f.lit(study_id).alias("studyId"),
+                f.upper(
+                    f.regexp_extract(
+                        f.input_file_name(), r"([^/]+)(\.tsv\.gz|\.gz|\.tsv)", 1
+                    )
+                ).alias("studyId"),
                 # Add variant information.
                 f.concat_ws(
                     "_",
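The switch from `raw_file` string-splitting to `f.input_file_name()` means the study ID is now derived per input file, even when several files are read in one pass. A quick check of what the regular expression captures (file names invented; Python's `re` backtracks the same way as the JVM regex engine Spark uses):

import re

PATTERN = r"([^/]+)(\.tsv\.gz|\.gz|\.tsv)"

for path in [
    "gs://bucket/finngen_R10_AB1_ARTHROPOD.gz",
    "gs://bucket/finngen_R10_AB1_ARTHROPOD.tsv",
]:
    match = re.search(PATTERN, path)
    assert match is not None
    # Group 1 is the bare file name; upper-casing it yields the study ID.
    print(match.group(1).upper())  # FINNGEN_R10_AB1_ARTHROPOD

# Caveat: for a ".tsv.gz" name, the greedy "[^/]+" only backtracks far enough
# for the ".gz" alternative to succeed, so group 1 retains the ".tsv" suffix.
print(re.search(PATTERN, "sumstats.tsv.gz").group(1))  # sumstats.tsv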
27 changes: 16 additions & 11 deletions src/gentropy/method/colocalisation.py
@@ -26,6 +26,9 @@ class ECaviar:
     It extends [CAVIAR](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC5142122/#bib18) framework to explicitly estimate the posterior probability that the same variant is causal in 2 studies while accounting for the uncertainty of LD. eCAVIAR computes the colocalization posterior probability (**CLPP**) by utilizing the marginal posterior probabilities. This framework allows for **multiple variants to be causal** in a single locus.
     """
 
+    METHOD_NAME: str = "eCAVIAR"
+    METHOD_METRIC: str = "clpp"
+
     @staticmethod
     def _get_clpp(left_pp: Column, right_pp: Column) -> Column:
         """Calculate the colocalisation posterior probability (CLPP).
@@ -81,7 +84,7 @@ def colocalise(
                     f.count("*").alias("numberColocalisingVariants"),
                     f.sum(f.col("clpp")).alias("clpp"),
                 )
-                .withColumn("colocalisationMethod", f.lit("eCAVIAR"))
+                .withColumn("colocalisationMethod", f.lit(cls.METHOD_NAME))
             ),
             _schema=Colocalisation.get_schema(),
         )
@@ -108,6 +111,8 @@ class Coloc:
         PSEUDOCOUNT (float): Pseudocount to avoid log(0). Defaults to 1e-10.
     """
 
+    METHOD_NAME: str = "COLOC"
+    METHOD_METRIC: str = "llr"
     PSEUDOCOUNT: float = 1e-10
 
     @staticmethod
@@ -154,24 +159,24 @@ def colocalise(
         posteriors = f.udf(Coloc._get_posteriors, VectorUDT())
         return Colocalisation(
             _df=(
-                overlapping_signals.df
+                overlapping_signals.df.select("*", "statistics.*")
                 # Before summing log_BF columns nulls need to be filled with 0:
-                .fillna(0, subset=["statistics.left_logBF", "statistics.right_logBF"])
+                .fillna(0, subset=["left_logBF", "right_logBF"])
                 # Sum of log_BFs for each pair of signals
                 .withColumn(
                     "sum_log_bf",
-                    f.col("statistics.left_logBF") + f.col("statistics.right_logBF"),
+                    f.col("left_logBF") + f.col("right_logBF"),
                 )
                 # Group by overlapping peak and generating dense vectors of log_BF:
                 .groupBy("chromosome", "leftStudyLocusId", "rightStudyLocusId")
                 .agg(
                     f.count("*").alias("numberColocalisingVariants"),
-                    fml.array_to_vector(
-                        f.collect_list(f.col("statistics.left_logBF"))
-                    ).alias("left_logBF"),
-                    fml.array_to_vector(
-                        f.collect_list(f.col("statistics.right_logBF"))
-                    ).alias("right_logBF"),
+                    fml.array_to_vector(f.collect_list(f.col("left_logBF"))).alias(
+                        "left_logBF"
+                    ),
+                    fml.array_to_vector(f.collect_list(f.col("right_logBF"))).alias(
+                        "right_logBF"
+                    ),
                     fml.array_to_vector(f.collect_list(f.col("sum_log_bf"))).alias(
                         "sum_log_bf"
                     ),
@@ -253,7 +258,7 @@ def colocalise(
                 "lH3bf",
                 "lH4bf",
             )
-            .withColumn("colocalisationMethod", f.lit("COLOC"))
+            .withColumn("colocalisationMethod", f.lit(cls.METHOD_NAME))
             ),
             _schema=Colocalisation.get_schema(),
         )
185 changes: 79 additions & 106 deletions src/gentropy/method/l2g/feature_factory.py
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 from functools import reduce
+from itertools import chain
 from typing import TYPE_CHECKING
 
 import pyspark.sql.functions as f
@@ -12,6 +13,7 @@
 )
 from gentropy.dataset.l2g_feature import L2GFeature
 from gentropy.dataset.study_locus import CredibleInterval, StudyLocus
+from gentropy.method.colocalisation import Coloc, ECaviar
 
 if TYPE_CHECKING:
     from pyspark.sql import Column, DataFrame
@@ -24,164 +26,135 @@
 class ColocalisationFactory:
     """Feature extraction in colocalisation."""
 
+    @classmethod
+    def _add_colocalisation_metric(cls: type[ColocalisationFactory]) -> Column:
+        """Expression that adds a `colocalisationMetric` column to the colocalisation dataframe in preparation for feature extraction.
+
+        Returns:
+            Column: The expression that adds a `colocalisationMetric` column with the derived metric
+        """
+        method_metric_map = {
+            ECaviar.METHOD_NAME: ECaviar.METHOD_METRIC,
+            Coloc.METHOD_NAME: Coloc.METHOD_METRIC,
+        }
+        map_expr = f.create_map(*[f.lit(x) for x in chain(*method_metric_map.items())])
+        return map_expr[f.col("colocalisationMethod")].alias("colocalisationMetric")
+
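The `create_map` construction is the densest line in the new helper: `chain(*method_metric_map.items())` flattens the dict into an alternating key/value sequence, which is the argument layout `f.create_map` expects, and indexing the resulting map column with `colocalisationMethod` performs a per-row lookup. The flattening step in plain Python:

from itertools import chain

method_metric_map = {"eCAVIAR": "clpp", "COLOC": "llr"}

# chain(*items()) yields key1, value1, key2, value2, ...
print(list(chain(*method_metric_map.items())))
# ['eCAVIAR', 'clpp', 'COLOC', 'llr']

Wrapping each element in `f.lit` then turns that sequence into a literal MapType column.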
     @staticmethod
     def _get_max_coloc_per_credible_set(
+        colocalisation: Colocalisation,
         credible_set: StudyLocus,
         studies: StudyIndex,
-        colocalisation: Colocalisation,
-        colocalisation_method: str,
     ) -> L2GFeature:
         """Get the maximum colocalisation posterior probability for each pair of overlapping study-locus per type of colocalisation method and QTL type.
 
         Args:
+            colocalisation (Colocalisation): Colocalisation dataset
             credible_set (StudyLocus): Study locus dataset
             studies (StudyIndex): Study index dataset
-            colocalisation (Colocalisation): Colocalisation dataset
-            colocalisation_method (str): Colocalisation method to extract the max from
 
         Returns:
             L2GFeature: Stores the features with the max coloc probabilities for each pair of study-locus
-
-        Raises:
-            ValueError: If the colocalisation method is not supported
         """
-        if colocalisation_method not in ["COLOC", "eCAVIAR"]:
-            raise ValueError(
-                f"Colocalisation method {colocalisation_method} not supported"
-            )
-        if colocalisation_method == "COLOC":
-            coloc_score_col_name = "log2h4h3"
-            coloc_feature_col_template = "ColocLlrMaximum"
-
-        elif colocalisation_method == "eCAVIAR":
-            coloc_score_col_name = "clpp"
-            coloc_feature_col_template = "ColocClppMaximum"
+        colocalisation_df = colocalisation.df.select(
+            f.col("leftStudyLocusId").alias("studyLocusId"),
+            "rightStudyLocusId",
+            f.coalesce("log2h4h3", "clpp").alias("score"),
+            ColocalisationFactory._add_colocalisation_metric(),
+        )
 
         colocalising_credible_sets = (
             credible_set.df.select("studyLocusId", "studyId")
             # annotate studyLoci with overlapping IDs on the left - to just keep GWAS associations
             .join(
-                colocalisation.df.selectExpr(
-                    "leftStudyLocusId as studyLocusId",
-                    "rightStudyLocusId",
-                    "colocalisationMethod",
-                    f"{coloc_score_col_name} as coloc_score",
-                ),
+                colocalisation_df,
                 on="studyLocusId",
                 how="inner",
             )
             # bring study metadata to just keep QTL studies on the right
             .join(
-                credible_set.df.selectExpr(
-                    "studyLocusId as rightStudyLocusId", "studyId as right_studyId"
+                credible_set.df.join(
+                    studies.df.select("studyId", "studyType", "geneId"), "studyId"
+                ).selectExpr(
+                    "studyLocusId as rightStudyLocusId",
+                    "studyType as right_studyType",
+                    "geneId",
                 ),
                 on="rightStudyLocusId",
                 how="inner",
             )
-            .join(
-                f.broadcast(
-                    studies.df.selectExpr(
-                        "studyId as right_studyId",
-                        "studyType as right_studyType",
-                        "geneId",
-                    )
-                ),
-                on="right_studyId",
-                how="inner",
-            )
-            .filter(
-                (f.col("colocalisationMethod") == colocalisation_method)
-                & (f.col("right_studyType") != "gwas")
-            )
-            .select("studyLocusId", "right_studyType", "geneId", "coloc_score")
+            .filter(f.col("right_studyType") != "gwas")
+            .select(
+                "studyLocusId",
+                "right_studyType",
+                "geneId",
+                "score",
+                "colocalisationMetric",
+            )
         )

-        # Max PP calculation per studyLocus AND type of QTL
-        local_max = get_record_with_maximum_value(
-            colocalising_credible_sets,
-            ["studyLocusId", "right_studyType", "geneId"],
-            "coloc_score",
+        # Max PP calculation per credible set AND type of QTL AND colocalisation method
+        local_max = (
+            get_record_with_maximum_value(
+                colocalising_credible_sets,
+                ["studyLocusId", "right_studyType", "geneId", "colocalisationMetric"],
+                "score",
+            )
+            .select(
+                "*",
+                f.col("score").alias("max_score"),
+                f.lit("Local").alias("score_type"),
+            )
+            .drop("score")
         )

-        intercept = 0.0001
         neighbourhood_max = (
             local_max.selectExpr(
-                "studyLocusId", "coloc_score as coloc_local_max", "geneId"
+                "studyLocusId", "max_score as local_max_score", "geneId"
             )
             .join(
                 # Add maximum in the neighborhood
                 get_record_with_maximum_value(
                     colocalising_credible_sets.withColumnRenamed(
-                        "coloc_score", "coloc_neighborhood_max"
+                        "score", "tmp_nbh_max_score"
                     ),
-                    ["studyLocusId", "right_studyType"],
-                    "coloc_neighborhood_max",
+                    ["studyLocusId", "right_studyType", "colocalisationMetric"],
+                    "tmp_nbh_max_score",
                 ).drop("geneId"),
                 on="studyLocusId",
             )
+            .withColumn("score_type", f.lit("Neighborhood"))
             .withColumn(
-                f"{coloc_feature_col_template}Neighborhood",
+                "max_score",
                 f.log10(
                     f.abs(
-                        f.col("coloc_local_max")
-                        - f.col("coloc_neighborhood_max")
-                        + f.lit(intercept)
+                        f.col("local_max_score")
+                        - f.col("tmp_nbh_max_score")
+                        + f.lit(0.0001)  # intercept
                     )
                 ),
             )
-        ).drop("coloc_neighborhood_max", "coloc_local_max")

-        # Split feature per molQTL
-        local_dfs = []
-        nbh_dfs = []
-        qtl_types: list[str] = (
-            colocalising_credible_sets.select("right_studyType")
-            .distinct()
-            .toPandas()["right_studyType"]
-            .tolist()
-        )
-        for qtl_type in qtl_types:
-            filtered_local_max = (
-                local_max.filter(f.col("right_studyType") == qtl_type)
-                .withColumnRenamed(
-                    "coloc_score",
-                    f"{qtl_type}{coloc_feature_col_template}",
-                )
-                .drop("right_studyType")
-            )
-            local_dfs.append(filtered_local_max)
-
-            filtered_neighbourhood_max = (
-                neighbourhood_max.filter(f.col("right_studyType") == qtl_type)
-                .withColumnRenamed(
-                    f"{coloc_feature_col_template}Neighborhood",
-                    f"{qtl_type}{coloc_feature_col_template}Neighborhood",
-                )
-                .drop("right_studyType")
-            )
-            nbh_dfs.append(filtered_neighbourhood_max)
-
-        wide_dfs = reduce(
-            lambda x, y: x.unionByName(y, allowMissingColumns=True),
-            local_dfs + nbh_dfs,
-        )
+        ).drop("tmp_nbh_max_score", "local_max_score")

         return L2GFeature(
-            _df=convert_from_wide_to_long(
-                wide_dfs.groupBy("studyLocusId", "geneId").agg(
-                    *(
-                        f.first(f.col(c), ignorenulls=True).alias(c)
-                        for c in wide_dfs.columns
-                        if c
-                        not in [
-                            "studyLocusId",
-                            "geneId",
-                        ]
-                    )
-                ),
-                id_vars=("studyLocusId", "geneId"),
-                var_name="featureName",
-                value_name="featureValue",
+            _df=(
+                # Combine local and neighborhood metrics
+                local_max.unionByName(
+                    neighbourhood_max, allowMissingColumns=True
+                ).select(
+                    "studyLocusId",
+                    "geneId",
+                    # Feature name is a concatenation of the QTL type, colocalisation metric and if it's local or in the vicinity
+                    f.concat_ws(
+                        "",
+                        f.col("right_studyType"),
+                        f.lit("Coloc"),
+                        f.initcap(f.col("colocalisationMetric")),
+                        f.lit("Maximum"),
+                        f.regexp_replace(f.col("score_type"), "Local", ""),
+                    ).alias("featureName"),
+                    f.col("max_score").cast("float").alias("featureValue"),
+                )
             ),
             _schema=L2GFeature.get_schema(),
         )
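To see what feature names the `concat_ws` expression produces, here is a plain-Python mirror of the same string logic; `initcap` capitalises the single-word metric and `regexp_replace` strips the literal "Local" so only neighborhood features carry a suffix (QTL types are illustrative):

def feature_name(right_study_type: str, metric: str, score_type: str) -> str:
    # Mirrors concat_ws("", right_studyType, "Coloc", initcap(metric),
    #                   "Maximum", regexp_replace(score_type, "Local", "")).
    return (
        right_study_type
        + "Coloc"
        + metric.capitalize()
        + "Maximum"
        + score_type.replace("Local", "")
    )

print(feature_name("eqtl", "clpp", "Local"))        # eqtlColocClppMaximum
print(feature_name("pqtl", "llr", "Neighborhood"))  # pqtlColocLlrMaximumNeighborhood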
(Diffs for the remaining 2 changed files were not loaded.)
