From 86600b0df666f4ed2874f446c3831fe2baa8ff13 Mon Sep 17 00:00:00 2001
From: Yakov
Date: Tue, 9 Apr 2024 12:24:09 +0100
Subject: [PATCH] feat: add FM step with carma and sumstat imputation (#568)

* feat: add fm step with carma and sumstat imputation

* fix: adding log

* fix: fixing carma

* fix: resolving conflict

* fix: resolve conflicts with dev v2

* fix: silencing FutureWarning in Carma

---
 src/gentropy/method/carma.py     |  19 +-
 src/gentropy/susie_finemapper.py | 296 ++++++++++++++++++++++++++++++-
 2 files changed, 310 insertions(+), 5 deletions(-)

diff --git a/src/gentropy/method/carma.py b/src/gentropy/method/carma.py
index 75cb32c79..af8816706 100644
--- a/src/gentropy/method/carma.py
+++ b/src/gentropy/method/carma.py
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 import concurrent.futures
+import warnings
 from itertools import combinations
 from math import floor, lgamma
 from typing import Any
@@ -32,6 +33,8 @@ def time_limited_CARMA_spike_slab_noEM(
             - B_list: A dataframe containing the marginal likelihoods and the corresponding model space or None.
             - Outliers: A list of outlier SNPs or None.
         """
+        # Ignore pandas future warnings
+        warnings.simplefilter(action="ignore", category=FutureWarning)
         try:
             # Execute CARMA.CARMA_spike_slab_noEM with a timeout
             with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
@@ -854,9 +857,19 @@ def _MCS_modified(  # noqa: C901
                     sec_sample = np.random.choice(
                         range(0, 3), 1, p=np.exp(aa) / np.sum(np.exp(aa))
                     )
-                    S = set_gamma[sec_sample[0]][
-                        int(set_star["gamma_set_index"][sec_sample[0]])
-                    ].tolist()
+                    if set_gamma[sec_sample[0]] is not None:
+                        S = set_gamma[sec_sample[0]][
+                            int(set_star["gamma_set_index"][sec_sample[0]])
+                        ].tolist()
+                    else:
+                        sec_sample = np.random.choice(
+                            range(1, 3),
+                            1,
+                            p=np.exp(aa)[[1, 2]] / np.sum(np.exp(aa)[[1, 2]]),
+                        )
+                        S = set_gamma[sec_sample[0]][
+                            int(set_star["gamma_set_index"][sec_sample[0]])
+                        ].tolist()
 
                     for item in conditional_S:
                         if item not in S:
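
The `_MCS_modified` change above guards against a `None` entry in `set_gamma`: when the configuration sampled from the three-way softmax over `aa` is unavailable, the sampler renormalises over the remaining two configurations and resamples. A condensed sketch of that logic, with names borrowed from the patch (illustrative only, not a drop-in replacement):

import numpy as np

def sample_valid_config(aa: np.ndarray, set_gamma: list) -> int:
    # Sample one of three model-space configurations with probability softmax(aa)
    p = np.exp(aa) / np.sum(np.exp(aa))
    idx = np.random.choice(range(0, 3), 1, p=p)[0]
    if set_gamma[idx] is None:
        # Fallback: renormalise over configurations 1 and 2 and resample,
        # mirroring the patch's assumption that only entry 0 can be None
        p_rest = np.exp(aa)[[1, 2]] / np.sum(np.exp(aa)[[1, 2]])
        idx = np.random.choice(range(1, 3), 1, p=p_rest)[0]
    return idx
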
diff --git a/src/gentropy/susie_finemapper.py b/src/gentropy/susie_finemapper.py
index c5d86b1a8..aba5b4c40 100644
--- a/src/gentropy/susie_finemapper.py
+++ b/src/gentropy/susie_finemapper.py
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import time
 from typing import Any
 
 import numpy as np
@@ -15,6 +16,8 @@
 from gentropy.dataset.study_locus import StudyLocus
 from gentropy.dataset.summary_statistics import SummaryStatistics
 from gentropy.datasource.gnomad.ld import GnomADLDMatrix
+from gentropy.method.carma import CARMA
+from gentropy.method.sumstat_imputation import SummaryStatisticsImputation
 from gentropy.method.susie_inf import SUSIE_inf
 
 
@@ -150,6 +153,7 @@ def susie_inf_to_studylocus(
         region: str,
         variant_index: DataFrame,
         cs_lbf_thr: float = 2,
+        sum_pips: float = 0.99,
     ) -> StudyLocus:
         """Convert SuSiE-inf output to StudyLocus DataFrame.
 
@@ -160,6 +164,7 @@ def susie_inf_to_studylocus(
             region (str): region
             variant_index (DataFrame): DataFrame with variant information
             cs_lbf_thr (float): credible set logBF threshold, default is 2
+            sum_pips (float): the expected sum of posterior probabilities in the locus, default is 0.99 (99% credible set)
 
         Returns:
             StudyLocus: StudyLocus object with fine-mapped credible sets
@@ -189,8 +194,8 @@ def susie_inf_to_studylocus(
                     susie_result[:, i + 1].astype(float).argsort()[::-1]
                 ]
                 cumsum_arr = np.cumsum(sorted_arr[:, i + 1].astype(float))
-                filter_row = np.argmax(cumsum_arr >= 0.99)
-                if filter_row == 0 and cumsum_arr[0] < 0.99:
+                filter_row = np.argmax(cumsum_arr >= sum_pips)
+                if filter_row == 0 and cumsum_arr[0] < sum_pips:
                     filter_row = len(cumsum_arr)
                 filter_row += 1
                 filtered_arr = sorted_arr[:filter_row]
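
The `sum_pips` parameter generalises the previously hard-coded 0.99 cutoff above: variants are ranked by posterior inclusion probability (PIP) and the credible set is truncated at the first variant where the cumulative PIP reaches `sum_pips` (all variants are kept if the threshold is never reached). A minimal standalone sketch of the same truncation, assuming a bare PIP vector rather than the full `susie_result` array used in the patch:

import numpy as np

def credible_set_indices(pips: np.ndarray, sum_pips: float = 0.99) -> np.ndarray:
    # Rank variants by descending PIP and accumulate their probabilities
    order = np.argsort(pips)[::-1]
    cumsum = np.cumsum(pips[order])
    cut = np.argmax(cumsum >= sum_pips)    # first position reaching the threshold
    if cut == 0 and cumsum[0] < sum_pips:  # threshold never reached: keep all
        cut = len(cumsum) - 1
    return order[: cut + 1]
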
@@ -378,3 +383,290 @@ def susie_finemapper_ss_gathered(
             region=region,
             variant_index=variant_index,
         )
+
+    @staticmethod
+    def susie_finemapper_from_prepared_dataframes(
+        GWAS_df: DataFrame,
+        ld_index: DataFrame,
+        gnomad_ld: np.ndarray,
+        L: int,
+        session: Session,
+        studyId: str,
+        region: str,
+        susie_est_tausq: bool = False,
+        run_carma: bool = False,
+        run_sumstat_imputation: bool = False,
+        carma_time_limit: int = 600,
+        imputed_r2_threshold: float = 0.8,
+        ld_score_threshold: float = 4,
+        sum_pips: float = 0.99,
+    ) -> dict[str, Any]:
+        """Susie fine-mapper function that uses LD, z-scores, variant info and other options for fine-mapping.
+
+        Args:
+            GWAS_df (DataFrame): GWAS DataFrame with mandatory columns: z, variantId
+            ld_index (DataFrame): LD index DataFrame
+            gnomad_ld (np.ndarray): GnomAD LD matrix
+            L (int): number of causal variants
+            session (Session): Spark session
+            studyId (str): study ID
+            region (str): region
+            susie_est_tausq (bool): estimate tau squared, default is False
+            run_carma (bool): run CARMA, default is False
+            run_sumstat_imputation (bool): run summary statistics imputation, default is False
+            carma_time_limit (int): CARMA time limit, default is 600 seconds
+            imputed_r2_threshold (float): imputed R2 threshold, default is 0.8
+            ld_score_threshold (float): LD score threshold for imputation, default is 4
+            sum_pips (float): the expected sum of posterior probabilities in the locus, default is 0.99 (99% credible set)
+
+        Returns:
+            dict[str, Any]: dictionary with the fine-mapped StudyLocus and a log DataFrame (number of GWAS variants, LD variants, overlapping variants, outliers, imputed variants and variants to fine-map)
+        """
+        # PLEASE DO NOT REMOVE THIS LINE
+        pd.DataFrame.iteritems = pd.DataFrame.items
+
+        start_time = time.time()
+        GWAS_df = GWAS_df.toPandas()
+        ld_index = ld_index.toPandas()
+        ld_index = ld_index.reset_index()
+
+        N_gwas = len(GWAS_df)
+        N_ld = len(ld_index)
+
+        # Filtering out the variants that are not in the LD matrix, we don't need them
+        df_columns = ["variantId", "z"]
+        GWAS_df = GWAS_df.merge(ld_index, on="variantId", how="inner")
+        GWAS_df = GWAS_df[df_columns].reset_index()
+        N_after_merge = len(GWAS_df)
+
+        merged_df = GWAS_df.merge(
+            ld_index, left_on="variantId", right_on="variantId", how="inner"
+        )
+        indices = merged_df["index_y"].values
+
+        ld_to_fm = gnomad_ld[indices][:, indices]
+        z_to_fm = GWAS_df["z"].values
+
+        if run_carma:
+            carma_output = CARMA.time_limited_CARMA_spike_slab_noEM(
+                z=z_to_fm, ld=ld_to_fm, sec_threshold=carma_time_limit
+            )
+            if carma_output["Outliers"] is not None and carma_output["Outliers"] != []:
+                GWAS_df.drop(carma_output["Outliers"], inplace=True)
+                GWAS_df = GWAS_df.reset_index()
+                ld_index = ld_index.reset_index()
+                merged_df = GWAS_df.merge(
+                    ld_index, left_on="variantId", right_on="variantId", how="inner"
+                )
+                indices = merged_df["index_y"].values
+
+                ld_to_fm = gnomad_ld[indices][:, indices]
+                z_to_fm = GWAS_df["z"].values
+                N_outliers = len(carma_output["Outliers"])
+            else:
+                N_outliers = 0
+        else:
+            N_outliers = 0
+
+        if run_sumstat_imputation:
+            known = indices
+            unknown = [
+                index for index in list(range(len(gnomad_ld))) if index not in known
+            ]
+            sig_t = gnomad_ld[known, :][:, known]
+            sig_i_t = gnomad_ld[unknown, :][:, known]
+            zt = z_to_fm
+
+            sumstat_imp_res = SummaryStatisticsImputation.raiss_model(
+                z_scores_known=zt,
+                ld_matrix_known=sig_t,
+                ld_matrix_known_missing=sig_i_t,
+                lamb=0.01,
+                rtol=0.01,
+            )
+
+            bool_index = (sumstat_imp_res["imputation_r2"] >= imputed_r2_threshold) * (
+                sumstat_imp_res["ld_score"] >= ld_score_threshold
+            )
+            if sum(bool_index) >= 1:
+                indices = np.where(bool_index)[0]
+                index_to_add = [unknown[i] for i in indices]
+                index_to_fm = np.concatenate((known, index_to_add))
+
+                ld_to_fm = gnomad_ld[index_to_fm][:, index_to_fm]
+
+                snp_info_to_add = pd.DataFrame(
+                    {
+                        "variantId": ld_index.iloc[index_to_add, :]["variantId"],
+                        "z": sumstat_imp_res["mu"][indices],
+                    }
+                )
+                GWAS_df = pd.concat([GWAS_df, snp_info_to_add], ignore_index=True)
+                z_to_fm = GWAS_df["z"].values
+
+                N_imputed = len(indices)
+            else:
+                N_imputed = 0
+        else:
+            N_imputed = 0
+
+        susie_output = SUSIE_inf.susie_inf(
+            z=z_to_fm, LD=ld_to_fm, L=L, est_tausq=susie_est_tausq
+        )
+
+        schema = StructType([StructField("variantId", StringType(), True)])
+        variant_index = (
+            session.spark.createDataFrame(
+                GWAS_df[["variantId"]],
+                schema=schema,
+            )
+            .withColumn(
+                "chromosome", f.split(f.col("variantId"), "_")[0].cast("string")
+            )
+            .withColumn("position", f.split(f.col("variantId"), "_")[1].cast("int"))
+        )
+
+        study_locus = SusieFineMapperStep.susie_inf_to_studylocus(
+            susie_output=susie_output,
+            session=session,
+            studyId=studyId,
+            region=region,
+            variant_index=variant_index,
+            sum_pips=sum_pips,
+        )
+
+        end_time = time.time()
+
+        log_df = pd.DataFrame(
+            {
+                "N_gwas": N_gwas,
+                "N_ld": N_ld,
+                "N_overlap": N_after_merge,
+                "N_outliers": N_outliers,
+                "N_imputed": N_imputed,
+                "N_final_to_fm": len(ld_to_fm),
+                "elapsed_time": end_time - start_time,
+            },
+            index=[0],
+        )
+
+        return {
+            "study_locus": study_locus,
+            "log": log_df,
+        }
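
In the imputation branch above, the element-wise product of the two boolean arrays acts as a logical AND: an imputed variant enters the fine-mapping set only if it passes both the imputation-R2 and the LD-score filter. A minimal sketch of the same filter, assuming the `imputation_r2` and `ld_score` fields of the `raiss_model` output used by the patch:

import numpy as np

def passing_imputed_indices(
    imp: dict, r2_threshold: float = 0.8, ld_score_threshold: float = 4.0
) -> np.ndarray:
    # Keep only imputed variants that satisfy both quality filters
    keep = (imp["imputation_r2"] >= r2_threshold) & (imp["ld_score"] >= ld_score_threshold)
    return np.where(keep)[0]
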
+
+    @staticmethod
+    def susie_finemapper_one_studylocus_row_v2_dev(
+        GWAS: SummaryStatistics,
+        session: Session,
+        study_locus_row: Row,
+        study_index: StudyIndex,
+        window: int = 1_000_000,
+        L: int = 10,
+        susie_est_tausq: bool = False,
+        run_carma: bool = False,
+        run_sumstat_imputation: bool = False,
+        carma_time_limit: int = 600,
+        imputed_r2_threshold: float = 0.8,
+        ld_score_threshold: float = 4,
+        sum_pips: float = 0.99,
+    ) -> dict[str, Any]:
+        """Susie fine-mapper function that uses summary statistics, chromosome and position as inputs.
+
+        Args:
+            GWAS (SummaryStatistics): GWAS summary statistics
+            session (Session): Spark session
+            study_locus_row (Row): StudyLocus row
+            study_index (StudyIndex): StudyIndex object
+            window (int): window size for fine-mapping
+            L (int): number of causal variants
+            susie_est_tausq (bool): estimate tau squared, default is False
+            run_carma (bool): run CARMA, default is False
+            run_sumstat_imputation (bool): run summary statistics imputation, default is False
+            carma_time_limit (int): CARMA time limit, default is 600 seconds
+            imputed_r2_threshold (float): imputed R2 threshold, default is 0.8
+            ld_score_threshold (float): LD score threshold for imputation, default is 4
+            sum_pips (float): the expected sum of posterior probabilities in the locus, default is 0.99 (99% credible set)
+
+        Returns:
+            dict[str, Any]: dictionary with the fine-mapped StudyLocus and a log DataFrame (number of GWAS variants, LD variants, overlapping variants, outliers, imputed variants and variants to fine-map)
+        """
+        # PLEASE DO NOT REMOVE THIS LINE
+        pd.DataFrame.iteritems = pd.DataFrame.items
+
+        chromosome = study_locus_row["chromosome"]
+        position = study_locus_row["position"]
+        studyId = study_locus_row["studyId"]
+
+        study_index_df = study_index._df
+        study_index_df = study_index_df.filter(f.col("studyId") == studyId)
+        major_population = study_index_df.select(
+            "studyId",
+            f.array_max(f.col("ldPopulationStructure"))
+            .getItem("ldPopulation")
+            .alias("majorPopulation"),
+        ).collect()[0]["majorPopulation"]
+
+        region = (
+            chromosome
+            + ":"
+            + str(int(position - window / 2))
+            + "-"
+            + str(int(position + window / 2))
+        )
+        gwas_df = (
+            GWAS.df.withColumn("z", f.col("beta") / f.col("standardError"))
+            .withColumn(
+                "chromosome", f.split(f.col("variantId"), "_")[0].cast("string")
+            )
+            .withColumn("position", f.split(f.col("variantId"), "_")[1].cast("int"))
+            .filter(f.col("studyId") == studyId)
+            .filter(f.col("z").isNotNull())
+            .filter(f.col("chromosome") == chromosome)
+            .filter(f.col("position") >= position - window / 2)
+            .filter(f.col("position") <= position + window / 2)
+        )
+
+        ld_index = (
+            GnomADLDMatrix()
+            .get_locus_index(
+                study_locus_row=study_locus_row,
+                window_size=window,
+                major_population=major_population,
+            )
+            .withColumn(
+                "variantId",
+                f.concat(
+                    f.lit(chromosome),
+                    f.lit("_"),
+                    f.col("`locus.position`"),
+                    f.lit("_"),
+                    f.col("alleles").getItem(0),
+                    f.lit("_"),
+                    f.col("alleles").getItem(1),
+                ).cast("string"),
+            )
+        )
+
+        gnomad_ld = GnomADLDMatrix.get_numpy_matrix(
+            ld_index, gnomad_ancestry=major_population
+        )
+
+        out = SusieFineMapperStep.susie_finemapper_from_prepared_dataframes(
+            GWAS_df=gwas_df,
+            ld_index=ld_index,
+            gnomad_ld=gnomad_ld,
+            L=L,
+            session=session,
+            studyId=studyId,
+            region=region,
+            susie_est_tausq=susie_est_tausq,
+            run_carma=run_carma,
+            run_sumstat_imputation=run_sumstat_imputation,
+            carma_time_limit=carma_time_limit,
+            imputed_r2_threshold=imputed_r2_threshold,
+            ld_score_threshold=ld_score_threshold,
+            sum_pips=sum_pips,
+        )
+
+        return out
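
For orientation, a minimal usage sketch of the new entry point; the surrounding objects (`session`, `summary_statistics`, `study_locus_row`, `study_index`) are assumed to be prepared elsewhere and are not part of this patch:

from gentropy.susie_finemapper import SusieFineMapperStep

# Hypothetical, pre-loaded inputs: a Session, SummaryStatistics for one study,
# a StudyLocus row (chromosome, position, studyId) and a StudyIndex
result = SusieFineMapperStep.susie_finemapper_one_studylocus_row_v2_dev(
    GWAS=summary_statistics,
    session=session,
    study_locus_row=study_locus_row,
    study_index=study_index,
    window=1_000_000,
    L=10,
    run_carma=True,               # drop CARMA outliers before SuSiE-inf
    run_sumstat_imputation=True,  # impute z-scores for LD variants missing from the GWAS
)
study_locus = result["study_locus"]  # fine-mapped credible sets
log_df = result["log"]               # per-locus diagnostics (N_gwas, N_outliers, ...)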