From 645dd802eb1fca9e742b32629cd70e4d119800fd Mon Sep 17 00:00:00 2001
From: Yakov Tsepilov
Date: Fri, 5 Apr 2024 16:16:32 +0100
Subject: [PATCH] feat: add fm step with carma and sumstat imputation

---
 src/gentropy/susie_finemapper.py | 159 +++++++++++++++++++++++++++++++
 1 file changed, 159 insertions(+)

diff --git a/src/gentropy/susie_finemapper.py b/src/gentropy/susie_finemapper.py
index 03e1c2a0b..24cebb490 100644
--- a/src/gentropy/susie_finemapper.py
+++ b/src/gentropy/susie_finemapper.py
@@ -15,6 +15,8 @@
 from gentropy.dataset.study_locus import StudyLocus
 from gentropy.dataset.summary_statistics import SummaryStatistics
 from gentropy.datasource.gnomad.ld import GnomADLDMatrix
+from gentropy.method.carma import CARMA
+from gentropy.method.sumstat_imputation import SummaryStatisticsImputation
 from gentropy.method.susie_inf import SUSIE_inf
 
 
@@ -263,3 +265,160 @@ def susie_inf_to_studylocus(
             _df=cred_sets,
             _schema=StudyLocus.get_schema(),
         )
+
+    @staticmethod
+    def susie_finemapper_from_prepared_dataframes(
+        GWAS_df: DataFrame,
+        ld_index: DataFrame,
+        gnomad_ld: np.ndarray,
+        L: int,
+        session: Session,
+        studyId: str,
+        region: str,
+        susie_est_tausq: bool = False,
+        run_carma: bool = False,
+        run_sumstat_imputation: bool = False,
+        carma_time_limit: int = 600,
+        imputed_r2_threshold: float = 0.8,
+        ld_score_threshold: float = 4,
+    ) -> dict[str, Any]:
+        """Susie fine-mapper function that uses LD, z-scores, variant info and other options for Fine-Mapping.
+
+        Args:
+            GWAS_df (DataFrame): GWAS DataFrame with mandatory columns: z, variantId, chromosome, position
+            ld_index (DataFrame): LD index DataFrame
+            gnomad_ld (np.ndarray): GnomAD LD matrix
+            L (int): number of causal variants
+            session (Session): Spark session
+            studyId (str): study ID
+            region (str): region
+            susie_est_tausq (bool): estimate tau squared, default is False
+            run_carma (bool): run CARMA, default is False
+            run_sumstat_imputation (bool): run summary statistics imputation, default is False
+            carma_time_limit (int): CARMA time limit, default is 600 seconds
+            imputed_r2_threshold (float): imputed R2 threshold, default is 0.8
+            ld_score_threshold (float): LD score threshold for imputation, default is 4
+
+        Returns:
+            dict[str, Any]: dictionary with study locus, number of GWAS variants, number of LD variants, number of variants after merge, number of outliers, number of imputed variants, number of variants to fine-map
+        """
+        GWAS_df = GWAS_df.toPandas()
+        ld_index = ld_index.toPandas()
+        ld_index = ld_index.reset_index()
+
+        N_gwas = len(GWAS_df)
+        N_ld = len(ld_index)
+
+        # Filtering out the variants that are not in the LD matrix, we don't need them
+        df_columns = ["variantId", "z"]
+        GWAS_df = GWAS_df.merge(ld_index, on="variantId", how="inner")
+        GWAS_df = GWAS_df[df_columns].reset_index()
+        N_after_merge = len(GWAS_df)
+
+        merged_df = GWAS_df.merge(
+            ld_index, left_on="variantId", right_on="variantId", how="inner"
+        )
+        indices = merged_df["index_y"].values
+
+        ld_to_fm = gnomad_ld[indices][:, indices]
+        z_to_fm = GWAS_df["z"].values
+
+        if run_carma:
+            carma_output = CARMA.time_limited_CARMA_spike_slab_noEM(
+                z=z_to_fm, ld=ld_to_fm, sec_threshold=carma_time_limit
+            )
+            if carma_output["Outliers"] != [] and carma_output["Outliers"] is not None:
+                GWAS_df.drop(carma_output["Outliers"], inplace=True)
+                GWAS_df = GWAS_df.reset_index()
+                ld_index = ld_index.reset_index()
+                merged_df = GWAS_df.merge(
+                    ld_index, left_on="variantId", right_on="variantId", how="inner"
+                )
+                indices = merged_df["index_y"].values
+
+                ld_to_fm = gnomad_ld[indices][:, indices]
+                z_to_fm = GWAS_df["z"].values
+                N_outliers = len(carma_output["Outliers"])
+            else:
+                N_outliers = 0
+        else:
+            N_outliers = 0
+
+        if run_sumstat_imputation:
+            known = indices
+            unknown = [
+                index for index in list(range(len(gnomad_ld))) if index not in known
+            ]
+            sig_t = gnomad_ld[known, :][:, known]
+            sig_i_t = gnomad_ld[unknown, :][:, known]
+            zt = z_to_fm
+
+            sumstat_imp_res = SummaryStatisticsImputation.raiss_model(
+                z_scores_known=zt,
+                ld_matrix_known=sig_t,
+                ld_matrix_known_missing=sig_i_t,
+                lamb=0.01,
+                rtol=0.01,
+            )
+
+            if (
+                sum(
+                    (sumstat_imp_res["imputation_r2"] >= imputed_r2_threshold)
+                    * (sumstat_imp_res["ld_score"] >= ld_score_threshold)
+                )
+                >= 1
+            ):
+                indices = np.where(
+                    (sumstat_imp_res["imputation_r2"] >= imputed_r2_threshold)
+                    * (sumstat_imp_res["ld_score"] >= ld_score_threshold)
+                )[0]
+                index_to_add = [unknown[i] for i in indices]
+                index_to_fm = np.concatenate((known, index_to_add))
+
+                ld_to_fm = gnomad_ld[index_to_fm][:, index_to_fm]
+
+                snp_info_to_add = pd.DataFrame(
+                    {
+                        "variantId": ld_index.iloc[index_to_add, :]["variantId"],
+                        "z": sumstat_imp_res["mu"][indices],
+                    }
+                ).reset_index()
+                GWAS_df = pd.concat([GWAS_df, snp_info_to_add], ignore_index=True)
+                z_to_fm = GWAS_df["z"].values
+
+                N_imputed = len(indices)
+            else:
+                N_imputed = 0
+        else:
+            N_imputed = 0
+
+        susie_output = SUSIE_inf.susie_inf(
+            z=z_to_fm, LD=ld_to_fm, L=L, est_tausq=susie_est_tausq
+        )
+
+        schema = StructType([StructField("variantId", StringType(), True)])
+        variant_index = (
+            session.spark.createDataFrame(
+                GWAS_df[["variantId"]],
+                schema=schema,
+            )
+            .withColumn("chromosome", f.split(f.col("variantId"), "_")[0])
+            .withColumn("position", f.split(f.col("variantId"), "_")[1])
+        )
+
+        study_locus = SusieFineMapperStep.susie_inf_to_studylocus(
+            susie_output=susie_output,
+            session=session,
+            studyId=studyId,
+            region=region,
+            variant_index=variant_index,
+        )
+        return {
+            "study_locus": study_locus,
+            "N_gwas": N_gwas,
+            "N_ld": N_ld,
+            "N_after_merge": N_after_merge,
+            "N_outliers": N_outliers,
+            "N_imputed": N_imputed,
+            "N_to_fm": len(ld_to_fm),
+        }