def __init__(
    self,
    session: Session,
    study_locus_to_finemap: str,
    study_locus_collected_path: str,
    study_index_path: str,
    output_path: str,
    locus_radius: int = 500_000,
    locus_L: int = 10,
) -> None:
    """Run fine-mapping on a studyLocusId from a collected studyLocus table.

    The requested study locus is looked up in the collected studyLocus table,
    fine-mapped with ``susie_finemapper_ss_gathered`` and the result is written
    as parquet to ``<output_path>/<study_locus_to_finemap>``.

    Args:
        session (Session): Spark session
        study_locus_to_finemap (str): studyLocusId of the study locus to fine-map
        study_locus_collected_path (str): path to the collected study locus
        study_index_path (str): path to the study index
        output_path (str): base output path; the result is written to a sub-path named after the studyLocusId
        locus_radius (int): Radius of base-pair window around the locus, default is 500_000
        locus_L (int): Maximum number of causal variants in locus, default is 10

    Raises:
        ValueError: if no row of the collected studyLocus table has the requested studyLocusId
    """
    # Read studyLocus: keep only the row(s) whose studyLocusId matches the request.
    matching_rows = (
        StudyLocus.from_parquet(session, study_locus_collected_path)
        .df.filter(f.col("studyLocusId") == study_locus_to_finemap)
        .collect()
    )
    # Fail with an explicit message instead of an opaque IndexError on `[0]`.
    if not matching_rows:
        raise ValueError(
            f"studyLocusId {study_locus_to_finemap!r} not found in "
            f"{study_locus_collected_path!r}"
        )
    study_locus = matching_rows[0]
    study_index = StudyIndex.from_parquet(session, study_index_path)

    # Run fine-mapping. The fourth argument is passed as twice the radius,
    # i.e. the full width of the window around the locus.
    # NOTE(review): presumably the callee centres this window on the lead
    # position (it uses `window / 2`) -- confirm against its signature.
    result = self.susie_finemapper_ss_gathered(
        session,
        study_locus,
        study_index,
        locus_radius * 2,
        locus_L,
    )

    # Write result under a sub-path named after the studyLocusId; strip any
    # trailing slash from the base path so the joined path has no "//".
    result.df.write.mode(session.write_mode).parquet(
        f"{output_path.rstrip('/')}/{study_locus_to_finemap}"
    )