diff --git a/src/gentropy/susie_finemapper.py b/src/gentropy/susie_finemapper.py index 001f34bbf..740b6ecb3 100644 --- a/src/gentropy/susie_finemapper.py +++ b/src/gentropy/susie_finemapper.py @@ -6,12 +6,13 @@ import numpy as np import pyspark.sql.functions as f -from pyspark.sql import DataFrame, Window +from pyspark.sql import DataFrame, Row, Window from gentropy.common.session import Session from gentropy.dataset.study_index import StudyIndex from gentropy.dataset.study_locus import StudyLocus from gentropy.dataset.summary_statistics import SummaryStatistics +from gentropy.datasource.gnomad.ld import GnomADLDMatrix from gentropy.method.susie_inf import SUSIE_inf @@ -26,26 +27,28 @@ class SusieFineMapperStep: def susie_finemapper_chr_pos( GWAS: SummaryStatistics, session: Session, - chromosome: str, - position: int, - _studyId: str, + study_locus_row: Row, study_index: StudyIndex, window: int = 1_000_000, + L: int = 10, ) -> StudyLocus: """Susie fine-mapper function that uses Summary Statstics, chromosome and position as inputs. Args: GWAS (SummaryStatistics): GWAS summary statistics session (Session): Spark session - chromosome (str): chromosome - position (int): position - _studyId (str): study ID + study_locus_row (Row): StudyLocus row study_index (StudyIndex): StudyIndex object window (int): window size for fine-mapping + L (int): number of causal variants Returns: StudyLocus: StudyLocus object with fine-mapped credible sets """ + chromosome = study_locus_row.chromosome + position = study_locus_row.position + _studyId = study_locus_row.studyId + study_index_df = study_index._df study_index_df = study_index_df.filter(study_index_df.studyId == _studyId) _major_population = study_index_df.select( @@ -71,59 +74,57 @@ def susie_finemapper_chr_pos( & (f.col("position") <= position + (window / 2)) ).withColumn("z", f.col("beta") / f.col("standardError")) - _z = np.array([row["z"] for row in GWAS_df.select("z").collect()]) - - # # Extract summary statistics - # _ss = ( - # SummaryStatistics.get_locus_sumstats(session, window, _locus) - # .withColumn("z", f.col("beta") / f.col("standardError")) - # .withColumn("ref", f.split(f.col("variantId"), "_").getItem(2)) - # .withColumn("alt", f.split(f.col("variantId"), "_").getItem(3)) - # .select( - # "variantId", - # f.col("chromosome").alias("chr"), - # f.col("position").alias("pos"), - # "ref", - # "alt", - # "beta", - # "pValueMantissa", - # "pValueExponent", - # "effectAlleleFrequencyFromSource", - # "standardError", - # "z", - # ) - # ) - - # Extract LD index - # _index = GnomADLDMatrix.get_locus_index( - # session, locus, window_size=window, major_population=_major_population - # ) - # _join = ( - # _ss.join( - # _index.alias("_index"), - # on=( - # (_ss["chr"] == _index["chromosome"]) - # & (_ss["pos"] == _index["position"]) - # & (_ss["ref"] == _index["referenceAllele"]) - # & (_ss["alt"] == _index["alternateAllele"]) - # ), - # ) - # .drop("ref", "alt", "chr", "pos") - # .sort("idx") - # ) - - # Extracting z-scores and LD matrix, then running SuSiE-inf - # _ld = GnomADLDMatrix.get_locus_matrix(_join, gnomad_ancestry=_major_population) - - _ld = 1 - susie_output = SUSIE_inf.susie_inf(z=_z, LD=_ld, L=10) + ld_index = ( + GnomADLDMatrix() + .get_locus_index( + study_locus_row=study_locus_row, + window_size=window, + major_population=_major_population, + ) + .withColumn( + "variantId", + f.concat( + f.lit(chromosome), + f.lit("_"), + f.col("`locus.position`"), + f.lit("_"), + f.col("alleles").getItem(0), + f.lit("_"), + f.col("alleles").getItem(1), + ).cast("string"), + ) + ) + + gnomad_ld = GnomADLDMatrix.get_numpy_matrix( + ld_index, gnomad_ancestry=_major_population + ) + + GWAS_df = GWAS_df.toPandas() + ld_index = ld_index.toPandas() + ld_index = ld_index.reset_index() + + # Filtering out the variants that are not in the LD matrix, we don't need them + df_columns = GWAS_df.columns + GWAS_df = GWAS_df.merge(ld_index, on="variantId", how="inner") + GWAS_df = GWAS_df[df_columns].reset_index() + + merged_df = GWAS_df.merge( + ld_index, left_on="variantId", right_on="variantId", how="inner" + ) + indices = merged_df["index_y"].values + + ld_to_fm = gnomad_ld[indices][:, indices] + z_to_fm = GWAS_df["z"].values + + susie_output = SUSIE_inf.susie_inf(z=z_to_fm, LD=ld_to_fm, L=L) + variant_index = session.spark.createDataFrame(GWAS_df) return SusieFineMapperStep.susie_inf_to_studylocus( susie_output=susie_output, session=session, _studyId=_studyId, _region=_region, - variant_index=GWAS, + variant_index=variant_index, ) @staticmethod @@ -151,6 +152,7 @@ def susie_inf_to_studylocus( variants = np.array( [row["variantId"] for row in variant_index.select("variantId").collect()] ).reshape(-1, 1) + PIPs = susie_output["PIP"] lbfs = susie_output["lbf_variable"] mu = susie_output["mu"] @@ -181,6 +183,7 @@ def susie_inf_to_studylocus( win = Window.rowsBetween( Window.unboundedPreceding, Window.unboundedFollowing ) + cred_set = ( session.spark.createDataFrame( cred_set.tolist(), diff --git a/tests/gentropy/method/test_susie_inf.py b/tests/gentropy/method/test_susie_inf.py index 393f786d7..e7a5d219f 100644 --- a/tests/gentropy/method/test_susie_inf.py +++ b/tests/gentropy/method/test_susie_inf.py @@ -68,6 +68,7 @@ def test_SUSIE_inf_convert_to_study_locus( est_tausq=False, ) gwas_df = sample_summary_statistics._df.limit(21) + L1 = SusieFineMapperStep.susie_inf_to_studylocus( susie_output=susie_output, session=session,