feat: add fm step with carma and sumstat imputation
addramir committed Apr 5, 2024
1 parent e1d20f3 commit 645dd80
Showing 1 changed file with 159 additions and 0 deletions.
159 changes: 159 additions & 0 deletions src/gentropy/susie_finemapper.py
@@ -15,6 +15,8 @@
from gentropy.dataset.study_locus import StudyLocus
from gentropy.dataset.summary_statistics import SummaryStatistics
from gentropy.datasource.gnomad.ld import GnomADLDMatrix
from gentropy.method.carma import CARMA
from gentropy.method.sumstat_imputation import SummaryStatisticsImputation
from gentropy.method.susie_inf import SUSIE_inf


@@ -263,3 +265,160 @@ def susie_inf_to_studylocus(
            _df=cred_sets,
            _schema=StudyLocus.get_schema(),
        )

    @staticmethod
    def susie_finemapper_from_prepared_dataframes(
        GWAS_df: DataFrame,
        ld_index: DataFrame,
        gnomad_ld: np.ndarray,
        L: int,
        session: Session,
        studyId: str,
        region: str,
        susie_est_tausq: bool = False,
        run_carma: bool = False,
        run_sumstat_imputation: bool = False,
        carma_time_limit: int = 600,
        imputed_r2_threshold: float = 0.8,
        ld_score_threshold: float = 4,
    ) -> dict[str, Any]:
"""Susie fine-mapper function that uses LD, z-scores, variant info and other options for Fine-Mapping.
Args:
GWAS_df (DataFrame): GWAS DataFrame with mandotary columns: z, variantId, chromosome, position
ld_index (DataFrame): LD index DataFrame
gnomad_ld (np.ndarray): GnomAD LD matrix
L (int): number of causal variants
session (Session): Spark session
studyId (str): study ID
region (str): region
susie_est_tausq (bool): estimate tau squared, default is False
run_carma (bool): run CARMA, default is False
run_sumstat_imputation (bool): run summary statistics imputation, default is False
carma_time_limit (int): CARMA time limit, default is 600 seconds
imputed_r2_threshold (float): imputed R2 threshold, default is 0.8
ld_score_threshold (float): LD score threshold ofr imputation, default is 4
Returns:
dict[str, Any]: dictionary with study locus, number of GWAS variants, number of LD variants, number of variants after merge, number of outliers, number of imputed variants, number of variants to fine-map
"""
        GWAS_df = GWAS_df.toPandas()
        ld_index = ld_index.toPandas()
        ld_index = ld_index.reset_index()

        N_gwas = len(GWAS_df)
        N_ld = len(ld_index)

        # Filter out variants that are not in the LD matrix; we don't need them
        df_columns = ["variantId", "z"]
        GWAS_df = GWAS_df.merge(ld_index, on="variantId", how="inner")
        GWAS_df = GWAS_df[df_columns].reset_index()
        N_after_merge = len(GWAS_df)

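        # Recover each variant's positional index in the GnomAD LD matrix:
        # "index_y" is the row number carried over from ld_index.reset_index().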
        merged_df = GWAS_df.merge(
            ld_index, left_on="variantId", right_on="variantId", how="inner"
        )
        indices = merged_df["index_y"].values

        ld_to_fm = gnomad_ld[indices][:, indices]
        z_to_fm = GWAS_df["z"].values

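        # Optional CARMA step: flag outlier variants whose z-scores are inconsistent
        # with the LD matrix and drop them before fine-mapping.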
        if run_carma:
            carma_output = CARMA.time_limited_CARMA_spike_slab_noEM(
                z=z_to_fm, ld=ld_to_fm, sec_threshold=carma_time_limit
            )
            if carma_output["Outliers"] is not None and carma_output["Outliers"] != []:
                GWAS_df.drop(carma_output["Outliers"], inplace=True)
                GWAS_df = GWAS_df.reset_index()
                ld_index = ld_index.reset_index()
                merged_df = GWAS_df.merge(
                    ld_index, left_on="variantId", right_on="variantId", how="inner"
                )
                indices = merged_df["index_y"].values

                ld_to_fm = gnomad_ld[indices][:, indices]
                z_to_fm = GWAS_df["z"].values
                N_outliers = len(carma_output["Outliers"])
            else:
                N_outliers = 0
        else:
            N_outliers = 0

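        # Optional summary-statistics imputation: use the RAISS model to impute z-scores
        # for variants that are present in the LD matrix but missing from the GWAS data.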
        if run_sumstat_imputation:
            known = indices
            unknown = [
                index for index in list(range(len(gnomad_ld))) if index not in known
            ]
            sig_t = gnomad_ld[known, :][:, known]
            sig_i_t = gnomad_ld[unknown, :][:, known]
            zt = z_to_fm

            sumstat_imp_res = SummaryStatisticsImputation.raiss_model(
                z_scores_known=zt,
                ld_matrix_known=sig_t,
                ld_matrix_known_missing=sig_i_t,
                lamb=0.01,
                rtol=0.01,
            )

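            # Keep imputed variants only if they pass both the imputation R2 and
            # LD score thresholds; otherwise fine-map the observed variants only.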
            if (
                sum(
                    (sumstat_imp_res["imputation_r2"] >= imputed_r2_threshold)
                    * (sumstat_imp_res["ld_score"] >= ld_score_threshold)
                )
                >= 1
            ):
                indices = np.where(
                    (sumstat_imp_res["imputation_r2"] >= imputed_r2_threshold)
                    * (sumstat_imp_res["ld_score"] >= ld_score_threshold)
                )[0]
                index_to_add = [unknown[i] for i in indices]
                index_to_fm = np.concatenate((known, index_to_add))

                ld_to_fm = gnomad_ld[index_to_fm][:, index_to_fm]

                snp_info_to_add = pd.DataFrame(
                    {
                        "variantId": ld_index.iloc[index_to_add, :]["variantId"],
                        "z": sumstat_imp_res["mu"][indices],
                    }
                ).reset_index()
                GWAS_df = pd.concat([GWAS_df, snp_info_to_add], ignore_index=True)
                z_to_fm = GWAS_df["z"].values

                N_imputed = len(indices)
            else:
                N_imputed = 0
        else:
            N_imputed = 0

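        # Run SuSiE-inf on the final set of z-scores and the matching LD sub-matrix.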
        susie_output = SUSIE_inf.susie_inf(
            z=z_to_fm, LD=ld_to_fm, L=L, est_tausq=susie_est_tausq
        )

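        # Build a minimal Spark variant index; chromosome and position are parsed from
        # the underscore-delimited variantId.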
        schema = StructType([StructField("variantId", StringType(), True)])
        variant_index = (
            session.spark.createDataFrame(
                GWAS_df[["variantId"]],
                schema=schema,
            )
            .withColumn("chromosome", f.split(f.col("variantId"), "_")[0])
            .withColumn("position", f.split(f.col("variantId"), "_")[1])
        )

        study_locus = SusieFineMapperStep.susie_inf_to_studylocus(
            susie_output=susie_output,
            session=session,
            studyId=studyId,
            region=region,
            variant_index=variant_index,
        )
        return {
            "study_locus": study_locus,
            "N_gwas": N_gwas,
            "N_ld": N_ld,
            "N_after_merge": N_after_merge,
            "N_outliers": N_outliers,
            "N_imputed": N_imputed,
            "N_to_fm": len(ld_to_fm),
        }

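A minimal usage sketch for the new step (illustration only, not part of this commit). It assumes the per-locus inputs have already been prepared elsewhere, for example with the existing GnomADLDMatrix and summary-statistics helpers; the variable names gwas_df, ld_index_df, gnomad_ld_matrix, study_id and region_str are placeholders.

# Hypothetical call to the new step, with CARMA and sumstat imputation enabled;
# all input variables below are assumed to be prepared beforehand.
result = SusieFineMapperStep.susie_finemapper_from_prepared_dataframes(
    GWAS_df=gwas_df,              # Spark DataFrame with z, variantId, chromosome, position
    ld_index=ld_index_df,         # LD index DataFrame for the locus
    gnomad_ld=gnomad_ld_matrix,   # numpy LD matrix aligned with ld_index_df
    L=10,                         # number of causal variants (SuSiE L)
    session=session,              # gentropy Session
    studyId=study_id,
    region=region_str,
    run_carma=True,
    run_sumstat_imputation=True,
)
study_locus = result["study_locus"]  # StudyLocus with SuSiE credible sets
print(result["N_outliers"], result["N_imputed"], result["N_to_fm"])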