@staticmethod
def susie_finemapper_ss_gathered(
    session: Session,
    study_locus_row: Row,
    study_index: StudyIndex,
    window: int = 1_000_000,
    L: int = 10,
) -> StudyLocus:
    """Fine-map one locus with SuSiE-inf using the summary statistics gathered in a StudyLocus row.

    The row provides the lead variant (`chromosome`, `position`), the `studyId`
    and the `locus` collection of per-variant summary statistics. Z-scores are
    derived as `beta / standardError`, the variants are restricted to those
    present in the gnomAD LD index for the study's major population, and the
    resulting z-score vector and LD matrix are passed to `SUSIE_inf.susie_inf`.

    Args:
        session (Session): Spark session
        study_locus_row (Row): StudyLocus row; must expose `chromosome`, `position`, `studyId` and `locus`
        study_index (StudyIndex): StudyIndex object, used to look up the LD population of the study
        window (int): total window size around the lead position (half on each side)
        L (int): number of causal variants passed to SuSiE-inf

    Returns:
        StudyLocus: StudyLocus object with fine-mapped credible sets

    Raises:
        ValueError: if the studyId of the row is not present in the study index
    """
    # PLEASE DO NOT REMOVE THIS LINE
    # Re-adds the `iteritems` alias on pandas DataFrames. Presumably required by
    # the Spark <-> pandas conversions below on recent pandas -- confirm before
    # touching.
    pd.DataFrame.iteritems = pd.DataFrame.items

    chromosome = study_locus_row["chromosome"]
    position = study_locus_row["position"]
    studyId = study_locus_row["studyId"]

    # NOTE(review): `array_max` on an array of structs compares the structs
    # field by field in declaration order. If `ldPopulation` is declared before
    # the sample-size field, this selects the alphabetically last population,
    # not the one with the largest share -- confirm against the StudyIndex schema.
    population_rows = (
        study_index._df.filter(f.col("studyId") == studyId)
        .select(
            f.array_max(f.col("ldPopulationStructure"))
            .getItem("ldPopulation")
            .alias("majorPopulation")
        )
        .collect()
    )
    if not population_rows:
        # Previously surfaced as a bare "list index out of range".
        raise ValueError(f"Study {studyId} was not found in the study index.")
    major_population = population_rows[0]["majorPopulation"]

    # Region label "<chromosome>:<start>-<end>", centred on the lead position.
    half_window = window / 2
    region = (
        f"{chromosome}:{int(position - half_window)}-{int(position + half_window)}"
    )

    # Per-variant summary statistics of the locus; chromosome and position are
    # parsed from the variantId (split on "_": first and second tokens).
    gwas_df = (
        session.spark.createDataFrame(study_locus_row.locus)
        .withColumn("z", f.col("beta") / f.col("standardError"))
        .withColumn("chromosome", f.split(f.col("variantId"), "_")[0])
        .withColumn("position", f.split(f.col("variantId"), "_")[1])
        .filter(f.col("z").isNotNull())
    )

    # LD index of the window, with a variantId rebuilt as
    # "<chromosome>_<position>_<allele0>_<allele1>" so it can be joined to the GWAS.
    ld_index = (
        GnomADLDMatrix()
        .get_locus_index(
            study_locus_row=study_locus_row,
            window_size=window,
            major_population=major_population,
        )
        .withColumn(
            "variantId",
            f.concat(
                f.lit(chromosome),
                f.lit("_"),
                f.col("`locus.position`"),
                f.lit("_"),
                f.col("alleles").getItem(0),
                f.lit("_"),
                f.col("alleles").getItem(1),
            ).cast("string"),
        )
    )

    # Filtering out the variants that are not in the LD matrix, we don't need them.
    # Sorting by `idx` keeps the z-score vector in the same order as the rows
    # handed to `get_numpy_matrix`.
    gwas_index = gwas_df.join(
        ld_index.select("variantId", "alleles", "idx"), on="variantId"
    ).sort("idx")

    gnomad_ld = GnomADLDMatrix.get_numpy_matrix(
        gwas_index, gnomad_ancestry=major_population
    )

    pd_df = gwas_index.toPandas()
    z_to_fm = np.array(pd_df["z"])

    susie_output = SUSIE_inf.susie_inf(z=z_to_fm, LD=gnomad_ld, L=L)

    # Variant index handed to the StudyLocus conversion. `position` came out of
    # a string split, so it is cast to int to fit the IntegerType field.
    schema = StructType(
        [
            StructField("variantId", StringType(), True),
            StructField("chromosome", StringType(), True),
            StructField("position", IntegerType(), True),
        ]
    )
    pd_df["position"] = pd_df["position"].astype(int)
    variant_index = session.spark.createDataFrame(
        pd_df[["variantId", "chromosome", "position"]],
        schema=schema,
    )

    return SusieFineMapperStep.susie_inf_to_studylocus(
        susie_output=susie_output,
        session=session,
        studyId=studyId,
        region=region,
        variant_index=variant_index,
    )