feat: add FM step with carma and sumstat imputation (#568)
* feat: add fm step with carma and sumstat imputation

* fix: adding log

* fix: fixing carma

* fix: resolving conflict

* fix: resolve conflicts with dev v2

* fix: silencing FutureWarning in CARMA
addramir committed Apr 9, 2024
1 parent b76fd07 commit 86600b0
Showing 2 changed files with 310 additions and 5 deletions.
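
For orientation, here is a minimal sketch of how the new step might be invoked end to end. The class, method, and argument names are taken from the diff below; session, gwas, study_index, and row are hypothetical inputs assumed to be prepared elsewhere (a gentropy Session, SummaryStatistics, StudyIndex, and a StudyLocus row).

from gentropy.susie_finemapper import SusieFineMapperStep

# Hypothetical driver code; see the diff below for the full signature.
result = SusieFineMapperStep.susie_finemapper_one_studylocus_row_v2_dev(
    GWAS=gwas,
    session=session,
    study_locus_row=row,
    study_index=study_index,
    window=1_000_000,
    L=10,
    run_carma=True,               # drop CARMA-flagged outlier SNPs before SuSiE
    run_sumstat_imputation=True,  # impute z-scores for untyped LD-matrix variants
)
result["study_locus"]  # StudyLocus with fine-mapped credible sets
result["log"]          # one-row pandas DataFrame with counts and timing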
19 changes: 16 additions & 3 deletions src/gentropy/method/carma.py
@@ -2,6 +2,7 @@
from __future__ import annotations

import concurrent.futures
import warnings
from itertools import combinations
from math import floor, lgamma
from typing import Any
@@ -32,6 +33,8 @@ def time_limited_CARMA_spike_slab_noEM(
- B_list: A dataframe containing the marginal likelihoods and the corresponding model space or None.
- Outliers: A list of outlier SNPs or None.
"""
# Ignore pandas future warnings
warnings.simplefilter(action="ignore", category=FutureWarning)
try:
# Execute CARMA.CARMA_spike_slab_noEM with a timeout
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
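
Note that warnings.simplefilter above mutates the process-wide warning filters. A scoped alternative, sketched here for comparison (not what the diff does), restores the previous filter state on exit:

import warnings

with warnings.catch_warnings():
    warnings.simplefilter(action="ignore", category=FutureWarning)
    ...  # run CARMA here; the earlier filters are restored afterwards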
@@ -854,9 +857,19 @@ def _MCS_modified( # noqa: C901
sec_sample = np.random.choice(
range(0, 3), 1, p=np.exp(aa) / np.sum(np.exp(aa))
)
S = set_gamma[sec_sample[0]][
int(set_star["gamma_set_index"][sec_sample[0]])
].tolist()
if set_gamma[sec_sample[0]] is not None:
S = set_gamma[sec_sample[0]][
int(set_star["gamma_set_index"][sec_sample[0]])
].tolist()
else:
sec_sample = np.random.choice(
range(1, 3),
1,
p=np.exp(aa)[[1, 2]] / np.sum(np.exp(aa)[[1, 2]]),
)
S = set_gamma[sec_sample[0]][
int(set_star["gamma_set_index"][sec_sample[0]])
].tolist()

for item in conditional_S:
if item not in S:
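The fallback added above re-draws the model-space index from {1, 2} when slot 0 of set_gamma is None, renormalizing the softmax weights over the surviving entries. A standalone illustration with hypothetical log-weights aa:

import numpy as np

aa = np.array([-1.0, -2.0, -3.0])  # hypothetical log-weights for the three moves
p_all = np.exp(aa) / np.sum(np.exp(aa))  # full softmax over indices 0..2
p_sub = np.exp(aa)[[1, 2]] / np.sum(np.exp(aa)[[1, 2]])  # renormalized over 1..2
sec_sample = np.random.choice(range(1, 3), 1, p=p_sub)  # mirrors the None-slot branch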
296 changes: 294 additions & 2 deletions src/gentropy/susie_finemapper.py
@@ -2,6 +2,7 @@

from __future__ import annotations

import time
from typing import Any

import numpy as np
@@ -15,6 +16,8 @@
from gentropy.dataset.study_locus import StudyLocus
from gentropy.dataset.summary_statistics import SummaryStatistics
from gentropy.datasource.gnomad.ld import GnomADLDMatrix
from gentropy.method.carma import CARMA
from gentropy.method.sumstat_imputation import SummaryStatisticsImputation
from gentropy.method.susie_inf import SUSIE_inf


@@ -150,6 +153,7 @@ def susie_inf_to_studylocus(
region: str,
variant_index: DataFrame,
cs_lbf_thr: float = 2,
sum_pips: float = 0.99,
) -> StudyLocus:
"""Convert SuSiE-inf output to StudyLocus DataFrame.
@@ -160,6 +164,7 @@
region (str): region
variant_index (DataFrame): DataFrame with variant information
cs_lbf_thr (float): credible set logBF threshold, default is 2
sum_pips (float): the cumulative sum of posterior inclusion probabilities that defines a credible set, default is 0.99 (99% credible set)
Returns:
StudyLocus: StudyLocus object with fine-mapped credible sets
@@ -189,8 +194,8 @@
susie_result[:, i + 1].astype(float).argsort()[::-1]
]
cumsum_arr = np.cumsum(sorted_arr[:, i + 1].astype(float))
filter_row = np.argmax(cumsum_arr >= 0.99)
if filter_row == 0 and cumsum_arr[0] < 0.99:
filter_row = np.argmax(cumsum_arr >= sum_pips)
if filter_row == 0 and cumsum_arr[0] < sum_pips:
filter_row = len(cumsum_arr)
filter_row += 1
filtered_arr = sorted_arr[:filter_row]
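
A toy illustration of the sum_pips cut introduced above, with hypothetical PIP values:

import numpy as np

pips = np.array([0.60, 0.30, 0.08, 0.02])    # PIPs already sorted descending
cumsum_arr = np.cumsum(pips)                 # [0.60, 0.90, 0.98, 1.00]
filter_row = np.argmax(cumsum_arr >= 0.99)   # 3: first prefix whose sum reaches 0.99
if filter_row == 0 and cumsum_arr[0] < 0.99:
    filter_row = len(cumsum_arr)             # threshold never reached: keep everything
filter_row += 1
credible_set = pips[:filter_row]             # the top four variants form the 99% set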
@@ -378,3 +383,290 @@ def susie_finemapper_ss_gathered(
region=region,
variant_index=variant_index,
)

@staticmethod
def susie_finemapper_from_prepared_dataframes(
GWAS_df: DataFrame,
ld_index: DataFrame,
gnomad_ld: np.ndarray,
L: int,
session: Session,
studyId: str,
region: str,
susie_est_tausq: bool = False,
run_carma: bool = False,
run_sumstat_imputation: bool = False,
carma_time_limit: int = 600,
imputed_r2_threshold: float = 0.8,
ld_score_threshold: float = 4,
sum_pips: float = 0.99,
) -> dict[str, Any]:
"""Susie fine-mapper function that uses LD, z-scores, variant info and other options for Fine-Mapping.
Args:
GWAS_df (DataFrame): GWAS DataFrame with mandotary columns: z, variantId
ld_index (DataFrame): LD index DataFrame
gnomad_ld (np.ndarray): GnomAD LD matrix
L (int): number of causal variants
session (Session): Spark session
studyId (str): study ID
region (str): region
susie_est_tausq (bool): estimate tau squared, default is False
run_carma (bool): run CARMA, default is False
run_sumstat_imputation (bool): run summary statistics imputation, default is False
carma_time_limit (int): CARMA time limit, default is 600 seconds
imputed_r2_threshold (float): imputed R2 threshold, default is 0.8
ld_score_threshold (float): LD score threshold for imputation, default is 4
sum_pips (float): the cumulative sum of posterior inclusion probabilities that defines a credible set, default is 0.99 (99% credible set)
Returns:
dict[str, Any]: dictionary with the fine-mapped study locus and a one-row log DataFrame (numbers of GWAS variants, LD variants, overlapping variants, outliers, imputed variants, and variants passed to fine-mapping, plus elapsed time)
"""
# PLEASE DO NOT REMOVE THIS LINE
pd.DataFrame.iteritems = pd.DataFrame.items
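# (pandas 2.x removed DataFrame.iteritems, but Spark's pandas conversion path
# still calls it; the line above restores it as an alias of DataFrame.items)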

start_time = time.time()
GWAS_df = GWAS_df.toPandas()
ld_index = ld_index.toPandas()
ld_index = ld_index.reset_index()

N_gwas = len(GWAS_df)
N_ld = len(ld_index)

# Filtering out the variants that are not in the LD matrix, we don't need them
df_columns = ["variantId", "z"]
GWAS_df = GWAS_df.merge(ld_index, on="variantId", how="inner")
GWAS_df = GWAS_df[df_columns].reset_index()
N_after_merge = len(GWAS_df)

merged_df = GWAS_df.merge(
ld_index, left_on="variantId", right_on="variantId", how="inner"
)
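# "index_y" is ld_index's row number from reset_index(), i.e. each variant's
# row/column position in the gnomad_ld matrix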
indices = merged_df["index_y"].values

ld_to_fm = gnomad_ld[indices][:, indices]
z_to_fm = GWAS_df["z"].values

if run_carma:
carma_output = CARMA.time_limited_CARMA_spike_slab_noEM(
z=z_to_fm, ld=ld_to_fm, sec_threshold=carma_time_limit
)
if carma_output["Outliers"] is not None and carma_output["Outliers"] != []:
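# Outliers is a list of row indices into GWAS_df (reset above); dropping
# them removes the flagged SNPs before the LD mapping is rebuilt below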
GWAS_df.drop(carma_output["Outliers"], inplace=True)
GWAS_df = GWAS_df.reset_index()
ld_index = ld_index.reset_index()
merged_df = GWAS_df.merge(
ld_index, left_on="variantId", right_on="variantId", how="inner"
)
indices = merged_df["index_y"].values

ld_to_fm = gnomad_ld[indices][:, indices]
z_to_fm = GWAS_df["z"].values
N_outliers = len(carma_output["Outliers"])
else:
N_outliers = 0
else:
N_outliers = 0

if run_sumstat_imputation:
known = indices
unknown = [
index for index in list(range(len(gnomad_ld))) if index not in known
]
sig_t = gnomad_ld[known, :][:, known]
sig_i_t = gnomad_ld[unknown, :][:, known]
zt = z_to_fm

sumstat_imp_res = SummaryStatisticsImputation.raiss_model(
z_scores_known=zt,
ld_matrix_known=sig_t,
ld_matrix_known_missing=sig_i_t,
lamb=0.01,
rtol=0.01,
)
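
For reference, a sketch of the conditional-Gaussian estimator that a RAISS-style raiss_model call is built on (an assumption about its internals; the exact regularization inside gentropy may differ):

import numpy as np

def conditional_z(sig_i_t: np.ndarray, sig_t: np.ndarray, z_t: np.ndarray, lamb: float) -> np.ndarray:
    """Impute E[z_unknown | z_known] = sig_i_t @ inv(sig_t + lamb * I) @ z_known."""
    n = sig_t.shape[0]
    return sig_i_t @ np.linalg.solve(sig_t + lamb * np.eye(n), z_t)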

bool_index = (sumstat_imp_res["imputation_r2"] >= imputed_r2_threshold) * (
sumstat_imp_res["ld_score"] >= ld_score_threshold
)
if sum(bool_index) >= 1:
indices = np.where(bool_index)[0]
index_to_add = [unknown[i] for i in indices]
index_to_fm = np.concatenate((known, index_to_add))

ld_to_fm = gnomad_ld[index_to_fm][:, index_to_fm]

snp_info_to_add = pd.DataFrame(
{
"variantId": ld_index.iloc[index_to_add, :]["variantId"],
"z": sumstat_imp_res["mu"][indices],
}
)
GWAS_df = pd.concat([GWAS_df, snp_info_to_add], ignore_index=True)
z_to_fm = GWAS_df["z"].values

N_imputed = len(indices)
else:
N_imputed = 0
else:
N_imputed = 0

susie_output = SUSIE_inf.susie_inf(
z=z_to_fm, LD=ld_to_fm, L=L, est_tausq=susie_est_tausq
)

schema = StructType([StructField("variantId", StringType(), True)])
variant_index = (
session.spark.createDataFrame(
GWAS_df[["variantId"]],
schema=schema,
)
.withColumn(
"chromosome", f.split(f.col("variantId"), "_")[0].cast("string")
)
.withColumn("position", f.split(f.col("variantId"), "_")[1].cast("int"))
)
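# variantId follows "chr_pos_ref_alt" (hypothetical example: "1_55052794_C_T"),
# so splitting on "_" recovers chromosome and position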

study_locus = SusieFineMapperStep.susie_inf_to_studylocus(
susie_output=susie_output,
session=session,
studyId=studyId,
region=region,
variant_index=variant_index,
sum_pips=sum_pips,
)

end_time = time.time()

log_df = pd.DataFrame(
{
"N_gwas": N_gwas,
"N_ld": N_ld,
"N_overlap": N_after_merge,
"N_outliers": N_outliers,
"N_imputed": N_imputed,
"N_final_to_fm": len(ld_to_fm),
"eleapsed_time": end_time - start_time,
},
index=[0],
)

return {
"study_locus": study_locus,
"log": log_df,
}

@staticmethod
def susie_finemapper_one_studylocus_row_v2_dev(
GWAS: SummaryStatistics,
session: Session,
study_locus_row: Row,
study_index: StudyIndex,
window: int = 1_000_000,
L: int = 10,
susie_est_tausq: bool = False,
run_carma: bool = False,
run_sumstat_imputation: bool = False,
carma_time_limit: int = 600,
imputed_r2_threshold: float = 0.8,
ld_score_threshold: float = 4,
sum_pips: float = 0.99,
) -> dict[str, Any]:
"""Susie fine-mapper function that uses Summary Statstics, chromosome and position as inputs.
Args:
GWAS (SummaryStatistics): GWAS summary statistics
session (Session): Spark session
study_locus_row (Row): StudyLocus row
study_index (StudyIndex): StudyIndex object
window (int): window size for fine-mapping
L (int): number of causal variants
susie_est_tausq (bool): estimate tau squared, default is False
run_carma (bool): run CARMA, default is False
run_sumstat_imputation (bool): run summary statistics imputation, default is False
carma_time_limit (int): CARMA time limit, default is 600 seconds
imputed_r2_threshold (float): imputed R2 threshold, default is 0.8
ld_score_threshold (float): LD score threshold for imputation, default is 4
sum_pips (float): the cumulative sum of posterior inclusion probabilities that defines a credible set, default is 0.99 (99% credible set)
Returns:
dict[str, Any]: dictionary with the fine-mapped study locus and a one-row log DataFrame (numbers of GWAS variants, LD variants, overlapping variants, outliers, imputed variants, and variants passed to fine-mapping, plus elapsed time)
"""
# PLEASE DO NOT REMOVE THIS LINE
pd.DataFrame.iteritems = pd.DataFrame.items

chromosome = study_locus_row["chromosome"]
position = study_locus_row["position"]
studyId = study_locus_row["studyId"]

study_index_df = study_index._df
study_index_df = study_index_df.filter(f.col("studyId") == studyId)
major_population = study_index_df.select(
"studyId",
f.array_max(f.col("ldPopulationStructure"))
.getItem("ldPopulation")
.alias("majorPopulation"),
).collect()[0]["majorPopulation"]

region = (
chromosome
+ ":"
+ str(int(position - window / 2))
+ "-"
+ str(int(position + window / 2))
)
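# hypothetical example: chromosome "1", position 55_000_000, window 1_000_000
# gives region "1:54500000-55500000"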
gwas_df = (
GWAS.df.withColumn("z", f.col("beta") / f.col("standardError"))
.withColumn(
"chromosome", f.split(f.col("variantId"), "_")[0].cast("string")
)
.withColumn("position", f.split(f.col("variantId"), "_")[1].cast("int"))
.filter(f.col("studyId") == studyId)
.filter(f.col("z").isNotNull())
.filter(f.col("chromosome") == chromosome)
.filter(f.col("position") >= position - window / 2)
.filter(f.col("position") <= position + window / 2)
)

ld_index = (
GnomADLDMatrix()
.get_locus_index(
study_locus_row=study_locus_row,
window_size=window,
major_population=major_population,
)
.withColumn(
"variantId",
f.concat(
f.lit(chromosome),
f.lit("_"),
f.col("`locus.position`"),
f.lit("_"),
f.col("alleles").getItem(0),
f.lit("_"),
f.col("alleles").getItem(1),
).cast("string"),
)
)

gnomad_ld = GnomADLDMatrix.get_numpy_matrix(
ld_index, gnomad_ancestry=major_population
)

out = SusieFineMapperStep.susie_finemapper_from_prepared_dataframes(
GWAS_df=gwas_df,
ld_index=ld_index,
gnomad_ld=gnomad_ld,
L=L,
session=session,
studyId=studyId,
region=region,
susie_est_tausq=susie_est_tausq,
run_carma=run_carma,
run_sumstat_imputation=run_sumstat_imputation,
carma_time_limit=carma_time_limit,
imputed_r2_threshold=imputed_r2_threshold,
ld_score_threshold=ld_score_threshold,
sum_pips=sum_pips,
)

return out
