Skip to content

Commit

Permalink
feat: adding init to finemapping step (#577)
Browse files Browse the repository at this point in the history
* feat: adding init to finemapping step

* fix: removing some commented lines

* chore: fixing indents

* fix: schema

* feat: changing output path to include studyLocusId mapped

---------

Co-authored-by: Yakov <[email protected]>
  • Loading branch information
Daniel-Considine and addramir committed Apr 23, 2024
1 parent 7ed4703 commit 900dd64
Showing 1 changed file with 49 additions and 2 deletions.
51 changes: 49 additions & 2 deletions src/gentropy/susie_finemapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,47 @@ class SusieFineMapperStep:
In the future this step will be refactored and moved to the methods module.
"""

def __init__(
    self,
    session: Session,
    study_locus_to_finemap: str,
    study_locus_collected_path: str,
    study_index_path: str,
    output_path: str,
    locus_radius: int = 500_000,
    locus_L: int = 10,
) -> None:
    """Fine-map a single studyLocusId taken from a collected studyLocus table.

    Args:
        session (Session): Spark session
        study_locus_to_finemap (str): path to the study locus to fine-map
        study_locus_collected_path (str): path to the collected study locus
        study_index_path (str): path to the study index
        output_path (str): path to the output
        locus_radius (int): Radius of base-pair window around the locus, default is 500_000
        locus_L (int): Maximum number of causal variants in locus, default is 10
    """
    # Load the collected credible sets and select the single row matching the
    # requested studyLocusId. NOTE(review): collect()[0] assumes exactly one
    # match exists — an absent id would raise IndexError here.
    selected_locus_row = (
        StudyLocus.from_parquet(session, study_locus_collected_path)
        .df.filter(f.col("studyLocusId") == study_locus_to_finemap)
        .collect()[0]
    )
    index = StudyIndex.from_parquet(session, study_index_path)

    # Run SuSiE fine-mapping on the selected row; the window passed down is
    # the full diameter of the locus (2 * radius).
    finemapping_result = self.susie_finemapper_ss_gathered(
        session,
        selected_locus_row,
        index,
        locus_radius * 2,
        locus_L,
    )

    # Persist the result in a per-locus subdirectory keyed by studyLocusId.
    finemapping_result.df.write.mode(session.write_mode).parquet(
        output_path + "/" + study_locus_to_finemap
    )

@staticmethod
def susie_finemapper_one_studylocus_row(
GWAS: SummaryStatistics,
Expand Down Expand Up @@ -317,9 +358,15 @@ def susie_finemapper_ss_gathered(
+ str(int(position + window / 2))
)

schema = StudyLocus.get_schema()
gwas_df = session.spark.createDataFrame([study_locus_row], schema=schema)
exploded_df = gwas_df.select(f.explode("locus").alias("locus"))

result_df = exploded_df.select(
"locus.variantId", "locus.beta", "locus.standardError"
)
gwas_df = (
session.spark.createDataFrame(study_locus_row.locus)
.withColumn("z", f.col("beta") / f.col("standardError"))
result_df.withColumn("z", f.col("beta") / f.col("standardError"))
.withColumn("chromosome", f.split(f.col("variantId"), "_")[0])
.withColumn("position", f.split(f.col("variantId"), "_")[1])
.filter(f.col("z").isNotNull())
Expand Down

0 comments on commit 900dd64

Please sign in to comment.